Fixed malformed -*- lines in PBQP headers.
[oota-llvm.git] / lib / CodeGen / PBQP / HeuristicSolver.h
1 //===-- HeuristicSolver.h - Heuristic PBQP Solver ---------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Heuristic PBQP solver. This solver is able to perform optimal reductions for
11 // nodes of degree 0, 1 or 2. For nodes of degree >2 a plugable heuristic is
12 // used to to select a node for reduction. 
13 //
14 //===----------------------------------------------------------------------===//
15
16 #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17 #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
18
19 #include "Solver.h"
20 #include "AnnotatedGraph.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <limits>
23
24 namespace PBQP {
25
26 /// \brief Important types for the HeuristicSolverImpl.
27 /// 
28 /// Declared seperately to allow access to heuristic classes before the solver
29 /// is fully constructed.
30 template <typename HeuristicNodeData, typename HeuristicEdgeData>
31 class HSITypes {
32 public:
33
34   class NodeData;
35   class EdgeData;
36
37   typedef AnnotatedGraph<NodeData, EdgeData> SolverGraph;
38   typedef typename SolverGraph::NodeIterator GraphNodeIterator;
39   typedef typename SolverGraph::EdgeIterator GraphEdgeIterator;
40   typedef typename SolverGraph::AdjEdgeIterator GraphAdjEdgeIterator;
41
42   typedef std::list<GraphNodeIterator> NodeList;
43   typedef typename NodeList::iterator NodeListIterator;
44
45   typedef std::vector<GraphNodeIterator> NodeStack;
46   typedef typename NodeStack::iterator NodeStackIterator;
47
48   class NodeData {
49     friend class EdgeData;
50
51   private:
52
53     typedef std::list<GraphEdgeIterator> LinksList;
54
55     unsigned numLinks;
56     LinksList links, solvedLinks;
57     NodeListIterator bucketItr;
58     HeuristicNodeData heuristicData;
59
60   public:
61
62     typedef typename LinksList::iterator AdjLinkIterator;
63
64   private:
65
66     AdjLinkIterator addLink(const GraphEdgeIterator &edgeItr) {
67       ++numLinks;
68       return links.insert(links.end(), edgeItr);
69     }
70
71     void delLink(const AdjLinkIterator &adjLinkItr) {
72       --numLinks;
73       links.erase(adjLinkItr);
74     }
75
76   public:
77
78     NodeData() : numLinks(0) {}
79
80     unsigned getLinkDegree() const { return numLinks; }
81
82     HeuristicNodeData& getHeuristicData() { return heuristicData; }
83     const HeuristicNodeData& getHeuristicData() const {
84       return heuristicData;
85     }
86
87     void setBucketItr(const NodeListIterator &bucketItr) {
88       this->bucketItr = bucketItr;
89     }
90
91     const NodeListIterator& getBucketItr() const {
92       return bucketItr;
93     }
94
95     AdjLinkIterator adjLinksBegin() {
96       return links.begin();
97     }
98
99     AdjLinkIterator adjLinksEnd() {
100       return links.end();
101     }
102
103     void addSolvedLink(const GraphEdgeIterator &solvedLinkItr) {
104       solvedLinks.push_back(solvedLinkItr);
105     }
106
107     AdjLinkIterator solvedLinksBegin() {
108       return solvedLinks.begin();
109     }
110
111     AdjLinkIterator solvedLinksEnd() {
112       return solvedLinks.end();
113     }
114
115   };
116
117   class EdgeData {
118   private:
119
120     SolverGraph &g;
121     GraphNodeIterator node1Itr, node2Itr;
122     HeuristicEdgeData heuristicData;
123     typename NodeData::AdjLinkIterator node1ThisEdgeItr, node2ThisEdgeItr;
124
125   public:
126
127     EdgeData(SolverGraph &g) : g(g) {}
128
129     HeuristicEdgeData& getHeuristicData() { return heuristicData; }
130     const HeuristicEdgeData& getHeuristicData() const {
131       return heuristicData;
132     }
133
134     void setup(const GraphEdgeIterator &thisEdgeItr) {
135       node1Itr = g.getEdgeNode1Itr(thisEdgeItr);
136       node2Itr = g.getEdgeNode2Itr(thisEdgeItr);
137
138       node1ThisEdgeItr = g.getNodeData(node1Itr).addLink(thisEdgeItr);
139       node2ThisEdgeItr = g.getNodeData(node2Itr).addLink(thisEdgeItr);
140     }
141
142     void unlink() {
143       g.getNodeData(node1Itr).delLink(node1ThisEdgeItr);
144       g.getNodeData(node2Itr).delLink(node2ThisEdgeItr);
145     }
146
147   };
148
149 };
150
151 template <typename Heuristic>
152 class HeuristicSolverImpl {
153 public:
154   // Typedefs to make life easier:
155   typedef HSITypes<typename Heuristic::NodeData,
156                    typename Heuristic::EdgeData> HSIT;
157   typedef typename HSIT::SolverGraph SolverGraph;
158   typedef typename HSIT::NodeData NodeData;
159   typedef typename HSIT::EdgeData EdgeData;
160   typedef typename HSIT::GraphNodeIterator GraphNodeIterator;
161   typedef typename HSIT::GraphEdgeIterator GraphEdgeIterator;
162   typedef typename HSIT::GraphAdjEdgeIterator GraphAdjEdgeIterator;
163
164   typedef typename HSIT::NodeList NodeList;
165   typedef typename HSIT::NodeListIterator NodeListIterator;
166
167   typedef std::vector<GraphNodeIterator> NodeStack;
168   typedef typename NodeStack::iterator NodeStackIterator;
169
170   /// \brief Constructor, which performs all the actual solver work.
171   HeuristicSolverImpl(const SimpleGraph &orig) :
172     solution(orig.getNumNodes(), true)
173   {
174     copyGraph(orig);
175     simplify();
176     setup();
177     computeSolution();
178     computeSolutionCost(orig);
179   }
180
181   /// \brief Returns the graph for this solver.
182   SolverGraph& getGraph() { return g; }
183
184   /// \brief Return the solution found by this solver.
185   const Solution& getSolution() const { return solution; }
186
187 private:
188
189   /// \brief Add the given node to the appropriate bucket for its link
190   /// degree.
191   void addToBucket(const GraphNodeIterator &nodeItr) {
192     NodeData &nodeData = g.getNodeData(nodeItr);
193
194     switch (nodeData.getLinkDegree()) {
195       case 0: nodeData.setBucketItr(
196                 r0Bucket.insert(r0Bucket.end(), nodeItr));
197               break;                                            
198       case 1: nodeData.setBucketItr(
199                 r1Bucket.insert(r1Bucket.end(), nodeItr));
200               break;
201       case 2: nodeData.setBucketItr(
202                 r2Bucket.insert(r2Bucket.end(), nodeItr));
203               break;
204       default: heuristic.addToRNBucket(nodeItr);
205                break;
206     }
207   }
208
209   /// \brief Remove the given node from the appropriate bucket for its link
210   /// degree.
211   void removeFromBucket(const GraphNodeIterator &nodeItr) {
212     NodeData &nodeData = g.getNodeData(nodeItr);
213
214     switch (nodeData.getLinkDegree()) {
215       case 0: r0Bucket.erase(nodeData.getBucketItr()); break;
216       case 1: r1Bucket.erase(nodeData.getBucketItr()); break;
217       case 2: r2Bucket.erase(nodeData.getBucketItr()); break;
218       default: heuristic.removeFromRNBucket(nodeItr); break;
219     }
220   }
221
222 public:
223
224   /// \brief Add a link.
225   void addLink(const GraphEdgeIterator &edgeItr) {
226     g.getEdgeData(edgeItr).setup(edgeItr);
227
228     if ((g.getNodeData(g.getEdgeNode1Itr(edgeItr)).getLinkDegree() > 2) ||
229         (g.getNodeData(g.getEdgeNode2Itr(edgeItr)).getLinkDegree() > 2)) {
230       heuristic.handleAddLink(edgeItr);
231     }
232   }
233
234   /// \brief Remove link, update info for node.
235   ///
236   /// Only updates information for the given node, since usually the other
237   /// is about to be removed.
238   void removeLink(const GraphEdgeIterator &edgeItr,
239                   const GraphNodeIterator &nodeItr) {
240
241     if (g.getNodeData(nodeItr).getLinkDegree() > 2) {
242       heuristic.handleRemoveLink(edgeItr, nodeItr);
243     }
244     g.getEdgeData(edgeItr).unlink();
245   }
246
247   /// \brief Remove link, update info for both nodes. Useful for R2 only.
248   void removeLinkR2(const GraphEdgeIterator &edgeItr) {
249     GraphNodeIterator node1Itr = g.getEdgeNode1Itr(edgeItr);
250
251     if (g.getNodeData(node1Itr).getLinkDegree() > 2) {
252       heuristic.handleRemoveLink(edgeItr, node1Itr);
253     }
254     removeLink(edgeItr, g.getEdgeNode2Itr(edgeItr));
255   }
256
257   /// \brief Removes all links connected to the given node.
258   void unlinkNode(const GraphNodeIterator &nodeItr) {
259     NodeData &nodeData = g.getNodeData(nodeItr);
260
261     typedef std::vector<GraphEdgeIterator> TempEdgeList;
262
263     TempEdgeList edgesToUnlink;
264     edgesToUnlink.reserve(nodeData.getLinkDegree());
265
266     // Copy adj edges into a temp vector. We want to destroy them during
267     // the unlink, and we can't do that while we're iterating over them.
268     std::copy(nodeData.adjLinksBegin(), nodeData.adjLinksEnd(),
269               std::back_inserter(edgesToUnlink));
270
271     for (typename TempEdgeList::iterator
272          edgeItr = edgesToUnlink.begin(), edgeEnd = edgesToUnlink.end();
273          edgeItr != edgeEnd; ++edgeItr) {
274
275       GraphNodeIterator otherNode = g.getEdgeOtherNode(*edgeItr, nodeItr);
276
277       removeFromBucket(otherNode);
278       removeLink(*edgeItr, otherNode);
279       addToBucket(otherNode);
280     }
281   }
282
283   /// \brief Push the given node onto the stack to be solved with
284   /// backpropagation.
285   void pushStack(const GraphNodeIterator &nodeItr) {
286     stack.push_back(nodeItr);
287   }
288
289   /// \brief Set the solution of the given node.
290   void setSolution(const GraphNodeIterator &nodeItr, unsigned solIndex) {
291     solution.setSelection(g.getNodeID(nodeItr), solIndex);
292
293     for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
294          adjEdgeEnd = g.adjEdgesEnd(nodeItr);
295          adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
296       GraphEdgeIterator edgeItr(*adjEdgeItr);
297       GraphNodeIterator adjNodeItr(g.getEdgeOtherNode(edgeItr, nodeItr));
298       g.getNodeData(adjNodeItr).addSolvedLink(edgeItr);
299     }
300   }
301
302 private:
303
304   SolverGraph g;
305   Heuristic heuristic;
306   Solution solution;
307
308   NodeList r0Bucket,
309            r1Bucket,
310            r2Bucket;
311
312   NodeStack stack;
313
314   // Copy the SimpleGraph into an annotated graph which we can use for reduction.
315   void copyGraph(const SimpleGraph &orig) {
316
317     assert((g.getNumEdges() == 0) && (g.getNumNodes() == 0) &&
318            "Graph should be empty prior to solver setup.");
319
320     assert(orig.areNodeIDsValid() &&
321            "Cannot copy from a graph with invalid node IDs.");
322
323     std::vector<GraphNodeIterator> newNodeItrs;
324
325     for (unsigned nodeID = 0; nodeID < orig.getNumNodes(); ++nodeID) {
326       newNodeItrs.push_back(
327         g.addNode(orig.getNodeCosts(orig.getNodeItr(nodeID)), NodeData()));
328     }
329
330     for (SimpleGraph::ConstEdgeIterator
331          origEdgeItr = orig.edgesBegin(), origEdgeEnd = orig.edgesEnd();
332          origEdgeItr != origEdgeEnd; ++origEdgeItr) {
333
334       unsigned id1 = orig.getNodeID(orig.getEdgeNode1Itr(origEdgeItr)),
335                id2 = orig.getNodeID(orig.getEdgeNode2Itr(origEdgeItr));
336
337       g.addEdge(newNodeItrs[id1], newNodeItrs[id2],
338                 orig.getEdgeCosts(origEdgeItr), EdgeData(g));
339     }
340
341     // Assign IDs to the new nodes using the ordering from the old graph,
342     // this will lead to nodes in the new graph getting the same ID as the
343     // corresponding node in the old graph.
344     g.assignNodeIDs(newNodeItrs);
345   }
346
347   // Simplify the annotated graph by eliminating independent edges and trivial
348   // nodes. 
349   void simplify() {
350     disconnectTrivialNodes();
351     eliminateIndependentEdges();
352   }
353
354   // Eliminate trivial nodes.
355   void disconnectTrivialNodes() {
356     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
357          nodeItr != nodeEnd; ++nodeItr) {
358
359       if (g.getNodeCosts(nodeItr).getLength() == 1) {
360
361         std::vector<GraphEdgeIterator> edgesToRemove;
362
363         for (GraphAdjEdgeIterator adjEdgeItr = g.adjEdgesBegin(nodeItr),
364              adjEdgeEnd = g.adjEdgesEnd(nodeItr);
365              adjEdgeItr != adjEdgeEnd; ++adjEdgeItr) {
366
367           GraphEdgeIterator edgeItr = *adjEdgeItr;
368
369           if (g.getEdgeNode1Itr(edgeItr) == nodeItr) {
370             GraphNodeIterator otherNodeItr = g.getEdgeNode2Itr(edgeItr);
371             g.getNodeCosts(otherNodeItr) +=
372               g.getEdgeCosts(edgeItr).getRowAsVector(0);
373           }
374           else {
375             GraphNodeIterator otherNodeItr = g.getEdgeNode1Itr(edgeItr);
376             g.getNodeCosts(otherNodeItr) +=
377               g.getEdgeCosts(edgeItr).getColAsVector(0);
378           }
379
380           edgesToRemove.push_back(edgeItr);
381         }
382
383         while (!edgesToRemove.empty()) {
384           g.removeEdge(edgesToRemove.back());
385           edgesToRemove.pop_back();
386         }
387       }
388     }
389   }
390
391   void eliminateIndependentEdges() {
392     std::vector<GraphEdgeIterator> edgesToProcess;
393
394     for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
395          edgeItr != edgeEnd; ++edgeItr) {
396       edgesToProcess.push_back(edgeItr);
397     }
398
399     while (!edgesToProcess.empty()) {
400       tryToEliminateEdge(edgesToProcess.back());
401       edgesToProcess.pop_back();
402     }
403   }
404
405   void tryToEliminateEdge(const GraphEdgeIterator &edgeItr) {
406     if (tryNormaliseEdgeMatrix(edgeItr)) {
407       g.removeEdge(edgeItr); 
408     }
409   }
410
411   bool tryNormaliseEdgeMatrix(const GraphEdgeIterator &edgeItr) {
412
413     Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
414     Vector &uCosts = g.getNodeCosts(g.getEdgeNode1Itr(edgeItr)),
415                &vCosts = g.getNodeCosts(g.getEdgeNode2Itr(edgeItr));
416
417     for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
418       PBQPNum rowMin = edgeCosts.getRowMin(r);
419       uCosts[r] += rowMin;
420       if (rowMin != std::numeric_limits<PBQPNum>::infinity()) {
421         edgeCosts.subFromRow(r, rowMin);
422       }
423       else {
424         edgeCosts.setRow(r, 0);
425       }
426     }
427
428     for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
429       PBQPNum colMin = edgeCosts.getColMin(c);
430       vCosts[c] += colMin;
431       if (colMin != std::numeric_limits<PBQPNum>::infinity()) {
432         edgeCosts.subFromCol(c, colMin);
433       }
434       else {
435         edgeCosts.setCol(c, 0);
436       }
437     }
438
439     return edgeCosts.isZero();
440   }
441
442   void setup() {
443     setupLinks();
444     heuristic.initialise(*this);
445     setupBuckets();
446   }
447
448   void setupLinks() {
449     for (GraphEdgeIterator edgeItr = g.edgesBegin(), edgeEnd = g.edgesEnd();
450          edgeItr != edgeEnd; ++edgeItr) {
451       g.getEdgeData(edgeItr).setup(edgeItr);
452     }
453   }
454
455   void setupBuckets() {
456     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
457          nodeItr != nodeEnd; ++nodeItr) {
458       addToBucket(nodeItr);
459     }
460   }
461
462   void computeSolution() {
463     assert(g.areNodeIDsValid() &&
464            "Nodes cannot be added/removed during reduction.");
465
466     reduce();
467     computeTrivialSolutions();
468     backpropagate();
469   }
470
471   void printNode(const GraphNodeIterator &nodeItr) {
472     llvm::errs() << "Node " << g.getNodeID(nodeItr) << " (" << &*nodeItr << "):\n"
473                  << "  costs = " << g.getNodeCosts(nodeItr) << "\n"
474                  << "  link degree = " << g.getNodeData(nodeItr).getLinkDegree() << "\n"
475                  << "  links = [ ";
476
477     for (typename HSIT::NodeData::AdjLinkIterator 
478          aeItr = g.getNodeData(nodeItr).adjLinksBegin(),
479          aeEnd = g.getNodeData(nodeItr).adjLinksEnd();
480          aeItr != aeEnd; ++aeItr) {
481       llvm::errs() << "(" << g.getNodeID(g.getEdgeNode1Itr(*aeItr))
482                    << ", " << g.getNodeID(g.getEdgeNode2Itr(*aeItr))
483                    << ") ";
484     }
485     llvm::errs() << "]\n";
486   }
487
488   void dumpState() {
489     llvm::errs() << "\n";
490
491     for (GraphNodeIterator nodeItr = g.nodesBegin(), nodeEnd = g.nodesEnd();
492          nodeItr != nodeEnd; ++nodeItr) {
493       printNode(nodeItr);
494     }
495
496     NodeList* buckets[] = { &r0Bucket, &r1Bucket, &r2Bucket };
497
498     for (unsigned b = 0; b < 3; ++b) {
499       NodeList &bucket = *buckets[b];
500
501       llvm::errs() << "Bucket " << b << ": [ ";
502
503       for (NodeListIterator nItr = bucket.begin(), nEnd = bucket.end();
504            nItr != nEnd; ++nItr) {
505         llvm::errs() << g.getNodeID(*nItr) << " ";
506       }
507
508       llvm::errs() << "]\n";
509     }
510
511     llvm::errs() << "Stack: [ ";
512     for (NodeStackIterator nsItr = stack.begin(), nsEnd = stack.end();
513          nsItr != nsEnd; ++nsItr) {
514       llvm::errs() << g.getNodeID(*nsItr) << " ";
515     }
516     llvm::errs() << "]\n";
517   }
518
519   void reduce() {
520     bool reductionFinished = r1Bucket.empty() && r2Bucket.empty() &&
521       heuristic.rNBucketEmpty();
522
523     while (!reductionFinished) {
524
525       if (!r1Bucket.empty()) {
526         processR1();
527       }
528       else if (!r2Bucket.empty()) {
529         processR2();
530       }
531       else if (!heuristic.rNBucketEmpty()) {
532         solution.setProvedOptimal(false);
533         solution.incRNReductions();
534         heuristic.processRN();
535       } 
536       else reductionFinished = true;
537     }
538       
539   }
540
541   void processR1() {
542
543     // Remove the first node in the R0 bucket:
544     GraphNodeIterator xNodeItr = r1Bucket.front();
545     r1Bucket.pop_front();
546
547     solution.incR1Reductions();
548
549     //llvm::errs() << "Applying R1 to " << g.getNodeID(xNodeItr) << "\n";
550
551     assert((g.getNodeData(xNodeItr).getLinkDegree() == 1) &&
552            "Node in R1 bucket has degree != 1");
553
554     GraphEdgeIterator edgeItr = *g.getNodeData(xNodeItr).adjLinksBegin();
555
556     const Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
557
558     const Vector &xCosts = g.getNodeCosts(xNodeItr);
559     unsigned xLen = xCosts.getLength();
560
561     // Duplicate a little code to avoid transposing matrices:
562     if (xNodeItr == g.getEdgeNode1Itr(edgeItr)) {
563       GraphNodeIterator yNodeItr = g.getEdgeNode2Itr(edgeItr);
564       Vector &yCosts = g.getNodeCosts(yNodeItr);
565       unsigned yLen = yCosts.getLength();
566
567       for (unsigned j = 0; j < yLen; ++j) {
568         PBQPNum min = edgeCosts[0][j] + xCosts[0];
569         for (unsigned i = 1; i < xLen; ++i) {
570           PBQPNum c = edgeCosts[i][j] + xCosts[i];
571           if (c < min)
572             min = c;
573         }
574         yCosts[j] += min;
575       }
576     }
577     else {
578       GraphNodeIterator yNodeItr = g.getEdgeNode1Itr(edgeItr);
579       Vector &yCosts = g.getNodeCosts(yNodeItr);
580       unsigned yLen = yCosts.getLength();
581
582       for (unsigned i = 0; i < yLen; ++i) {
583         PBQPNum min = edgeCosts[i][0] + xCosts[0];
584
585         for (unsigned j = 1; j < xLen; ++j) {
586           PBQPNum c = edgeCosts[i][j] + xCosts[j];
587           if (c < min)
588             min = c;
589         }
590         yCosts[i] += min;
591       }
592     }
593
594     unlinkNode(xNodeItr);
595     pushStack(xNodeItr);
596   }
597
598   void processR2() {
599
600     GraphNodeIterator xNodeItr = r2Bucket.front();
601     r2Bucket.pop_front();
602
603     solution.incR2Reductions();
604
605     // Unlink is unsafe here. At some point it may optimistically more a node
606     // to a lower-degree list when its degree will later rise, or vice versa,
607     // violating the assumption that node degrees monotonically decrease
608     // during the reduction phase. Instead we'll bucket shuffle manually.
609     pushStack(xNodeItr);
610
611     assert((g.getNodeData(xNodeItr).getLinkDegree() == 2) &&
612            "Node in R2 bucket has degree != 2");
613
614     const Vector &xCosts = g.getNodeCosts(xNodeItr);
615
616     typename NodeData::AdjLinkIterator tempItr =
617       g.getNodeData(xNodeItr).adjLinksBegin();
618
619     GraphEdgeIterator yxEdgeItr = *tempItr,
620                       zxEdgeItr = *(++tempItr);
621
622     GraphNodeIterator yNodeItr = g.getEdgeOtherNode(yxEdgeItr, xNodeItr),
623                       zNodeItr = g.getEdgeOtherNode(zxEdgeItr, xNodeItr);
624
625     removeFromBucket(yNodeItr);
626     removeFromBucket(zNodeItr);
627
628     removeLink(yxEdgeItr, yNodeItr);
629     removeLink(zxEdgeItr, zNodeItr);
630
631     // Graph some of the costs:
632     bool flipEdge1 = (g.getEdgeNode1Itr(yxEdgeItr) == xNodeItr),
633          flipEdge2 = (g.getEdgeNode1Itr(zxEdgeItr) == xNodeItr);
634
635     const Matrix *yxCosts = flipEdge1 ?
636       new Matrix(g.getEdgeCosts(yxEdgeItr).transpose()) :
637       &g.getEdgeCosts(yxEdgeItr),
638                      *zxCosts = flipEdge2 ?
639       new Matrix(g.getEdgeCosts(zxEdgeItr).transpose()) :
640         &g.getEdgeCosts(zxEdgeItr);
641
642     unsigned xLen = xCosts.getLength(),
643              yLen = yxCosts->getRows(),
644              zLen = zxCosts->getRows();
645
646     // Compute delta:
647     Matrix delta(yLen, zLen);
648
649     for (unsigned i = 0; i < yLen; ++i) {
650       for (unsigned j = 0; j < zLen; ++j) {
651         PBQPNum min = (*yxCosts)[i][0] + (*zxCosts)[j][0] + xCosts[0];
652         for (unsigned k = 1; k < xLen; ++k) {
653           PBQPNum c = (*yxCosts)[i][k] + (*zxCosts)[j][k] + xCosts[k];
654           if (c < min) {
655             min = c;
656           }
657         }
658         delta[i][j] = min;
659       }
660     }
661
662     if (flipEdge1)
663       delete yxCosts;
664
665     if (flipEdge2)
666       delete zxCosts;
667
668     // Deal with the potentially induced yz edge.
669     GraphEdgeIterator yzEdgeItr = g.findEdge(yNodeItr, zNodeItr);
670     if (yzEdgeItr == g.edgesEnd()) {
671       yzEdgeItr = g.addEdge(yNodeItr, zNodeItr, delta, EdgeData(g));
672     }
673     else {
674       // There was an edge, but we're going to screw with it. Delete the old
675       // link, update the costs. We'll re-link it later.
676       removeLinkR2(yzEdgeItr);
677       g.getEdgeCosts(yzEdgeItr) +=
678         (yNodeItr == g.getEdgeNode1Itr(yzEdgeItr)) ?
679         delta : delta.transpose();
680     }
681
682     bool nullCostEdge = tryNormaliseEdgeMatrix(yzEdgeItr);
683
684     // Nulled the edge, remove it entirely.
685     if (nullCostEdge) {
686       g.removeEdge(yzEdgeItr);
687     }
688     else {
689       // Edge remains - re-link it.
690       addLink(yzEdgeItr);
691     }
692
693     addToBucket(yNodeItr);
694     addToBucket(zNodeItr);
695     }
696
697   void computeTrivialSolutions() {
698
699     for (NodeListIterator r0Itr = r0Bucket.begin(), r0End = r0Bucket.end();
700          r0Itr != r0End; ++r0Itr) {
701       GraphNodeIterator nodeItr = *r0Itr;
702
703       solution.incR0Reductions();
704       setSolution(nodeItr, g.getNodeCosts(nodeItr).minIndex());
705     }
706
707   }
708
709   void backpropagate() {
710     while (!stack.empty()) {
711       computeSolution(stack.back());
712       stack.pop_back();
713     }
714   }
715
716   void computeSolution(const GraphNodeIterator &nodeItr) {
717
718     NodeData &nodeData = g.getNodeData(nodeItr);
719
720     Vector v(g.getNodeCosts(nodeItr));
721
722     // Solve based on existing links.
723     for (typename NodeData::AdjLinkIterator
724          solvedLinkItr = nodeData.solvedLinksBegin(),
725          solvedLinkEnd = nodeData.solvedLinksEnd();
726          solvedLinkItr != solvedLinkEnd; ++solvedLinkItr) {
727
728       GraphEdgeIterator solvedEdgeItr(*solvedLinkItr);
729       Matrix &edgeCosts = g.getEdgeCosts(solvedEdgeItr);
730
731       if (nodeItr == g.getEdgeNode1Itr(solvedEdgeItr)) {
732         GraphNodeIterator adjNode(g.getEdgeNode2Itr(solvedEdgeItr));
733         unsigned adjSolution =
734           solution.getSelection(g.getNodeID(adjNode));
735         v += edgeCosts.getColAsVector(adjSolution);
736       }
737       else {
738         GraphNodeIterator adjNode(g.getEdgeNode1Itr(solvedEdgeItr));
739         unsigned adjSolution =
740           solution.getSelection(g.getNodeID(adjNode));
741         v += edgeCosts.getRowAsVector(adjSolution);
742       }
743
744     }
745
746     setSolution(nodeItr, v.minIndex());
747   }
748
749   void computeSolutionCost(const SimpleGraph &orig) {
750     PBQPNum cost = 0.0;
751
752     for (SimpleGraph::ConstNodeIterator
753          nodeItr = orig.nodesBegin(), nodeEnd = orig.nodesEnd();
754          nodeItr != nodeEnd; ++nodeItr) {
755
756       unsigned nodeId = orig.getNodeID(nodeItr);
757
758       cost += orig.getNodeCosts(nodeItr)[solution.getSelection(nodeId)];
759     }
760
761     for (SimpleGraph::ConstEdgeIterator
762          edgeItr = orig.edgesBegin(), edgeEnd = orig.edgesEnd();
763          edgeItr != edgeEnd; ++edgeItr) {
764
765       SimpleGraph::ConstNodeIterator n1 = orig.getEdgeNode1Itr(edgeItr),
766                                      n2 = orig.getEdgeNode2Itr(edgeItr);
767       unsigned sol1 = solution.getSelection(orig.getNodeID(n1)),
768                sol2 = solution.getSelection(orig.getNodeID(n2));
769
770       cost += orig.getEdgeCosts(edgeItr)[sol1][sol2];
771     }
772
773     solution.setSolutionCost(cost);
774   }
775
776 };
777
778 template <typename Heuristic>
779 class HeuristicSolver : public Solver {
780 public:
781   Solution solve(const SimpleGraph &g) const {
782     HeuristicSolverImpl<Heuristic> solverImpl(g);
783     return solverImpl.getSolution();
784   }
785 };
786
787 }
788
789 #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H