Commit state of repository at time of OOPSLA 2015 submission.
[satcheck.git] / clang / src / add_mc2_annotations.cpp
1 // -*-  indent-tabs-mode:nil;  -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
5 //
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
16 //
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
20 // distribution.
21 //
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
25 //
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
34 //
35 // Patrick Lam (prof.lam@gmail.com)
36 //
37 // Base code:
38 // Eli Bendersky (eliben@gmail.com)
39 //
40 //------------------------------------------------------------------------------
41 #include <sstream>
42 #include <string>
43 #include <map>
44 #include <stdbool.h>
45
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
61
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
66 using namespace llvm;
67
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
70
// Returns the MC-variable spelling for a program variable: the given name
// prefixed with "_m" (e.g. "x" becomes "_mx").
static std::string encode(std::string varName) {
    return "_m" + varName;
}
76
// File-scope counter, zero on start-up. Not referenced in the visible part of
// the file -- presumably supplies the numbers passed to encodeFn (confirm).
static int fnCount = 0;

// Returns the identifier "_fn<num>", e.g. encodeFn(3) == "_fn3".
static std::string encodeFn(int num) {
    return "_fn" + std::to_string(num);
}
83
// File-scope counter, zero on start-up. Not referenced in the visible part of
// the file -- presumably supplies the numbers passed to encodePtr (confirm).
static int ptrCount = 0;

// Returns the identifier "_p<num>", e.g. encodePtr(3) == "_p3".
static std::string encodePtr(int num) {
    return "_p" + std::to_string(num);
}
90
// File-scope counter, zero on start-up; RMWHandler::run post-increments it
// each time it names a read-modify-write result.
static int rmwCount = 0;

// Returns the identifier "_rmw<num>", e.g. encodeRMW(3) == "_rmw3".
static std::string encodeRMW(int num) {
    return "_rmw" + std::to_string(num);
}
97
// File-scope counter, zero on start-up. Not referenced in the visible part of
// the file -- presumably supplies the numbers passed to encodeBranch (confirm).
static int branchCount = 0;

// Returns the identifier "_br<num>", e.g. encodeBranch(3) == "_br3".
static std::string encodeBranch(int num) {
    return "_br" + std::to_string(num);
}
104
// File-scope counter, zero on start-up. Not referenced in the visible part of
// the file -- presumably supplies the numbers passed to encodeCond (confirm).
static int condCount = 0;

// Returns the identifier "_cond<num>", e.g. encodeCond(3) == "_cond3".
static std::string encodeCond(int num) {
    return "_cond" + std::to_string(num);
}
111
// File-scope counter, zero on start-up. Not referenced in the visible part of
// the file -- presumably supplies the numbers passed to encodeRV (confirm).
static int rvCount = 0;

