Fix apparent bug...
[satcheck.git] / clang / src / add_mc2_annotations.cpp
1 // -*-  indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
5 //
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
16 //
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
20 // distribution.
21 //
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
25 //
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
34 //
35 // Patrick Lam (prof.lam@gmail.com)
36 //
37 // Base code:
38 // Eli Bendersky (eliben@gmail.com)
39 //
40 //------------------------------------------------------------------------------
41 #include <sstream>
42 #include <string>
43 #include <map>
44 #include <stdbool.h>
45
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
61
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
66 using namespace llvm;
67
// File-scope language-options object.
// NOTE(review): no use of it is visible in this part of the file — confirm where it is consumed.
static LangOptions LangOpts;
// Command-line option category labelled "Add MC2 Annotations".
// NOTE(review): presumably handed to the tooling option parser in main() — confirm.
static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
70
// Returns the MC2 shadow-variable name for a program variable:
// the original name with "_m" prepended (e.g. "x" becomes "_mx").
static std::string encode(std::string varName) {
    return "_m" + varName;
}
76
// Running index for "_fn" names.
// NOTE(review): its uses are outside this block — presumably fed to encodeFn().
static int fnCount = 0;

// Builds the identifier "_fn<num>".
static std::string encodeFn(int num) {
    return "_fn" + std::to_string(num);
}
83
// Running index for "_p" names.
// NOTE(review): its uses are outside this block — presumably fed to encodePtr().
static int ptrCount = 0;

// Builds the identifier "_p<num>".
static std::string encodePtr(int num) {
    return "_p" + std::to_string(num);
}
90
// Running index for "_rmw" names; RMWHandler::run() post-increments it
// once per matched read-modify-write call.
static int rmwCount = 0;

// Builds the identifier "_rmw<num>".
static std::string encodeRMW(int num) {
    return "_rmw" + std::to_string(num);
}
97
// Running index for "_br" names.
// NOTE(review): its uses are outside this block — presumably fed to encodeBranch().
static int branchCount = 0;

// Builds the identifier "_br<num>".
static std::string encodeBranch(int num) {
    return "_br" + std::to_string(num);
}
104
// Running index for "_cond" names.
// NOTE(review): its uses are outside this block — presumably fed to encodeCond().
static int condCount = 0;

// Builds the identifier "_cond<num>".
static std::string encodeCond(int num) {
    return "_cond" + std::to_string(num);
}
111
// Running index for "_rv" names.
// NOTE(review): its uses are outside this block — presumably fed to encodeRV().
static int rvCount = 0;

