1 // -*- indent-tabs-mode:nil; -*-
2 //------------------------------------------------------------------------------
3 // Add MC2 annotations to C code.
4 // Copyright 2015 Patrick Lam <prof.lam@gmail.com>
6 // Permission is hereby granted, free of charge, to any person
7 // obtaining a copy of this software and associated documentation
8 // files (the "Software"), to deal with the Software without
9 // restriction, including without limitation the rights to use, copy,
10 // modify, merge, publish, distribute, sublicense, and/or sell copies
11 // of the Software, and to permit persons to whom the Software is
12 // furnished to do so, subject to the following conditions:
14 // Redistributions of source code must retain the above copyright
15 // notice, this list of conditions and the following disclaimers.
17 // Redistributions in binary form must reproduce the above copyright
18 // notice, this list of conditions and the following disclaimers in
19 // the documentation and/or other materials provided with the
22 // Neither the names of the University of Waterloo, nor the names of
23 // its contributors may be used to endorse or promote products derived
24 // from this Software without specific prior written permission.
26 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
27 // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
28 // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
29 // NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT
30 // HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
31 // WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
33 // DEALINGS WITH THE SOFTWARE.
35 // Patrick Lam (prof.lam@gmail.com)
38 // Eli Bendersky (eliben@gmail.com)
40 //------------------------------------------------------------------------------
46 #include "clang/AST/AST.h"
47 #include "clang/AST/ASTContext.h"
48 #include "clang/AST/ASTConsumer.h"
49 #include "clang/AST/RecursiveASTVisitor.h"
50 #include "clang/ASTMatchers/ASTMatchers.h"
51 #include "clang/ASTMatchers/ASTMatchFinder.h"
52 #include "clang/Frontend/ASTConsumers.h"
53 #include "clang/Frontend/FrontendActions.h"
54 #include "clang/Frontend/CompilerInstance.h"
55 #include "clang/Lex/Lexer.h"
56 #include "clang/Tooling/CommonOptionsParser.h"
57 #include "clang/Tooling/Tooling.h"
58 #include "clang/Rewrite/Core/Rewriter.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/ADT/STLExtras.h"
62 using namespace clang;
63 using namespace clang::ast_matchers;
64 using namespace clang::driver;
65 using namespace clang::tooling;
68 static LangOptions LangOpts;
69 static llvm::cl::OptionCategory AddMC2AnnotationsCategory("Add MC2 Annotations");
71 static std::string encode(std::string varName) {
73 nn << "_m" << varName;
78 static std::string encodeFn(int num) {
85 static std::string encodePtr(int num) {
92 static std::string encodeRMW(int num) {
98 static int branchCount;
99 static std::string encodeBranch(int num) {
100 std::stringstream nn;
105 static int condCount;
106 static std::string encodeCond(int num) {
107 std::stringstream nn;
108 nn << "_cond" << num;
113 static std::string encodeRV(int num) {
114 std::stringstream nn;
119 static int funcCount;
121 struct ProvisionalName {
123 const DeclRefExpr * pname;
126 ProvisionalName(int index, const DeclRefExpr * pname) : index(index), pname(pname), length(encode(pname->getNameInfo().getName().getAsString()).length()), enabled(true) {}
127 ProvisionalName(int index, const DeclRefExpr * pname, int length) : index(index), pname(pname), length(length), enabled(true) {}
133 std::vector<ProvisionalName *> * pnames;
135 Update(SourceLocation loc, std::string update, std::vector<ProvisionalName *> * pnames) :
136 loc(loc), update(update), pnames(pnames) {}
139 for (auto pname : *pnames) delete pname;
144 void updateProvisionalName(std::vector<Update *> &DeferredUpdates, const ValueDecl * now_known, std::string mcVar) {
145 for (Update * u : DeferredUpdates) {
146 for (int i = 0; i < u->pnames->size(); i++) {
147 ProvisionalName * v = (*(u->pnames))[i];
148 if (!v->enabled) continue;
149 if (now_known == v->pname->getDecl()) {
151 std::string oldName = encode(v->pname->getNameInfo().getName().getAsString());
153 u->update.replace(v->index, v->length, mcVar);
154 for (int j = i+1; j < u->pnames->size(); j++) {
155 ProvisionalName * vv = (*(u->pnames))[j];
156 if (vv->index > v->index)
157 vv->index -= v->length - mcVar.length();
164 static const VarDecl * retrieveSingleDecl(const DeclStmt * s) {
165 // XXX iterate through all decls defined in s, not just the first one
166 assert(s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl()) && "unsupported form of decl");
167 if (s->isSingleDecl() && isa<VarDecl>(s->getSingleDecl())) {
168 return cast<VarDecl>(s->getSingleDecl());
172 class FindCallArgVisitor : public RecursiveASTVisitor<FindCallArgVisitor> {
174 FindCallArgVisitor() : DE(NULL), UnaryOp(NULL) {}
176 bool VisitStmt(Stmt * s) {
178 if (UnaryOperator * uo = dyn_cast<UnaryOperator>(s)) {
179 if (uo->getOpcode() == UnaryOperatorKind::UO_AddrOf ||
180 uo->getOpcode() == UnaryOperatorKind::UO_Deref)
185 if (!DE && (DE = dyn_cast<DeclRefExpr>(s)))
191 UnaryOp = NULL; DE = NULL;
194 const UnaryOperator * RetrieveUnaryOp() {
197 const Stmt * s = UnaryOp;
203 if (const UnaryOperator * op = dyn_cast<UnaryOperator>(s))
204 s = op->getSubExpr();
205 else if (const CastExpr * op = dyn_cast<CastExpr>(s))
206 s = op->getSubExpr();
207 else if (const MemberExpr * op = dyn_cast<MemberExpr>(s))
218 const DeclRefExpr * RetrieveDeclRefExpr() {
223 const UnaryOperator * UnaryOp;
224 const DeclRefExpr * DE;
227 class FindLocalsVisitor : public RecursiveASTVisitor<FindLocalsVisitor> {
229 FindLocalsVisitor() : Vars() {}
231 bool VisitDeclRefExpr(DeclRefExpr * de) {
232 Vars.push_back(de->getDecl());
240 const TinyPtrVector<const NamedDecl *> RetrieveVars() {
245 TinyPtrVector<const NamedDecl *> Vars;
249 class MallocHandler : public MatchFinder::MatchCallback {
251 MallocHandler(std::set<const Expr *> & MallocExprs) :
252 MallocExprs(MallocExprs) {}
254 virtual void run(const MatchFinder::MatchResult &Result) {
255 const CallExpr * ce = Result.Nodes.getNodeAs<CallExpr>("callExpr");
257 MallocExprs.insert(ce);
261 std::set<const Expr *> &MallocExprs;
264 static void generateMC2Function(Rewriter & Rewrite,
269 const DeclRefExpr * lhs,
271 std::vector<ProvisionalName *> * vars1,
272 std::vector<Update *> & DeferredUpdates) {
273 // prettyprint the LHS (&newnode->value)
274 // e.g. int * _tmp0 = &newnode->value;
276 llvm::raw_string_ostream S(SStr);
277 e->printPretty(S, nullptr, Rewrite.getLangOpts());
278 const std::string &Str = S.str();
280 std::stringstream prel;
281 prel << "\nvoid * " << tmpname << " = " << Str << ";\n";
283 // MCID _p0 = MC2_function(1, MC2_PTR_LENGTH, _tmp0, _fn0);
284 prel << "MCID " << tmpFn << " = MC2_function_id(" << ++funcCount << ", 1, MC2_PTR_LENGTH, " << tmpname << ", ";
286 // XXX generate casts when they'd eliminate warnings
287 ProvisionalName * v = new ProvisionalName(prel.tellp(), lhs);
290 prel << encode(lhsName) << "); ";
292 Update * u = new Update(loc, prel.str(), vars1);
293 DeferredUpdates.push_back(u);
296 class LoadHandler : public MatchFinder::MatchCallback {
298 LoadHandler(Rewriter &Rewrite,
299 std::set<const NamedDecl *> & DeclsRead,
300 std::set<const VarDecl *> & DeclsNeedingMC,
301 std::map<const NamedDecl *, std::string> &DeclToMCVar,
302 std::map<const Expr *, std::string> &ExprToMCVar,
303 std::set<const Stmt *> & StmtsHandled,
304 std::map<const Expr *, SourceLocation> &Redirector,
305 std::vector<Update *> & DeferredUpdates) :
306 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeclToMCVar(DeclToMCVar),
307 ExprToMCVar(ExprToMCVar),
308 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
310 virtual void run(const MatchFinder::MatchResult &Result) {
311 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
312 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
313 const VarDecl * d = Result.Nodes.getNodeAs<VarDecl>("decl");
314 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
315 const Expr * lhs = NULL;
316 if (s && isa<BinaryOperator>(s)) lhs = cast<BinaryOperator>(s)->getLHS();
319 const DeclRefExpr * rhs = NULL;
320 MemberExpr * ml = NULL;
321 bool isAddrOfR = false, isAddrMemberR = false;
323 StmtsHandled.insert(s);
325 std::string n, n_decl;
327 FindCallArgVisitor fcaVisitor;
329 fcaVisitor.TraverseStmt(ce->getArg(0));
330 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
331 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
332 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
333 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
335 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
337 FindLocalsVisitor flv;
339 flv.TraverseStmt(const_cast<Stmt*>(cast<Stmt>(lhs)));
340 for (auto & d : flv.RetrieveVars()) {
341 const VarDecl * dd = cast<VarDecl>(d);
343 // XXX todo rhs for non-decl stmts
344 if (!isa<ParmVarDecl>(dd))
345 DeclsNeedingMC.insert(dd);
347 DeclToMCVar[dd] = encode(n);
350 FindCallArgVisitor fcaVisitor;
352 fcaVisitor.TraverseStmt(ce->getArg(0));
353 rhs = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
354 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
355 isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
356 isAddrMemberR = ruop && isa<MemberExpr>(ruop->getSubExpr());
358 ml = dyn_cast<MemberExpr>(ruop->getSubExpr());
362 DeclsNeedingMC.insert(d);
364 DeclToMCVar[d] = encode(n);
368 fcaVisitor.TraverseStmt(ce);
369 const DeclRefExpr * dd = cast<DeclRefExpr>(fcaVisitor.RetrieveDeclRefExpr()->IgnoreParens());
370 updateProvisionalName(DeferredUpdates, dd->getDecl(), encode(n));
371 DeclToMCVar[dd->getDecl()] = encode(n);
376 std::stringstream nol;
378 if (lhs && isa<DeclRefExpr>(lhs)) {
379 const DeclRefExpr * ll = cast<DeclRefExpr>(lhs);
380 ProvisionalName * v = new ProvisionalName(nol.tellp(), ll);
387 nol << n_decl << encode(n) << "=";
388 nol << "MC2_nextOpLoadOffset(";
390 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
392 nol << encode(rhs->getNameInfo().getName().getAsString());
394 nol << ", MC2_OFFSET(";
395 nol << ml->getBase()->getType().getAsString();
397 nol << ml->getMemberDecl()->getName().str();
399 } else if (!isAddrOfR) {
401 nol << n_decl << encode(n) << "=";
402 nol << "MC2_nextOpLoad(";
403 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
405 nol << encode(rhs->getNameInfo().getName().getAsString());
408 nol << n_decl << encode(n) << "=";
409 nol << "MC2_nextOpLoad(";
414 nol << n_decl << encode(n) << "=";
415 nol << "MC2_nextOpLoad(";
423 SourceLocation ss = s->getLocStart();
425 // if the load appears as its own stmt and is the 1st stmt, containingStmt may be the containing CompoundStmt;
426 // move over 1 so that we get the right location.
427 if (isa<CompoundStmt>(s)) ss = ss.getLocWithOffset(1);
428 const Expr * e = dyn_cast<Expr>(s);
429 if (e && Redirector.count(e) > 0)
431 Update * u = new Update(ss, nol.str(), vars);
432 DeferredUpdates.insert(DeferredUpdates.begin(), u);
437 std::set<const NamedDecl *> & DeclsRead;
438 std::set<const VarDecl *> & DeclsNeedingMC;
439 std::map<const Expr *, std::string> &ExprToMCVar;
440 std::map<const NamedDecl *, std::string> &DeclToMCVar;
441 std::set<const Stmt *> &StmtsHandled;
442 std::vector<Update *> &DeferredUpdates;
443 std::map<const Expr *, SourceLocation> &Redirector;
446 class StoreHandler : public MatchFinder::MatchCallback {
448 StoreHandler(Rewriter &Rewrite,
449 std::set<const NamedDecl *> & DeclsRead,
450 std::set<const VarDecl *> &DeclsNeedingMC,
451 std::vector<Update *> & DeferredUpdates) :
452 Rewrite(Rewrite), DeclsRead(DeclsRead), DeclsNeedingMC(DeclsNeedingMC), DeferredUpdates(DeferredUpdates) {}
454 virtual void run(const MatchFinder::MatchResult &Result) {
455 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
456 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
459 fcaVisitor.TraverseStmt(ce->getArg(0));
460 const DeclRefExpr * lhs = fcaVisitor.RetrieveDeclRefExpr();
461 const UnaryOperator * luop = fcaVisitor.RetrieveUnaryOp();
463 std::stringstream nol;
468 if (luop && luop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
469 isAddrMemberL = isa<MemberExpr>(luop->getSubExpr());
470 isAddrOfL = !isa<MemberExpr>(luop->getSubExpr());
475 nol << "MC2_nextOpStore(";
479 MemberExpr * ml = cast<MemberExpr>(luop->getSubExpr());
481 nol << "MC2_nextOpStoreOffset(";
483 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
486 nol << encode(lhs->getNameInfo().getName().getAsString());
487 if (!isa<ParmVarDecl>(lhs->getDecl()))
488 DeclsNeedingMC.insert(cast<VarDecl>(lhs->getDecl()));
490 nol << ", MC2_OFFSET(";
491 nol << ml->getBase()->getType().getAsString();
493 nol << ml->getMemberDecl()->getName().str();
496 nol << "MC2_nextOpStore(";
497 ProvisionalName * v = new ProvisionalName(nol.tellp(), lhs);
500 nol << encode(lhs->getNameInfo().getName().getAsString());
505 nol << "MC2_nextOpStore(";
512 fcaVisitor.TraverseStmt(ce->getArg(1));
513 const DeclRefExpr * rhs = fcaVisitor.RetrieveDeclRefExpr();
514 const UnaryOperator * ruop = fcaVisitor.RetrieveUnaryOp();
516 bool isAddrOfR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
517 bool isDerefR = ruop && ruop->getOpcode() == UnaryOperatorKind::UO_Deref;
519 if (rhs && !isAddrOfR) {
520 assert (!isDerefR && "Must use atomic load for dereferences!");
521 ProvisionalName * v = new ProvisionalName(nol.tellp(), rhs);
524 nol << encode(rhs->getNameInfo().getName().getAsString());
525 DeclsRead.insert(rhs->getDecl());
531 Update * u = new Update(ce->getLocStart(), nol.str(), vars);
532 DeferredUpdates.push_back(u);
537 FindCallArgVisitor fcaVisitor;
538 std::set<const NamedDecl *> & DeclsRead;
539 std::set<const VarDecl *> & DeclsNeedingMC;
540 std::vector<Update *> &DeferredUpdates;
543 class RMWHandler : public MatchFinder::MatchCallback {
545 RMWHandler(Rewriter &rewrite,
546 std::set<const NamedDecl *> & DeclsRead,
547 std::set<const NamedDecl *> & DeclsInCond,
548 std::map<const NamedDecl *, std::string> &DeclToMCVar,
549 std::map<const Expr *, std::string> &ExprToMCVar,
550 std::set<const Stmt *> & StmtsHandled,
551 std::map<const Expr *, SourceLocation> &Redirector,
552 std::vector<Update *> & DeferredUpdates) :
553 rewrite(rewrite), DeclsRead(DeclsRead), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
554 ExprToMCVar(ExprToMCVar),
555 StmtsHandled(StmtsHandled), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
557 virtual void run(const MatchFinder::MatchResult &Result) {
558 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
559 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
560 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
562 std::stringstream nol;
564 std::string rmwMCVar;
565 rmwMCVar = encodeRMW(rmwCount++);
567 const VarDecl * rmw_lhs;
569 StmtsHandled.insert(s);
570 assert (isa<DeclStmt>(s) || isa<BinaryOperator>(s) && "unknown RMW format: not declrefexpr, not binaryoperator");
572 if ((ds = dyn_cast<DeclStmt>(s))) {
573 rmw_lhs = retrieveSingleDecl(ds);
575 const Expr * e = cast<BinaryOperator>(s)->getLHS();
576 assert (isa<DeclRefExpr>(e));
577 rmw_lhs = cast<VarDecl>(cast<DeclRefExpr>(e)->getDecl());
579 DeclToMCVar[rmw_lhs] = rmwMCVar;
582 // retrieve effective LHS of the RMW
584 fcaVisitor.TraverseStmt(ce->getArg(1));
585 const DeclRefExpr * elhs = fcaVisitor.RetrieveDeclRefExpr();
586 const UnaryOperator * eluop = fcaVisitor.RetrieveUnaryOp();
587 bool isAddrMemberL = false;
589 if (eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
590 isAddrMemberL = isa<MemberExpr>(eluop->getSubExpr());
593 nol << "MCID " << rmwMCVar;
595 MemberExpr * ml = cast<MemberExpr>(eluop->getSubExpr());
597 nol << " = MC2_nextRMWOffset(";
599 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
602 nol << encode(elhs->getNameInfo().getName().getAsString());
604 nol << ", MC2_OFFSET(";
605 nol << ml->getBase()->getType().getAsString();
607 nol << ml->getMemberDecl()->getName().str();
610 nol << " = MC2_nextRMW(";
611 bool isAddrOfL = eluop && eluop->getOpcode() == UnaryOperatorKind::UO_AddrOf;
617 ProvisionalName * v = new ProvisionalName(nol.tellp(), elhs);
620 std::string elhsName = encode(elhs->getNameInfo().getName().getAsString());
629 // handle both RHS ops
631 for (int arg = 2; arg < 4; arg++) {
633 fcaVisitor.TraverseStmt(ce->getArg(arg));
634 const DeclRefExpr * a = fcaVisitor.RetrieveDeclRefExpr();
635 const UnaryOperator * op = fcaVisitor.RetrieveUnaryOp();
637 bool isAddrOfR = op && op->getOpcode() == UnaryOperatorKind::UO_AddrOf;
638 bool isDerefR = op && op->getOpcode() == UnaryOperatorKind::UO_Deref;
640 if (a && !isAddrOfR) {
641 assert (!isDerefR && "Must use atomic load for dereferences!");
643 DeclsInCond.insert(a->getDecl());
645 if (outputted > 0) nol << ", ";
648 bool alreadyMCVar = false;
649 if (DeclToMCVar.find(a->getDecl()) != DeclToMCVar.end()) {
651 nol << DeclToMCVar[a->getDecl()];
654 std::string an = "MCID_NODEP";
655 ProvisionalName * v = new ProvisionalName(nol.tellp(), a, an.length());
660 DeclsRead.insert(a->getDecl());
663 if (outputted > 0) nol << ", ";
671 SourceLocation place = s ? s->getLocStart() : ce->getLocStart();
672 const Expr * e = s ? dyn_cast<Expr>(s) : ce;
673 if (e && Redirector.count(e) > 0)
674 place = Redirector[e];
675 Update * u = new Update(place, nol.str(), vars);
676 DeferredUpdates.insert(DeferredUpdates.begin(), u);
681 FindCallArgVisitor fcaVisitor;
682 std::set<const NamedDecl *> &DeclsRead;
683 std::set<const NamedDecl *> &DeclsInCond;
684 std::map<const NamedDecl *, std::string> &DeclToMCVar;
685 std::map<const Expr *, std::string> &ExprToMCVar;
686 std::set<const Stmt *> &StmtsHandled;
687 std::vector<Update *> &DeferredUpdates;
688 std::map<const Expr *, SourceLocation> &Redirector;
691 class FindReturnsBreaksVisitor : public RecursiveASTVisitor<FindReturnsBreaksVisitor> {
693 FindReturnsBreaksVisitor() : Returns(), Breaks() {}
695 bool VisitStmt(Stmt * s) {
696 if (isa<ReturnStmt>(s))
697 Returns.push_back(cast<ReturnStmt>(s));
699 if (isa<BreakStmt>(s))
700 Breaks.push_back(cast<BreakStmt>(s));
705 Returns.clear(); Breaks.clear();
708 const std::vector<const ReturnStmt *> RetrieveReturns() {
712 const std::vector<const BreakStmt *> RetrieveBreaks() {
717 std::vector<const ReturnStmt *> Returns;
718 std::vector<const BreakStmt *> Breaks;
721 class LoopHandler : public MatchFinder::MatchCallback {
723 LoopHandler(Rewriter &rewrite) : rewrite(rewrite) {}
725 virtual void run(const MatchFinder::MatchResult &Result) {
726 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("s");
728 rewrite.InsertText(s->getLocStart(), "MC2_enterLoop();\n", true, true);
730 // annotate all returns with MC2_exitLoop()
731 // annotate all breaks that aren't further nested with MC2_exitLoop().
732 FindReturnsBreaksVisitor frbv;
734 frbv.TraverseStmt(const_cast<Stmt *>(cast<ForStmt>(s)->getBody()));
735 if (isa<WhileStmt>(s))
736 frbv.TraverseStmt(const_cast<Stmt *>(cast<WhileStmt>(s)->getBody()));
738 frbv.TraverseStmt(const_cast<Stmt *>(cast<DoStmt>(s)->getBody()));
740 for (auto & r : frbv.RetrieveReturns()) {
741 rewrite.InsertText(r->getLocStart(), "MC2_exitLoop();\n", true, true);
744 // need to find all breaks and returns embedded inside the loop
746 rewrite.InsertTextAfterToken(s->getLocEnd().getLocWithOffset(1), "\nMC2_exitLoop();\n");
753 /* Inserts MC2_function for any variables which are subsequently used by the model checker, as long as they depend on MC-visible [currently: read] state. */
754 class AssignHandler : public MatchFinder::MatchCallback {
756 AssignHandler(Rewriter &rewrite, std::set<const NamedDecl *> &DeclsRead,
757 std::set<const NamedDecl *> &DeclsInCond,
758 std::set<const VarDecl *> &DeclsNeedingMC,
759 std::map<const NamedDecl *, std::string> &DeclToMCVar,
760 std::set<const Stmt *> &StmtsHandled,
761 std::set<const Expr *> &MallocExprs,
762 std::vector<Update *> &DeferredUpdates) :
764 DeclsRead(DeclsRead),
765 DeclsInCond(DeclsInCond),
766 DeclsNeedingMC(DeclsNeedingMC),
767 DeclToMCVar(DeclToMCVar),
768 StmtsHandled(StmtsHandled),
769 MallocExprs(MallocExprs),
770 DeferredUpdates(DeferredUpdates) {}
772 virtual void run(const MatchFinder::MatchResult &Result) {
773 BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
774 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
775 FindLocalsVisitor flv;
777 const VarDecl * lhs = NULL;
778 const Expr * rhs = NULL;
781 if (s && (ds = dyn_cast<DeclStmt>(s))) {
782 // long term goal: refactor the run() method to deal with one assignment at a time
783 // for now, if there is only declarations and no rhs's, we'll ignore this stmt
784 if (!ds->isSingleDecl()) {
785 for (auto & d : ds->decls()) {
786 VarDecl * vd = dyn_cast<VarDecl>(d);
787 if (!d || vd->hasInit())
788 assert(0 && "unsupported form of decl");
793 lhs = retrieveSingleDecl(ds);
796 if (StmtsHandled.find(ds) != StmtsHandled.end() || StmtsHandled.find(op) != StmtsHandled.end())
800 if (lhs->hasInit()) {
801 rhs = lhs->getInit();
803 rhs = rhs->IgnoreCasts();
809 std::set<std::string> mcState;
811 bool lhsTooComplicated = false;
813 flv.TraverseStmt(op);
816 if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
817 lhs = dyn_cast<VarDecl>(vd->getDecl());
819 // kick the can along...
820 lhsTooComplicated = true;
825 rhs = rhs->IgnoreCasts();
828 // rhs must be MC-active state, i.e. in declsread
829 // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
830 flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
833 if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
834 for (auto & d : flv.RetrieveVars()) {
835 if (DeclToMCVar.count(d) > 0)
836 mcState.insert(DeclToMCVar[d]);
837 else if (DeclsRead.find(d) != DeclsRead.end())
838 mcState.insert(encode(d->getName().str()));
842 if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
843 if (lhsTooComplicated)
844 assert(0 && "couldn't find LHS of = operator");
846 std::stringstream nol;
847 std::string _lhsStr, lhsStr;
848 std::string mcVar = encodeFn(fnCount++);
850 lhsStr = lhs->getName().str();
851 _lhsStr = encode(lhsStr);
852 DeclToMCVar[lhs] = mcVar;
853 DeclsNeedingMC.insert(cast<VarDecl>(lhs));
856 if (!(MallocExprs.find(rhs) != MallocExprs.end()))
857 function_id = ++funcCount;
858 nol << "\n" << mcVar << " = MC2_function_id(" << function_id << ", " << mcState.size();
860 nol << ", sizeof (" << lhsStr << "), (uint64_t)" << lhsStr;
862 nol << ", MC2_PTR_LENGTH";
863 for (auto & d : mcState) {
871 SourceLocation place;
873 place = op->getLocEnd().getLocWithOffset(1);
875 place = s->getLocEnd();
876 rewrite.InsertText(place.getLocWithOffset(1), nol.str(), true, true);
878 updateProvisionalName(DeferredUpdates, lhs, mcVar);
884 std::set<const NamedDecl *> &DeclsRead, &DeclsInCond;
885 std::set<const VarDecl *> &DeclsNeedingMC;
886 std::map<const NamedDecl *, std::string> &DeclToMCVar;
887 std::set<const Stmt *> &StmtsHandled;
888 std::set<const Expr *> &MallocExprs;
889 std::vector<Update *> &DeferredUpdates;
892 // record vars used in conditions
893 class BranchConditionRefactoringHandler : public MatchFinder::MatchCallback {
895 BranchConditionRefactoringHandler(Rewriter &Rewrite,
896 std::set<const NamedDecl *> & DeclsInCond,
897 std::map<const NamedDecl *, std::string> &DeclToMCVar,
898 std::map<const Expr *, std::string> &ExprToMCVar,
899 std::map<const Expr *, SourceLocation> &Redirector,
900 std::vector<Update *> &DeferredUpdates) :
901 Rewrite(Rewrite), DeclsInCond(DeclsInCond), DeclToMCVar(DeclToMCVar),
902 ExprToMCVar(ExprToMCVar), Redirector(Redirector), DeferredUpdates(DeferredUpdates) {}
904 virtual void run(const MatchFinder::MatchResult &Result) {
905 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
906 Expr * cond = is->getCond();
908 // refactor out complicated conditions
909 FindCallArgVisitor flv;
910 flv.TraverseStmt(cond);
913 BinaryOperator * bc = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("bc"));
915 std::string condVar = encodeCond(condCount++);
916 std::stringstream condVarEncoded;
917 condVarEncoded << condVar << "_m";
919 // prettyprint the binary op
920 // e.g. int _cond0 = x == y;
922 llvm::raw_string_ostream S(SStr);
923 bc->printPretty(S, nullptr, Rewrite.getLangOpts());
924 const std::string &Str = S.str();
926 std::stringstream prel;
928 bool is_equality = false;
929 // handle equality tests
930 if (bc->getOpcode() == BO_EQ) {
931 Expr * lhs = bc->getLHS()->IgnoreCasts(), * rhs = bc->getRHS()->IgnoreCasts();
932 if (isa<DeclRefExpr>(lhs) && isa<DeclRefExpr>(rhs)) {
933 DeclRefExpr * l = dyn_cast<DeclRefExpr>(lhs), *r = dyn_cast<DeclRefExpr>(rhs);
935 prel << "\nMCID " << condVarEncoded.str() << ";\n";
936 std::string ld = DeclToMCVar.find(l->getDecl())->second,
937 rd = DeclToMCVar.find(r->getDecl())->second;
939 prel << "\nint " << condVar << " = MC2_equals(" <<
940 ld << ", (uint64_t)" << l->getNameInfo().getName().getAsString() << ", " <<
941 rd << ", (uint64_t)" << r->getNameInfo().getName().getAsString() << ", " <<
942 "&" << condVarEncoded.str() << ");\n";
947 prel << "\nint " << condVar << " = " << Str << ";";
948 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
949 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
950 if (DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
951 prel << ", " << DeclToMCVar[d->getDecl()];
956 ExprToMCVar[cond] = condVarEncoded.str();
957 Rewrite.InsertText(is->getLocStart(), prel.str(), false, true);
959 // rewrite the binary op with the newly-inserted var
960 Expr * RO = bc->getRHS(); // used for location only
962 int cl = Lexer::MeasureTokenLength(RO->getLocStart(), Rewrite.getSourceMgr(), Rewrite.getLangOpts());
963 SourceRange SR(cond->getLocStart(), Rewrite.getSourceMgr().getExpansionLoc(RO->getLocStart()).getLocWithOffset(cl-1));
964 Rewrite.ReplaceText(SR, condVar);
966 std::string condVar = encodeCond(condCount++);
967 std::stringstream condVarEncoded;
968 condVarEncoded << condVar << "_m";
971 llvm::raw_string_ostream S(SStr);
972 cond->printPretty(S, nullptr, Rewrite.getLangOpts());
973 const std::string &Str = S.str();
975 std::stringstream prel;
976 prel << "\nint " << condVar << " = " << Str << ";";
977 prel << "\nMCID " << condVarEncoded.str() << " = MC2_function_id(" << ++funcCount << ", 1, sizeof(" << condVar << "), " << condVar;
978 std::vector<ProvisionalName *> * vars = new std::vector<ProvisionalName *>();
979 const DeclRefExpr * d = flv.RetrieveDeclRefExpr();
980 if (isa<VarDecl>(d->getDecl()) && DeclToMCVar.find(d->getDecl()) != DeclToMCVar.end()) {
981 prel << ", " << DeclToMCVar[d->getDecl()];
984 ProvisionalName * v = new ProvisionalName(prel.tellp(), d, 0);
989 ExprToMCVar[cond] = condVarEncoded.str();
990 // gross hack; should look for any callexprs in cond
991 // but right now, if it's a unaryop, just manually traverse
992 if (isa<UnaryOperator>(cond)) {
993 Expr * e = dyn_cast<UnaryOperator>(cond)->getSubExpr();
994 ExprToMCVar[e] = condVarEncoded.str();
996 Update * u = new Update(is->getLocStart(), prel.str(), vars);
997 DeferredUpdates.push_back(u);
999 // rewrite the call op with the newly-inserted var
1000 SourceRange SR(cond->getLocStart(), cond->getLocEnd());
1001 Redirector[cond] = is->getLocStart();
1002 Rewrite.ReplaceText(SR, condVar);
1005 std::deque<const Decl *> q;
1006 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1008 while (!q.empty()) {
1009 const Decl * d = q.back();
1011 if (isa<NamedDecl>(d))
1012 DeclsInCond.insert(cast<NamedDecl>(d));
1015 if ((vd = dyn_cast<VarDecl>(d))) {
1016 if (vd->hasInit()) {
1017 const Expr * e = vd->getInit();
1019 flv.TraverseStmt(const_cast<Expr *>(e));
1020 const NamedDecl * d = flv.RetrieveDeclRefExpr()->getDecl();
1029 std::set<const NamedDecl *> & DeclsInCond;
1030 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1031 std::map<const Expr *, std::string> &ExprToMCVar;
1032 std::map<const Expr *, SourceLocation> &Redirector;
1033 std::vector<Update *> &DeferredUpdates;
1036 class BranchAnnotationHandler : public MatchFinder::MatchCallback {
1038 BranchAnnotationHandler(Rewriter &rewrite,
1039 std::map<const NamedDecl *, std::string> & DeclToMCVar,
1040 std::map<const Expr *, std::string> & ExprToMCVar)
1042 DeclToMCVar(DeclToMCVar),
1043 ExprToMCVar(ExprToMCVar){}
1044 virtual void run(const MatchFinder::MatchResult &Result) {
1045 IfStmt * is = const_cast<IfStmt *>(Result.Nodes.getNodeAs<IfStmt>("if"));
1047 // if the branch condition is interesting:
1048 // (but right now, not too interesting)
1049 Expr * cond = is->getCond()->IgnoreCasts();
1051 FindLocalsVisitor flv;
1052 flv.TraverseStmt(cond);
1053 if (flv.RetrieveVars().size() == 0) return;
1055 const NamedDecl * condVar = flv.RetrieveVars()[0];
1057 std::string mCondVar;
1058 if (ExprToMCVar.count(cond) > 0)
1059 mCondVar = ExprToMCVar[cond];
1060 else if (DeclToMCVar.count(condVar) > 0)
1061 mCondVar = DeclToMCVar[condVar];
1063 mCondVar = encode(condVar->getName());
1064 std::string brVar = encodeBranch(branchCount++);
1066 std::stringstream brline;
1067 brline << "MCID " << brVar << ";\n";
1068 Rewrite.InsertText(is->getLocStart(), brline.str(), false, true);
1070 Stmt * ts = is->getThen(), * es = is->getElse();
1071 bool tHasChild = hasChild(ts);
1074 if (isa<CompoundStmt>(ts))
1075 tfl = getFirstChild(ts)->getLocStart();
1077 tfl = ts->getLocStart();
1079 tfl = ts->getLocStart().getLocWithOffset(1);
1080 SourceLocation tsl = ts->getLocEnd().getLocWithOffset(-1);
1082 std::stringstream tlineStart, mergeStmt, eline;
1084 UnaryOperator * uop = dyn_cast<UnaryOperator>(cond);
1085 tlineStart << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "1" << ", 2, true);\n";
1086 eline << brVar << " = MC2_branchUsesID(" << mCondVar << ", " << "0" << ", 2, true);";
1088 mergeStmt << "\tMC2_merge(" << brVar << ");\n";
1090 Rewrite.InsertText(tfl, tlineStart.str(), false, true);
1093 int extra_else_offset = 0;
1095 if (tHasChild) { tls = getLastChild(ts); }
1096 if (tls) extra_else_offset = 2; else extra_else_offset = 1;
1098 if (!tHasChild || (!isa<ReturnStmt>(tls) && !isa<BreakStmt>(tls))) {
1099 extra_else_offset = 0;
1100 Rewrite.InsertText(tsl.getLocWithOffset(1), mergeStmt.str(), true, true);
1102 if (tHasChild && !isa<CompoundStmt>(ts)) {
1103 Rewrite.InsertText(tls->getLocStart(), "{", false, true);
1104 SourceLocation tend = Lexer::getLocForEndOfToken(tls->getLocStart(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
1105 Rewrite.InsertText(tend.getLocWithOffset(2), "}", true, true);
1106 extra_else_offset++;
1108 if (tHasChild && isa<CompoundStmt>(ts)) extra_else_offset++;
1111 SourceLocation esl = es->getLocEnd().getLocWithOffset(-1);
1112 bool eHasChild = hasChild(es);
1114 if (eHasChild) els = getLastChild(es); else els = es;
1120 if (isa<CompoundStmt>(es))
1121 el = getFirstChild(es)->getLocStart();
1123 el = es->getLocStart();
1126 el = es->getLocStart().getLocWithOffset(1);
1127 Rewrite.InsertText(el, eline.str(), false, true);
1129 if (eHasChild && !isa<CompoundStmt>(es)) {
1130 Rewrite.InsertText(el, "{", false, true);
1131 Rewrite.InsertText(es->getLocEnd().getLocWithOffset(1), "}", true, true);
1134 if (!eHasChild || (!isa<ReturnStmt>(els) && !isa<BreakStmt>(els)))
1135 Rewrite.InsertText(esl.getLocWithOffset(1), mergeStmt.str(), true, true);
1138 std::stringstream eCompoundLine;
1139 eCompoundLine << " else { " << eline.str() << mergeStmt.str() << " }";
1140 SourceLocation tend = Lexer::getLocForEndOfToken(ts->getLocEnd(), 0, Rewrite.getSourceMgr(), Rewrite.getLangOpts());
1141 Rewrite.InsertText(tend.getLocWithOffset(extra_else_offset),
1142 eCompoundLine.str(), false, true);
// hasChild: does statement `s` contain at least one statement to annotate?
// A statement that is not a CompoundStmt (e.g. an unbraced then/else body)
// counts as its own child and yields true; a CompoundStmt yields true only
// when its body is non-empty, so an empty `{}` yields false.
// NOTE(review): this copy of the file lacks the method's closing brace (and
// other short lines); restore them from the original source before building.
1147 bool hasChild(Stmt * s) {
1148 if (!isa<CompoundStmt>(s)) return true;
1149 return (!cast<CompoundStmt>(s)->body_empty());
// getFirstChild: returns the first statement in the body of `s`.
// Preconditions, checked only by assert() (so unchecked under NDEBUG):
//   - `s` is a CompoundStmt. Wrapping a bare then/else statement into a
//     CompoundStmt is not implemented, as the assert message says.
//   - the compound body is non-empty (hasChild() tests exactly this for a
//     CompoundStmt).
1152 Stmt * getFirstChild(Stmt * s) {
1153 assert(isa<CompoundStmt>(s) && "haven't yet added code to rewrite then/elsestmt to CompoundStmt");
1154 assert(!cast<CompoundStmt>(s)->body_empty());
1155 return *(cast<CompoundStmt>(s)->body_begin());
// getLastChild: for a CompoundStmt, returns the last statement of its body.
// A non-empty body is required and is checked only by assert().
// NOTE(review): the declaration of `cs` and the path taken when `s` is not a
// CompoundStmt are on lines missing from this copy of the file. Callers pass
// unbraced then/else statements here and then call getLocStart() on the
// result, so that path presumably returns `s` itself. TODO confirm.
1158 Stmt * getLastChild(Stmt * s) {
1160 if ((cs = dyn_cast<CompoundStmt>(s))) {
1161 assert (!cs->body_empty());
1162 return cs->body_back();
// Members of the enclosing handler. This is presumably BranchAnnotationHandler,
// which MyASTConsumer constructs with (R, DeclToMCVar, ExprToMCVar).
// Both members are references to maps owned by MyASTConsumer.
// DeclToMCVar maps a declaration to the name of the MCID variable emitted for
// it. ExprToMCVar presumably does the same for expressions. TODO confirm.
1168 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1169 std::map<const Expr *, std::string> &ExprToMCVar;
// FunctionCallHandler: MatchCallback for the calls selected by
// MatcherFunctionCall. A call is selected if its parent is a compound
// statement, or if it is nested under a variable declaration or an `=`
// assignment; in those cases the enclosing statement is bound as
// "containingStmt".
// For each matched call it does the following:
//   1. thrd_create(...): records the thread entry function in ThreadMains.
//      The entry function is argument 1, with casts and a leading `&`
//      stripped. FunctionDeclHandler and ReturnHandler skip functions in
//      this set.
//   2. A non-void call with a bound containing statement:
//        - declares a fresh `MCID <rv>;` before that statement;
//        - passes `&<rv>` as an extra last argument;
//        - when the statement is a DeclStmt, records the declared variable
//          -> <rv> in DeclToMCVar.
//   3. Every argument: inserts text (`nol`) before it. The MCID name chosen
//      is the DeclToMCVar entry of a plain-variable argument, otherwise
//      MCID_NODEP.
// NOTE(review): several short lines of this class are missing from this copy
// of the file: access specifiers, the `rewrite` member declaration, closing
// braces, and the lines right after ThreadMains.insert(fd). So it cannot be
// confirmed here whether step 1 returns early, or exactly how `nol` is built
// in step 3.
1172 class FunctionCallHandler : public MatchFinder::MatchCallback {
1174 FunctionCallHandler(Rewriter &rewrite,
1175 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1176 std::set<const FunctionDecl *> &ThreadMains)
1177 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1179 virtual void run(const MatchFinder::MatchResult &Result) {
1180 CallExpr * ce = const_cast<CallExpr *>(Result.Nodes.getNodeAs<CallExpr>("callExpr"));
// NOTE(review): getCalleeDecl() returns null when the callee expression does
// not refer to a declaration (e.g. `table[i](x)`), and dyn_cast<> asserts on
// a null argument. Using dyn_cast_or_null<> plus a null check on `nd` before
// the nd->getName() call below would avoid a crash on such calls.
1181 Decl * d = ce->getCalleeDecl();
1182 NamedDecl * nd = dyn_cast<NamedDecl>(d);
// Statement enclosing the call, if the matcher bound one (may be null).
1183 const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
1184 ASTContext *Context = Result.Context;
// thrd_create(thr, func, arg): remember `func` as a thread entry point.
1186 if (nd->getName() == "thrd_create") {
1187 Expr * callee0 = ce->getArg(1)->IgnoreCasts();
1188 UnaryOperator * callee1;
// Accept both `func` and `&func`.
1189 if ((callee1 = dyn_cast<UnaryOperator>(callee0))) {
1190 if (callee1->getOpcode() == UnaryOperatorKind::UO_AddrOf)
1191 callee0 = callee1->getSubExpr();
1193 DeclRefExpr * callee = dyn_cast<DeclRefExpr>(callee0);
1194 if (!callee) return;
// NOTE(review): if the referenced declaration is not a FunctionDecl (e.g. a
// function-pointer variable), fd is null and nullptr is inserted into the set.
1195 FunctionDecl * fd = dyn_cast<FunctionDecl>(callee->getDecl());
1196 ThreadMains.insert(fd);
// Non-void call: give the return value its own MCID, passed by pointer as an
// extra trailing argument.
1203 if (s && !ce->getCallReturnType(*Context)->isVoidType()) {
1204 // TODO check that the type is mc-visible also?
1205 const DeclStmt * ds;
1206 const VarDecl * lhs = NULL;
1207 std::string mc_rv = encodeRV(rvCount++);
// `MCID <rv>;` goes in front of the whole containing statement.
// InsertAfter=false places it before any text already inserted at that
// location; indentNewLines=true indents the inserted lines to match.
1209 std::stringstream brline;
1210 brline << "MCID " << mc_rv << ";\n";
1211 rewrite.InsertText(s->getLocStart(), brline.str(), false, true);
// Append `&<rv>` as the last argument, just before the call's `)`.
1213 std::stringstream nol;
1214 if (ce->getNumArgs() > 0) nol << ", ";
1215 nol << "&" << mc_rv;
1216 rewrite.InsertTextBefore(ce->getRParenLoc(), nol.str());
// `T x = f(...);`: associate the declared variable with the call's MCID.
// A multi-declaration is rejected (via assert) if any declarator is not a
// plain VarDecl or has an initializer.
1218 if (s && (ds = dyn_cast<DeclStmt>(s))) {
1219 if (!ds->isSingleDecl()) {
1220 for (auto & d : ds->decls()) {
1221 VarDecl * vd = dyn_cast<VarDecl>(d);
// NOTE(review): this tests `d` (elements of decls() are never null) rather
// than `vd`. A non-VarDecl in the group makes vd null, and vd->hasInit() then
// dereferences it; `!vd` is almost certainly what was meant. This `d` also
// shadows the callee `Decl * d` above.
1222 if (!d || vd->hasInit())
1223 assert(0 && "unsupported form of decl");
1228 lhs = retrieveSingleDecl(ds);
1231 DeclToMCVar[lhs] = mc_rv;
// Per-argument MCID: the DeclToMCVar entry for a plain variable argument,
// otherwise MCID_NODEP.
1234 for (const auto & a : ce->arguments()) {
1235 std::stringstream nol;
1237 std::string aa = "MCID_NODEP";
1239 Expr * e = a->IgnoreCasts();
// dr is null for any argument that is not a plain variable reference. The
// guard before dr->getDecl() is presumably on a line missing from this copy.
1240 DeclRefExpr * dr = dyn_cast<DeclRefExpr>(e);
1242 NamedDecl * d = dr->getDecl();
1243 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1244 aa = DeclToMCVar[d];
// NOTE(review): validity is checked on the argument's end location, but the
// text is inserted at its start location. Confirm this is intended.
1249 if (a->getLocEnd().isValid())
1250 rewrite.InsertTextBefore(a->getLocStart(), nol.str());
1256 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1257 std::set<const FunctionDecl *> &ThreadMains;
// ReturnHandler: handles every `return` statement. The matcher also binds the
// enclosing function as "containingFunction".
// For each return, it inserts `*retval = <mcid>;` plus a newline just before
// the `return` keyword. `retval` is the `MCID * retval` out-parameter that
// FunctionDeclHandler appends to non-void functions.
// <mcid> is the DeclToMCVar entry of the first local variable that
// FindLocalsVisitor reports for the returned expression. In the visible code
// it otherwise stays MCID_NODEP.
// Functions that are skipped:
//   - thread entry functions (ThreadMains);
//   - functions without a plain identifier name;
//   - user_main.
// NOTE(review): rs->getRetValue() is null for a bare `return;`, and this
// handler does not check the function's return type. The two lines right
// after getRetValue() are missing from this copy of the file, so confirm that
// this case is filtered out. Otherwise `*retval = MCID_NODEP;` would be
// emitted into void functions that have no `retval` parameter.
1260 class ReturnHandler : public MatchFinder::MatchCallback {
1262 ReturnHandler(Rewriter &rewrite,
1263 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1264 std::set<const FunctionDecl *> &ThreadMains)
1265 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), ThreadMains(ThreadMains) {}
1267 virtual void run(const MatchFinder::MatchResult &Result) {
1268 const FunctionDecl * fd = Result.Nodes.getNodeAs<FunctionDecl>("containingFunction");
1269 ReturnStmt * rs = const_cast<ReturnStmt *>(Result.Nodes.getNodeAs<ReturnStmt>("returnStmt"));
1270 Expr * rv = const_cast<Expr *>(rs->getRetValue());
1273 if (ThreadMains.find(fd) != ThreadMains.end()) return;
1274 // not sure why this is explicitly needed, but crashes without it
// Likely reason: NamedDecl::getName() requires the name to be a simple
// identifier and asserts otherwise. The getIdentifier() test guards that call.
1275 if (!fd->getIdentifier() || fd->getName() == "user_main") return;
// Pick the MCID of the first local variable mentioned in the returned value.
1277 FindLocalsVisitor flv;
1278 flv.TraverseStmt(rv);
1279 std::string mrv = "MCID_NODEP";
1281 if (flv.RetrieveVars().size() > 0) {
1282 const NamedDecl * returnVar = flv.RetrieveVars()[0];
1283 if (DeclToMCVar.find(returnVar) != DeclToMCVar.end()) {
1284 mrv = DeclToMCVar[returnVar];
1287 std::stringstream nol;
1288 nol << "*retval = " << mrv << ";\n";
1289 rewrite.InsertText(rs->getLocStart(), nol.str(), false, true);
1294 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1295 std::set<const FunctionDecl *> &ThreadMains;
// VarDeclHandler: handles each variable declaration that is in DeclsNeedingMC.
// That set is shared with the load/store/assign handlers, which
// MyASTConsumer runs in earlier MatchFinder passes.
// For each such declaration it inserts a companion `MCID <name>; `
// immediately before the declaration. <name> is the declaration's DeclToMCVar
// entry if one exists, otherwise encode(<variable name>).
// NOTE(review): the declaration of `dn` and, presumably, an `else` between
// the two assignments to `dn` are on lines missing from this copy of the file.
1298 class VarDeclHandler : public MatchFinder::MatchCallback {
1300 VarDeclHandler(Rewriter &rewrite,
1301 std::map<const NamedDecl *, std::string> &DeclToMCVar,
1302 std::set<const VarDecl *> &DeclsNeedingMC)
1303 : rewrite(rewrite), DeclToMCVar(DeclToMCVar), DeclsNeedingMC(DeclsNeedingMC) {}
1305 virtual void run(const MatchFinder::MatchResult &Result) {
1306 VarDecl * d = const_cast<VarDecl *>(Result.Nodes.getNodeAs<VarDecl>("d"));
1307 std::stringstream nol;
// Only variables that some earlier pass marked as needing an MCID.
1309 if (DeclsNeedingMC.find(d) == DeclsNeedingMC.end()) return;
// Reuse an already-assigned MCID name, else derive one from the variable name.
1312 if (DeclToMCVar.find(d) != DeclToMCVar.end())
1313 dn = DeclToMCVar[d];
1315 dn = encode(d->getName().str());
1317 nol << "MCID " << dn << "; ";
1319 if (d->getLocStart().isValid())
1320 rewrite.InsertTextBefore(d->getLocStart(), nol.str());
1325 std::map<const NamedDecl *, std::string> &DeclToMCVar;
1326 std::set<const VarDecl *> &DeclsNeedingMC;
// FunctionDeclHandler: rewrites function signatures so they carry MCIDs.
// functionDecl() matches prototypes as well as definitions, so both are
// rewritten. For every function with a plain identifier name:
//   - user_main is registered in ThreadMains and left unchanged;
//   - functions already in ThreadMains are left unchanged. These are
//     thrd_create targets, collected by FunctionCallHandler, whose
//     MatchFinder runs before this one.
//   - otherwise each parameter `T p` gets a preceding `MCID <encode(p)>, `.
//     A non-void function also gets an extra trailing `MCID * retval`
//     parameter, which is the out-parameter ReturnHandler assigns through.
// NOTE(review): ThreadMains is keyed by the exact FunctionDecl pointer. A
// thread function with both a prototype and a definition has two
// FunctionDecls, and the one recorded from the thrd_create call may not be
// the one matched here. Comparing getCanonicalDecl() would be more robust.
// Confirm.
1329 class FunctionDeclHandler : public MatchFinder::MatchCallback {
1331 FunctionDeclHandler(Rewriter &rewrite,
1332 std::set<const FunctionDecl *> &ThreadMains)
1333 : rewrite(rewrite), ThreadMains(ThreadMains) {}
1335 virtual void run(const MatchFinder::MatchResult &Result) {
1336 FunctionDecl * fd = const_cast<FunctionDecl *>(Result.Nodes.getNodeAs<FunctionDecl>("fd"));
// getName() below requires a simple identifier name.
1338 if (!fd->getIdentifier()) return;
1340 if (fd->getName() == "user_main") { ThreadMains.insert(fd); return; }
1342 if (ThreadMains.find(fd) != ThreadMains.end()) return;
// Insertion point for `retval` when there are no parameters: one character
// past the end of the function name, i.e. just after `(`.
// NOTE(review): this assumes no whitespace between the name and `(`.
1344 SourceLocation LastParam = fd->getNameInfo().getLocStart().getLocWithOffset(fd->getName().size()).getLocWithOffset(1);
1345 for (auto & p : fd->params()) {
1346 std::stringstream nol;
1347 nol << "MCID " << encode(p->getName()) << ", ";
1348 if (p->getLocStart().isValid())
1349 rewrite.InsertText(p->getLocStart(), nol.str(), false);
// getLocEnd() is the start of the parameter's last token (normally its name),
// so adding the name's length points just past the parameter.
// NOTE(review): this is off for unnamed parameters (length 0), and for
// declarators whose last token is not the name (e.g. `int a[]`).
1350 if (p->getLocEnd().isValid())
1351 LastParam = p->getLocEnd().getLocWithOffset(p->getName().size());
// Non-void functions return their result's MCID through `MCID * retval`.
1354 if (!fd->getReturnType()->isVoidType()) {
1355 std::stringstream nol;
1356 if (fd->param_size() > 0) nol << ", ";
1357 nol << "MCID * retval";
1358 rewrite.InsertText(LastParam, nol.str(), false);
1364 std::set<const FunctionDecl *> &ThreadMains;
// BailHandler: registered on gotoStmt() in MatcherSanity. Any goto in the
// input trips the assert, because the rewriting passes do not support goto.
// NOTE(review): this relies on assert(), so in an NDEBUG build a goto is
// silently accepted. An explicit diagnostic followed by exit would be more
// reliable. Also, MatcherSanity runs last in HandleTranslationUnit, after the
// other passes have already rewritten the file.
1367 class BailHandler : public MatchFinder::MatchCallback {
1370 virtual void run(const MatchFinder::MatchResult &Result) {
1371 assert(0 && "we don't handle goto statements");
// MyASTConsumer does three things:
//   - owns the analysis state shared between the handlers (the sets and maps
//     declared near the bottom of the class);
//   - constructs every handler with references to that state;
//   - registers the handlers' AST matchers in five MatchFinders, which
//     HandleTranslationUnit runs in sequence.
// Members are initialised in declaration order, and the shared containers are
// declared before the handlers. So the containers already exist when the
// handlers are constructed with references to them.
// NOTE(review): several short lines of this class are missing from this copy
// of the file. Examples: access specifiers, the handler argument of several
// addMatcher calls, closing braces, and the declarations of the `R` and
// `LangOpts` members.
1375 class MyASTConsumer : public ASTConsumer {
1377 MyASTConsumer(Rewriter &R) : R(R),
1381 HandlerMalloc(MallocExprs),
1382 HandlerLoad(R, DeclsRead, DeclsNeedingMC, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1383 HandlerStore(R, DeclsRead, DeclsNeedingMC, DeferredUpdates),
1384 HandlerRMW(R, DeclsRead, DeclsInCond, DeclToMCVar, ExprToMCVar, StmtsHandled, Redirector, DeferredUpdates),
1386 HandlerBranchConditionRefactoring(R, DeclsInCond, DeclToMCVar, ExprToMCVar, Redirector, DeferredUpdates),
1387 HandlerAssign(R, DeclsRead, DeclsInCond, DeclsNeedingMC, DeclToMCVar, StmtsHandled, MallocExprs, DeferredUpdates),
1388 HandlerAnnotateBranch(R, DeclToMCVar, ExprToMCVar),
1389 HandlerFunctionDecl(R, ThreadMains),
1390 HandlerFunctionCall(R, DeclToMCVar, ThreadMains),
1391 HandlerReturn(R, DeclToMCVar, ThreadMains),
1392 HandlerVarDecl(R, DeclToMCVar, DeclsNeedingMC),
// Pass 1 (MatcherFunctionCall): calls that are statements by themselves, or
// that sit under a variable declaration or an `=` assignment.
1394 MatcherFunctionCall.addMatcher(callExpr(anyOf(hasParent(compoundStmt()),
1395 hasAncestor(varDecl(hasParent(stmt().bind("containingStmt")))),
1396 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")))).bind("callExpr"),
1397 &HandlerFunctionCall);
// Pass 2 (MatcherLoadStore): malloc/calloc calls.
1398 MatcherLoadStore.addMatcher
1399 (callExpr(callee(functionDecl(anyOf(hasName("malloc"), hasName("calloc"))))).bind("callExpr"),
// load_32/load_64 calls, together with the declaration, assignment or
// statement that contains them.
1402 MatcherLoadStore.addMatcher
1403 (callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64")))),
1404 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1405 hasAncestor(binaryOperator(hasOperatorName("=")).bind("containingStmt")),
1406 hasParent(stmt().bind("containingStmt"))))
// store_32/store_64 calls.
1410 MatcherLoadStore.addMatcher(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr"),
// rmw_32/rmw_64 calls, with the declaration or the assignment (and its
// variable LHS) that contains them.
1413 MatcherLoadStore.addMatcher
1414 (callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64")))),
1415 anyOf(hasAncestor(varDecl(hasParent(stmt().bind("containingStmt"))).bind("decl")),
1416 hasAncestor(binaryOperator(hasOperatorName("="),
1417 hasLHS(declRefExpr().bind("lhs"))).bind("containingStmt")),
// If statements, for branch-condition refactoring. The trailing `anything()`
// alternative makes every if statement match. The earlier alternatives only
// contribute the "bc" / "callExpr" bindings when they apply.
1422 MatcherLoadStore.addMatcher(ifStmt(hasCondition
1423 (anyOf(binaryOperator().bind("bc"),
1424 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("load_32"), hasName("load_64"))))).bind("callExpr")),
1425 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("store_32"), hasName("store_64"))))).bind("callExpr")),
1426 hasDescendant(callExpr(callee(functionDecl(anyOf(hasName("rmw_32"), hasName("rmw_64"))))).bind("callExpr")),
1427 anything()))).bind("if"),
1428 &HandlerBranchConditionRefactoring);
// Loops (for / while / do), each bound as "s".
1430 MatcherLoadStore.addMatcher(forStmt().bind("s"),
1432 MatcherLoadStore.addMatcher(whileStmt().bind("s"),
1434 MatcherLoadStore.addMatcher(doStmt().bind("s"),
// Pass 3 (MatcherFunction): `=` assignments that are statements by themselves
// or sit inside a declaration; declaration statements; and if statements for
// branch annotation.
1437 MatcherFunction.addMatcher(binaryOperator(anyOf(hasAncestor(declStmt().bind("containingStmt")),
1438 hasParent(compoundStmt())),
1439 hasOperatorName("=")).bind("op"),
1441 MatcherFunction.addMatcher(declStmt().bind("containingStmt"), &HandlerAssign);
1443 MatcherFunction.addMatcher(ifStmt().bind("if"),
1444 &HandlerAnnotateBranch);
// Pass 4 (MatcherFunctionDecl): function signatures, variable declarations,
// and return statements.
1446 MatcherFunctionDecl.addMatcher(functionDecl().bind("fd"),
1447 &HandlerFunctionDecl);
1448 MatcherFunctionDecl.addMatcher(varDecl().bind("d"), &HandlerVarDecl);
1449 MatcherFunctionDecl.addMatcher(returnStmt(hasAncestor(functionDecl().bind("containingFunction"))).bind("returnStmt"),
// Pass 5 (MatcherSanity): reject goto.
1452 MatcherSanity.addMatcher(gotoStmt(), &HandlerBail);
1455 // Called once, after the entire translation unit has been parsed
// (not once per top-level declaration).
// It runs the matcher passes in a fixed order:
//   1. call rewriting, which also fills ThreadMains for later passes;
//   2. load/store/RMW/branch-condition/loop handling;
//   3. assignments and branch annotation;
//   4. signatures, variable declarations and returns;
//   5. the goto check.
// Finally it applies the updates that the handlers deferred.
1457 void HandleTranslationUnit(ASTContext &Context) override {
1458 LangOpts = Context.getLangOpts();
1460 MatcherFunctionCall.matchAST(Context);
1461 MatcherLoadStore.matchAST(Context);
1462 MatcherFunction.matchAST(Context);
1463 MatcherFunctionDecl.matchAST(Context);
1464 MatcherSanity.matchAST(Context);
// Deferred insertions go after any text already inserted at the same
// location (InsertAfter=true), with inserted new lines indented.
// NOTE(review): DeferredUpdates holds raw `Update *` and is cleared without
// deleting them. If the handlers allocate them with `new` (not visible here),
// they leak; std::vector<std::unique_ptr<Update>> would fix that.
1466 for (auto & u : DeferredUpdates) {
1467 R.InsertText(u->loc, u->update, true, true);
1470 DeferredUpdates.clear();
1474 /* DeclsRead contains all local variables 'x' which:
1475 * 1) appear in 'x = load_32(...);'
1476 * 2) appear in 'y = store_32(x);' */
// DeclsInCond: presumably variables that appear in branch conditions (it is
// shared with the RMW, branch-condition and assignment handlers). TODO confirm.
1477 std::set<const NamedDecl *> DeclsRead, DeclsInCond;
1478 std::map<const NamedDecl *, std::string> DeclToMCVar;
1479 std::map<const Expr *, std::string> ExprToMCVar;
1480 std::set<const VarDecl *> DeclsNeedingMC;
1481 std::set<const FunctionDecl *> ThreadMains;
1482 std::set<const Stmt *> StmtsHandled;
1483 std::set<const Expr *> MallocExprs;
1484 std::map<const Expr *, SourceLocation> Redirector;
1485 std::vector<Update *> DeferredUpdates;
1489 MallocHandler HandlerMalloc;
1490 LoadHandler HandlerLoad;
1491 StoreHandler HandlerStore;
1492 RMWHandler HandlerRMW;
1493 LoopHandler HandlerLoop;
1494 BranchConditionRefactoringHandler HandlerBranchConditionRefactoring;
1495 BranchAnnotationHandler HandlerAnnotateBranch;
1496 AssignHandler HandlerAssign;
1497 FunctionDeclHandler HandlerFunctionDecl;
1498 FunctionCallHandler HandlerFunctionCall;
1499 ReturnHandler HandlerReturn;
1500 VarDeclHandler HandlerVarDecl;
1501 BailHandler HandlerBail;
1502 MatchFinder MatcherLoadStore, MatcherFunction, MatcherFunctionDecl, MatcherFunctionCall, MatcherSanity;
1505 // For each source file provided to the tool, a new FrontendAction is created.
// CreateASTConsumer points the Rewriter at this compiler instance's
// SourceManager and LangOptions, then hands it to a new MyASTConsumer.
// EndSourceFileAction logs the file name to stderr and writes the rewritten
// main file to stdout.
1506 class MyFrontendAction : public ASTFrontendAction {
1508 MyFrontendAction() {}
1509 void EndSourceFileAction() override {
1510 SourceManager &SM = TheRewriter.getSourceMgr();
// NOTE(review): getFileEntryForID() can return null when the main file has no
// file entry (e.g. an in-memory buffer); this log line would then crash.
1511 llvm::errs() << "** EndSourceFileAction for: "
1512 << SM.getFileEntryForID(SM.getMainFileID())->getName() << "\n";
1514 // Now emit the rewritten buffer.
// getEditBuffer() creates the buffer if nothing was rewritten, so an
// unmodified file is echoed unchanged.
1515 TheRewriter.getEditBuffer(SM.getMainFileID()).write(llvm::outs());
1518 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
1519 StringRef file) override {
1520 llvm::errs() << "** Creating AST consumer for: " << file << "\n";
1521 TheRewriter.setSourceMgr(CI.getSourceManager(), CI.getLangOpts());
// NOTE(review): llvm::make_unique was removed in newer LLVM releases;
// std::make_unique is the drop-in replacement.
1522 return llvm::make_unique<MyASTConsumer>(TheRewriter);
1526 Rewriter TheRewriter;
// Entry point. It parses the standard Clang tooling command line (source
// files plus compilation-database options) under the option category
// AddMC2AnnotationsCategory, then runs MyFrontendAction over each source file.
// It returns ClangTool::run's status, which is 0 on success.
// NOTE(review): newer Clang releases expect CommonOptionsParser::create(...),
// which returns llvm::Expected, instead of this constructor. Confirm against
// the targeted LLVM version.
1529 int main(int argc, const char **argv) {
1530 CommonOptionsParser op(argc, argv, AddMC2AnnotationsCategory);
1531 ClangTool Tool(op.getCompilations(), op.getSourcePathList());
1533 return Tool.run(newFrontendActionFactory<MyFrontendAction>().get());