// Returns the identifier "_rv<num>", e.g. encodeRV(3) == "_rv3".
static std::string encodeRV(int num) {
    return "_rv" + std::to_string(num);
}
118
// File-scope counter, zero on start-up; generateMC2Function pre-increments it
// to produce the first argument of each emitted MC2_function_id(...) call.
static int funcCount = 0;
120
121 struct ProvisionalName {
122     int index, length;
123     const DeclRefExpr * pname;
124     bool enabled;
125
126     ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127     ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
128 };
129
130 struct Update {
131     SourceLocation loc;
132     std::string update;
133     std::vector<ProvisionalName *> * pnames;
134
135     Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) : 
136         loc(loc), update(update), pnames(pnames) {}
137
138     ~Update() { 
139         for (auto pname : *pnames) delete pname;
140         delete pnames; 
141     }
142 };
143
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145     for (Update * u : DeferredUpdates) {
146         for (int i = 0; i < u->pnames->size(); i++) {
147             ProvisionalName * v = (*(u->pnames))[i];
148             if (!v->enabled) continue;
149             if (now_known == v->pname->getDecl()) {
150                 v->enabled = false;
151                 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
152
153                 u->update.replace(v->index, v->length, mcVar);
154                 for (int j = i+1; j < u->pnames->size(); j++) {
155                     ProvisionalName * vv = (*(u->pnames))[j];
156                     if (vv->index > v->index)
157                         vv->index -= v->length - mcVar.length();
158                 }
159             }
160         }
161     }
162 }
163
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165     // XXX iterate through all decls defined in s, not just the first one
166     assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167     if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168         return cast<VarDecl>(s->getSingleDecl());
169     } else return NULL;
170 }
171
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
173 public:
174     FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
175
176     bool VisitStmt(Stmt * s) {
177         if (!UnaryOp) {
178             if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179                 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180                     uo->getOpcode() == UnaryOperatorKind::UO_Deref)
181                     UnaryOp = uo;
182             }
183         }
184
185         if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
186             ;
187         return true;
188     }
189
190     void Clear() {
191         UnaryOp = NULL; DE = NULL;
192     }
193
194     const UnaryOperator * RetrieveUnaryOp() {
195         if (UnaryOp) {
196             bool found = false;
197             const Stmt * s = UnaryOp;
198             while (s != NULL) {
199                 if (s == DE) {
200                     found = true; break;
201                 }
202
203                 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204                     s = op->getSubExpr();
205                 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206                     s = op->getSubExpr();
207                 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
208                     s = op->getBase();
209                 else
210                     s = NULL;
211             }
212             if (found)
213                 return UnaryOp;
214         }
215         return NULL;
216     }
217
218     const DeclRefExpr * RetrieveDeclRefExpr() {
219         return DE;
220     }
221
222 private:
223     const UnaryOperator * UnaryOp;
224     const DeclRefExpr * DE;
225 };
226
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
228 public:
229     FindLocalsVisitor() : Vars() {}
230
231     bool VisitDeclRefExpr(DeclRefExpr * de) {
232         Vars.push_back(de->getDecl());
233         return true;
234     }
235
236     void Clear() {
237         Vars.clear();
238     }
239
240     const TinyPtrVector<const NamedDecl *> RetrieveVars() {
241         return Vars;
242     }
243
244 private:
245     TinyPtrVector<const NamedDecl *> Vars;
246 };
247
248
249 class MallocHandler : public MatchFinder::MatchCallback {
250 public:
251     MallocHandler(std::set<const Expr *> & MallocExprs) :
252         MallocExprs(MallocExprs) {}
253
254     virtual void run(const MatchFinder::MatchResult &Result) {
255         const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
256
257         MallocExprs.insert(ce);
258     }
259
260     private:
261     std::set<const Expr *> &MallocExprs;
262 };
263
264 static void generateMC2Function(Rewriter & Rewrite,
265                                 const Expr * e, 
266                                 SourceLocation loc,
267                                 std::string tmpname, 
268                                 std::string tmpFn, 
269                                 const DeclRefExpr * lhs, 
270                                 std::string lhsName,
271                                 std::vector<ProvisionalName *> * vars1,
272                                 std::vector<Update *> & DeferredUpdates) {
273     // prettyprint the LHS (&newnode->value)
274     // e.g. int * _tmp0 = &newnode->value;
275     std::string SStr;
276     llvm::raw_string_ostream S(SStr);
277     e->printPretty(S, nullptr, Rewrite.getLangOpts());
278     const std::string &Str = S.str();
279
280     std::stringstream prel;
281     prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
282
283     // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284     prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
285     if (lhs) {
286         // XXX generate casts when they'd eliminate warnings
287         ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
288         vars1->push_back(v);
289     }
290     prel << encode(lhsName) << "); ";
291
292     Update * u = new Update(loc, prel.str(), vars1);
293     DeferredUpdates.push_back(u);
294 }
295
296 class LoadHandler : public MatchFinder::MatchCallback {
297 public:
298     LoadHandler(Rewriter &Rewrite,
299                 std::set<const NamedDecl *> & DeclsRead,
300                 std::set<const VarDecl *> & DeclsNeedingMC, 
301                 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302                 std::map<const Expr *, std::string> &ExprToMCVar,
303                 std::set<const Stmt *> & StmtsHandled,
304                 std::map<const Expr *, SourceLocation> &Redirector,
305                 std::vector<Update *> & DeferredUpdates) : 
306         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307         ExprToMCVar(ExprToMCVar),
308         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
309
310     virtual void run(const MatchFinder::MatchResult &Result) {
311         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313         const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315         const Expr * lhs = NULL;
316         if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
317         if (!s) s = ce;
318
319         const DeclRefExpr * rhs = NULL;
320         MemberExpr * ml = NULL;
321         bool isAddrOfR = false, isAddrMemberR = false;
322
323         StmtsHandled.insert(s);
324         
325         std::string n, n_decl;
326         if (lhs) {
327             FindCallArgVisitor fcaVisitor;
328             fcaVisitor.Clear();
329             fcaVisitor.TraverseStmt(ce->getArg(0));
330             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
334             if (isAddrMemberR)
335                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
336
337             FindLocalsVisitor flv;
338             flv.Clear();
339             flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340             for (auto & d : flv.RetrieveVars()) {
341                 const VarDecl * dd = cast<VarDecl>(d);
342                 n = dd->getName();
343                 // XXX todo rhs for non-decl stmts
344                 if (!isa<ParmVarDecl>(dd))
345                     DeclsNeedingMC.insert(dd);
346                 DeclsRead.insert(d);
347                 DeclToMCVar[dd] = encode(n);
348             }
349         } else {
350             FindCallArgVisitor fcaVisitor;
351             fcaVisitor.Clear();
352             fcaVisitor.TraverseStmt(ce->getArg(0));
353             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
357             if (isAddrMemberR)
358                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
359
360             if (d) {
361                 n = d->getName();
362                 DeclsNeedingMC.insert(d);
363                 DeclsRead.insert(d);
364                 DeclToMCVar[d] = encode(n);
365             } else {
366                 n = ExprToMCVar[ce];
367                 fcaVisitor.Clear();
368                 fcaVisitor.TraverseStmt(ce);
369                 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370                 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371                 DeclToMCVar[dd->getDecl()] = encode(n);
372                 n_decl = "MCID ";
373             }
374         }
375         
376         std::stringstream nol;
377
378         if (lhs && isa<DeclRefExpr>(lhs)) {
379             const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380             ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
381             vars->push_back(v);
382         }
383
384         if (rhs) {
385             if (isAddrMemberR) {
386                 if (!n.empty()) 
387                     nol << n_decl << encode(n) << "=";
388                 nol << "MC2_nextOpLoadOffset(";
389
390                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
391                 vars->push_back(v);
392                 nol << encode(rhs->getNameInfo().getName().getAsString());
393
394                 nol << ", MC2_OFFSET(";
395                 nol << ml->getBase()->getType().getAsString();
396                 nol << ", ";
397                 nol << ml->getMemberDecl()->getName().str();
398                 nol << ")";
399             } else if (!isAddrOfR) {
400                 if (!n.empty()) 
401                     nol << n_decl << encode(n) << "=";
402                 nol << "MC2_nextOpLoad(";
403                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
404                 vars->push_back(v);
405                 nol << encode(rhs->getNameInfo().getName().getAsString());
406             } else {
407                 if (!n.empty()) 
408                     nol << n_decl << encode(n) << "=";
409                 nol << "MC2_nextOpLoad(";
410                 nol << "MCID_NODEP";
411             }
412         } else {
413             if (!n.empty()) 
414                 nol << n_decl << encode(n) << "=";
415             nol << "MC2_nextOpLoad(";
416             nol << "MCID_NODEP";
417         }
418
419         if (lhs)
420             nol << "), ";
421         else
422             nol << "); ";
423         SourceLocation ss = s->getLocStart();
424         // eek gross hack:
425         // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426         // move over 1 so that we get the right location.
427         if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428         const Expr * e = dyn_cast<Expr>(s);
429         if (e && Redirector.count(e) > 0)
430             ss = Redirector[e];
431         Update * u = new Update(ss, nol.str(), vars);
432         DeferredUpdates.insert(DeferredUpdates.begin(), u);
433     }
434
435     private:
436     Rewriter &Rewrite;
437     std::set<const NamedDecl *> & DeclsRead;
438     std::set<const VarDecl *> & DeclsNeedingMC;
439     std::map<const Expr *, std::string> &ExprToMCVar;
440     std::map<const NamedDecl *, std::string> &DeclToMCVar;
441     std::set<const Stmt *> &StmtsHandled;
442     std::vector<Update *> &DeferredUpdates;
443     std::map<const Expr *, SourceLocation> &Redirector;
444 };
445
446 class StoreHandler : public MatchFinder::MatchCallback {
447 public:
448     StoreHandler(Rewriter &Rewrite,
449                  std::set<const NamedDecl *> & DeclsRead,
450                  std::set<const VarDecl *> &DeclsNeedingMC,
451                  std::vector<Update *> & DeferredUpdates) : 
452         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
453
454     virtual void run(const MatchFinder::MatchResult &Result) {
455         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
457
458         fcaVisitor.Clear();
459         fcaVisitor.TraverseStmt(ce->getArg(0));
460         const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461         const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
462     
463         std::stringstream nol;
464
465         bool isAddrMemberL;
466         bool isAddrOfL;
467
468         if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469             isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470             isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
471         }
472
473         if (lhs) {
474             if (isAddrOfL) {
475                 nol << "MC2_nextOpStore(";
476                 nol << "MCID_NODEP";
477             } else {
478                 if (isAddrMemberL) {
479                     MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
480
481                     nol << "MC2_nextOpStoreOffset(";
482
483                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
484                     vars->push_back(v);
485
486                     nol << encode(lhs->getNameInfo().getName().getAsString());
487                     if (!isa<ParmVarDecl>(lhs->getDecl()))
488                         DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
489
490                     nol << ", MC2_OFFSET(";
491                     nol << ml->getBase()->getType().getAsString();
492                     nol << ", ";
493                     nol << ml->getMemberDecl()->getName().str();
494                     nol << ")";
495                 } else {
496                     nol << "MC2_nextOpStore(";
497                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
498                     vars->push_back(v);
499
500                     nol << encode(lhs->getNameInfo().getName().getAsString());
501                 }
502             }
503         }
504         else {
505             nol << "MC2_nextOpStore(";
506             nol << "MCID_NODEP";
507         }
508         
509         nol << ", ";
510
511         fcaVisitor.Clear();
512         fcaVisitor.TraverseStmt(ce->getArg(1));
513         const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514         const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
515
516         bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517         bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
518
519         if (rhs && !isAddrOfR) {
520             assert (!isDerefR && "Must use atomic load for dereferences!");
521             ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
522             vars->push_back(v);
523
524             nol << encode(rhs->getNameInfo().getName().getAsString());
525             DeclsRead.insert(rhs->getDecl());
526         }
527         else
528             nol << "MCID_NODEP";
529         
530         nol << ");\n";
531         Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532         DeferredUpdates.push_back(u);
533     }
534
535     private:
536     Rewriter &Rewrite;
537     FindCallArgVisitor fcaVisitor;
538     std::set<const NamedDecl *> & DeclsRead;
539     std::set<const VarDecl *> & DeclsNeedingMC;
540     std::vector<Update *> &DeferredUpdates;
541 };
542
543 class RMWHandler : public MatchFinder::MatchCallback {
544 public:
545     RMWHandler(Rewriter &rewrite,
546                  std::set<const NamedDecl *> & DeclsRead,
547                  std::set<const NamedDecl *> & DeclsInCond,
548                  std::map<const NamedDecl *, std::string> &DeclToMCVar,
549                  std::map<const Expr *, std::string> &ExprToMCVar,
550                  std::set<const Stmt *> & StmtsHandled,
551                  std::map<const Expr *, SourceLocation> &Redirector,
552                  std::vector<Update *> & DeferredUpdates) : 
553         rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
554         ExprToMCVar(ExprToMCVar),
555         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
556
557     virtual void run(const MatchFinder::MatchResult &Result) {
558         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
561
562         std::stringstream nol;
563         
564         std::string rmwMCVar;
565         rmwMCVar = encodeRMW(rmwCount++);
566
567         const VarDecl * rmw_lhs;
568         if (s) {
569             StmtsHandled.insert(s);
570             assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
571             const DeclStmt * ds;
572             if ((ds = dyn_cast<DeclStmt>(s))) {
573                 rmw_lhs = retrieveSingleDecl(ds);
574             } else {
575                 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576                 assert (isa<DeclRefExpr>(e));
577                 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
578             }
579             DeclToMCVar[rmw_lhs] = rmwMCVar;
580         }
581
582         // retrieve effective LHS of the RMW
583         fcaVisitor.Clear();
584         fcaVisitor.TraverseStmt(ce->getArg(1));
585         const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586         const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587         bool isAddrMemberL = false;
588
589         if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590             isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
591         }
592
593         nol << "MCID " << rmwMCVar;
594         if (isAddrMemberL) {
595             MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
596
597             nol << " = MC2_nextRMWOffset(";
598
599             ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
600             vars->push_back(v);
601
602             nol << encode(elhs->getNameInfo().getName().getAsString());
603
604             nol << ", MC2_OFFSET(";
605             nol << ml->getBase()->getType().getAsString();
606             nol << ", ";
607             nol << ml->getMemberDecl()->getName().str();
608             nol << ")";
609         } else {
610             nol << " = MC2_nextRMW(";
611             bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
612
613             if (elhs) {
614                 if (isAddrOfL)
615                     nol << "MCID_NODEP";
616                 else {
617                     ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
618                     vars->push_back(v);
619
620                     std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
621                     nol << elhsName;
622                 }
623             }
624             else
625                 nol << "MCID_NODEP";
626         }
627         nol << ", ";
628
629         // handle both RHS ops
630         int outputted = 0;
631         for (int arg = 2; arg < 4; arg++) {
632             fcaVisitor.Clear();
633             fcaVisitor.TraverseStmt(ce->getArg(arg));
634             const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635             const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
636             
637             bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638             bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
639
640             if (a && !isAddrOfR) {
641                 assert (!isDerefR && "Must use atomic load for dereferences!");
642
643                 DeclsInCond.insert(a->getDecl());
644
645                 if (outputted > 0) nol << ", ";
646                 outputted++;
647
648                 bool alreadyMCVar = false;
649                 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
650                     alreadyMCVar = true;
651                     nol << DeclToMCVar[a->getDecl()];
652                 }
653                 else {
654                     std::string an = "MCID_NODEP";
655                     ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
656                     nol << an;
657                     vars->push_back(v);
658                 }
659
660                 DeclsRead.insert(a->getDecl());
661             }
662             else {
663                 if (outputted > 0) nol << ", ";
664                 outputted++;
665
666                 nol << "MCID_NODEP";
667             }
668         }
669         nol << ");\n";
670
671         SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672         const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673         if (e && Redirector.count(e) > 0)
674             place = Redirector[e];
675         Update * u = new Update(place, nol.str(), vars);
676         DeferredUpdates.insert(DeferredUpdates.begin(), u);
677     }
678     
679     private:
680     Rewriter &rewrite;
681     FindCallArgVisitor fcaVisitor;
682     std::set<const NamedDecl *> &DeclsRead;
683     std::set<const NamedDecl *> &DeclsInCond;
684     std::map<const NamedDecl *, std::string> &DeclToMCVar;
685     std::map<const Expr *, std::string> &ExprToMCVar;
686     std::set<const Stmt *> &StmtsHandled;
687     std::vector<Update *> &DeferredUpdates;
688     std::map<const Expr *, SourceLocation> &Redirector;
689 };
690
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
692 public:
693     FindReturnsBreaksVisitor() : Returns(), Breaks() {}
694
695     bool VisitStmt(Stmt * s) {
696         if (isa<ReturnStmt>(s))
697             Returns.push_back(cast<ReturnStmt>(s));
698
699         if (isa<BreakStmt>(s))
700             Breaks.push_back(cast<BreakStmt>(s));
701         return true;
702     }
703
704     void Clear() {
705         Returns.clear(); Breaks.clear();
706     }
707
708     const std::vector<const ReturnStmt *> RetrieveReturns() {
709         return Returns;
710     }
711
712     const std::vector<const BreakStmt *> RetrieveBreaks() {
713         return Breaks;
714     }
715
716 private:
717     std::vector<const ReturnStmt *> Returns;
718     std::vector<const BreakStmt *> Breaks;
719 };
720
721 class LoopHandler : public MatchFinder::MatchCallback {
722 public:
723     LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
724
725     virtual void run(const MatchFinder::MatchResult &Result) {
726         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
727
728         rewrite.InsertText(s->getLocStart(), "MC2_enterLoop();\n", true, true);
729
730         // annotate all returns with MC2_exitLoop()
731         // annotate all breaks that aren't further nested with MC2_exitLoop().
732         FindReturnsBreaksVisitor frbv;
733         if (isa<ForStmt>(s))
734             frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
735         if (isa<WhileStmt>(s))
736             frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
737         if (isa<DoStmt>(s))
738             frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
739
740         for (auto & r : frbv.RetrieveReturns()) {
741             rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
742         }
743         
744         // need to find all breaks and returns embedded inside the loop
745
746         rewrite.InsertTextAfterToken(s->getLocEnd().getLocWithOffset(1), "\nMC2_exitLoop();\n");
747     }
748
749 private:
750     Rewriter &rewrite;
751 };
752
753 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
754 class AssignHandler : public MatchFinder::MatchCallback {
755 public:
756     AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
757                   std::set<const NamedDecl *> &DeclsInCond,
758                   std::set<const VarDecl *> &DeclsNeedingMC,
759                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
760                   std::set<const Stmt *> &StmtsHandled,
761                   std::set<const Expr *> &MallocExprs,
762                   std::vector<Update *> &DeferredUpdates) :
763         rewrite(rewrite),
764         DeclsRead(DeclsRead),
765         DeclsInCond(DeclsInCond),
766         DeclsNeedingMC(DeclsNeedingMC),
767         DeclToMCVar(DeclToMCVar),
768         StmtsHandled(StmtsHandled),
769         MallocExprs(MallocExprs),
770         DeferredUpdates(DeferredUpdates) {}
771
772     virtual void run(const MatchFinder::MatchResult &Result) {
773         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
774         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
775         FindLocalsVisitor flv;
776
777         const VarDecl * lhs = NULL;
778         const Expr * rhs = NULL;
779         const DeclStmt * ds;
780
781         if (s && (ds = dyn_cast<DeclStmt>(s))) {
782             // long term goal: refactor the run() method to deal with one assignment at a time
783             // for now, if there is only declarations and no rhs's, we'll ignore this stmt
784             if (!ds->isSingleDecl()) {
785                 for (auto & d : ds->decls()) {
786                     VarDecl * vd = dyn_cast<VarDecl>(d);
787                     if (!d || vd->hasInit())
788                         assert(0 && "unsupported form of decl");
789                 }
790                 return;
791             }
792
793             lhs = retrieveSingleDecl(ds);
794         }
795
796         if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
797             return;
798
799         if (lhs) {
800             if (lhs->hasInit()) {
801                 rhs = lhs->getInit();
802                 if (rhs) {
803                     rhs = rhs->IgnoreCasts();
804                 }
805             }
806             else
807                 return;
808         }
809         std::set<std::string> mcState;
810
811         bool lhsTooComplicated = false;
812         if (op) {
813             flv.TraverseStmt(op);
814
815             DeclRefExpr * vd;
816             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
817                 lhs = dyn_cast<VarDecl>(vd->getDecl());
818             else {
819                 // kick the can along...
820                 lhsTooComplicated = true;
821             }
822
823             rhs = op->getRHS();
824             if (rhs) 
825                 rhs = rhs->IgnoreCasts();
826         }
827         else if (lhs) {
828             // rhs must be MC-active state, i.e. in declsread
829             // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
830             flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
831         }
832
833         if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
834             for (auto & d : flv.RetrieveVars()) {
835                 if (DeclToMCVar.count(d) > 0)
836                     mcState.insert(DeclToMCVar[d]);
837                 else if (DeclsRead.find(d) != DeclsRead.end())
838                     mcState.insert(encode(d->getName().str()));
839             }
840         }
841
842         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
843             if (lhsTooComplicated)
844                 assert(0 && "couldn't find LHS of = operator");
845
846             std::stringstream nol;
847             std::string _lhsStr, lhsStr;
848             std::string mcVar = encodeFn(fnCount++);
849             if (lhs) {
850                 lhsStr = lhs->getName().str();
851                 _lhsStr = encode(lhsStr);
852                 DeclToMCVar[lhs] = mcVar;
853                 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
854             }
855             int function_id = 0;
856             if (!(MallocExprs.find(rhs) != MallocExprs.end()))
857                 function_id = ++funcCount;
858             nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
859             if (lhs)
860                 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
861             else 
862                 nol << ", MC2_PTR_LENGTH";
863             for (auto & d : mcState) {
864                 nol <<  ", ";
865                 if (_lhsStr == d)
866                     nol << mcVar;
867                 else
868                     nol << d;
869             }
870             nol << "); ";
871             SourceLocation place;
872             if (op)
873                 place = op->getLocEnd().getLocWithOffset(1);
874             else
875                 place = s->getLocEnd();
876             rewrite.InsertText(place.getLocWithOffset(1), nol.str(), true, true);
877
878             updateProvisionalName(DeferredUpdates, lhs, mcVar);
879         }
880     }
881
882     private:
883     Rewriter &rewrite;
884     std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
885     std::set<const VarDecl *> &DeclsNeedingMC;
886     std::map<const NamedDecl *, std::string> &DeclToMCVar;
887     std::set<const Stmt *> &StmtsHandled;
888     std::set<const Expr *> &MallocExprs;
889     std::vector<Update *> &DeferredUpdates;
890 };
891
892 // record vars used in conditions
893 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
894 public:
895     BranchConditionRefactoringHandler(Rewriter &Rewrite,
896                                       std::set<const NamedDecl *> & DeclsInCond,
897                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
898                                       std::map<const Expr *, std::string> &ExprToMCVar,
899                                       std::map<const Expr *, SourceLocation> &Redirector,
900                                       std::vector<Update *> &DeferredUpdates) :
901         Rewrite(Rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
902         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
903
904     virtual void run(const MatchFinder::MatchResult &Result) {
905         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
906         Expr * cond = is->getCond();
907
908         // refactor out complicated conditions
909         FindCallArgVisitor flv;
910         flv.TraverseStmt(cond);
911         std::string mcVar;
912
913         BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
914         if (bc) {
915             std::string condVar = encodeCond(condCount++);
916             std::stringstream condVarEncoded;
917             condVarEncoded << condVar << "_m";
918
919             // prettyprint the binary op
920             // e.g. int _cond0 = x == y;
921             std::string SStr;
922             llvm::raw_string_ostream S(SStr);
923             bc->printPretty(S, nullptr, Rewrite.getLangOpts());
924             const std::string &Str = S.str();
925             
926             std::stringstream prel;
927
928             bool is_equality = false;
929             // handle equality tests
930             if (bc->getOpcode() == BO_EQ) {
931                 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
932                 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
933                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
934                     is_equality = true;
935                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
936                     std::string ld = DeclToMCVar.find(l->getDecl())->second,
937                         rd = DeclToMCVar.find(r->getDecl())->second;
938
939                     prel << "\nint " << condVar << " = MC2_equals(" <<
940                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
941                         rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
942                         "&" << condVarEncoded.str() << ");\n";
943                 }
944             }
945
946             if (!is_equality) {
947                 prel << "\nint " << condVar << " = " << Str << ";";
948                 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
949                 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
950                 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
951                     prel << ", " << DeclToMCVar[d->getDecl()];
952                 }
953                 prel << ");\n";
954             }
955
956             ExprToMCVar[cond] = condVarEncoded.str();
957             Rewrite.InsertText(is->getLocStart(), prel.str(), false, true);
958
959             // rewrite the binary op with the newly-inserted var
960             Expr * RO = bc->getRHS(); // used for location only
961
962             int cl = Lexer::MeasureTokenLength(RO->getLocStart(), Rewrite.getSourceMgr(), Rewrite.getLangOpts());
963             SourceRange SR(cond->getLocStart(), Rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
964             Rewrite.ReplaceText(SR, condVar);
965         } else {
966             std::string condVar = encodeCond(condCount++);
967             std::stringstream condVarEncoded;
968             condVarEncoded << condVar << "_m";
969
970             std::string SStr;
971             llvm::raw_string_ostream S(SStr);
972             cond->printPretty(S, nullptr, Rewrite.getLangOpts());
973             const std::string &Str = S.str();
974
975             std::stringstream prel;
976             prel << "\nint " << condVar << " = " << Str << ";";
977             prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
978             std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
979             const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
980             if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
981                 prel << ", " << DeclToMCVar[d->getDecl()];
982             } else {
983                 prel << ", ";
984                 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
985                 vars->push_back(v);
986             }
987             prel << ");\n";
988
989             ExprToMCVar[cond] = condVarEncoded.str();
990             // gross hack; should look for any callexprs in cond
991             // but right now, if it's a unaryop, just manually traverse
992             if (isa<UnaryOperator>(cond)) {
993                 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
994                 ExprToMCVar[e] = condVarEncoded.str();
995             }
996             Update * u = new Update(is->getLocStart(), prel.str(), vars);
997             DeferredUpdates.push_back(u);
998
999             // rewrite the call op with the newly-inserted var
1000             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1001             Redirector[cond] = is->getLocStart();
1002             Rewrite.ReplaceText(SR, condVar);
1003         }
1004
1005         std::deque<const Decl *> q;
1006         const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1007         q.push_back(d);
1008         while (!q.empty()) {
1009             const Decl * d = q.back();
1010             q.pop_back();
1011             if (isa<NamedDecl>(d))
1012                 DeclsInCond.insert(cast<NamedDecl>(d));
1013
1014             const VarDecl * vd;
1015             if ((vd = dyn_cast<VarDecl>(d))) {
1016                 if (vd->hasInit()) {
1017                     const Expr * e = vd->getInit();
1018                     flv.Clear();
1019                     flv.TraverseStmt(const_cast<Expr *>(e));
1020                     const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1021                     q.push_back(d);
1022                 }
1023             }
1024         }
1025     }
1026
1027 private:
1028     Rewriter &Rewrite;
1029     std::set<const NamedDecl *> & DeclsInCond;
1030     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1031     std::map<const Expr *, std::string> &ExprToMCVar;
1032     std::map<const Expr *, SourceLocation> &Redirector;
1033     std::vector<Update *> &DeferredUpdates;
1034 };
1035
1036 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1037 public:
1038     BranchAnnotationHandler(Rewriter &rewrite,
1039                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
1040                             std::map<const Expr *, std::string> & ExprToMCVar)
1041         : Rewrite(rewrite),
1042           DeclToMCVar(DeclToMCVar),
1043           ExprToMCVar(ExprToMCVar){}
1044     virtual void run(const MatchFinder::MatchResult &Result) {
1045         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1046
1047         // if the branch condition is interesting:
1048         // (but right now, not too interesting)
1049         Expr * cond = is->getCond()->IgnoreCasts();
1050
1051         FindLocalsVisitor flv;
1052         flv.TraverseStmt(cond);
1053         if (flv.RetrieveVars().size() == 0) return;
1054
1055         const NamedDecl * condVar = flv.RetrieveVars()[0];
1056
1057         std::string mCondVar;
1058         if (ExprToMCVar.count(cond) > 0)
1059             mCondVar = ExprToMCVar[cond];
1060         else if (DeclToMCVar.count(condVar) > 0) 
1061             mCondVar = DeclToMCVar[condVar];
1062         else
1063             mCondVar = encode(condVar->getName());
1064         std::string brVar = encodeBranch(branchCount++);
1065
1066         std::stringstream brline;
1067         brline << "MCID " << brVar << ";\n";
1068         Rewrite.InsertText(is->getLocStart(), brline.str(), false, true);
1069
1070         Stmt * ts = is->getThen(), * es = is->getElse();
1071         bool tHasChild = hasChild(ts);
1072         SourceLocation tfl;
1073         if (tHasChild) {
1074             if (isa<CompoundStmt>(ts))
1075                 tfl = getFirstChild(ts)->getLocStart();
1076             else
1077                 tfl = ts->getLocStart();
1078         } else
1079             tfl = ts->getLocStart().getLocWithOffset(1);
1080         SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1081
1082         std::stringstream tlineStart, mergeStmt, eline;
1083
1084         UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1085         tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1086         eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1087
1088         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1089
1090         Rewrite.InsertText(tfl, tlineStart.str(), false, true);
1091
1092         Stmt * tls = NULL;
1093         int extra_else_offset = 0;
1094
1095         if (tHasChild) { tls = getLastChild(ts); }
1096         if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1097
1098         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1099             extra_else_offset = 0;
1100             Rewrite.InsertText(tsl.getLocWithOffset(1), mergeStmt.str(), true, true);
1101         }
1102         if (tHasChild && !isa<CompoundStmt>(ts)) {
1103             Rewrite.InsertText(tls->getLocStart(), "{", false, true);
1104             SourceLocation tend = Lexer::getLocForEndOfToken(tls->getLocStart(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
1105             Rewrite.InsertText(tend.getLocWithOffset(2), "}", true, true);
1106             extra_else_offset++;
1107         }
1108         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1109
1110         if (es) {
1111             SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1112             bool eHasChild = hasChild(es); 
1113             Stmt * els = NULL;
1114             if (eHasChild) els = getLastChild(es); else els = es;
1115             
1116             eline << "\n";
1117
1118             SourceLocation el;
1119             if (eHasChild) {
1120                 if (isa<CompoundStmt>(es))
1121                     el = getFirstChild(es)->getLocStart();
1122                 else {
1123                     el = es->getLocStart();
1124                 }
1125             } else
1126                 el = es->getLocStart().getLocWithOffset(1);
1127             Rewrite.InsertText(el, eline.str(), false, true);
1128
1129             if (eHasChild && !isa<CompoundStmt>(es)) {
1130                 Rewrite.InsertText(el, "{", false, true);
1131                 Rewrite.InsertText(es->getLocEnd().getLocWithOffset(1), "}", true, true);
1132             }
1133
1134             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1135                 Rewrite.InsertText(esl.getLocWithOffset(1), mergeStmt.str(), true, true);
1136         }
1137         else {
1138             std::stringstream eCompoundLine;
1139             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1140             SourceLocation tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
1141             Rewrite.InsertText(tend.getLocWithOffset(extra_else_offset),
1142               eCompoundLine.str(), false, true);
1143         }
1144     }
1145 private:
1146
1147     bool hasChild(Stmt * s) {
1148         if (!isa<CompoundStmt>(s)) return true;
1149         return (!cast<CompoundStmt>(s)->body_empty());
1150     }
1151
1152     Stmt * getFirstChild(Stmt * s) {
1153         assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1154         assert(!cast<CompoundStmt>(s)->body_empty());
1155         return *(cast<CompoundStmt>(s)->body_begin());
1156     }
1157
1158     Stmt * getLastChild(Stmt * s) {
1159         CompoundStmt * cs;
1160         if ((cs = dyn_cast<CompoundStmt>(s))) {
1161             assert (!cs->body_empty());
1162             return cs->body_back();
1163         }
1164         return s;
1165     }
1166
1167     Rewriter &Rewrite;
1168     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1169     std::map<const Expr *, std::string> &ExprToMCVar;
1170 };
1171
1172 class FunctionCallHandler : public MatchFinder::MatchCallback {
1173 public:
1174     FunctionCallHandler(Rewriter &rewrite,
1175                         std::map<const NamedDecl *, std::string> &DeclToMCVar,
1176                         std::set<const FunctionDecl *> &ThreadMains)
1177         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1178
1179     virtual void run(const MatchFinder::MatchResult &Result) {
1180         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
1181         Decl * d = ce->getCalleeDecl();
1182         NamedDecl * nd = dyn_cast<NamedDecl>(d);
1183         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1184         ASTContext *Context = Result.Context;
1185
1186         if (nd->getName() == "thrd_create") {
1187             Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1188             UnaryOperator * callee1;
1189             if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1190                 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1191                     callee0 = callee1->getSubExpr();
1192             }
1193             DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1194             if (!callee) return;
1195             FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1196             ThreadMains.insert(fd);
1197             return;
1198         }
1199
1200         if (!d->hasBody())
1201             return;
1202
1203         if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1204             // TODO check that the type is mc-visible also?
1205             const DeclStmt * ds;
1206             const VarDecl * lhs = NULL;
1207             std::string mc_rv = encodeRV(rvCount++);
1208
1209             std::stringstream brline;
1210             brline << "MCID " << mc_rv << ";\n";
1211             rewrite.InsertText(s->getLocStart(), brline.str(), false, true);
1212
1213             std::stringstream nol;
1214             if (ce->getNumArgs() > 0) nol << ", ";
1215             nol << "&" << mc_rv;
1216             rewrite.InsertTextBefore(ce->getRParenLoc(), nol.str());
1217
1218             if (s && (ds = dyn_cast<DeclStmt>(s))) {
1219                 if (!ds->isSingleDecl()) {
1220                     for (auto & d : ds->decls()) {
1221                         VarDecl * vd = dyn_cast<VarDecl>(d);
1222                         if (!d || vd->hasInit())
1223                             assert(0 && "unsupported form of decl");
1224                     }
1225                     return;
1226                 }
1227
1228                 lhs = retrieveSingleDecl(ds);
1229             }
1230
1231             DeclToMCVar[lhs] = mc_rv;
1232         }
1233
1234         for (const auto & a : ce->arguments()) {
1235             std::stringstream nol;
1236
1237             std::string aa = "MCID_NODEP";
1238
1239             Expr * e = a->IgnoreCasts();
1240             DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1241             if (dr) { 
1242                 NamedDecl * d = dr->getDecl();
1243                 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1244                     aa = DeclToMCVar[d];
1245             }
1246
1247             nol << aa << ", ";
1248             
1249             if (a->getLocEnd().isValid())
1250                 rewrite.InsertTextBefore(a->getLocStart(), nol.str());
1251         }
1252     }
1253
1254 private:
1255     Rewriter &rewrite;
1256     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1257     std::set<const FunctionDecl *> &ThreadMains;
1258 };
1259
1260 class ReturnHandler : public MatchFinder::MatchCallback {
1261 public:
1262     ReturnHandler(Rewriter &rewrite,
1263                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
1264                   std::set<const FunctionDecl *> &ThreadMains)
1265         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1266
1267     virtual void run(const MatchFinder::MatchResult &Result) {
1268         const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1269         ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1270         Expr * rv = const_cast<Expr *>(rs->getRetValue());
1271
1272         if (!rv) return;        
1273         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1274         // not sure why this is explicitly needed, but crashes without it
1275         if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1276
1277         FindLocalsVisitor flv;
1278         flv.TraverseStmt(rv);
1279         std::string mrv = "MCID_NODEP";
1280
1281         if (flv.RetrieveVars().size() > 0) {
1282             const NamedDecl * returnVar = flv.RetrieveVars()[0];
1283             if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1284                 mrv = DeclToMCVar[returnVar];
1285             }
1286         }
1287         std::stringstream nol;
1288         nol << "*retval = " << mrv << ";\n";
1289         rewrite.InsertText(rs->getLocStart(), nol.str(), false, true);
1290     }
1291
1292 private:
1293     Rewriter &rewrite;
1294     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1295     std::set<const FunctionDecl *> &ThreadMains;
1296 };
1297
1298 class VarDeclHandler : public MatchFinder::MatchCallback {
1299 public:
1300     VarDeclHandler(Rewriter &rewrite,
1301                    std::map<const NamedDecl *, std::string> &DeclToMCVar,
1302                    std::set<const VarDecl *> &DeclsNeedingMC)
1303         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1304
1305     virtual void run(const MatchFinder::MatchResult &Result) {
1306         VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1307         std::stringstream nol;
1308
1309         if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1310
1311         std::string dn;
1312         if (DeclToMCVar.find(d) != DeclToMCVar.end())
1313             dn = DeclToMCVar[d];
1314         else
1315             dn = encode(d->getName().str());
1316
1317         nol << "MCID " << dn << "; ";
1318
1319         if (d->getLocStart().isValid())
1320             rewrite.InsertTextBefore(d->getLocStart(), nol.str());
1321     }
1322
1323 private:
1324     Rewriter &rewrite;
1325     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1326     std::set<const VarDecl *> &DeclsNeedingMC;
1327 };
1328
1329 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1330 public:
1331     FunctionDeclHandler(Rewriter &rewrite,
1332                         std::set<const FunctionDecl *> &ThreadMains)
1333         : rewrite(rewrite), ThreadMains(ThreadMains) {}
1334
1335     virtual void run(const MatchFinder::MatchResult &Result) {
1336         FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
1337
1338         if (!fd->getIdentifier()) return;
1339
1340         if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1341
1342         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1343
1344         SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1345         for (auto & p : fd->params()) {
1346             std::stringstream nol;
1347             nol << "MCID " << encode(p->getName()) << ", ";
1348             if (p->getLocStart().isValid())
1349                 rewrite.InsertText(p->getLocStart(), nol.str(), false);
1350             if (p->getLocEnd().isValid())
1351                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1352         }
1353
1354         if (!fd->getReturnType()->isVoidType()) {
1355             std::stringstream nol;
1356             if (fd->param_size() > 0) nol << ", ";
1357             nol << "MCID * retval";
1358             rewrite.InsertText(LastParam, nol.str(), false);
1359         }
1360     }
1361
1362 private:
1363     Rewriter &rewrite;
1364     std::set<const FunctionDecl *> &ThreadMains;
1365 };
1366
1367 class BailHandler : public MatchFinder::MatchCallback {
1368 public:
1369     BailHandler() {}
1370     virtual void run(const MatchFinder::MatchResult &Result) {
1371         assert(0 && "we don't handle goto statements");
1372     }
1373 };
1374
1375 class MyASTConsumer : public ASTConsumer {
1376 public:
1377     MyASTConsumer(Rewriter &R) : R(R),
1378                                  DeclsRead(),
1379                                  DeclsInCond(),
1380                                  DeclToMCVar(),
1381                                  HandlerMalloc(MallocExprs),
1382                                  HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1383                                  HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1384                                  HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1385                                  HandlerLoop(R),
1386                                  HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1387                                  HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1388                                  HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1389                                  HandlerFunctionDecl(R, ThreadMains),
1390                                  HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1391                                  HandlerReturn(R, DeclToMCVar, ThreadMains),
1392                                  HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
1393                                  HandlerBail() {
1394         MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1395                                                       hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1396                                                       hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1397                                        &HandlerFunctionCall);
1398         MatcherLoadStore.addMatcher
1399             (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1400              &HandlerMalloc);
1401
1402         MatcherLoadStore.addMatcher
1403             (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1404                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1405                             hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1406                             hasParent(stmt().bind("containingStmt"))))
1407              .bind("callExpr"),
1408              &HandlerLoad);
1409
1410         MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1411                                     &HandlerStore);
1412
1413         MatcherLoadStore.addMatcher
1414             (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1415                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1416                             hasAncestor(binaryOperator(hasOperatorName("="),
1417                                                        hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1418                             anything()))
1419              .bind("callExpr"),
1420              &HandlerRMW);
1421
1422         MatcherLoadStore.addMatcher(ifStmt(hasCondition
1423                                            (anyOf(binaryOperator().bind("bc"),
1424                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1425                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1426                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1427                                                   anything()))).bind("if"),
1428                                     &HandlerBranchConditionRefactoring);
1429
1430         MatcherLoadStore.addMatcher(forStmt().bind("s"),
1431                                     &HandlerLoop);
1432         MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1433                                     &HandlerLoop);
1434         MatcherLoadStore.addMatcher(doStmt().bind("s"),
1435                                     &HandlerLoop);
1436
1437         MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1438                                                         hasParent(compoundStmt())),
1439                                                         hasOperatorName("=")).bind("op"),
1440                                    &HandlerAssign);
1441         MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1442
1443         MatcherFunction.addMatcher(ifStmt().bind("if"),
1444                                    &HandlerAnnotateBranch);
1445
1446         MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1447                                        &HandlerFunctionDecl);
1448         MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1449         MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
1450                                    &HandlerReturn);
1451
1452         MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1453     }
1454
1455     // Override the method that gets called for each parsed top-level
1456     // declaration.
1457     void HandleTranslationUnit(ASTContext &Context) override {
1458         LangOpts = Context.getLangOpts();
1459
1460         MatcherFunctionCall.matchAST(Context);
1461         MatcherLoadStore.matchAST(Context);
1462         MatcherFunction.matchAST(Context);
1463         MatcherFunctionDecl.matchAST(Context);
1464         MatcherSanity.matchAST(Context);
1465
1466         for (auto & u : DeferredUpdates) {
1467             R.InsertText(u->loc, u->update, true, true);
1468             delete u;
1469         }
1470         DeferredUpdates.clear();
1471     }
1472
1473 private:
1474     /* DeclsRead contains all local variables 'x' which:
1475     * 1) appear in 'x = load_32(...);
1476     * 2) appear in 'y = store_32(x); */
1477     std::set<const NamedDecl *> DeclsRead, DeclsInCond;
1478     std::map<const NamedDecl *, std::string> DeclToMCVar;
1479     std::map<const Expr *, std::string> ExprToMCVar;
1480     std::set<const VarDecl *> DeclsNeedingMC;
1481     std::set<const FunctionDecl *> ThreadMains;
1482     std::set<const Stmt *> StmtsHandled;
1483     std::set<const Expr *> MallocExprs;
1484     std::map<const Expr *, SourceLocation> Redirector;
1485     std::vector<Update *> DeferredUpdates;
1486
1487     Rewriter &R;
1488
1489     MallocHandler HandlerMalloc;
1490     LoadHandler HandlerLoad;
1491     StoreHandler HandlerStore;
1492     RMWHandler HandlerRMW;
1493     LoopHandler HandlerLoop;
1494     BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1495     BranchAnnotationHandler HandlerAnnotateBranch;
1496     AssignHandler HandlerAssign;
1497     FunctionDeclHandler HandlerFunctionDecl;
1498     FunctionCallHandler HandlerFunctionCall;
1499     ReturnHandler HandlerReturn;
1500     VarDeclHandler HandlerVarDecl;
1501     BailHandler HandlerBail;
1502     MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1503 };
1504
1505 // For each source file provided to the tool, a new FrontendAction is created.
1506 class MyFrontendAction : public ASTFrontendAction {
1507 public:
1508     MyFrontendAction() {}
1509     void EndSourceFileAction() override {
1510         SourceManager &SM = TheRewriter.getSourceMgr();
1511         llvm::errs() << "** EndSourceFileAction for: "
1512                      << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1513
1514         // Now emit the rewritten buffer.
1515         TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1516     }
1517
1518     std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1519                                                    StringRef file) override {
1520         llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1521         TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1522         return llvm::make_unique<MyASTConsumer>(TheRewriter);
1523     }
1524
1525 private:
1526     Rewriter TheRewriter;
1527 };
1528
1529 int main(int argc, const char **argv) {
1530     CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1531     ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1532     
1533     return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());
1534 }