fix replacement around macro expansion; add new benchmarks from Stavros Aronis
[satcheck.git] / clang / src / add_mc2_annotations.cpp
1 // -*-  indent-tabs-mode:nil; c-basic-offset:4; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
5 //
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
13 //
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
16 //
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
20 // distribution.
21 //
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
25 //
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
34 //
35 // Patrick Lam (prof.lam@gmail.com)
36 //
37 // Base code:
38 // Eli Bendersky (eliben@gmail.com)
39 //
40 //------------------------------------------------------------------------------
41 #include <sstream>
42 #include <string>
43 #include <map>
44 #include <stdbool.h>
45
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
61
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
66 using namespace llvm;
67
// File-level language options object. It is not referenced in the portion of
// the file visible here -- NOTE(review): confirm where it is configured/used.
static LangOptions LangOpts;
// Command-line option category labelled "Add MC2 Annotations"; presumably
// handed to the tooling option parser elsewhere in the file -- TODO confirm.
static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
70
// Returns the MC-side identifier for a program variable: the variable's
// name prefixed with "_m".
static std::string encode(std::string varName) {
    return "_m" + varName;
}
76
// Running counter declared next to encodeFn(); it is not touched by this block.
static int fnCount;

// Builds the identifier "_fn<num>".
static std::string encodeFn(int num) {
    return "_fn" + std::to_string(num);
}
83
// Running counter declared next to encodePtr(); it is not touched by this block.
static int ptrCount;

// Builds the identifier "_p<num>".
static std::string encodePtr(int num) {
    return "_p" + std::to_string(num);
}
90
// Running counter of read-modify-write operations (used with encodeRMW()).
static int rmwCount;

// Builds the identifier "_rmw<num>".
static std::string encodeRMW(int num) {
    return "_rmw" + std::to_string(num);
}
97
// Running counter declared next to encodeBranch(); it is not touched by this block.
static int branchCount;

// Builds the identifier "_br<num>".
static std::string encodeBranch(int num) {
    return "_br" + std::to_string(num);
}
104
// Running counter declared next to encodeCond(); it is not touched by this block.
static int condCount;

// Builds the identifier "_cond<num>".
static std::string encodeCond(int num) {
    return "_cond" + std::to_string(num);
}
111
// Running counter declared next to encodeRV(); it is not touched by this block.
static int rvCount;

// Builds the identifier "_rv<num>".
static std::string encodeRV(int num) {
    return "_rv" + std::to_string(num);
}

