7088f364568d862a96acf814fcaa770e5a48bd75
[oota-llvm.git] / lib / CodeGen / PBQP / HeuristicSolver.h
1 #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
2 #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
3
4 #include "Solver.h"
5 #include "AnnotatedGraph.h"
6
7 #include <limits>
8 #include <iostream>
9
10 namespace PBQP {
11
12 /// \brief Important types for the HeuristicSolverImpl.
13 /// 
14 /// Declared seperately to allow access to heuristic classes before the solver
15 /// is fully constructed.
16 template <typename HeuristicNodeData, typename HeuristicEdgeData>
17 class HSITypes {
18 public:
19
20   class NodeData;
21   class EdgeData;
22
23   typedef AnnotatedGraph<NodeData, EdgeData> SolverGraph;
24   typedef typename SolverGraph::NodeIterator GraphNodeIterator;
25   typedef typename SolverGraph::EdgeIterator GraphEdgeIterator;
26   typedef typename SolverGraph::AdjEdgeIterator GraphAdjEdgeIterator;
27
28   typedef std::list<GraphNodeIterator> NodeList;
29   typedef typename NodeList::iterator NodeListIterator;
30
31   typedef std::vector<GraphNodeIterator> NodeStack;
32   typedef typename NodeStack::iterator NodeStackIterator;
33
34   class NodeData {
35     friend class EdgeData;
36
37   private:
38
39     typedef std::list<GraphEdgeIterator> LinksList;
40
41     unsigned numLinks;
42     LinksList links, solvedLinks;
43     NodeListIterator bucketItr;
44     HeuristicNodeData heuristicData;
45
46   public:
47
48     typedef typename LinksList::iterator AdjLinkIterator;
49
50   private:
51
52     AdjLinkIterator addLink(const GraphEdgeIterator &edgeItr) {
53       ++numLinks;
54       return links.insert(links.end(), edgeItr);
55     }
56
57     void delLink(const AdjLinkIterator &adjLinkItr) {
58       --numLinks;
59       links.erase(adjLinkItr);
60     }
61
62   public:
63
64     NodeData() : numLinks(0) {}
65
66     unsigned getLinkDegree() const { return numLinks; }
67
68     HeuristicNodeData& getHeuristicData() { return heuristicData; }
69     const HeuristicNodeData& getHeuristicData() const {
70       return heuristicData;
71     }
72
73     void setBucketItr(const NodeListIterator &bucketItr) {
74       this->bucketItr = bucketItr;
75     }
76
77     const NodeListIterator& getBucketItr() const {
78       return bucketItr;
79     }
80
81     AdjLinkIterator adjLinksBegin() {
82       return links.begin();
83     }
84
85     AdjLinkIterator adjLinksEnd() {
86       return links.end();
87     }
88
89     void addSolvedLink(const GraphEdgeIterator &solvedLinkItr) {
90       solvedLinks.push_back(solvedLinkItr);
91     }
92
93     AdjLinkIterator solvedLinksBegin() {
94       return solvedLinks.begin();
95     }
96
97     AdjLinkIterator solvedLinksEnd() {
98       return solvedLinks.end();
99     }
100
101   };
102
103   class EdgeData {
104   private:
105
106     SolverGraph &g;
107     GraphNodeIterator node1Itr, node2Itr;
108     HeuristicEdgeData heuristicData;
109     typename NodeData::AdjLinkIterator node1ThisEdgeItr, node2ThisEdgeItr;
110
111   public:
112
113     EdgeData(SolverGraph &g) : g(g) {}
114
115     HeuristicEdgeData& getHeuristicData() { return heuristicData; }
116     const HeuristicEdgeData& getHeuristicData() const {
117       return heuristicData;
118     }
119
120     void setup(const GraphEdgeIterator &thisEdgeItr) {
121       node1Itr = g.getEdgeNode1Itr(thisEdgeItr);
122       node2Itr = g.getEdgeNode2Itr(thisEdgeItr);
123
124       node1ThisEdgeItr = g.getNodeData(node1Itr).addLink(thisEdgeItr);
125       node2ThisEdgeItr = g.getNodeData(node2Itr).addLink(thisEdgeItr);
126     }
127
128     void unlink() {
129       g.getNodeData(node1Itr).delLink(node1ThisEdgeItr);
130       g.getNodeData(node2Itr).delLink(node2ThisEdgeItr);
131     }
132
133   };
134
135 };
136
137 template <typename Heuristic>
138 class HeuristicSolverImpl {
139 public:
140   // Typedefs to make life easier:
141   typedef HSITypes<typename Heuristic::NodeData,
142                    typename Heuristic::EdgeData> HSIT;
143   typedef typename HSIT::SolverGraph SolverGraph;
144   typedef typename HSIT::NodeData NodeData;
145   typedef typename HSIT::EdgeData EdgeData;
146   typedef typename HSIT::GraphNodeIterator GraphNodeIterator;
147   typedef typename HSIT::GraphEdgeIterator GraphEdgeIterator;
148   typedef typename HSIT::GraphAdjEdgeIterator GraphAdjEdgeIterator;
149
150   typedef typename HSIT::NodeList NodeList;
151   typedef typename HSIT::NodeListIterator NodeListIterator;
152
153   typedef std::vector<GraphNodeIterator> NodeStack;
154   typedef typename NodeStack::iterator NodeStackIterator;
155
156   /*!
157    * \brief Constructor, which performs all the actual solver work.
158    */
159   HeuristicSolverImpl(const SimpleGraph &orig) :
160     solution(orig.getNumNodes(), true)
161   {
162     copyGraph(orig);
163     simplify();
164     setup();
165     computeSolution();
166     computeSolutionCost(orig);
167   }
168
169   /*!
170    * \brief Returns the graph for this solver.
171    */
172   SolverGraph& getGraph() { return g; }
173
174   /*!
175    * \brief Return the solution found by this solver.
176    */
177   const Solution& getSolution() const { return solution; }
178
179 private:
180
181   /*!
182    * \brief Add the given node to the appropriate bucket for its link
183    * degree.
184    */
185   void addToBucket(const GraphNodeIterator &nodeItr) {
186     NodeData &nodeData = g.getNodeData(nodeItr);
187
188     switch (nodeData.getLinkDegree()) {
189       case 0: nodeData.setBucketItr(
190                 r0Bucket.insert(r0Bucket.end(), nodeItr));
191               break;                                            
192       case 1: nodeData.setBucketItr(
193                 r1Bucket.insert(r1Bucket.end(), nodeItr));
194               break;
195       case 2: nodeData.setBucketItr(
196                 r2Bucket.insert(r2Bucket.end(), nodeItr));
197               break;
198       default: heuristic.addToRNBucket(nodeItr);
199                break;
200     }
201   }
202
203   /*!
204    * \brief Remove the given node from the appropriate bucket for its link
205    * degree.
206    */
207   void removeFromBucket(const GraphNodeIterator &nodeItr) {
208     NodeData &nodeData = g.getNodeData(nodeItr);
209
210     switch (nodeData.getLinkDegree()) {
211       case 0: r0Bucket.erase(nodeData.getBucketItr()); break;
212       case 1: r1Bucket.erase(nodeData.getBucketItr()); break;
213       case 2: r2Bucket.erase(nodeData.getBucketItr()); break;
214       default: heuristic.removeFromRNBucket(nodeItr); break;
215     }
216   }
217
218 public:
219
220   /*!
221    * \brief Add a link.
222    */
223   void addLink(const GraphEdgeIterator &edgeItr) {
224     g.getEdgeData(edgeItr).setup(edgeItr);
225
226     if ((g.getNodeData(g.getEdgeNode1Itr(edgeItr)).getLinkDegree() > 2) ||
227         (g.getNodeData(g.getEdgeNode2Itr(edgeItr)).getLinkDegree() > 2)) {
228       heuristic.handleAddLink(edgeItr);
229     }
230   }
231
232   /*!
233    * \brief Remove link, update info for node.
234    *
235    * Only updates information for the given node, since usually the other
236    * is about to be removed.
237    */
238   void removeLink(const GraphEdgeIterator &edgeItr,
239                   const GraphNodeIterator &nodeItr) {
240
241     if (g.getNodeData(nodeItr).getLinkDegree() > 2) {
242       heuristic.handleRemoveLink(edgeItr, nodeItr);
243     }
244     g.getEdgeData(edgeItr).unlink();
245   }
246
247   /*!
248    * \brief Remove link, update info for both nodes. Useful for R2 only.
249    */
250   void removeLinkR2(const GraphEdgeIterator &edgeItr) {
251     GraphNodeIterator node1Itr = g.getEdgeNode1Itr(edgeItr);
252
253     if (g.getNodeData(node1Itr).getLinkDegree() > 2) {
254       heuristic.handleRemoveLink(edgeItr, node1Itr);
255     }
256     removeLink(edgeItr, g.getEdgeNode2Itr(edgeItr));
257   }
258
259   /*!
260    * \brief Removes all links connected to the given node.
261    */
262   void unlinkNode(const GraphNodeIterator &nodeItr) {
263     NodeData &nodeData = g.getNodeData(nodeItr);
264
265     typedef std::vector<GraphEdgeIterator> TempEdgeList;
266
267     TempEdgeList edgesToUnlink;
268     edgesToUnlink.reserve(nodeData.getLinkDegree());
269
270     // Copy adj edges into a temp vector. We want to destroy them during
271     // the unlink, and we can't do that while we're iterating over them.
272     std::copy(nodeData.adjLinksBegin(), nodeData.adjLinksEnd(),
273               std::back_inserter(edgesToUnlink));
274
275     for (typename TempEdgeList::iterator
276          edgeItr = edgesToUnlink.begin(), edgeEnd = edgesToUnlink.end();
277          edgeItr != edgeEnd; ++edgeItr) {
278
279       GraphNodeIterator otherNode = g.getEdgeOtherNode(*edgeItr, nodeItr);
280
281       removeFromBucket(otherNode);
282       removeLink(*edgeItr, otherNode);
283       addToBucket(otherNode);
284     }
285   }
286
287   /*!
288    * \brief Push the given node onto the stack to be solved with
289    * backpropagation.
290    */
291   void pushStack(const GraphNodeIterator &nodeItr) {
292     stack.push_back(nodeItr);
293   }
294
295   /*!
296    * \brief Set the solution of the given node.
297    */
298   void setSolution(const GraphNodeIterator &nodeItr, unsigned solIndex) {
299     solution.setSelection(g.getNodeID(nodeItr), solIndex);
300
301     for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
302          adjEdgeEnd = g.adjEdgesEnd(nodeItr);
303          adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
304       GraphEdgeIterator edgeItr(*adjEdgeItr);
305       GraphNodeIterator adjNodeItr(g.getEdgeOtherNode(edgeItr, nodeItr));
306       g.getNodeData(adjNodeItr).addSolvedLink(edgeItr);
307     }
308   }
309
310 private:
311
312   SolverGraph g;
313   Heuristic heuristic;
314   Solution solution;
315
316   NodeList r0Bucket,
317            r1Bucket,
318            r2Bucket;
319
320   NodeStack stack;
321
322   // Copy the SimpleGraph into an annotated graph which we can use for reduction.
323   void copyGraph(const SimpleGraph &orig) {
324
325     assert((g.getNumEdges() == 0) && (g.getNumNodes() == 0) &&
326            "Graph should be empty prior to solver setup.");
327
328     assert(orig.areNodeIDsValid() &&
329            "Cannot copy from a graph with invalid node IDs.");
330
331     std::vector<GraphNodeIterator> newNodeItrs;
332
333     for (unsigned nodeID = 0; nodeID < orig.getNumNodes(); ++nodeID) {
334       newNodeItrs.push_back(
335         g.addNode(orig.getNodeCosts(orig.getNodeItr(nodeID)), NodeData()));
336     }
337
338     for (SimpleGraph::ConstEdgeIterator
339          origEdgeItr = orig.edgesBegin(), origEdgeEnd = orig.edgesEnd();
340          origEdgeItr != origEdgeEnd; ++origEdgeItr) {
341
342       unsigned id1 = orig.getNodeID(orig.getEdgeNode1Itr(origEdgeItr)),
343                id2 = orig.getNodeID(orig.getEdgeNode2Itr(origEdgeItr));
344
345       g.addEdge(newNodeItrs[id1], newNodeItrs[id2],
346                 orig.getEdgeCosts(origEdgeItr), EdgeData(g));
347     }
348
349     // Assign IDs to the new nodes using the ordering from the old graph,
350     // this will lead to nodes in the new graph getting the same ID as the
351     // corresponding node in the old graph.
352     g.assignNodeIDs(newNodeItrs);
353   }
354
355   // Simplify the annotated graph by eliminating independent edges and trivial
356   // nodes. 
357   void simplify() {
358     disconnectTrivialNodes();
359     eliminateIndependentEdges();
360   }
361
362   // Eliminate trivial nodes.
363   void disconnectTrivialNodes() {
364     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
365          nodeItr != nodeEnd; ++nodeItr) {
366
367       if (g.getNodeCosts(nodeItr).getLength() == 1) {
368
369         std::vector<GraphEdgeIterator> edgesToRemove;
370
371         for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
372              adjEdgeEnd = g.adjEdgesEnd(nodeItr);
373              adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
374
375           GraphEdgeIterator edgeItr = *adjEdgeItr;
376
377           if (g.getEdgeNode1Itr(edgeItr) == nodeItr) {
378             GraphNodeIterator otherNodeItr = g.getEdgeNode2Itr(edgeItr);
379             g.getNodeCosts(otherNodeItr) +=
380               g.getEdgeCosts(edgeItr).getRowAsVector(0);
381           }
382           else {
383             GraphNodeIterator otherNodeItr = g.getEdgeNode1Itr(edgeItr);
384             g.getNodeCosts(otherNodeItr) +=
385               g.getEdgeCosts(edgeItr).getColAsVector(0);
386           }
387
388           edgesToRemove.push_back(edgeItr);
389         }
390
391         while (!edgesToRemove.empty()) {
392           g.removeEdge(edgesToRemove.back());
393           edgesToRemove.pop_back();
394         }
395       }
396     }
397   }
398
399   void eliminateIndependentEdges() {
400     std::vector<GraphEdgeIterator> edgesToProcess;
401
402     for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
403          edgeItr != edgeEnd; ++edgeItr) {
404       edgesToProcess.push_back(edgeItr);
405     }
406
407     while (!edgesToProcess.empty()) {
408       tryToEliminateEdge(edgesToProcess.back());
409       edgesToProcess.pop_back();
410     }
411   }
412
413   void tryToEliminateEdge(const GraphEdgeIterator &edgeItr) {
414     if (tryNormaliseEdgeMatrix(edgeItr)) {
415       g.removeEdge(edgeItr); 
416     }
417   }
418
419   bool tryNormaliseEdgeMatrix(const GraphEdgeIterator &edgeItr) {
420
421     Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
422     Vector &uCosts = g.getNodeCosts(g.getEdgeNode1Itr(edgeItr)),
423                &vCosts = g.getNodeCosts(g.getEdgeNode2Itr(edgeItr));
424
425     for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
426       PBQPNum rowMin = edgeCosts.getRowMin(r);
427       uCosts[r] += rowMin;
428       if (rowMin != std::numeric_limits<PBQPNum>::infinity()) {
429         edgeCosts.subFromRow(r, rowMin);
430       }
431       else {
432         edgeCosts.setRow(r, 0);
433       }
434     }
435
436     for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
437       PBQPNum colMin = edgeCosts.getColMin(c);
438       vCosts[c] += colMin;
439       if (colMin != std::numeric_limits<PBQPNum>::infinity()) {
440         edgeCosts.subFromCol(c, colMin);
441       }
442       else {
443         edgeCosts.setCol(c, 0);
444       }
445     }
446
447     return edgeCosts.isZero();
448   }
449
450   void setup() {
451     setupLinks();
452     heuristic.initialise(*this);
453     setupBuckets();
454   }
455
456   void setupLinks() {
457     for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
458          edgeItr != edgeEnd; ++edgeItr) {
459       g.getEdgeData(edgeItr).setup(edgeItr);
460     }
461   }
462
463   void setupBuckets() {
464     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
465          nodeItr != nodeEnd; ++nodeItr) {
466       addToBucket(nodeItr);
467     }
468   }
469
470   void computeSolution() {
471     assert(g.areNodeIDsValid() &&
472            "Nodes cannot be added/removed during reduction.");
473
474     reduce();
475     computeTrivialSolutions();
476     backpropagate();
477   }
478
479   void printNode(const GraphNodeIterator &nodeItr) {
480
481     std::cerr << "Node " << g.getNodeID(nodeItr) << " (" << &*nodeItr << "):\n"
482               << "  costs = " << g.getNodeCosts(nodeItr) << "\n"
483               << "  link degree = " << g.getNodeData(nodeItr).getLinkDegree() << "\n"
484               << "  links = [ ";
485
486     for (typename HSIT::NodeData::AdjLinkIterator 
487          aeItr = g.getNodeData(nodeItr).adjLinksBegin(),
488          aeEnd = g.getNodeData(nodeItr).adjLinksEnd();
489          aeItr != aeEnd; ++aeItr) {
490       std::cerr << "(" << g.getNodeID(g.getEdgeNode1Itr(*aeItr))
491                 << ", " << g.getNodeID(g.getEdgeNode2Itr(*aeItr))
492                 << ") ";
493     }
494     std::cout << "]\n";
495   }
496
497   void dumpState() {
498
499     std::cerr << "\n";
500
501     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
502          nodeItr != nodeEnd; ++nodeItr) {
503       printNode(nodeItr);
504     }
505
506     NodeList* buckets[] = { &r0Bucket, &r1Bucket, &r2Bucket };
507
508     for (unsigned b = 0; b < 3; ++b) {
509       NodeList &bucket = *buckets[b];
510
511       std::cerr << "Bucket " << b << ": [ ";
512
513       for (NodeListIterator nItr = bucket.begin(), nEnd = bucket.end();
514            nItr != nEnd; ++nItr) {
515         std::cerr << g.getNodeID(*nItr) << " ";
516       }
517
518       std::cerr << "]\n";
519     }
520
521     std::cerr << "Stack: [ ";
522     for (NodeStackIterator nsItr = stack.begin(), nsEnd = stack.end();
523          nsItr != nsEnd; ++nsItr) {
524       std::cerr << g.getNodeID(*nsItr) << " ";
525     }
526     std::cerr << "]\n";
527   }
528
529   void reduce() {
530     bool reductionFinished = r1Bucket.empty() && r2Bucket.empty() &&
531       heuristic.rNBucketEmpty();
532
533     while (!reductionFinished) {
534
535       if (!r1Bucket.empty()) {
536         processR1();
537       }
538       else if (!r2Bucket.empty()) {
539         processR2();
540       }
541       else if (!heuristic.rNBucketEmpty()) {
542         solution.setProvedOptimal(false);
543         solution.incRNReductions();
544         heuristic.processRN();
545       } 
546       else reductionFinished = true;
547     }
548       
549   };
550
551   void processR1() {
552
553     // Remove the first node in the R0 bucket:
554     GraphNodeIterator xNodeItr = r1Bucket.front();
555     r1Bucket.pop_front();
556
557     solution.incR1Reductions();
558
559     //std::cerr << "Applying R1 to " << g.getNodeID(xNodeItr) << "\n";
560
561     assert((g.getNodeData(xNodeItr).getLinkDegree() == 1) &&
562            "Node in R1 bucket has degree != 1");
563
564     GraphEdgeIterator edgeItr = *g.getNodeData(xNodeItr).adjLinksBegin();
565
566     const Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
567
568     const Vector &xCosts = g.getNodeCosts(xNodeItr);
569     unsigned xLen = xCosts.getLength();
570
571     // Duplicate a little code to avoid transposing matrices:
572     if (xNodeItr == g.getEdgeNode1Itr(edgeItr)) {
573       GraphNodeIterator yNodeItr = g.getEdgeNode2Itr(edgeItr);
574       Vector &yCosts = g.getNodeCosts(yNodeItr);
575       unsigned yLen = yCosts.getLength();
576
577       for (unsigned j = 0; j < yLen; ++j) {
578         PBQPNum min = edgeCosts[0][j] + xCosts[0];
579         for (unsigned i = 1; i < xLen; ++i) {
580           PBQPNum c = edgeCosts[i][j] + xCosts[i];
581           if (c < min)
582             min = c;
583         }
584         yCosts[j] += min;
585       }
586     }
587     else {
588       GraphNodeIterator yNodeItr = g.getEdgeNode1Itr(edgeItr);
589       Vector &yCosts = g.getNodeCosts(yNodeItr);
590       unsigned yLen = yCosts.getLength();
591
592       for (unsigned i = 0; i < yLen; ++i) {
593         PBQPNum min = edgeCosts[i][0] + xCosts[0];
594
595         for (unsigned j = 1; j < xLen; ++j) {
596           PBQPNum c = edgeCosts[i][j] + xCosts[j];
597           if (c < min)
598             min = c;
599         }
600         yCosts[i] += min;
601       }
602     }
603
604     unlinkNode(xNodeItr);
605     pushStack(xNodeItr);
606   }
607
608   void processR2() {
609
610     GraphNodeIterator xNodeItr = r2Bucket.front();
611     r2Bucket.pop_front();
612
613     solution.incR2Reductions();
614
615     // Unlink is unsafe here. At some point it may optimistically more a node
616     // to a lower-degree list when its degree will later rise, or vice versa,
617     // violating the assumption that node degrees monotonically decrease
618     // during the reduction phase. Instead we'll bucket shuffle manually.
619     pushStack(xNodeItr);
620
621     assert((g.getNodeData(xNodeItr).getLinkDegree() == 2) &&
622            "Node in R2 bucket has degree != 2");
623
624     const Vector &xCosts = g.getNodeCosts(xNodeItr);
625
626     typename NodeData::AdjLinkIterator tempItr =
627       g.getNodeData(xNodeItr).adjLinksBegin();
628
629     GraphEdgeIterator yxEdgeItr = *tempItr,
630                       zxEdgeItr = *(++tempItr);
631
632     GraphNodeIterator yNodeItr = g.getEdgeOtherNode(yxEdgeItr, xNodeItr),
633                       zNodeItr = g.getEdgeOtherNode(zxEdgeItr, xNodeItr);
634
635     removeFromBucket(yNodeItr);
636     removeFromBucket(zNodeItr);
637
638     removeLink(yxEdgeItr, yNodeItr);
639     removeLink(zxEdgeItr, zNodeItr);
640
641     // Graph some of the costs:
642     bool flipEdge1 = (g.getEdgeNode1Itr(yxEdgeItr) == xNodeItr),
643          flipEdge2 = (g.getEdgeNode1Itr(zxEdgeItr) == xNodeItr);
644
645     const Matrix *yxCosts = flipEdge1 ?
646       new Matrix(g.getEdgeCosts(yxEdgeItr).transpose()) :
647       &g.getEdgeCosts(yxEdgeItr),
648                      *zxCosts = flipEdge2 ?
649       new Matrix(g.getEdgeCosts(zxEdgeItr).transpose()) :
650         &g.getEdgeCosts(zxEdgeItr);
651
652     unsigned xLen = xCosts.getLength(),
653              yLen = yxCosts->getRows(),
654              zLen = zxCosts->getRows();
655
656     // Compute delta:
657     Matrix delta(yLen, zLen);
658
659     for (unsigned i = 0; i < yLen; ++i) {
660       for (unsigned j = 0; j < zLen; ++j) {
661         PBQPNum min = (*yxCosts)[i][0] + (*zxCosts)[j][0] + xCosts[0];
662         for (unsigned k = 1; k < xLen; ++k) {
663           PBQPNum c = (*yxCosts)[i][k] + (*zxCosts)[j][k] + xCosts[k];
664           if (c < min) {
665             min = c;
666           }
667         }
668         delta[i][j] = min;
669       }
670     }
671
672     if (flipEdge1)
673       delete yxCosts;
674
675     if (flipEdge2)
676       delete zxCosts;
677
678     // Deal with the potentially induced yz edge.
679     GraphEdgeIterator yzEdgeItr = g.findEdge(yNodeItr, zNodeItr);
680     if (yzEdgeItr == g.edgesEnd()) {
681       yzEdgeItr = g.addEdge(yNodeItr, zNodeItr, delta, EdgeData(g));
682     }
683     else {
684       // There was an edge, but we're going to screw with it. Delete the old
685       // link, update the costs. We'll re-link it later.
686       removeLinkR2(yzEdgeItr);
687       g.getEdgeCosts(yzEdgeItr) +=
688         (yNodeItr == g.getEdgeNode1Itr(yzEdgeItr)) ?
689         delta : delta.transpose();
690     }
691
692     bool nullCostEdge = tryNormaliseEdgeMatrix(yzEdgeItr);
693
694     // Nulled the edge, remove it entirely.
695     if (nullCostEdge) {
696       g.removeEdge(yzEdgeItr);
697     }
698     else {
699       // Edge remains - re-link it.
700       addLink(yzEdgeItr);
701     }
702
703     addToBucket(yNodeItr);
704     addToBucket(zNodeItr);
705     }
706
707   void computeTrivialSolutions() {
708
709     for (NodeListIterator r0Itr = r0Bucket.begin(), r0End = r0Bucket.end();
710          r0Itr != r0End; ++r0Itr) {
711       GraphNodeIterator nodeItr = *r0Itr;
712
713       solution.incR0Reductions();
714       setSolution(nodeItr, g.getNodeCosts(nodeItr).minIndex());
715     }
716
717   }
718
719   void backpropagate() {
720     while (!stack.empty()) {
721       computeSolution(stack.back());
722       stack.pop_back();
723     }
724   }
725
726   void computeSolution(const GraphNodeIterator &nodeItr) {
727
728     NodeData &nodeData = g.getNodeData(nodeItr);
729
730     Vector v(g.getNodeCosts(nodeItr));
731
732     // Solve based on existing links.
733     for (typename NodeData::AdjLinkIterator
734          solvedLinkItr = nodeData.solvedLinksBegin(),
735          solvedLinkEnd = nodeData.solvedLinksEnd();
736          solvedLinkItr != solvedLinkEnd; ++solvedLinkItr) {
737
738       GraphEdgeIterator solvedEdgeItr(*solvedLinkItr);
739       Matrix &edgeCosts = g.getEdgeCosts(solvedEdgeItr);
740
741       if (nodeItr == g.getEdgeNode1Itr(solvedEdgeItr)) {
742         GraphNodeIterator adjNode(g.getEdgeNode2Itr(solvedEdgeItr));
743         unsigned adjSolution =
744           solution.getSelection(g.getNodeID(adjNode));
745         v += edgeCosts.getColAsVector(adjSolution);
746       }
747       else {
748         GraphNodeIterator adjNode(g.getEdgeNode1Itr(solvedEdgeItr));
749         unsigned adjSolution =
750           solution.getSelection(g.getNodeID(adjNode));
751         v += edgeCosts.getRowAsVector(adjSolution);
752       }
753
754     }
755
756     setSolution(nodeItr, v.minIndex());
757   }
758
759   void computeSolutionCost(const SimpleGraph &orig) {
760     PBQPNum cost = 0.0;
761
762     for (SimpleGraph::ConstNodeIterator
763          nodeItr = orig.nodesBegin(), nodeEnd = orig.nodesEnd();
764          nodeItr != nodeEnd; ++nodeItr) {
765
766       unsigned nodeId = orig.getNodeID(nodeItr);
767
768       cost += orig.getNodeCosts(nodeItr)[solution.getSelection(nodeId)];
769     }
770
771     for (SimpleGraph::ConstEdgeIterator
772          edgeItr = orig.edgesBegin(), edgeEnd = orig.edgesEnd();
773          edgeItr != edgeEnd; ++edgeItr) {
774
775       SimpleGraph::ConstNodeIterator n1 = orig.getEdgeNode1Itr(edgeItr),
776                                      n2 = orig.getEdgeNode2Itr(edgeItr);
777       unsigned sol1 = solution.getSelection(orig.getNodeID(n1)),
778                sol2 = solution.getSelection(orig.getNodeID(n2));
779
780       cost += orig.getEdgeCosts(edgeItr)[sol1][sol2];
781     }
782
783     solution.setSolutionCost(cost);
784   }
785
786 };
787
788 template <typename Heuristic>
789 class HeuristicSolver : public Solver {
790 public:
791   Solution solve(const SimpleGraph &g) const {
792     HeuristicSolverImpl<Heuristic> solverImpl(g);
793     return solverImpl.getSolution();
794   }
795 };
796
797 }
798
799 #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H