X-Git-Url: http://plrg.eecs.uci.edu/git/?p=satcheck.git;a=blobdiff_plain;f=clang%2Fsrc%2Fadd_mc2_annotations.cpp;h=07246f9cc99f3beaaef2ede15b25649d4f8e2f0a;hp=a2539e545c3a98d1e9884429a49e9c00edd2cf44;hb=c0828349d8d79e469f450fb1e6b8dd717637c5f0;hpb=52b59882a22019190a4f6cb35c65ce1221b4a96e diff --git a/clang/src/add_mc2_annotations.cpp b/clang/src/add_mc2_annotations.cpp index a2539e5..07246f9 100644 --- a/clang/src/add_mc2_annotations.cpp +++ b/clang/src/add_mc2_annotations.cpp @@ -774,7 +774,7 @@ public: virtual void run(const MatchFinder::MatchResult &Result) { BinaryOperator * op = const_cast(Result.Nodes.getNodeAs("op")); const Stmt * s = Result.Nodes.getNodeAs("containingStmt"); - FindLocalsVisitor flv; + FindLocalsVisitor locals, locals_rhs; const VarDecl * lhs = NULL; const Expr * rhs = NULL; @@ -810,10 +810,11 @@ public: } std::set mcState; + bool lhsUsedInCond; + bool rhsRead = false; + bool lhsTooComplicated = false; if (op) { - flv.TraverseStmt(op); - DeclRefExpr * vd; if ((vd = dyn_cast(op->getLHS()))) lhs = dyn_cast(vd->getDecl()); @@ -826,21 +827,37 @@ public: if (rhs) rhs = rhs->IgnoreCasts(); } - else if (lhs) { - // rhs must be MC-active state, i.e. in declsread - // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions - flv.TraverseStmt(const_cast(cast(rhs))); + + // rhs must be MC-active state, i.e. in declsread + // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions + + if (rhs) { + locals_rhs.TraverseStmt(const_cast(cast(rhs))); + for (auto & nd : locals_rhs.RetrieveVars()) { + if (DeclsRead.find(nd) != DeclsRead.end()) + rhsRead = true; + } } - if (DeclsInCond.find(lhs) != DeclsInCond.end()) { - for (auto & d : flv.RetrieveVars()) { + locals.TraverseStmt(const_cast(cast(rhs))); + + lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end(); + if (lhsUsedInCond) { + for (auto & d : locals.RetrieveVars()) { + if (DeclToMCVar.count(d) > 0) + mcState.insert(DeclToMCVar[d]); + else if (DeclsRead.find(d) != DeclsRead.end()) + mcState.insert(encode(d->getName().str())); + } + } + if (rhsRead) { + for (auto & d : locals_rhs.RetrieveVars()) { if (DeclToMCVar.count(d) > 0) mcState.insert(DeclToMCVar[d]); else if (DeclsRead.find(d) != DeclsRead.end()) mcState.insert(encode(d->getName().str())); } } - if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) { if (lhsTooComplicated) assert(0 && "couldn't find LHS of = operator"); @@ -871,9 +888,9 @@ public: } nol << "); "; SourceLocation place; - if (op) - place = op->getLocEnd().getLocWithOffset(1); - else + if (op) { + place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1); + } else place = s->getLocEnd(); rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)), nol.str(), true, true); @@ -1487,7 +1504,9 @@ private: /* DeclsRead contains all local variables 'x' which: * 1) appear in 'x = load_32(...); * 2) appear in 'y = store_32(x); */ - std::set DeclsRead, DeclsInCond; + std::set DeclsRead; + /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */ + std::set DeclsInCond; std::map DeclToMCVar; std::map ExprToMCVar; std::set DeclsNeedingMC;