Fix snapshot code
[model-checker.git] / cyclegraph.cc
index 20623347213d0958bfdff79d142595420461b015..def51f9671e346251a4f108abe1c78c8d09a8bfe 100644 (file)
@@ -2,13 +2,12 @@
 #include "action.h"
 #include "common.h"
 #include "promise.h"
-#include "model.h"
 #include "threads-model.h"
 
 /** Initializes a CycleGraph object. */
 CycleGraph::CycleGraph() :
        discovered(new HashTable<const CycleNode *, const CycleNode *, uintptr_t, 4, model_malloc, model_calloc, model_free>(16)),
-       queue(new model_vector< const CycleNode * >()),
+       queue(new ModelVector<const CycleNode *>()),
        hasCycles(false),
        oldCycles(false)
 {
@@ -198,29 +197,33 @@ bool CycleGraph::mergeNodes(CycleNode *w_node, CycleNode *p_node)
  */
 bool CycleGraph::addNodeEdge(CycleNode *fromnode, CycleNode *tonode)
 {
-       bool added;
-
-       if ((added = fromnode->addEdge(tonode))) {
+       if (fromnode->addEdge(tonode)) {
                rollbackvector.push_back(fromnode);
                if (!hasCycles)
                        hasCycles = checkReachable(tonode, fromnode);
-       }
+       } else
+               return false; /* No new edge */
 
        /*
-        * If the fromnode has a rmwnode that is not the tonode, we should add
-        * an edge between its rmwnode and the tonode
+        * If the fromnode has a rmwnode that is not the tonode, we should
+        * follow its RMW chain to add an edge at the end, unless we encounter
+        * tonode along the way
         */
        CycleNode *rmwnode = fromnode->getRMW();
-       if (rmwnode && rmwnode != tonode) {
-               if (rmwnode->addEdge(tonode)) {
-                       if (!hasCycles)
-                               hasCycles = checkReachable(tonode, rmwnode);
+       if (rmwnode) {
+               while (rmwnode != tonode && rmwnode->getRMW())
+                       rmwnode = rmwnode->getRMW();
+
+               if (rmwnode != tonode) {
+                       if (rmwnode->addEdge(tonode)) {
+                               if (!hasCycles)
+                                       hasCycles = checkReachable(tonode, rmwnode);
 
-                       rollbackvector.push_back(rmwnode);
-                       added = true;
+                               rollbackvector.push_back(rmwnode);
+                       }
                }
        }
-       return added;
+       return true;
 }
 
 /**
@@ -228,8 +231,8 @@ bool CycleGraph::addNodeEdge(CycleNode *fromnode, CycleNode *tonode)
  *
  * Handles special case of a RMW action, where the ModelAction rmw reads from
  * the ModelAction/Promise from. The key differences are:
- * (1) no write can occur in between the rmw and the from action.
- * (2) Only one RMW action can read from a given write.
+ *  -# No write can occur in between the @a rmw and @a from actions.
+ *  -# Only one RMW action can read from a given write.
  *
  * @param from The edge comes from this ModelAction/Promise
  * @param rmw The edge points to this ModelAction; this action must read from
@@ -244,6 +247,9 @@ void CycleGraph::addRMWEdge(const T *from, const ModelAction *rmw)
        CycleNode *fromnode = getNode(from);
        CycleNode *rmwnode = getNode(rmw);
 
+       /* We assume that this RMW has no RMW reading from it yet */
+       ASSERT(!rmwnode->getRMW());
+
        /* Two RMW actions cannot read from the same write. */
        if (fromnode->setRMW(rmwnode))
                hasCycles = true;
@@ -307,12 +313,12 @@ static void print_node(FILE *file, const CycleNode *node, int label)
 {
        if (node->is_promise()) {
                const Promise *promise = node->getPromise();
-               int idx = model->get_promise_number(promise);
+               int idx = promise->get_index();
                fprintf(file, "P%u", idx);
                if (label) {
                        int first = 1;
                        fprintf(file, " [label=\"P%d, T", idx);
-                       for (unsigned int i = 0 ; i < model->get_num_threads(); i++)
+                       for (unsigned int i = 0 ; i < promise->max_available_thread_idx(); i++)
                                if (promise->thread_is_available(int_to_id(i))) {
                                        fprintf(file, "%s%u", first ? "": ",", i);
                                        first = 0;
@@ -450,9 +456,8 @@ bool CycleGraph::checkPromise(const ModelAction *fromact, Promise *promise) cons
 
                if (node->getPromise() == promise)
                        return true;
-               if (!node->is_promise() &&
-                               promise->eliminate_thread(node->getAction()->get_tid()))
-                       return true;
+               if (!node->is_promise())
+                       promise->eliminate_thread(node->getAction()->get_tid());
 
                for (unsigned int i = 0; i < node->getNumEdges(); i++) {
                        CycleNode *next = node->getEdge(i);
@@ -465,6 +470,7 @@ bool CycleGraph::checkPromise(const ModelAction *fromact, Promise *promise) cons
        return false;
 }
 
+/** @brief Begin a new sequence of graph additions which can be rolled back */
 void CycleGraph::startChanges()
 {
        ASSERT(rollbackvector.empty());
@@ -524,7 +530,7 @@ CycleNode::CycleNode(const Promise *promise) :
 
 /**
  * @param i The index of the edge to return
- * @returns The CycleNode edge indexed by i
+ * @returns The CycleNode edge indexed by i
  */
 CycleNode * CycleNode::getEdge(unsigned int i) const
 {
@@ -537,11 +543,16 @@ unsigned int CycleNode::getNumEdges() const
        return edges.size();
 }
 
+/**
+ * @param i The index of the back edge to return
+ * @returns The CycleNode back-edge indexed by i
+ */
 CycleNode * CycleNode::getBackEdge(unsigned int i) const
 {
        return back_edges[i];
 }
 
+/** @returns The number of edges entering this CycleNode */
 unsigned int CycleNode::getNumBackEdges() const
 {
        return back_edges.size();
@@ -554,7 +565,7 @@ unsigned int CycleNode::getNumBackEdges() const
  * @return True if the element was found; false otherwise
  */
 template <typename T>
-static bool vector_remove_node(std::vector<T, SnapshotAlloc<T> >& v, const T n)
+static bool vector_remove_node(SnapVector<T>& v, const T n)
 {
        for (unsigned int i = 0; i < v.size(); i++) {
                if (v[i] == n) {