// Builds the identifier "_rv<num>".
static std::string encodeRV(int num) {
    return "_rv" + std::to_string(num);
}
118
// Number of MC2_function_id() call sites emitted so far; generateMC2Function()
// pre-increments it, so the first emitted id is 1.
static int funcCount;
120
121 struct ProvisionalName {
122     int index, length;
123     const DeclRefExpr * pname;
124     bool enabled;
125
126     ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127     ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
128 };
129
130 struct Update {
131     SourceLocation loc;
132     std::string update;
133     std::vector<ProvisionalName *> * pnames;
134
135     Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) : 
136         loc(loc), update(update), pnames(pnames) {}
137
138     ~Update() { 
139         for (auto pname : *pnames) delete pname;
140         delete pnames; 
141     }
142 };
143
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145     for (Update * u : DeferredUpdates) {
146         for (int i = 0; i < u->pnames->size(); i++) {
147             ProvisionalName * v = (*(u->pnames))[i];
148             if (!v->enabled) continue;
149             if (now_known == v->pname->getDecl()) {
150                 v->enabled = false;
151                 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
152
153                 u->update.replace(v->index, v->length, mcVar);
154                 for (int j = i+1; j < u->pnames->size(); j++) {
155                     ProvisionalName * vv = (*(u->pnames))[j];
156                     if (vv->index > v->index)
157                         vv->index -= v->length - mcVar.length();
158                 }
159             }
160         }
161     }
162 }
163
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165     // XXX iterate through all decls defined in s, not just the first one
166     assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167     if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168         return cast<VarDecl>(s->getSingleDecl());
169     } else return NULL;
170 }
171
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
173 public:
174     FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
175
176     bool VisitStmt(Stmt * s) {
177         if (!UnaryOp) {
178             if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179                 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180                     uo->getOpcode() == UnaryOperatorKind::UO_Deref)
181                     UnaryOp = uo;
182             }
183         }
184
185         if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
186             ;
187         return true;
188     }
189
190     void Clear() {
191         UnaryOp = NULL; DE = NULL;
192     }
193
194     const UnaryOperator * RetrieveUnaryOp() {
195         if (UnaryOp) {
196             bool found = false;
197             const Stmt * s = UnaryOp;
198             while (s != NULL) {
199                 if (s == DE) {
200                     found = true; break;
201                 }
202
203                 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204                     s = op->getSubExpr();
205                 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206                     s = op->getSubExpr();
207                 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
208                     s = op->getBase();
209                 else
210                     s = NULL;
211             }
212             if (found)
213                 return UnaryOp;
214         }
215         return NULL;
216     }
217
218     const DeclRefExpr * RetrieveDeclRefExpr() {
219         return DE;
220     }
221
222 private:
223     const UnaryOperator * UnaryOp;
224     const DeclRefExpr * DE;
225 };
226
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
228 public:
229     FindLocalsVisitor() : Vars() {}
230
231     bool VisitDeclRefExpr(DeclRefExpr * de) {
232         Vars.push_back(de->getDecl());
233         return true;
234     }
235
236     void Clear() {
237         Vars.clear();
238     }
239
240     const TinyPtrVector<const NamedDecl *> RetrieveVars() {
241         return Vars;
242     }
243
244 private:
245     TinyPtrVector<const NamedDecl *> Vars;
246 };
247
248
249 class MallocHandler : public MatchFinder::MatchCallback {
250 public:
251     MallocHandler(std::set<const Expr *> & MallocExprs) :
252         MallocExprs(MallocExprs) {}
253
254     virtual void run(const MatchFinder::MatchResult &Result) {
255         const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
256
257         MallocExprs.insert(ce);
258     }
259
260     private:
261     std::set<const Expr *> &MallocExprs;
262 };
263
264 static void generateMC2Function(Rewriter & Rewrite,
265                                 const Expr * e, 
266                                 SourceLocation loc,
267                                 std::string tmpname, 
268                                 std::string tmpFn, 
269                                 const DeclRefExpr * lhs, 
270                                 std::string lhsName,
271                                 std::vector<ProvisionalName *> * vars1,
272                                 std::vector<Update *> & DeferredUpdates) {
273     // prettyprint the LHS (&newnode->value)
274     // e.g. int * _tmp0 = &newnode->value;
275     std::string SStr;
276     llvm::raw_string_ostream S(SStr);
277     e->printPretty(S, nullptr, Rewrite.getLangOpts());
278     const std::string &Str = S.str();
279
280     std::stringstream prel;
281     prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
282
283     // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284     prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
285     if (lhs) {
286         // XXX generate casts when they'd eliminate warnings
287         ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
288         vars1->push_back(v);
289     }
290     prel << encode(lhsName) << "); ";
291
292     Update * u = new Update(loc, prel.str(), vars1);
293     DeferredUpdates.push_back(u);
294 }
295
296 class LoadHandler : public MatchFinder::MatchCallback {
297 public:
298     LoadHandler(Rewriter &Rewrite,
299                 std::set<const NamedDecl *> & DeclsRead,
300                 std::set<const VarDecl *> & DeclsNeedingMC, 
301                 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302                 std::map<const Expr *, std::string> &ExprToMCVar,
303                 std::set<const Stmt *> & StmtsHandled,
304                 std::map<const Expr *, SourceLocation> &Redirector,
305                 std::vector<Update *> & DeferredUpdates) : 
306         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307         ExprToMCVar(ExprToMCVar),
308         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
309
310     virtual void run(const MatchFinder::MatchResult &Result) {
311         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313         const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315         const Expr * lhs = NULL;
316         if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
317         if (!s) s = ce;
318
319         const DeclRefExpr * rhs = NULL;
320         MemberExpr * ml = NULL;
321         bool isAddrOfR = false, isAddrMemberR = false;
322
323         StmtsHandled.insert(s);
324         
325         std::string n, n_decl;
326         if (lhs) {
327             FindCallArgVisitor fcaVisitor;
328             fcaVisitor.Clear();
329             fcaVisitor.TraverseStmt(ce->getArg(0));
330             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
334             if (isAddrMemberR)
335                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
336
337             FindLocalsVisitor flv;
338             flv.Clear();
339             flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340             for (auto & d : flv.RetrieveVars()) {
341                 const VarDecl * dd = cast<VarDecl>(d);
342                 n = dd->getName();
343                 // XXX todo rhs for non-decl stmts
344                 if (!isa<ParmVarDecl>(dd))
345                     DeclsNeedingMC.insert(dd);
346                 DeclsRead.insert(d);
347                 DeclToMCVar[dd] = encode(n);
348             }
349         } else {
350             FindCallArgVisitor fcaVisitor;
351             fcaVisitor.Clear();
352             fcaVisitor.TraverseStmt(ce->getArg(0));
353             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
357             if (isAddrMemberR)
358                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
359
360             if (d) {
361                 n = d->getName();
362                 DeclsNeedingMC.insert(d);
363                 DeclsRead.insert(d);
364                 DeclToMCVar[d] = encode(n);
365             } else {
366                 n = ExprToMCVar[ce];
367                 fcaVisitor.Clear();
368                 fcaVisitor.TraverseStmt(ce);
369                 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370                 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371                 DeclToMCVar[dd->getDecl()] = encode(n);
372                 n_decl = "MCID ";
373             }
374         }
375         
376         std::stringstream nol;
377
378         if (lhs && isa<DeclRefExpr>(lhs)) {
379             const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380             ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
381             vars->push_back(v);
382         }
383
384         if (rhs) {
385             if (isAddrMemberR) {
386                 if (!n.empty()) 
387                     nol << n_decl << encode(n) << "=";
388                 nol << "MC2_nextOpLoadOffset(";
389
390                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
391                 vars->push_back(v);
392                 nol << encode(rhs->getNameInfo().getName().getAsString());
393
394                 nol << ", MC2_OFFSET(";
395                 nol << ml->getBase()->getType().getAsString();
396                 nol << ", ";
397                 nol << ml->getMemberDecl()->getName().str();
398                 nol << ")";
399             } else if (!isAddrOfR) {
400                 if (!n.empty()) 
401                     nol << n_decl << encode(n) << "=";
402                 nol << "MC2_nextOpLoad(";
403                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
404                 vars->push_back(v);
405                 nol << encode(rhs->getNameInfo().getName().getAsString());
406             } else {
407                 if (!n.empty()) 
408                     nol << n_decl << encode(n) << "=";
409                 nol << "MC2_nextOpLoad(";
410                 nol << "MCID_NODEP";
411             }
412         } else {
413             if (!n.empty()) 
414                 nol << n_decl << encode(n) << "=";
415             nol << "MC2_nextOpLoad(";
416             nol << "MCID_NODEP";
417         }
418
419         if (lhs)
420             nol << "), ";
421         else
422             nol << "); ";
423         SourceLocation ss = s->getLocStart();
424         // eek gross hack:
425         // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426         // move over 1 so that we get the right location.
427         if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428         const Expr * e = dyn_cast<Expr>(s);
429         if (e && Redirector.count(e) > 0)
430             ss = Redirector[e];
431         Update * u = new Update(ss, nol.str(), vars);
432         DeferredUpdates.insert(DeferredUpdates.begin(), u);
433     }
434
435     private:
436     Rewriter &Rewrite;
437     std::set<const NamedDecl *> & DeclsRead;
438     std::set<const VarDecl *> & DeclsNeedingMC;
439     std::map<const Expr *, std::string> &ExprToMCVar;
440     std::map<const NamedDecl *, std::string> &DeclToMCVar;
441     std::set<const Stmt *> &StmtsHandled;
442     std::vector<Update *> &DeferredUpdates;
443     std::map<const Expr *, SourceLocation> &Redirector;
444 };
445
446 class StoreHandler : public MatchFinder::MatchCallback {
447 public:
448     StoreHandler(Rewriter &Rewrite,
449                  std::set<const NamedDecl *> & DeclsRead,
450                  std::set<const VarDecl *> &DeclsNeedingMC,
451                  std::vector<Update *> & DeferredUpdates) : 
452         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
453
454     virtual void run(const MatchFinder::MatchResult &Result) {
455         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
457
458         fcaVisitor.Clear();
459         fcaVisitor.TraverseStmt(ce->getArg(0));
460         const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461         const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
462     
463         std::stringstream nol;
464
465         bool isAddrMemberL;
466         bool isAddrOfL;
467
468         if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469             isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470             isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
471         }
472
473         if (lhs) {
474             if (isAddrOfL) {
475                 nol << "MC2_nextOpStore(";
476                 nol << "MCID_NODEP";
477             } else {
478                 if (isAddrMemberL) {
479                     MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
480
481                     nol << "MC2_nextOpStoreOffset(";
482
483                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
484                     vars->push_back(v);
485
486                     nol << encode(lhs->getNameInfo().getName().getAsString());
487                     if (!isa<ParmVarDecl>(lhs->getDecl()))
488                         DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
489
490                     nol << ", MC2_OFFSET(";
491                     nol << ml->getBase()->getType().getAsString();
492                     nol << ", ";
493                     nol << ml->getMemberDecl()->getName().str();
494                     nol << ")";
495                 } else {
496                     nol << "MC2_nextOpStore(";
497                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
498                     vars->push_back(v);
499
500                     nol << encode(lhs->getNameInfo().getName().getAsString());
501                 }
502             }
503         }
504         else {
505             nol << "MC2_nextOpStore(";
506             nol << "MCID_NODEP";
507         }
508         
509         nol << ", ";
510
511         fcaVisitor.Clear();
512         fcaVisitor.TraverseStmt(ce->getArg(1));
513         const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514         const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
515
516         bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517         bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
518
519         if (rhs && !isAddrOfR) {
520             assert (!isDerefR && "Must use atomic load for dereferences!");
521             ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
522             vars->push_back(v);
523
524             nol << encode(rhs->getNameInfo().getName().getAsString());
525             DeclsRead.insert(rhs->getDecl());
526         }
527         else
528             nol << "MCID_NODEP";
529         
530         nol << ");\n";
531         Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532         DeferredUpdates.push_back(u);
533     }
534
535     private:
536     Rewriter &Rewrite;
537     FindCallArgVisitor fcaVisitor;
538     std::set<const NamedDecl *> & DeclsRead;
539     std::set<const VarDecl *> & DeclsNeedingMC;
540     std::vector<Update *> &DeferredUpdates;
541 };
542
543 class RMWHandler : public MatchFinder::MatchCallback {
544 public:
545     RMWHandler(Rewriter &rewrite,
546                  std::set<const NamedDecl *> & DeclsRead,
547                  std::set<const NamedDecl *> & DeclsInCond,
548                  std::map<const NamedDecl *, std::string> &DeclToMCVar,
549                  std::map<const Expr *, std::string> &ExprToMCVar,
550                  std::set<const Stmt *> & StmtsHandled,
551                  std::map<const Expr *, SourceLocation> &Redirector,
552                  std::vector<Update *> & DeferredUpdates) : 
553         rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
554         ExprToMCVar(ExprToMCVar),
555         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
556
557     virtual void run(const MatchFinder::MatchResult &Result) {
558         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
561
562         std::stringstream nol;
563         
564         std::string rmwMCVar;
565         rmwMCVar = encodeRMW(rmwCount++);
566
567         const VarDecl * rmw_lhs;
568         if (s) {
569             StmtsHandled.insert(s);
570             assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
571             const DeclStmt * ds;
572             if ((ds = dyn_cast<DeclStmt>(s))) {
573                 rmw_lhs = retrieveSingleDecl(ds);
574             } else {
575                 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576                 assert (isa<DeclRefExpr>(e));
577                 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
578             }
579             DeclToMCVar[rmw_lhs] = rmwMCVar;
580         }
581
582         // retrieve effective LHS of the RMW
583         fcaVisitor.Clear();
584         fcaVisitor.TraverseStmt(ce->getArg(1));
585         const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586         const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587         bool isAddrMemberL = false;
588
589         if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590             isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
591         }
592
593         nol << "MCID " << rmwMCVar;
594         if (isAddrMemberL) {
595             MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
596
597             nol << " = MC2_nextRMWOffset(";
598
599             ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
600             vars->push_back(v);
601
602             nol << encode(elhs->getNameInfo().getName().getAsString());
603
604             nol << ", MC2_OFFSET(";
605             nol << ml->getBase()->getType().getAsString();
606             nol << ", ";
607             nol << ml->getMemberDecl()->getName().str();
608             nol << ")";
609         } else {
610             nol << " = MC2_nextRMW(";
611             bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
612
613             if (elhs) {
614                 if (isAddrOfL)
615                     nol << "MCID_NODEP";
616                 else {
617                     ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
618                     vars->push_back(v);
619
620                     std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
621                     nol << elhsName;
622                 }
623             }
624             else
625                 nol << "MCID_NODEP";
626         }
627         nol << ", ";
628
629         // handle both RHS ops
630         int outputted = 0;
631         for (int arg = 2; arg < 4; arg++) {
632             fcaVisitor.Clear();
633             fcaVisitor.TraverseStmt(ce->getArg(arg));
634             const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635             const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
636             
637             bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638             bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
639
640             if (a && !isAddrOfR) {
641                 assert (!isDerefR && "Must use atomic load for dereferences!");
642
643                 DeclsInCond.insert(a->getDecl());
644
645                 if (outputted > 0) nol << ", ";
646                 outputted++;
647
648                 bool alreadyMCVar = false;
649                 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
650                     alreadyMCVar = true;
651                     nol << DeclToMCVar[a->getDecl()];
652                 }
653                 else {
654                     std::string an = "MCID_NODEP";
655                     ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
656                     nol << an;
657                     vars->push_back(v);
658                 }
659
660                 DeclsRead.insert(a->getDecl());
661             }
662             else {
663                 if (outputted > 0) nol << ", ";
664                 outputted++;
665
666                 nol << "MCID_NODEP";
667             }
668         }
669         nol << ");\n";
670
671         SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672         const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673         if (e && Redirector.count(e) > 0)
674             place = Redirector[e];
675         Update * u = new Update(place, nol.str(), vars);
676         DeferredUpdates.insert(DeferredUpdates.begin(), u);
677     }
678     
679     private:
680     Rewriter &rewrite;
681     FindCallArgVisitor fcaVisitor;
682     std::set<const NamedDecl *> &DeclsRead;
683     std::set<const NamedDecl *> &DeclsInCond;
684     std::map<const NamedDecl *, std::string> &DeclToMCVar;
685     std::map<const Expr *, std::string> &ExprToMCVar;
686     std::set<const Stmt *> &StmtsHandled;
687     std::vector<Update *> &DeferredUpdates;
688     std::map<const Expr *, SourceLocation> &Redirector;
689 };
690
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
692 public:
693     FindReturnsBreaksVisitor() : Returns(), Breaks() {}
694
695     bool VisitStmt(Stmt * s) {
696         if (isa<ReturnStmt>(s))
697             Returns.push_back(cast<ReturnStmt>(s));
698
699         if (isa<BreakStmt>(s))
700             Breaks.push_back(cast<BreakStmt>(s));
701         return true;
702     }
703
704     void Clear() {
705         Returns.clear(); Breaks.clear();
706     }
707
708     const std::vector<const ReturnStmt *> RetrieveReturns() {
709         return Returns;
710     }
711
712     const std::vector<const BreakStmt *> RetrieveBreaks() {
713         return Breaks;
714     }
715
716 private:
717     std::vector<const ReturnStmt *> Returns;
718     std::vector<const BreakStmt *> Breaks;
719 };
720
721 class LoopHandler : public MatchFinder::MatchCallback {
722 public:
723     LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
724
725     virtual void run(const MatchFinder::MatchResult &Result) {
726         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
727
728         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729                            "MC2_enterLoop();\n", true, true);
730
731         // annotate all returns with MC2_exitLoop()
732         // annotate all breaks that aren't further nested with MC2_exitLoop().
733         FindReturnsBreaksVisitor frbv;
734         if (isa<ForStmt>(s))
735             frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736         if (isa<WhileStmt>(s))
737             frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
738         if (isa<DoStmt>(s))
739             frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
740
741         for (auto & r : frbv.RetrieveReturns()) {
742             rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
743         }
744         
745         // need to find all breaks and returns embedded inside the loop
746
747         rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748                                      "\nMC2_exitLoop();\n");
749     }
750
751 private:
752     Rewriter &rewrite;
753 };
754
755 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
756 class AssignHandler : public MatchFinder::MatchCallback {
757 public:
758     AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759                   std::set<const NamedDecl *> &DeclsInCond,
760                   std::set<const VarDecl *> &DeclsNeedingMC,
761                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
762                   std::set<const Stmt *> &StmtsHandled,
763                   std::set<const Expr *> &MallocExprs,
764                   std::vector<Update *> &DeferredUpdates) :
765         rewrite(rewrite),
766         DeclsRead(DeclsRead),
767         DeclsInCond(DeclsInCond),
768         DeclsNeedingMC(DeclsNeedingMC),
769         DeclToMCVar(DeclToMCVar),
770         StmtsHandled(StmtsHandled),
771         MallocExprs(MallocExprs),
772         DeferredUpdates(DeferredUpdates) {}
773
774     virtual void run(const MatchFinder::MatchResult &Result) {
775         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777         FindLocalsVisitor locals, locals_rhs;
778
779         const VarDecl * lhs = NULL;
780         const Expr * rhs = NULL;
781         const DeclStmt * ds;
782
783         if (s && (ds = dyn_cast<DeclStmt>(s))) {
784             // long term goal: refactor the run() method to deal with one assignment at a time
785             // for now, if there is only declarations and no rhs's, we'll ignore this stmt
786             if (!ds->isSingleDecl()) {
787                 for (auto & d : ds->decls()) {
788                     VarDecl * vd = dyn_cast<VarDecl>(d);
789                     if (!d || vd->hasInit())
790                         assert(0 && "unsupported form of decl");
791                 }
792                 return;
793             }
794
795             lhs = retrieveSingleDecl(ds);
796         }
797
798         if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
799             return;
800
801         if (lhs) {
802             if (lhs->hasInit()) {
803                 rhs = lhs->getInit();
804                 if (rhs) {
805                     rhs = rhs->IgnoreCasts();
806                 }
807             }
808             else
809                 return;
810         }
811         std::set<std::string> mcState;
812
813         bool lhsUsedInCond;
814         bool rhsRead = false;
815
816         bool lhsTooComplicated = false;
817         if (op) {
818             DeclRefExpr * vd;
819             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
820                 lhs = dyn_cast<VarDecl>(vd->getDecl());
821             else {
822                 // kick the can along...
823                 lhsTooComplicated = true;
824             }
825
826             rhs = op->getRHS();
827             if (rhs) 
828                 rhs = rhs->IgnoreCasts();
829         }
830
831         // rhs must be MC-active state, i.e. in declsread
832         // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
833
834         if (rhs) {
835             locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
836             for (auto & nd : locals_rhs.RetrieveVars()) {
837                 if (DeclsRead.find(nd) != DeclsRead.end())
838                     rhsRead = true;
839             }
840         }
841
842         locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
843
844         lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
845         if (lhsUsedInCond) {
846             for (auto & d : locals.RetrieveVars()) {
847                 if (DeclToMCVar.count(d) > 0)
848                     mcState.insert(DeclToMCVar[d]);
849                 else if (DeclsRead.find(d) != DeclsRead.end())
850                     mcState.insert(encode(d->getName().str()));
851             }
852         }
853         if (rhsRead) {
854             for (auto & d : locals_rhs.RetrieveVars()) {
855                 if (DeclToMCVar.count(d) > 0)
856                     mcState.insert(DeclToMCVar[d]);
857                 else if (DeclsRead.find(d) != DeclsRead.end())
858                     mcState.insert(encode(d->getName().str()));
859             }
860         }
861         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
862             if (lhsTooComplicated)
863                 assert(0 && "couldn't find LHS of = operator");
864
865             std::stringstream nol;
866             std::string _lhsStr, lhsStr;
867             std::string mcVar = encodeFn(fnCount++);
868             if (lhs) {
869                 lhsStr = lhs->getName().str();
870                 _lhsStr = encode(lhsStr);
871                 DeclToMCVar[lhs] = mcVar;
872                 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
873             }
874             int function_id = 0;
875             if (!(MallocExprs.find(rhs) != MallocExprs.end()))
876                 function_id = ++funcCount;
877             nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
878             if (lhs)
879                 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
880             else 
881                 nol << ", MC2_PTR_LENGTH";
882             for (auto & d : mcState) {
883                 nol <<  ", ";
884                 if (_lhsStr == d)
885                     nol << mcVar;
886                 else
887                     nol << d;
888             }
889             nol << "); ";
890             SourceLocation place;
891             if (op) {
892                 place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
893             } else
894                 place = s->getLocEnd();
895             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
896                                nol.str(), true, true);
897
898             updateProvisionalName(DeferredUpdates, lhs, mcVar);
899         }
900     }
901
902     private:
903     Rewriter &rewrite;
904     std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
905     std::set<const VarDecl *> &DeclsNeedingMC;
906     std::map<const NamedDecl *, std::string> &DeclToMCVar;
907     std::set<const Stmt *> &StmtsHandled;
908     std::set<const Expr *> &MallocExprs;
909     std::vector<Update *> &DeferredUpdates;
910 };
911
912 // record vars used in conditions
913 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
914 public:
915     BranchConditionRefactoringHandler(Rewriter &rewrite,
916                                       std::set<const NamedDecl *> & DeclsInCond,
917                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
918                                       std::map<const Expr *, std::string> &ExprToMCVar,
919                                       std::map<const Expr *, SourceLocation> &Redirector,
920                                       std::vector<Update *> &DeferredUpdates) :
921         rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
922         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
923
924     virtual void run(const MatchFinder::MatchResult &Result) {
925         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
926         Expr * cond = is->getCond();
927
928         // refactor out complicated conditions
929         FindCallArgVisitor flv;
930         flv.TraverseStmt(cond);
931         std::string mcVar;
932
933         BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
934         if (bc) {
935             std::string condVar = encodeCond(condCount++);
936             std::stringstream condVarEncoded;
937             condVarEncoded << condVar << "_m";
938
939             // prettyprint the binary op
940             // e.g. int _cond0 = x == y;
941             std::string SStr;
942             llvm::raw_string_ostream S(SStr);
943             bc->printPretty(S, nullptr, rewrite.getLangOpts());
944             const std::string &Str = S.str();
945             
946             std::stringstream prel;
947
948             bool is_equality = false;
949             // handle equality tests
950             if (bc->getOpcode() == BO_EQ) {
951                 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
952                 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
953                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
954                     is_equality = true;
955                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
956                     std::string ld, rd;
957                     if (DeclToMCVar.find(l->getDecl()) != DeclToMCVar.end())
958                         ld = DeclToMCVar.find(l->getDecl())->second;
959                     else
960                         ld = encode(l->getDecl()->getName());
961                     if (DeclToMCVar.find(r->getDecl()) != DeclToMCVar.end())
962                         rd = DeclToMCVar.find(r->getDecl())->second;
963                     else
964                         rd = encode(r->getDecl()->getName());
965
966                     prel << "\nint " << condVar << " = MC2_equals(" <<
967                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
968                         rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
969                         "&" << condVarEncoded.str() << ");\n";
970                 }
971             }
972
973             if (!is_equality) {
974                 prel << "\nint " << condVar << " = " << Str << ";";
975                 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
976                 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
977                 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
978                     prel << ", " << DeclToMCVar[d->getDecl()];
979                 }
980                 prel << ");\n";
981             }
982
983             ExprToMCVar[cond] = condVarEncoded.str();
984             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
985                                prel.str(), false, true);
986
987             // rewrite the binary op with the newly-inserted var
988             Expr * RO = bc->getRHS(); // used for location only
989
990             int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
991             SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
992             rewrite.ReplaceText(SR, condVar);
993         } else {
994             std::string condVar = encodeCond(condCount++);
995             std::stringstream condVarEncoded;
996             condVarEncoded << condVar << "_m";
997
998             std::string SStr;
999             llvm::raw_string_ostream S(SStr);
1000             cond->printPretty(S, nullptr, rewrite.getLangOpts());
1001             const std::string &Str = S.str();
1002
1003             std::stringstream prel;
1004             prel << "\nint " << condVar << " = " << Str << ";";
1005             prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
1006             std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
1007             const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
1008             if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
1009                 prel << ", " << DeclToMCVar[d->getDecl()];
1010             } else {
1011                 prel << ", ";
1012                 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
1013                 vars->push_back(v);
1014             }
1015             prel << ");\n";
1016
1017             ExprToMCVar[cond] = condVarEncoded.str();
1018             // gross hack; should look for any callexprs in cond
1019             // but right now, if it's a unaryop, just manually traverse
1020             if (isa<UnaryOperator>(cond)) {
1021                 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
1022                 ExprToMCVar[e] = condVarEncoded.str();
1023             }
1024             Update * u = new Update(is->getLocStart(), prel.str(), vars);
1025             DeferredUpdates.push_back(u);
1026
1027             // rewrite the call op with the newly-inserted var
1028             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1029             Redirector[cond] = is->getLocStart();
1030             rewrite.ReplaceText(SR, condVar);
1031         }
1032
1033         std::deque<const Decl *> q;
1034         const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1035         q.push_back(d);
1036         while (!q.empty()) {
1037             const Decl * d = q.back();
1038             q.pop_back();
1039             if (isa<NamedDecl>(d))
1040                 DeclsInCond.insert(cast<NamedDecl>(d));
1041
1042             const VarDecl * vd;
1043             if ((vd = dyn_cast<VarDecl>(d))) {
1044                 if (vd->hasInit()) {
1045                     const Expr * e = vd->getInit();
1046                     flv.Clear();
1047                     flv.TraverseStmt(const_cast<Expr *>(e));
1048                     const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1049                     q.push_back(d);
1050                 }
1051             }
1052         }
1053     }
1054
1055 private:
1056     Rewriter &rewrite;
1057     std::set<const NamedDecl *> & DeclsInCond;
1058     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1059     std::map<const Expr *, std::string> &ExprToMCVar;
1060     std::map<const Expr *, SourceLocation> &Redirector;
1061     std::vector<Update *> &DeferredUpdates;
1062 };
1063
1064 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1065 public:
1066     BranchAnnotationHandler(Rewriter &rewrite,
1067                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
1068                             std::map<const Expr *, std::string> & ExprToMCVar)
1069         : rewrite(rewrite),
1070           DeclToMCVar(DeclToMCVar),
1071           ExprToMCVar(ExprToMCVar){}
1072     virtual void run(const MatchFinder::MatchResult &Result) {
1073         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1074
1075         // if the branch condition is interesting:
1076         // (but right now, not too interesting)
1077         Expr * cond = is->getCond()->IgnoreCasts();
1078
1079         FindLocalsVisitor flv;
1080         flv.TraverseStmt(cond);
1081         if (flv.RetrieveVars().size() == 0) return;
1082
1083         const NamedDecl * condVar = flv.RetrieveVars()[0];
1084
1085         std::string mCondVar;
1086         if (ExprToMCVar.count(cond) > 0)
1087             mCondVar = ExprToMCVar[cond];
1088         else if (DeclToMCVar.count(condVar) > 0) 
1089             mCondVar = DeclToMCVar[condVar];
1090         else
1091             mCondVar = encode(condVar->getName());
1092         std::string brVar = encodeBranch(branchCount++);
1093
1094         std::stringstream brline;
1095         brline << "MCID " << brVar << ";\n";
1096         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1097                            brline.str(), false, true);
1098
1099         Stmt * ts = is->getThen(), * es = is->getElse();
1100         bool tHasChild = hasChild(ts);
1101         SourceLocation tfl;
1102         if (tHasChild) {
1103             if (isa<CompoundStmt>(ts))
1104                 tfl = getFirstChild(ts)->getLocStart();
1105             else
1106                 tfl = ts->getLocStart();
1107         } else
1108             tfl = ts->getLocStart().getLocWithOffset(1);
1109         SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1110
1111         std::stringstream tlineStart, mergeStmt, eline;
1112
1113         UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1114         tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1115         eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1116
1117         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1118
1119         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
1120
1121         Stmt * tls = NULL;
1122         int extra_else_offset = 0;
1123
1124         if (tHasChild) { tls = getLastChild(ts); }
1125         if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1126
1127         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1128             extra_else_offset = 0;
1129             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1130                                mergeStmt.str(), true, true);
1131         }
1132         if (tHasChild && !isa<CompoundStmt>(ts)) {
1133             rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1134             SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1135             rewrite.InsertText(tend, "}", true, true);
1136         }
1137         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1138
1139         if (es) {
1140             SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1141             bool eHasChild = hasChild(es); 
1142             Stmt * els = NULL;
1143             if (eHasChild) els = getLastChild(es); else els = es;
1144             
1145             eline << "\n";
1146
1147             SourceLocation el;
1148             if (eHasChild) {
1149                 if (isa<CompoundStmt>(es))
1150                     el = getFirstChild(es)->getLocStart();
1151                 else {
1152                     el = es->getLocStart();
1153                 }
1154             } else
1155                 el = es->getLocStart().getLocWithOffset(1);
1156             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
1157
1158             if (eHasChild && !isa<CompoundStmt>(es)) {
1159                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1160                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1161             }
1162
1163             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1164                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
1165         }
1166         else {
1167             std::stringstream eCompoundLine;
1168             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1169             SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1170             if (!tend.isValid())
1171                 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1172             rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
1173         }
1174     }
1175 private:
1176
1177     bool hasChild(Stmt * s) {
1178         if (!isa<CompoundStmt>(s)) return true;
1179         return (!cast<CompoundStmt>(s)->body_empty());
1180     }
1181
1182     Stmt * getFirstChild(Stmt * s) {
1183         assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1184         assert(!cast<CompoundStmt>(s)->body_empty());
1185         return *(cast<CompoundStmt>(s)->body_begin());
1186     }
1187
1188     Stmt * getLastChild(Stmt * s) {
1189         CompoundStmt * cs;
1190         if ((cs = dyn_cast<CompoundStmt>(s))) {
1191             assert (!cs->body_empty());
1192             return cs->body_back();
1193         }
1194         return s;
1195     }
1196
1197     Rewriter &rewrite;
1198     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1199     std::map<const Expr *, std::string> &ExprToMCVar;
1200 };
1201
1202 class FunctionCallHandler : public MatchFinder::MatchCallback {
1203 public:
1204     FunctionCallHandler(Rewriter &rewrite,
1205                         std::map<const NamedDecl *, std::string> &DeclToMCVar,
1206                         std::set<const FunctionDecl *> &ThreadMains)
1207         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1208
1209     virtual void run(const MatchFinder::MatchResult &Result) {
1210         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
1211         Decl * d = ce->getCalleeDecl();
1212         NamedDecl * nd = dyn_cast<NamedDecl>(d);
1213         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1214         ASTContext *Context = Result.Context;
1215
1216         if (nd->getName() == "thrd_create") {
1217             Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1218             UnaryOperator * callee1;
1219             if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1220                 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1221                     callee0 = callee1->getSubExpr();
1222             }
1223             DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1224             if (!callee) return;
1225             FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1226             ThreadMains.insert(fd);
1227             return;
1228         }
1229
1230         if (!d->hasBody())
1231             return;
1232
1233         if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1234             // TODO check that the type is mc-visible also?
1235             const DeclStmt * ds;
1236             const VarDecl * lhs = NULL;
1237             std::string mc_rv = encodeRV(rvCount++);
1238
1239             std::stringstream brline;
1240             brline << "MCID " << mc_rv << ";\n";
1241             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1242                                brline.str(), false, true);
1243
1244             std::stringstream nol;
1245             if (ce->getNumArgs() > 0) nol << ", ";
1246             nol << "&" << mc_rv;
1247             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
1248                                      nol.str());
1249
1250             if (s && (ds = dyn_cast<DeclStmt>(s))) {
1251                 if (!ds->isSingleDecl()) {
1252                     for (auto & d : ds->decls()) {
1253                         VarDecl * vd = dyn_cast<VarDecl>(d);
1254                         if (!d || vd->hasInit())
1255                             assert(0 && "unsupported form of decl");
1256                     }
1257                     return;
1258                 }
1259
1260                 lhs = retrieveSingleDecl(ds);
1261             }
1262
1263             DeclToMCVar[lhs] = mc_rv;
1264         }
1265
1266         for (const auto & a : ce->arguments()) {
1267             std::stringstream nol;
1268
1269             std::string aa = "MCID_NODEP";
1270
1271             Expr * e = a->IgnoreCasts();
1272             DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1273             if (dr) { 
1274                 NamedDecl * d = dr->getDecl();
1275                 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1276                     aa = DeclToMCVar[d];
1277             }
1278
1279             nol << aa << ", ";
1280             
1281             if (a->getLocEnd().isValid())
1282                 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
1283                                          nol.str());
1284         }
1285     }
1286
1287 private:
1288     Rewriter &rewrite;
1289     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1290     std::set<const FunctionDecl *> &ThreadMains;
1291 };
1292
1293 class ReturnHandler : public MatchFinder::MatchCallback {
1294 public:
1295     ReturnHandler(Rewriter &rewrite,
1296                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
1297                   std::set<const FunctionDecl *> &ThreadMains)
1298         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1299
1300     virtual void run(const MatchFinder::MatchResult &Result) {
1301         const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1302         ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1303         Expr * rv = const_cast<Expr *>(rs->getRetValue());
1304
1305         if (!rv) return;        
1306         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1307         // not sure why this is explicitly needed, but crashes without it
1308         if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1309
1310         FindLocalsVisitor flv;
1311         flv.TraverseStmt(rv);
1312         std::string mrv = "MCID_NODEP";
1313
1314         if (flv.RetrieveVars().size() > 0) {
1315             const NamedDecl * returnVar = flv.RetrieveVars()[0];
1316             if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1317                 mrv = DeclToMCVar[returnVar];
1318             }
1319         }
1320         std::stringstream nol;
1321         nol << "*retval = " << mrv << ";\n";
1322         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1323                            nol.str(), false, true);
1324     }
1325
1326 private:
1327     Rewriter &rewrite;
1328     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1329     std::set<const FunctionDecl *> &ThreadMains;
1330 };
1331
1332 class VarDeclHandler : public MatchFinder::MatchCallback {
1333 public:
1334     VarDeclHandler(Rewriter &rewrite,
1335                    std::map<const NamedDecl *, std::string> &DeclToMCVar,
1336                    std::set<const VarDecl *> &DeclsNeedingMC)
1337         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1338
1339     virtual void run(const MatchFinder::MatchResult &Result) {
1340         VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1341         std::stringstream nol;
1342
1343         if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1344
1345         std::string dn;
1346         if (DeclToMCVar.find(d) != DeclToMCVar.end())
1347             dn = DeclToMCVar[d];
1348         else
1349             dn = encode(d->getName().str());
1350
1351         nol << "MCID " << dn << "; ";
1352
1353         if (d->getLocStart().isValid())
1354             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
1355                                      nol.str());
1356     }
1357
1358 private:
1359     Rewriter &rewrite;
1360     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1361     std::set<const VarDecl *> &DeclsNeedingMC;
1362 };
1363
1364 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1365 public:
1366     FunctionDeclHandler(Rewriter &rewrite,
1367                         std::set<const FunctionDecl *> &ThreadMains)
1368         : rewrite(rewrite), ThreadMains(ThreadMains) {}
1369
1370     virtual void run(const MatchFinder::MatchResult &Result) {
1371         FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
1372
1373         if (!fd->getIdentifier()) return;
1374
1375         if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1376
1377         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1378
1379         SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1380         for (auto & p : fd->params()) {
1381             std::stringstream nol;
1382             nol << "MCID " << encode(p->getName()) << ", ";
1383             if (p->getLocStart().isValid())
1384                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
1385                                    nol.str(), false);
1386             if (p->getLocEnd().isValid())
1387                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1388         }
1389
1390         if (!fd->getReturnType()->isVoidType()) {
1391             std::stringstream nol;
1392             if (fd->param_size() > 0) nol << ", ";
1393             nol << "MCID * retval";
1394             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
1395                                nol.str(), false);
1396         }
1397     }
1398
1399 private:
1400     Rewriter &rewrite;
1401     std::set<const FunctionDecl *> &ThreadMains;
1402 };
1403
1404 class BailHandler : public MatchFinder::MatchCallback {
1405 public:
1406     BailHandler() {}
1407     virtual void run(const MatchFinder::MatchResult &Result) {
1408         assert(0 && "we don't handle goto statements");
1409     }
1410 };
1411
1412 class MyASTConsumer : public ASTConsumer {
1413 public:
1414     MyASTConsumer(Rewriter &R) : R(R),
1415                                  DeclsRead(),
1416                                  DeclsInCond(),
1417                                  DeclToMCVar(),
1418                                  HandlerMalloc(MallocExprs),
1419                                  HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1420                                  HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1421                                  HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1422                                  HandlerLoop(R),
1423                                  HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1424                                  HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1425                                  HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1426                                  HandlerFunctionDecl(R, ThreadMains),
1427                                  HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1428                                  HandlerReturn(R, DeclToMCVar, ThreadMains),
1429                                  HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
1430                                  HandlerBail() {
1431         MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1432                                                       hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1433                                                       hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1434                                        &HandlerFunctionCall);
1435         MatcherLoadStore.addMatcher
1436             (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1437              &HandlerMalloc);
1438
1439         MatcherLoadStore.addMatcher
1440             (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1441                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1442                             hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1443                             hasParent(stmt().bind("containingStmt"))))
1444              .bind("callExpr"),
1445              &HandlerLoad);
1446
1447         MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1448                                     &HandlerStore);
1449
1450         MatcherLoadStore.addMatcher
1451             (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1452                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1453                             hasAncestor(binaryOperator(hasOperatorName("="),
1454                                                        hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1455                             anything()))
1456              .bind("callExpr"),
1457              &HandlerRMW);
1458
1459         MatcherLoadStore.addMatcher(ifStmt(hasCondition
1460                                            (anyOf(binaryOperator().bind("bc"),
1461                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1462                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1463                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1464                                                   anything()))).bind("if"),
1465                                     &HandlerBranchConditionRefactoring);
1466
1467         MatcherLoadStore.addMatcher(forStmt().bind("s"),
1468                                     &HandlerLoop);
1469         MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1470                                     &HandlerLoop);
1471         MatcherLoadStore.addMatcher(doStmt().bind("s"),
1472                                     &HandlerLoop);
1473
1474         MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1475                                                         hasParent(compoundStmt())),
1476                                                         hasOperatorName("=")).bind("op"),
1477                                    &HandlerAssign);
1478         MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1479
1480         MatcherFunction.addMatcher(ifStmt().bind("if"),
1481                                    &HandlerAnnotateBranch);
1482
1483         MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1484                                        &HandlerFunctionDecl);
1485         MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1486         MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
1487                                    &HandlerReturn);
1488
1489         MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1490     }
1491
    // Override the method that gets called once the entire translation
    // unit has been parsed.