// Running counter of generated MC2 function ids.
static int funcCount;
120
121 struct ProvisionalName {
122     int index, length;
123     const DeclRefExpr * pname;
124     bool enabled;
125
126     ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127     ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
128 };
129
130 struct Update {
131     SourceLocation loc;
132     std::string update;
133     std::vector<ProvisionalName *> * pnames;
134
135     Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) : 
136         loc(loc), update(update), pnames(pnames) {}
137
138     ~Update() { 
139         for (auto pname : *pnames) delete pname;
140         delete pnames; 
141     }
142 };
143
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145     for (Update * u : DeferredUpdates) {
146         for (int i = 0; i < u->pnames->size(); i++) {
147             ProvisionalName * v = (*(u->pnames))[i];
148             if (!v->enabled) continue;
149             if (now_known == v->pname->getDecl()) {
150                 v->enabled = false;
151                 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
152
153                 u->update.replace(v->index, v->length, mcVar);
154                 for (int j = i+1; j < u->pnames->size(); j++) {
155                     ProvisionalName * vv = (*(u->pnames))[j];
156                     if (vv->index > v->index)
157                         vv->index -= v->length - mcVar.length();
158                 }
159             }
160         }
161     }
162 }
163
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165     // XXX iterate through all decls defined in s, not just the first one
166     assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167     if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168         return cast<VarDecl>(s->getSingleDecl());
169     } else return NULL;
170 }
171
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
173 public:
174     FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
175
176     bool VisitStmt(Stmt * s) {
177         if (!UnaryOp) {
178             if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179                 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180                     uo->getOpcode() == UnaryOperatorKind::UO_Deref)
181                     UnaryOp = uo;
182             }
183         }
184
185         if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
186             ;
187         return true;
188     }
189
190     void Clear() {
191         UnaryOp = NULL; DE = NULL;
192     }
193
194     const UnaryOperator * RetrieveUnaryOp() {
195         if (UnaryOp) {
196             bool found = false;
197             const Stmt * s = UnaryOp;
198             while (s != NULL) {
199                 if (s == DE) {
200                     found = true; break;
201                 }
202
203                 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204                     s = op->getSubExpr();
205                 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206                     s = op->getSubExpr();
207                 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
208                     s = op->getBase();
209                 else
210                     s = NULL;
211             }
212             if (found)
213                 return UnaryOp;
214         }
215         return NULL;
216     }
217
218     const DeclRefExpr * RetrieveDeclRefExpr() {
219         return DE;
220     }
221
222 private:
223     const UnaryOperator * UnaryOp;
224     const DeclRefExpr * DE;
225 };
226
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
228 public:
229     FindLocalsVisitor() : Vars() {}
230
231     bool VisitDeclRefExpr(DeclRefExpr * de) {
232         Vars.push_back(de->getDecl());
233         return true;
234     }
235
236     void Clear() {
237         Vars.clear();
238     }
239
240     const TinyPtrVector<const NamedDecl *> RetrieveVars() {
241         return Vars;
242     }
243
244 private:
245     TinyPtrVector<const NamedDecl *> Vars;
246 };
247
248
249 class MallocHandler : public MatchFinder::MatchCallback {
250 public:
251     MallocHandler(std::set<const Expr *> & MallocExprs) :
252         MallocExprs(MallocExprs) {}
253
254     virtual void run(const MatchFinder::MatchResult &Result) {
255         const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
256
257         MallocExprs.insert(ce);
258     }
259
260     private:
261     std::set<const Expr *> &MallocExprs;
262 };
263
264 static void generateMC2Function(Rewriter & Rewrite,
265                                 const Expr * e, 
266                                 SourceLocation loc,
267                                 std::string tmpname, 
268                                 std::string tmpFn, 
269                                 const DeclRefExpr * lhs, 
270                                 std::string lhsName,
271                                 std::vector<ProvisionalName *> * vars1,
272                                 std::vector<Update *> & DeferredUpdates) {
273     // prettyprint the LHS (&newnode->value)
274     // e.g. int * _tmp0 = &newnode->value;
275     std::string SStr;
276     llvm::raw_string_ostream S(SStr);
277     e->printPretty(S, nullptr, Rewrite.getLangOpts());
278     const std::string &Str = S.str();
279
280     std::stringstream prel;
281     prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
282
283     // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284     prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
285     if (lhs) {
286         // XXX generate casts when they'd eliminate warnings
287         ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
288         vars1->push_back(v);
289     }
290     prel << encode(lhsName) << "); ";
291
292     Update * u = new Update(loc, prel.str(), vars1);
293     DeferredUpdates.push_back(u);
294 }
295
296 class LoadHandler : public MatchFinder::MatchCallback {
297 public:
298     LoadHandler(Rewriter &Rewrite,
299                 std::set<const NamedDecl *> & DeclsRead,
300                 std::set<const VarDecl *> & DeclsNeedingMC, 
301                 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302                 std::map<const Expr *, std::string> &ExprToMCVar,
303                 std::set<const Stmt *> & StmtsHandled,
304                 std::map<const Expr *, SourceLocation> &Redirector,
305                 std::vector<Update *> & DeferredUpdates) : 
306         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307         ExprToMCVar(ExprToMCVar),
308         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
309
310     virtual void run(const MatchFinder::MatchResult &Result) {
311         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313         const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315         const Expr * lhs = NULL;
316         if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
317         if (!s) s = ce;
318
319         const DeclRefExpr * rhs = NULL;
320         MemberExpr * ml = NULL;
321         bool isAddrOfR = false, isAddrMemberR = false;
322
323         StmtsHandled.insert(s);
324         
325         std::string n, n_decl;
326         if (lhs) {
327             FindCallArgVisitor fcaVisitor;
328             fcaVisitor.Clear();
329             fcaVisitor.TraverseStmt(ce->getArg(0));
330             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
334             if (isAddrMemberR)
335                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
336
337             FindLocalsVisitor flv;
338             flv.Clear();
339             flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340             for (auto & d : flv.RetrieveVars()) {
341                 const VarDecl * dd = cast<VarDecl>(d);
342                 n = dd->getName();
343                 // XXX todo rhs for non-decl stmts
344                 if (!isa<ParmVarDecl>(dd))
345                     DeclsNeedingMC.insert(dd);
346                 DeclsRead.insert(d);
347                 DeclToMCVar[dd] = encode(n);
348             }
349         } else {
350             FindCallArgVisitor fcaVisitor;
351             fcaVisitor.Clear();
352             fcaVisitor.TraverseStmt(ce->getArg(0));
353             rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354             const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355             isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356             isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
357             if (isAddrMemberR)
358                 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
359
360             if (d) {
361                 n = d->getName();
362                 DeclsNeedingMC.insert(d);
363                 DeclsRead.insert(d);
364                 DeclToMCVar[d] = encode(n);
365             } else {
366                 n = ExprToMCVar[ce];
367                 fcaVisitor.Clear();
368                 fcaVisitor.TraverseStmt(ce);
369                 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370                 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371                 DeclToMCVar[dd->getDecl()] = encode(n);
372                 n_decl = "MCID ";
373             }
374         }
375         
376         std::stringstream nol;
377
378         if (lhs && isa<DeclRefExpr>(lhs)) {
379             const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380             ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
381             vars->push_back(v);
382         }
383
384         if (rhs) {
385             if (isAddrMemberR) {
386                 if (!n.empty()) 
387                     nol << n_decl << encode(n) << "=";
388                 nol << "MC2_nextOpLoadOffset(";
389
390                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
391                 vars->push_back(v);
392                 nol << encode(rhs->getNameInfo().getName().getAsString());
393
394                 nol << ", MC2_OFFSET(";
395                 nol << ml->getBase()->getType().getAsString();
396                 nol << ", ";
397                 nol << ml->getMemberDecl()->getName().str();
398                 nol << ")";
399             } else if (!isAddrOfR) {
400                 if (!n.empty()) 
401                     nol << n_decl << encode(n) << "=";
402                 nol << "MC2_nextOpLoad(";
403                 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
404                 vars->push_back(v);
405                 nol << encode(rhs->getNameInfo().getName().getAsString());
406             } else {
407                 if (!n.empty()) 
408                     nol << n_decl << encode(n) << "=";
409                 nol << "MC2_nextOpLoad(";
410                 nol << "MCID_NODEP";
411             }
412         } else {
413             if (!n.empty()) 
414                 nol << n_decl << encode(n) << "=";
415             nol << "MC2_nextOpLoad(";
416             nol << "MCID_NODEP";
417         }
418
419         if (lhs)
420             nol << "), ";
421         else
422             nol << "); ";
423         SourceLocation ss = s->getLocStart();
424         // eek gross hack:
425         // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426         // move over 1 so that we get the right location.
427         if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428         const Expr * e = dyn_cast<Expr>(s);
429         if (e && Redirector.count(e) > 0)
430             ss = Redirector[e];
431         Update * u = new Update(ss, nol.str(), vars);
432         DeferredUpdates.insert(DeferredUpdates.begin(), u);
433     }
434
435     private:
436     Rewriter &Rewrite;
437     std::set<const NamedDecl *> & DeclsRead;
438     std::set<const VarDecl *> & DeclsNeedingMC;
439     std::map<const Expr *, std::string> &ExprToMCVar;
440     std::map<const NamedDecl *, std::string> &DeclToMCVar;
441     std::set<const Stmt *> &StmtsHandled;
442     std::vector<Update *> &DeferredUpdates;
443     std::map<const Expr *, SourceLocation> &Redirector;
444 };
445
446 class StoreHandler : public MatchFinder::MatchCallback {
447 public:
448     StoreHandler(Rewriter &Rewrite,
449                  std::set<const NamedDecl *> & DeclsRead,
450                  std::set<const VarDecl *> &DeclsNeedingMC,
451                  std::vector<Update *> & DeferredUpdates) : 
452         Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
453
454     virtual void run(const MatchFinder::MatchResult &Result) {
455         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
457
458         fcaVisitor.Clear();
459         fcaVisitor.TraverseStmt(ce->getArg(0));
460         const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461         const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
462     
463         std::stringstream nol;
464
465         bool isAddrMemberL;
466         bool isAddrOfL;
467
468         if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469             isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470             isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
471         }
472
473         if (lhs) {
474             if (isAddrOfL) {
475                 nol << "MC2_nextOpStore(";
476                 nol << "MCID_NODEP";
477             } else {
478                 if (isAddrMemberL) {
479                     MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
480
481                     nol << "MC2_nextOpStoreOffset(";
482
483                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
484                     vars->push_back(v);
485
486                     nol << encode(lhs->getNameInfo().getName().getAsString());
487                     if (!isa<ParmVarDecl>(lhs->getDecl()))
488                         DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
489
490                     nol << ", MC2_OFFSET(";
491                     nol << ml->getBase()->getType().getAsString();
492                     nol << ", ";
493                     nol << ml->getMemberDecl()->getName().str();
494                     nol << ")";
495                 } else {
496                     nol << "MC2_nextOpStore(";
497                     ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
498                     vars->push_back(v);
499
500                     nol << encode(lhs->getNameInfo().getName().getAsString());
501                 }
502             }
503         }
504         else {
505             nol << "MC2_nextOpStore(";
506             nol << "MCID_NODEP";
507         }
508         
509         nol << ", ";
510
511         fcaVisitor.Clear();
512         fcaVisitor.TraverseStmt(ce->getArg(1));
513         const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514         const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
515
516         bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517         bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
518
519         if (rhs && !isAddrOfR) {
520             assert (!isDerefR && "Must use atomic load for dereferences!");
521             ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
522             vars->push_back(v);
523
524             nol << encode(rhs->getNameInfo().getName().getAsString());
525             DeclsRead.insert(rhs->getDecl());
526         }
527         else
528             nol << "MCID_NODEP";
529         
530         nol << ");\n";
531         Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532         DeferredUpdates.push_back(u);
533     }
534
535     private:
536     Rewriter &Rewrite;
537     FindCallArgVisitor fcaVisitor;
538     std::set<const NamedDecl *> & DeclsRead;
539     std::set<const VarDecl *> & DeclsNeedingMC;
540     std::vector<Update *> &DeferredUpdates;
541 };
542
543 class RMWHandler : public MatchFinder::MatchCallback {
544 public:
545     RMWHandler(Rewriter &rewrite,
546                  std::set<const NamedDecl *> & DeclsRead,
547                  std::set<const NamedDecl *> & DeclsInCond,
548                  std::map<const NamedDecl *, std::string> &DeclToMCVar,
549                  std::map<const Expr *, std::string> &ExprToMCVar,
550                  std::set<const Stmt *> & StmtsHandled,
551                  std::map<const Expr *, SourceLocation> &Redirector,
552                  std::vector<Update *> & DeferredUpdates) : 
553         rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar), 
554         ExprToMCVar(ExprToMCVar),
555         StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
556
557     virtual void run(const MatchFinder::MatchResult &Result) {
558         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560         std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
561
562         std::stringstream nol;
563         
564         std::string rmwMCVar;
565         rmwMCVar = encodeRMW(rmwCount++);
566
567         const VarDecl * rmw_lhs;
568         if (s) {
569             StmtsHandled.insert(s);
570             assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
571             const DeclStmt * ds;
572             if ((ds = dyn_cast<DeclStmt>(s))) {
573                 rmw_lhs = retrieveSingleDecl(ds);
574             } else {
575                 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576                 assert (isa<DeclRefExpr>(e));
577                 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
578             }
579             DeclToMCVar[rmw_lhs] = rmwMCVar;
580         }
581
582         // retrieve effective LHS of the RMW
583         fcaVisitor.Clear();
584         fcaVisitor.TraverseStmt(ce->getArg(1));
585         const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586         const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587         bool isAddrMemberL = false;
588
589         if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590             isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
591         }
592
593         nol << "MCID " << rmwMCVar;
594         if (isAddrMemberL) {
595             MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
596
597             nol << " = MC2_nextRMWOffset(";
598
599             ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
600             vars->push_back(v);
601
602             nol << encode(elhs->getNameInfo().getName().getAsString());
603
604             nol << ", MC2_OFFSET(";
605             nol << ml->getBase()->getType().getAsString();
606             nol << ", ";
607             nol << ml->getMemberDecl()->getName().str();
608             nol << ")";
609         } else {
610             nol << " = MC2_nextRMW(";
611             bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
612
613             if (elhs) {
614                 if (isAddrOfL)
615                     nol << "MCID_NODEP";
616                 else {
617                     ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
618                     vars->push_back(v);
619
620                     std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
621                     nol << elhsName;
622                 }
623             }
624             else
625                 nol << "MCID_NODEP";
626         }
627         nol << ", ";
628
629         // handle both RHS ops
630         int outputted = 0;
631         for (int arg = 2; arg < 4; arg++) {
632             fcaVisitor.Clear();
633             fcaVisitor.TraverseStmt(ce->getArg(arg));
634             const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635             const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
636             
637             bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638             bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
639
640             if (a && !isAddrOfR) {
641                 assert (!isDerefR && "Must use atomic load for dereferences!");
642
643                 DeclsInCond.insert(a->getDecl());
644
645                 if (outputted > 0) nol << ", ";
646                 outputted++;
647
648                 bool alreadyMCVar = false;
649                 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
650                     alreadyMCVar = true;
651                     nol << DeclToMCVar[a->getDecl()];
652                 }
653                 else {
654                     std::string an = "MCID_NODEP";
655                     ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
656                     nol << an;
657                     vars->push_back(v);
658                 }
659
660                 DeclsRead.insert(a->getDecl());
661             }
662             else {
663                 if (outputted > 0) nol << ", ";
664                 outputted++;
665
666                 nol << "MCID_NODEP";
667             }
668         }
669         nol << ");\n";
670
671         SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672         const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673         if (e && Redirector.count(e) > 0)
674             place = Redirector[e];
675         Update * u = new Update(place, nol.str(), vars);
676         DeferredUpdates.insert(DeferredUpdates.begin(), u);
677     }
678     
679     private:
680     Rewriter &rewrite;
681     FindCallArgVisitor fcaVisitor;
682     std::set<const NamedDecl *> &DeclsRead;
683     std::set<const NamedDecl *> &DeclsInCond;
684     std::map<const NamedDecl *, std::string> &DeclToMCVar;
685     std::map<const Expr *, std::string> &ExprToMCVar;
686     std::set<const Stmt *> &StmtsHandled;
687     std::vector<Update *> &DeferredUpdates;
688     std::map<const Expr *, SourceLocation> &Redirector;
689 };
690
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
692 public:
693     FindReturnsBreaksVisitor() : Returns(), Breaks() {}
694
695     bool VisitStmt(Stmt * s) {
696         if (isa<ReturnStmt>(s))
697             Returns.push_back(cast<ReturnStmt>(s));
698
699         if (isa<BreakStmt>(s))
700             Breaks.push_back(cast<BreakStmt>(s));
701         return true;
702     }
703
704     void Clear() {
705         Returns.clear(); Breaks.clear();
706     }
707
708     const std::vector<const ReturnStmt *> RetrieveReturns() {
709         return Returns;
710     }
711
712     const std::vector<const BreakStmt *> RetrieveBreaks() {
713         return Breaks;
714     }
715
716 private:
717     std::vector<const ReturnStmt *> Returns;
718     std::vector<const BreakStmt *> Breaks;
719 };
720
721 class LoopHandler : public MatchFinder::MatchCallback {
722 public:
723     LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
724
725     virtual void run(const MatchFinder::MatchResult &Result) {
726         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
727
728         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
729                            "MC2_enterLoop();\n", true, true);
730
731         // annotate all returns with MC2_exitLoop()
732         // annotate all breaks that aren't further nested with MC2_exitLoop().
733         FindReturnsBreaksVisitor frbv;
734         if (isa<ForStmt>(s))
735             frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
736         if (isa<WhileStmt>(s))
737             frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
738         if (isa<DoStmt>(s))
739             frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
740
741         for (auto & r : frbv.RetrieveReturns()) {
742             rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
743         }
744         
745         // need to find all breaks and returns embedded inside the loop
746
747         rewrite.InsertTextAfterToken(rewrite.getSourceMgr().getExpansionLoc(s->getLocEnd().getLocWithOffset(1)),
748                                      "\nMC2_exitLoop();\n");
749     }
750
751 private:
752     Rewriter &rewrite;
753 };
754
755 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
756 class AssignHandler : public MatchFinder::MatchCallback {
757 public:
758     AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
759                   std::set<const NamedDecl *> &DeclsInCond,
760                   std::set<const VarDecl *> &DeclsNeedingMC,
761                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
762                   std::set<const Stmt *> &StmtsHandled,
763                   std::set<const Expr *> &MallocExprs,
764                   std::vector<Update *> &DeferredUpdates) :
765         rewrite(rewrite),
766         DeclsRead(DeclsRead),
767         DeclsInCond(DeclsInCond),
768         DeclsNeedingMC(DeclsNeedingMC),
769         DeclToMCVar(DeclToMCVar),
770         StmtsHandled(StmtsHandled),
771         MallocExprs(MallocExprs),
772         DeferredUpdates(DeferredUpdates) {}
773
774     virtual void run(const MatchFinder::MatchResult &Result) {
775         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
776         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
777         FindLocalsVisitor flv;
778
779         const VarDecl * lhs = NULL;
780         const Expr * rhs = NULL;
781         const DeclStmt * ds;
782
783         if (s && (ds = dyn_cast<DeclStmt>(s))) {
784             // long term goal: refactor the run() method to deal with one assignment at a time
785             // for now, if there is only declarations and no rhs's, we'll ignore this stmt
786             if (!ds->isSingleDecl()) {
787                 for (auto & d : ds->decls()) {
788                     VarDecl * vd = dyn_cast<VarDecl>(d);
789                     if (!d || vd->hasInit())
790                         assert(0 && "unsupported form of decl");
791                 }
792                 return;
793             }
794
795             lhs = retrieveSingleDecl(ds);
796         }
797
798         if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
799             return;
800
801         if (lhs) {
802             if (lhs->hasInit()) {
803                 rhs = lhs->getInit();
804                 if (rhs) {
805                     rhs = rhs->IgnoreCasts();
806                 }
807             }
808             else
809                 return;
810         }
811         std::set<std::string> mcState;
812
813         bool lhsTooComplicated = false;
814         if (op) {
815             flv.TraverseStmt(op);
816
817             DeclRefExpr * vd;
818             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
819                 lhs = dyn_cast<VarDecl>(vd->getDecl());
820             else {
821                 // kick the can along...
822                 lhsTooComplicated = true;
823             }
824
825             rhs = op->getRHS();
826             if (rhs) 
827                 rhs = rhs->IgnoreCasts();
828         }
829         else if (lhs) {
830             // rhs must be MC-active state, i.e. in declsread
831             // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
832             flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
833         }
834
835         if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
836             for (auto & d : flv.RetrieveVars()) {
837                 if (DeclToMCVar.count(d) > 0)
838                     mcState.insert(DeclToMCVar[d]);
839                 else if (DeclsRead.find(d) != DeclsRead.end())
840                     mcState.insert(encode(d->getName().str()));
841             }
842         }
843
844         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
845             if (lhsTooComplicated)
846                 assert(0 && "couldn't find LHS of = operator");
847
848             std::stringstream nol;
849             std::string _lhsStr, lhsStr;
850             std::string mcVar = encodeFn(fnCount++);
851             if (lhs) {
852                 lhsStr = lhs->getName().str();
853                 _lhsStr = encode(lhsStr);
854                 DeclToMCVar[lhs] = mcVar;
855                 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
856             }
857             int function_id = 0;
858             if (!(MallocExprs.find(rhs) != MallocExprs.end()))
859                 function_id = ++funcCount;
860             nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
861             if (lhs)
862                 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
863             else 
864                 nol << ", MC2_PTR_LENGTH";
865             for (auto & d : mcState) {
866                 nol <<  ", ";
867                 if (_lhsStr == d)
868                     nol << mcVar;
869                 else
870                     nol << d;
871             }
872             nol << "); ";
873             SourceLocation place;
874             if (op)
875                 place = op->getLocEnd().getLocWithOffset(1);
876             else
877                 place = s->getLocEnd();
878             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
879                                nol.str(), true, true);
880
881             updateProvisionalName(DeferredUpdates, lhs, mcVar);
882         }
883     }
884
885     private:
886     Rewriter &rewrite;
887     std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
888     std::set<const VarDecl *> &DeclsNeedingMC;
889     std::map<const NamedDecl *, std::string> &DeclToMCVar;
890     std::set<const Stmt *> &StmtsHandled;
891     std::set<const Expr *> &MallocExprs;
892     std::vector<Update *> &DeferredUpdates;
893 };
894
895 // record vars used in conditions
896 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
897 public:
898     BranchConditionRefactoringHandler(Rewriter &rewrite,
899                                       std::set<const NamedDecl *> & DeclsInCond,
900                                       std::map<const NamedDecl *, std::string> &DeclToMCVar,
901                                       std::map<const Expr *, std::string> &ExprToMCVar,
902                                       std::map<const Expr *, SourceLocation> &Redirector,
903                                       std::vector<Update *> &DeferredUpdates) :
904         rewrite(rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
905         ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
906
907     virtual void run(const MatchFinder::MatchResult &Result) {
908         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
909         Expr * cond = is->getCond();
910
911         // refactor out complicated conditions
912         FindCallArgVisitor flv;
913         flv.TraverseStmt(cond);
914         std::string mcVar;
915
916         BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
917         if (bc) {
918             std::string condVar = encodeCond(condCount++);
919             std::stringstream condVarEncoded;
920             condVarEncoded << condVar << "_m";
921
922             // prettyprint the binary op
923             // e.g. int _cond0 = x == y;
924             std::string SStr;
925             llvm::raw_string_ostream S(SStr);
926             bc->printPretty(S, nullptr, rewrite.getLangOpts());
927             const std::string &Str = S.str();
928             
929             std::stringstream prel;
930
931             bool is_equality = false;
932             // handle equality tests
933             if (bc->getOpcode() == BO_EQ) {
934                 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
935                 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
936                     DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
937                     is_equality = true;
938                     prel << "\nMCID " << condVarEncoded.str() << ";\n";
939                     std::string ld = DeclToMCVar.find(l->getDecl())->second,
940                         rd = DeclToMCVar.find(r->getDecl())->second;
941
942                     prel << "\nint " << condVar << " = MC2_equals(" <<
943                         ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
944                         rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
945                         "&" << condVarEncoded.str() << ");\n";
946                 }
947             }
948
949             if (!is_equality) {
950                 prel << "\nint " << condVar << " = " << Str << ";";
951                 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
952                 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
953                 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
954                     prel << ", " << DeclToMCVar[d->getDecl()];
955                 }
956                 prel << ");\n";
957             }
958
959             ExprToMCVar[cond] = condVarEncoded.str();
960             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
961                                prel.str(), false, true);
962
963             // rewrite the binary op with the newly-inserted var
964             Expr * RO = bc->getRHS(); // used for location only
965
966             int cl = Lexer::MeasureTokenLength(RO->getLocStart(), rewrite.getSourceMgr(), rewrite.getLangOpts());
967             SourceRange SR(cond->getLocStart(), rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
968             rewrite.ReplaceText(SR, condVar);
969         } else {
970             std::string condVar = encodeCond(condCount++);
971             std::stringstream condVarEncoded;
972             condVarEncoded << condVar << "_m";
973
974             std::string SStr;
975             llvm::raw_string_ostream S(SStr);
976             cond->printPretty(S, nullptr, rewrite.getLangOpts());
977             const std::string &Str = S.str();
978
979             std::stringstream prel;
980             prel << "\nint " << condVar << " = " << Str << ";";
981             prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
982             std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
983             const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
984             if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
985                 prel << ", " << DeclToMCVar[d->getDecl()];
986             } else {
987                 prel << ", ";
988                 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
989                 vars->push_back(v);
990             }
991             prel << ");\n";
992
993             ExprToMCVar[cond] = condVarEncoded.str();
994             // gross hack; should look for any callexprs in cond
995             // but right now, if it's a unaryop, just manually traverse
996             if (isa<UnaryOperator>(cond)) {
997                 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
998                 ExprToMCVar[e] = condVarEncoded.str();
999             }
1000             Update * u = new Update(is->getLocStart(), prel.str(), vars);
1001             DeferredUpdates.push_back(u);
1002
1003             // rewrite the call op with the newly-inserted var
1004             SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1005             Redirector[cond] = is->getLocStart();
1006             rewrite.ReplaceText(SR, condVar);
1007         }
1008
1009         std::deque<const Decl *> q;
1010         const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1011         q.push_back(d);
1012         while (!q.empty()) {
1013             const Decl * d = q.back();
1014             q.pop_back();
1015             if (isa<NamedDecl>(d))
1016                 DeclsInCond.insert(cast<NamedDecl>(d));
1017
1018             const VarDecl * vd;
1019             if ((vd = dyn_cast<VarDecl>(d))) {
1020                 if (vd->hasInit()) {
1021                     const Expr * e = vd->getInit();
1022                     flv.Clear();
1023                     flv.TraverseStmt(const_cast<Expr *>(e));
1024                     const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1025                     q.push_back(d);
1026                 }
1027             }
1028         }
1029     }
1030
1031 private:
1032     Rewriter &rewrite;
1033     std::set<const NamedDecl *> & DeclsInCond;
1034     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1035     std::map<const Expr *, std::string> &ExprToMCVar;
1036     std::map<const Expr *, SourceLocation> &Redirector;
1037     std::vector<Update *> &DeferredUpdates;
1038 };
1039
1040 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1041 public:
1042     BranchAnnotationHandler(Rewriter &rewrite,
1043                             std::map<const NamedDecl *, std::string> & DeclToMCVar,
1044                             std::map<const Expr *, std::string> & ExprToMCVar)
1045         : rewrite(rewrite),
1046           DeclToMCVar(DeclToMCVar),
1047           ExprToMCVar(ExprToMCVar){}
1048     virtual void run(const MatchFinder::MatchResult &Result) {
1049         IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1050
1051         // if the branch condition is interesting:
1052         // (but right now, not too interesting)
1053         Expr * cond = is->getCond()->IgnoreCasts();
1054
1055         FindLocalsVisitor flv;
1056         flv.TraverseStmt(cond);
1057         if (flv.RetrieveVars().size() == 0) return;
1058
1059         const NamedDecl * condVar = flv.RetrieveVars()[0];
1060
1061         std::string mCondVar;
1062         if (ExprToMCVar.count(cond) > 0)
1063             mCondVar = ExprToMCVar[cond];
1064         else if (DeclToMCVar.count(condVar) > 0) 
1065             mCondVar = DeclToMCVar[condVar];
1066         else
1067             mCondVar = encode(condVar->getName());
1068         std::string brVar = encodeBranch(branchCount++);
1069
1070         std::stringstream brline;
1071         brline << "MCID " << brVar << ";\n";
1072         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(is->getLocStart()),
1073                            brline.str(), false, true);
1074
1075         Stmt * ts = is->getThen(), * es = is->getElse();
1076         bool tHasChild = hasChild(ts);
1077         SourceLocation tfl;
1078         if (tHasChild) {
1079             if (isa<CompoundStmt>(ts))
1080                 tfl = getFirstChild(ts)->getLocStart();
1081             else
1082                 tfl = ts->getLocStart();
1083         } else
1084             tfl = ts->getLocStart().getLocWithOffset(1);
1085         SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1086
1087         std::stringstream tlineStart, mergeStmt, eline;
1088
1089         UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1090         tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1091         eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1092
1093         mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1094
1095         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tfl), tlineStart.str(), false, true);
1096
1097         Stmt * tls = NULL;
1098         int extra_else_offset = 0;
1099
1100         if (tHasChild) { tls = getLastChild(ts); }
1101         if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1102
1103         if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1104             extra_else_offset = 0;
1105             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(tsl.getLocWithOffset(1)),
1106                                mergeStmt.str(), true, true);
1107         }
1108         if (tHasChild && !isa<CompoundStmt>(ts)) {
1109             rewrite.InsertText(rewrite.getSourceMgr().getFileLoc(tls->getLocStart()), "{", false, true);
1110             SourceLocation tend = Lexer::findLocationAfterToken(tls->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1111             rewrite.InsertText(tend, "}", true, true);
1112         }
1113         if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1114
1115         if (es) {
1116             SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1117             bool eHasChild = hasChild(es); 
1118             Stmt * els = NULL;
1119             if (eHasChild) els = getLastChild(es); else els = es;
1120             
1121             eline << "\n";
1122
1123             SourceLocation el;
1124             if (eHasChild) {
1125                 if (isa<CompoundStmt>(es))
1126                     el = getFirstChild(es)->getLocStart();
1127                 else {
1128                     el = es->getLocStart();
1129                 }
1130             } else
1131                 el = es->getLocStart().getLocWithOffset(1);
1132             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), eline.str(), false, true);
1133
1134             if (eHasChild && !isa<CompoundStmt>(es)) {
1135                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(el), "{", false, true);
1136                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(es->getLocEnd().getLocWithOffset(1)), "}", true, true);
1137             }
1138
1139             if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1140                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(esl.getLocWithOffset(1)), mergeStmt.str(), true, true);
1141         }
1142         else {
1143             std::stringstream eCompoundLine;
1144             eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1145             SourceLocation tend = Lexer::findLocationAfterToken(ts->getLocEnd(), tok::semi, rewrite.getSourceMgr(), rewrite.getLangOpts(), false);
1146             if (!tend.isValid())
1147                 tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts());
1148             rewrite.InsertText(tend.getLocWithOffset(1), eCompoundLine.str(), false, true);
1149         }
1150     }
1151 private:
1152
1153     bool hasChild(Stmt * s) {
1154         if (!isa<CompoundStmt>(s)) return true;
1155         return (!cast<CompoundStmt>(s)->body_empty());
1156     }
1157
1158     Stmt * getFirstChild(Stmt * s) {
1159         assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1160         assert(!cast<CompoundStmt>(s)->body_empty());
1161         return *(cast<CompoundStmt>(s)->body_begin());
1162     }
1163
1164     Stmt * getLastChild(Stmt * s) {
1165         CompoundStmt * cs;
1166         if ((cs = dyn_cast<CompoundStmt>(s))) {
1167             assert (!cs->body_empty());
1168             return cs->body_back();
1169         }
1170         return s;
1171     }
1172
1173     Rewriter &rewrite;
1174     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1175     std::map<const Expr *, std::string> &ExprToMCVar;
1176 };
1177
1178 class FunctionCallHandler : public MatchFinder::MatchCallback {
1179 public:
1180     FunctionCallHandler(Rewriter &rewrite,
1181                         std::map<const NamedDecl *, std::string> &DeclToMCVar,
1182                         std::set<const FunctionDecl *> &ThreadMains)
1183         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1184
1185     virtual void run(const MatchFinder::MatchResult &Result) {
1186         CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
1187         Decl * d = ce->getCalleeDecl();
1188         NamedDecl * nd = dyn_cast<NamedDecl>(d);
1189         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1190         ASTContext *Context = Result.Context;
1191
1192         if (nd->getName() == "thrd_create") {
1193             Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1194             UnaryOperator * callee1;
1195             if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1196                 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1197                     callee0 = callee1->getSubExpr();
1198             }
1199             DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1200             if (!callee) return;
1201             FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1202             ThreadMains.insert(fd);
1203             return;
1204         }
1205
1206         if (!d->hasBody())
1207             return;
1208
1209         if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1210             // TODO check that the type is mc-visible also?
1211             const DeclStmt * ds;
1212             const VarDecl * lhs = NULL;
1213             std::string mc_rv = encodeRV(rvCount++);
1214
1215             std::stringstream brline;
1216             brline << "MCID " << mc_rv << ";\n";
1217             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(s->getLocStart()),
1218                                brline.str(), false, true);
1219
1220             std::stringstream nol;
1221             if (ce->getNumArgs() > 0) nol << ", ";
1222             nol << "&" << mc_rv;
1223             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(ce->getRParenLoc()),
1224                                      nol.str());
1225
1226             if (s && (ds = dyn_cast<DeclStmt>(s))) {
1227                 if (!ds->isSingleDecl()) {
1228                     for (auto & d : ds->decls()) {
1229                         VarDecl * vd = dyn_cast<VarDecl>(d);
1230                         if (!d || vd->hasInit())
1231                             assert(0 && "unsupported form of decl");
1232                     }
1233                     return;
1234                 }
1235
1236                 lhs = retrieveSingleDecl(ds);
1237             }
1238
1239             DeclToMCVar[lhs] = mc_rv;
1240         }
1241
1242         for (const auto & a : ce->arguments()) {
1243             std::stringstream nol;
1244
1245             std::string aa = "MCID_NODEP";
1246
1247             Expr * e = a->IgnoreCasts();
1248             DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1249             if (dr) { 
1250                 NamedDecl * d = dr->getDecl();
1251                 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1252                     aa = DeclToMCVar[d];
1253             }
1254
1255             nol << aa << ", ";
1256             
1257             if (a->getLocEnd().isValid())
1258                 rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(a->getLocStart()),
1259                                          nol.str());
1260         }
1261     }
1262
1263 private:
1264     Rewriter &rewrite;
1265     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1266     std::set<const FunctionDecl *> &ThreadMains;
1267 };
1268
1269 class ReturnHandler : public MatchFinder::MatchCallback {
1270 public:
1271     ReturnHandler(Rewriter &rewrite,
1272                   std::map<const NamedDecl *, std::string> &DeclToMCVar,
1273                   std::set<const FunctionDecl *> &ThreadMains)
1274         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1275
1276     virtual void run(const MatchFinder::MatchResult &Result) {
1277         const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1278         ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1279         Expr * rv = const_cast<Expr *>(rs->getRetValue());
1280
1281         if (!rv) return;        
1282         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1283         // not sure why this is explicitly needed, but crashes without it
1284         if (!fd->getIdentifier() || fd->getName() == "user_main") return;
1285
1286         FindLocalsVisitor flv;
1287         flv.TraverseStmt(rv);
1288         std::string mrv = "MCID_NODEP";
1289
1290         if (flv.RetrieveVars().size() > 0) {
1291             const NamedDecl * returnVar = flv.RetrieveVars()[0];
1292             if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1293                 mrv = DeclToMCVar[returnVar];
1294             }
1295         }
1296         std::stringstream nol;
1297         nol << "*retval = " << mrv << ";\n";
1298         rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(rs->getLocStart()),
1299                            nol.str(), false, true);
1300     }
1301
1302 private:
1303     Rewriter &rewrite;
1304     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1305     std::set<const FunctionDecl *> &ThreadMains;
1306 };
1307
1308 class VarDeclHandler : public MatchFinder::MatchCallback {
1309 public:
1310     VarDeclHandler(Rewriter &rewrite,
1311                    std::map<const NamedDecl *, std::string> &DeclToMCVar,
1312                    std::set<const VarDecl *> &DeclsNeedingMC)
1313         : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1314
1315     virtual void run(const MatchFinder::MatchResult &Result) {
1316         VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1317         std::stringstream nol;
1318
1319         if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
1320
1321         std::string dn;
1322         if (DeclToMCVar.find(d) != DeclToMCVar.end())
1323             dn = DeclToMCVar[d];
1324         else
1325             dn = encode(d->getName().str());
1326
1327         nol << "MCID " << dn << "; ";
1328
1329         if (d->getLocStart().isValid())
1330             rewrite.InsertTextBefore(rewrite.getSourceMgr().getExpansionLoc(d->getLocStart()),
1331                                      nol.str());
1332     }
1333
1334 private:
1335     Rewriter &rewrite;
1336     std::map<const NamedDecl *, std::string> &DeclToMCVar;
1337     std::set<const VarDecl *> &DeclsNeedingMC;
1338 };
1339
1340 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1341 public:
1342     FunctionDeclHandler(Rewriter &rewrite,
1343                         std::set<const FunctionDecl *> &ThreadMains)
1344         : rewrite(rewrite), ThreadMains(ThreadMains) {}
1345
1346     virtual void run(const MatchFinder::MatchResult &Result) {
1347         FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
1348
1349         if (!fd->getIdentifier()) return;
1350
1351         if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1352
1353         if (ThreadMains.find(fd) != ThreadMains.end()) return;
1354
1355         SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1356         for (auto & p : fd->params()) {
1357             std::stringstream nol;
1358             nol << "MCID " << encode(p->getName()) << ", ";
1359             if (p->getLocStart().isValid())
1360                 rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(p->getLocStart()),
1361                                    nol.str(), false);
1362             if (p->getLocEnd().isValid())
1363                 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
1364         }
1365
1366         if (!fd->getReturnType()->isVoidType()) {
1367             std::stringstream nol;
1368             if (fd->param_size() > 0) nol << ", ";
1369             nol << "MCID * retval";
1370             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(LastParam),
1371                                nol.str(), false);
1372         }
1373     }
1374
1375 private:
1376     Rewriter &rewrite;
1377     std::set<const FunctionDecl *> &ThreadMains;
1378 };
1379
1380 class BailHandler : public MatchFinder::MatchCallback {
1381 public:
1382     BailHandler() {}
1383     virtual void run(const MatchFinder::MatchResult &Result) {
1384         assert(0 && "we don't handle goto statements");
1385     }
1386 };
1387
1388 class MyASTConsumer : public ASTConsumer {
1389 public:
1390     MyASTConsumer(Rewriter &R) : R(R),
1391                                  DeclsRead(),
1392                                  DeclsInCond(),
1393                                  DeclToMCVar(),
1394                                  HandlerMalloc(MallocExprs),
1395                                  HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1396                                  HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1397                                  HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1398                                  HandlerLoop(R),
1399                                  HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1400                                  HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1401                                  HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1402                                  HandlerFunctionDecl(R, ThreadMains),
1403                                  HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1404                                  HandlerReturn(R, DeclToMCVar, ThreadMains),
1405                                  HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
1406                                  HandlerBail() {
1407         MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1408                                                       hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1409                                                       hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1410                                        &HandlerFunctionCall);
1411         MatcherLoadStore.addMatcher
1412             (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
1413              &HandlerMalloc);
1414
1415         MatcherLoadStore.addMatcher
1416             (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1417                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1418                             hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1419                             hasParent(stmt().bind("containingStmt"))))
1420              .bind("callExpr"),
1421              &HandlerLoad);
1422
1423         MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
1424                                     &HandlerStore);
1425
1426         MatcherLoadStore.addMatcher
1427             (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1428                       anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1429                             hasAncestor(binaryOperator(hasOperatorName("="),
1430                                                        hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
1431                             anything()))
1432              .bind("callExpr"),
1433              &HandlerRMW);
1434
1435         MatcherLoadStore.addMatcher(ifStmt(hasCondition
1436                                            (anyOf(binaryOperator().bind("bc"),
1437                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1438                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1439                                                   hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1440                                                   anything()))).bind("if"),
1441                                     &HandlerBranchConditionRefactoring);
1442
1443         MatcherLoadStore.addMatcher(forStmt().bind("s"),
1444                                     &HandlerLoop);
1445         MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1446                                     &HandlerLoop);
1447         MatcherLoadStore.addMatcher(doStmt().bind("s"),
1448                                     &HandlerLoop);
1449
1450         MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1451                                                         hasParent(compoundStmt())),
1452                                                         hasOperatorName("=")).bind("op"),
1453                                    &HandlerAssign);
1454         MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1455
1456         MatcherFunction.addMatcher(ifStmt().bind("if"),
1457                                    &HandlerAnnotateBranch);
1458
1459         MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1460                                        &HandlerFunctionDecl);
1461         MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1462         MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
1463                                    &HandlerReturn);
1464
1465         MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1466     }
1467
1468     // Override the method that gets called for each parsed top-level
1469     // declaration.
1470     void HandleTranslationUnit(ASTContext &Context) override {
1471         LangOpts = Context.getLangOpts();
1472
1473         MatcherFunctionCall.matchAST(Context);
1474         MatcherLoadStore.matchAST(Context);
1475         MatcherFunction.matchAST(Context);
1476         MatcherFunctionDecl.matchAST(Context);
1477         MatcherSanity.matchAST(Context);
1478
1479         for (auto & u : DeferredUpdates) {
1480             R.InsertText(R.getSourceMgr().getExpansionLoc(u->loc), u->update, true, true);
1481             delete u;
1482         }
1483         DeferredUpdates.clear();
1484     }
1485
1486 private:
1487     /* DeclsRead contains all local variables 'x' which:
1488     * 1) appear in 'x = load_32(...);
1489     * 2) appear in 'y = store_32(x); */
1490     std::set<const NamedDecl *> DeclsRead, DeclsInCond;
1491     std::map<const NamedDecl *, std::string> DeclToMCVar;
1492     std::map<const Expr *, std::string> ExprToMCVar;
1493     std::set<const VarDecl *> DeclsNeedingMC;
1494     std::set<const FunctionDecl *> ThreadMains;
1495     std::set<const Stmt *> StmtsHandled;
1496     std::set<const Expr *> MallocExprs;
1497     std::map<const Expr *, SourceLocation> Redirector;
1498     std::vector<Update *> DeferredUpdates;
1499
1500     Rewriter &R;
1501
1502     MallocHandler HandlerMalloc;
1503     LoadHandler HandlerLoad;
1504     StoreHandler HandlerStore;
1505     RMWHandler HandlerRMW;
1506     LoopHandler HandlerLoop;
1507     BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1508     BranchAnnotationHandler HandlerAnnotateBranch;
1509     AssignHandler HandlerAssign;
1510     FunctionDeclHandler HandlerFunctionDecl;
1511     FunctionCallHandler HandlerFunctionCall;
1512     ReturnHandler HandlerReturn;
1513     VarDeclHandler HandlerVarDecl;
1514     BailHandler HandlerBail;
1515     MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1516 };
1517
1518 // For each source file provided to the tool, a new FrontendAction is created.
1519 class MyFrontendAction : public ASTFrontendAction {
1520 public:
1521     MyFrontendAction() {}
1522     void EndSourceFileAction() override {
1523         SourceManager &SM = TheRewriter.getSourceMgr();
1524         llvm::errs() << "** EndSourceFileAction for: "
1525                      << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1526
1527         // Now emit the rewritten buffer.
1528         TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1529     }
1530
1531     std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1532                                                    StringRef file) override {
1533         llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1534         TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
1535         return llvm::make_unique<MyASTConsumer>(TheRewriter);
1536     }
1537
1538 private:
1539     Rewriter TheRewriter;
1540 };
1541
1542 int main(int argc, const char **argv) {
1543     CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1544     ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1545     
1546     return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());
1547 }