add MC2_function call for assignments where RHS computed from loads; tweak tests
[satcheck.git] / clang / src / add_mc2_annotations.cpp
index a2539e545c3a98d1e9884429a49e9c00edd2cf44..07246f9cc99f3beaaef2ede15b25649d4f8e2f0a 100644 (file)
@@ -774,7 +774,7 @@ public:
     virtual void run(const MatchFinder::MatchResult &Result) {
         BinaryOperator * op = const_cast<BinaryOperator *>(Result.Nodes.getNodeAs<BinaryOperator>("op"));
         const Stmt * s = Result.Nodes.getNodeAs<Stmt>("containingStmt");
-        FindLocalsVisitor flv;
+        FindLocalsVisitor locals, locals_rhs;
 
         const VarDecl * lhs = NULL;
         const Expr * rhs = NULL;
@@ -810,10 +810,11 @@ public:
         }
         std::set<std::string> mcState;
 
+        bool lhsUsedInCond;
+        bool rhsRead = false;
+
         bool lhsTooComplicated = false;
         if (op) {
-            flv.TraverseStmt(op);
-
             DeclRefExpr * vd;
             if ((vd = dyn_cast<DeclRefExpr>(op->getLHS())))
                 lhs = dyn_cast<VarDecl>(vd->getDecl());
@@ -826,21 +827,37 @@ public:
             if (rhs) 
                 rhs = rhs->IgnoreCasts();
         }
-        else if (lhs) {
-            // rhs must be MC-active state, i.e. in declsread
-            // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
-            flv.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        // rhs must be MC-active state, i.e. in declsread
+        // lhs must be subsequently used in (1) store/load or (2) branch condition or (3) other functions and (3a) uses values from other functions or (3b) uses values from loads, stores, or phi functions
+
+        if (rhs) {
+            locals_rhs.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+            for (auto & nd : locals_rhs.RetrieveVars()) {
+                if (DeclsRead.find(nd) != DeclsRead.end())
+                    rhsRead = true;
+            }
         }
 
-        if (DeclsInCond.find(lhs) != DeclsInCond.end()) {
-            for (auto & d : flv.RetrieveVars()) {
+        locals.TraverseStmt(const_cast<Stmt *>(cast<Stmt>(rhs)));
+
+        lhsUsedInCond = DeclsInCond.find(lhs) != DeclsInCond.end();
+        if (lhsUsedInCond) {
+            for (auto & d : locals.RetrieveVars()) {
+                if (DeclToMCVar.count(d) > 0)
+                    mcState.insert(DeclToMCVar[d]);
+                else if (DeclsRead.find(d) != DeclsRead.end())
+                    mcState.insert(encode(d->getName().str()));
+            }
+        }
+        if (rhsRead) {
+            for (auto & d : locals_rhs.RetrieveVars()) {
                 if (DeclToMCVar.count(d) > 0)
                     mcState.insert(DeclToMCVar[d]);
                 else if (DeclsRead.find(d) != DeclsRead.end())
                     mcState.insert(encode(d->getName().str()));
             }
         }
-
         if (mcState.size() > 0 || MallocExprs.find(rhs) != MallocExprs.end()) {
             if (lhsTooComplicated)
                 assert(0 && "couldn't find LHS of = operator");
@@ -871,9 +888,9 @@ public:
             }
             nol << "); ";
             SourceLocation place;
-            if (op)
-                place = op->getLocEnd().getLocWithOffset(1);
-            else
+            if (op) {
+                place = Lexer::getLocForEndOfToken(op->getLocEnd(), 0, rewrite.getSourceMgr(), rewrite.getLangOpts()).getLocWithOffset(1);
+            else
                 place = s->getLocEnd();
             rewrite.InsertText(rewrite.getSourceMgr().getExpansionLoc(place.getLocWithOffset(1)),
                                nol.str(), true, true);
@@ -1487,7 +1504,9 @@ private:
     /* DeclsRead contains all local variables 'x' which:
     * 1) appear in 'x = load_32(...);
     * 2) appear in 'y = store_32(x); */
-    std::set<const NamedDecl *> DeclsRead, DeclsInCond;
+    std::set<const NamedDecl *> DeclsRead;
+    /* DeclsInCond contains all local variables 'x' used in a branch condition or rmw parameter */
+    std::set<const NamedDecl *> DeclsInCond;
     std::map<const NamedDecl *, std::string> DeclToMCVar;
     std::map<const Expr *, std::string> ExprToMCVar;
     std::set<const VarDecl *> DeclsNeedingMC;