1494     void HandleTranslationUnit(ASTContext &Context) override {
1495         LangOpts = Context.getLangOpts();
1496
1497         MatcherFunctionCall.matchAST(Context);
1498         MatcherLoadStore.matchAST(Context);
1499         MatcherFunction.matchAST(Context);
1500         MatcherFunctionDecl.matchAST(Context);
1501         MatcherSanity.matchAST(Context);
1502
1503         for (auto & u : DeferredUpdates) {
1504             R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
1505             delete u;
1506         }
1507         DeferredUpdates.clear();
1508     }
1509
1510 private:
    /* DeclsRead contains all local variables 'x' which:
     * 1) appear in 'x = load_32(...);', or
     * 2) appear in 'y = store_32(x);' */
1514     std::set<const NamedDecl *> DeclsRead;
1515     /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
1516     std::set<const NamedDecl *> DeclsInCond;
1517     std::map<const NamedDecl *, std::string> DeclToMCVar;
1518     std::map<const Expr *, std::string> ExprToMCVar;
1519     std::set<const VarDecl *> DeclsNeedingMC;
1520     std::set<const FunctionDecl *> ThreadMains;
1521     std::set<const Stmt *> StmtsHandled;
1522     std::set<const Expr *> MallocExprs;
1523     std::map<const Expr *, SourceLocation> Redirector;
1524     std::vector<Update *> DeferredUpdates;
1525
1526     Rewriter &R;
1527
1528     MallocHandler HandlerMalloc;
1529     LoadHandler HandlerLoad;
1530     StoreHandler HandlerStore;
1531     RMWHandler HandlerRMW;
1532     LoopHandler HandlerLoop;
1533     BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1534     BranchAnnotationHandler HandlerAnnotateBranch;
1535     AssignHandler HandlerAssign;
1536     FunctionDeclHandler HandlerFunctionDecl;
1537     FunctionCallHandler HandlerFunctionCall;
1538     ReturnHandler HandlerReturn;
1539     VarDeclHandler HandlerVarDecl;
1540     BailHandler HandlerBail;
1541     MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1542 };
1543
1544 // For each source file provided to the tool, a new FrontendAction is created.
1545 class MyFrontendAction : public ASTFrontendAction {
1546 public:
1547     MyFrontendAction() {}
1548     void EndSourceFileAction() override {
1549         SourceManager &SM = TheRewriter.getSourceMgr();
1550         llvm::errs() << "** EndSourceFileAction for: "
1551                      << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1552
1553         // Now emit the rewritten buffer.
1554         TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1555     }
1556
1557     std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1558                                                    StringRef file) override {
1559         llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1560         TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1561         return llvm::make_unique<MyASTConsumer>(TheRewriter);
1562     }
1563
1564 private:
1565     Rewriter TheRewriter;
1566 };
1567
1568 int main(int argc, const char **argv) {
1569     CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1570     ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1571     
1572     return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());
1573 }