Convert some tab stops into spaces.
[oota-llvm.git] / lib / CodeGen / PBQP / HeuristicSolver.h
1 //===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Heuristic PBQP solver. This solver is able to perform optimal reductions for
11 // nodes of degree 0, 1 or 2. For nodes of degree >2 a plugable heuristic is
12 // used to select a node for reduction. 
13 //
14 //===----------------------------------------------------------------------===//
15
16 #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17 #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
18
19 #include "Graph.h"
20 #include "Solution.h"
21 #include <vector>
22 #include <limits>
23
24 namespace PBQP {
25
26   /// \brief Heuristic PBQP solver implementation.
27   ///
28   /// This class should usually be created (and destroyed) indirectly via a call
29   /// to HeuristicSolver<HImpl>::solve(Graph&).
30   /// See the comments for HeuristicSolver.
31   ///
32   /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules,
33   /// backpropagation phase, and maintains the internal copy of the graph on
34   /// which the reduction is carried out (the original being kept to facilitate
35   /// backpropagation).
36   template <typename HImpl>
37   class HeuristicSolverImpl {
38   private:
39
40     typedef typename HImpl::NodeData HeuristicNodeData;
41     typedef typename HImpl::EdgeData HeuristicEdgeData;
42
43     typedef std::list<Graph::EdgeItr> SolverEdges;
44
45   public:
46   
47     /// \brief Iterator type for edges in the solver graph.
48     typedef SolverEdges::iterator SolverEdgeItr;
49
50   private:
51
52     class NodeData {
53     public:
54       NodeData() : solverDegree(0) {}
55
56       HeuristicNodeData& getHeuristicData() { return hData; }
57
58       SolverEdgeItr addSolverEdge(Graph::EdgeItr eItr) {
59         ++solverDegree;
60         return solverEdges.insert(solverEdges.end(), eItr);
61       }
62
63       void removeSolverEdge(SolverEdgeItr seItr) {
64         --solverDegree;
65         solverEdges.erase(seItr);
66       }
67
68       SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); }
69       SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); }
70       unsigned getSolverDegree() const { return solverDegree; }
71       void clearSolverEdges() {
72         solverDegree = 0;
73         solverEdges.clear(); 
74       }
75       
76     private:
77       HeuristicNodeData hData;
78       unsigned solverDegree;
79       SolverEdges solverEdges;
80     };
81  
82     class EdgeData {
83     public:
84       HeuristicEdgeData& getHeuristicData() { return hData; }
85
86       void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) {
87         this->n1SolverEdgeItr = n1SolverEdgeItr;
88       }
89
90       SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; }
91
92       void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){
93         this->n2SolverEdgeItr = n2SolverEdgeItr;
94       }
95
96       SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; }
97
98     private:
99
100       HeuristicEdgeData hData;
101       SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr;
102     };
103
104     Graph &g;
105     HImpl h;
106     Solution s;
107     std::vector<Graph::NodeItr> stack;
108
109     typedef std::list<NodeData> NodeDataList;
110     NodeDataList nodeDataList;
111
112     typedef std::list<EdgeData> EdgeDataList;
113     EdgeDataList edgeDataList;
114
115   public:
116
117     /// \brief Construct a heuristic solver implementation to solve the given
118     ///        graph.
119     /// @param g The graph representing the problem instance to be solved.
120     HeuristicSolverImpl(Graph &g) : g(g), h(*this) {}  
121
122     /// \brief Get the graph being solved by this solver.
123     /// @return The graph representing the problem instance being solved by this
124     ///         solver.
125     Graph& getGraph() { return g; }
126
127     /// \brief Get the heuristic data attached to the given node.
128     /// @param nItr Node iterator.
129     /// @return The heuristic data attached to the given node.
130     HeuristicNodeData& getHeuristicNodeData(Graph::NodeItr nItr) {
131       return getSolverNodeData(nItr).getHeuristicData();
132     }
133
134     /// \brief Get the heuristic data attached to the given edge.
135     /// @param eItr Edge iterator.
136     /// @return The heuristic data attached to the given node.
137     HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeItr eItr) {
138       return getSolverEdgeData(eItr).getHeuristicData();
139     }
140
141     /// \brief Begin iterator for the set of edges adjacent to the given node in
142     ///        the solver graph.
143     /// @param nItr Node iterator.
144     /// @return Begin iterator for the set of edges adjacent to the given node
145     ///         in the solver graph. 
146     SolverEdgeItr solverEdgesBegin(Graph::NodeItr nItr) {
147       return getSolverNodeData(nItr).solverEdgesBegin();
148     }
149
150     /// \brief End iterator for the set of edges adjacent to the given node in
151     ///        the solver graph.
152     /// @param nItr Node iterator.
153     /// @return End iterator for the set of edges adjacent to the given node in
154     ///         the solver graph. 
155     SolverEdgeItr solverEdgesEnd(Graph::NodeItr nItr) {
156       return getSolverNodeData(nItr).solverEdgesEnd();
157     }
158
159     /// \brief Remove a node from the solver graph.
160     /// @param eItr Edge iterator for edge to be removed.
161     ///
162     /// Does <i>not</i> notify the heuristic of the removal. That should be
163     /// done manually if necessary.
164     void removeSolverEdge(Graph::EdgeItr eItr) {
165       EdgeData &eData = getSolverEdgeData(eItr);
166       NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
167                &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
168
169       n1Data.removeSolverEdge(eData.getN1SolverEdgeItr());
170       n2Data.removeSolverEdge(eData.getN2SolverEdgeItr());
171     }
172
173     /// \brief Compute a solution to the PBQP problem instance with which this
174     ///        heuristic solver was constructed.
175     /// @return A solution to the PBQP problem.
176     ///
177     /// Performs the full PBQP heuristic solver algorithm, including setup,
178     /// calls to the heuristic (which will call back to the reduction rules in
179     /// this class), and cleanup.
180     Solution computeSolution() {
181       setup();
182       h.setup();
183       h.reduce();
184       backpropagate();
185       h.cleanup();
186       cleanup();
187       return s;
188     }
189
190     /// \brief Add to the end of the stack.
191     /// @param nItr Node iterator to add to the reduction stack.
192     void pushToStack(Graph::NodeItr nItr) {
193       getSolverNodeData(nItr).clearSolverEdges();
194       stack.push_back(nItr);
195     }
196
197     /// \brief Returns the solver degree of the given node.
198     /// @param nItr Node iterator for which degree is requested.
199     /// @return Node degree in the <i>solver</i> graph (not the original graph).
200     unsigned getSolverDegree(Graph::NodeItr nItr) {
201       return  getSolverNodeData(nItr).getSolverDegree();
202     }
203
204     /// \brief Set the solution of the given node.
205     /// @param nItr Node iterator to set solution for.
206     /// @param selection Selection for node.
207     void setSolution(const Graph::NodeItr &nItr, unsigned selection) {
208       s.setSelection(nItr, selection);
209
210       for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
211                              aeEnd = g.adjEdgesEnd(nItr);
212            aeItr != aeEnd; ++aeItr) {
213         Graph::EdgeItr eItr(*aeItr);
214         Graph::NodeItr anItr(g.getEdgeOtherNode(eItr, nItr));
215         getSolverNodeData(anItr).addSolverEdge(eItr);
216       }
217     }
218
219     /// \brief Apply rule R0.
220     /// @param nItr Node iterator for node to apply R0 to.
221     ///
222     /// Node will be automatically pushed to the solver stack.
223     void applyR0(Graph::NodeItr nItr) {
224       assert(getSolverNodeData(nItr).getSolverDegree() == 0 &&
225              "R0 applied to node with degree != 0.");
226
227       // Nothing to do. Just push the node onto the reduction stack.
228       pushToStack(nItr);
229     }
230
231     /// \brief Apply rule R1.
232     /// @param xnItr Node iterator for node to apply R1 to.
233     ///
234     /// Node will be automatically pushed to the solver stack.
235     void applyR1(Graph::NodeItr xnItr) {
236       NodeData &nd = getSolverNodeData(xnItr);
237       assert(nd.getSolverDegree() == 1 &&
238              "R1 applied to node with degree != 1.");
239
240       Graph::EdgeItr eItr = *nd.solverEdgesBegin();
241
242       const Matrix &eCosts = g.getEdgeCosts(eItr);
243       const Vector &xCosts = g.getNodeCosts(xnItr);
244       
245       // Duplicate a little to avoid transposing matrices.
246       if (xnItr == g.getEdgeNode1(eItr)) {
247         Graph::NodeItr ynItr = g.getEdgeNode2(eItr);
248         Vector &yCosts = g.getNodeCosts(ynItr);
249         for (unsigned j = 0; j < yCosts.getLength(); ++j) {
250           PBQPNum min = eCosts[0][j] + xCosts[0];
251           for (unsigned i = 1; i < xCosts.getLength(); ++i) {
252             PBQPNum c = eCosts[i][j] + xCosts[i];
253             if (c < min)
254               min = c;
255           }
256           yCosts[j] += min;
257         }
258         h.handleRemoveEdge(eItr, ynItr);
259      } else {
260         Graph::NodeItr ynItr = g.getEdgeNode1(eItr);
261         Vector &yCosts = g.getNodeCosts(ynItr);
262         for (unsigned i = 0; i < yCosts.getLength(); ++i) {
263           PBQPNum min = eCosts[i][0] + xCosts[0];
264           for (unsigned j = 1; j < xCosts.getLength(); ++j) {
265             PBQPNum c = eCosts[i][j] + xCosts[j];
266             if (c < min)
267               min = c;
268           }
269           yCosts[i] += min;
270         }
271         h.handleRemoveEdge(eItr, ynItr);
272       }
273       removeSolverEdge(eItr);
274       assert(nd.getSolverDegree() == 0 &&
275              "Degree 1 with edge removed should be 0.");
276       pushToStack(xnItr);
277     }
278
279     /// \brief Apply rule R2.
280     /// @param xnItr Node iterator for node to apply R2 to.
281     ///
282     /// Node will be automatically pushed to the solver stack.
283     void applyR2(Graph::NodeItr xnItr) {
284       assert(getSolverNodeData(xnItr).getSolverDegree() == 2 &&
285              "R2 applied to node with degree != 2.");
286
287       NodeData &nd = getSolverNodeData(xnItr);
288       const Vector &xCosts = g.getNodeCosts(xnItr);
289
290       SolverEdgeItr aeItr = nd.solverEdgesBegin();
291       Graph::EdgeItr yxeItr = *aeItr,
292                      zxeItr = *(++aeItr);
293
294       Graph::NodeItr ynItr = g.getEdgeOtherNode(yxeItr, xnItr),
295                      znItr = g.getEdgeOtherNode(zxeItr, xnItr);
296
297       bool flipEdge1 = (g.getEdgeNode1(yxeItr) == xnItr),
298            flipEdge2 = (g.getEdgeNode1(zxeItr) == xnItr);
299
300       const Matrix *yxeCosts = flipEdge1 ?
301         new Matrix(g.getEdgeCosts(yxeItr).transpose()) :
302         &g.getEdgeCosts(yxeItr);
303
304       const Matrix *zxeCosts = flipEdge2 ?
305         new Matrix(g.getEdgeCosts(zxeItr).transpose()) :
306         &g.getEdgeCosts(zxeItr);
307
308       unsigned xLen = xCosts.getLength(),
309                yLen = yxeCosts->getRows(),
310                zLen = zxeCosts->getRows();
311                
312       Matrix delta(yLen, zLen);
313
314       for (unsigned i = 0; i < yLen; ++i) {
315         for (unsigned j = 0; j < zLen; ++j) {
316           PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0];
317           for (unsigned k = 1; k < xLen; ++k) {
318             PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k];
319             if (c < min) {
320               min = c;
321             }
322           }
323           delta[i][j] = min;
324         }
325       }
326
327       if (flipEdge1)
328         delete yxeCosts;
329
330       if (flipEdge2)
331         delete zxeCosts;
332
333       Graph::EdgeItr yzeItr = g.findEdge(ynItr, znItr);
334       bool addedEdge = false;
335
336       if (yzeItr == g.edgesEnd()) {
337         yzeItr = g.addEdge(ynItr, znItr, delta);
338         addedEdge = true;
339       } else {
340         Matrix &yzeCosts = g.getEdgeCosts(yzeItr);
341         h.preUpdateEdgeCosts(yzeItr);
342         if (ynItr == g.getEdgeNode1(yzeItr)) {
343           yzeCosts += delta;
344         } else {
345           yzeCosts += delta.transpose();
346         }
347       }
348
349       bool nullCostEdge = tryNormaliseEdgeMatrix(yzeItr);
350
351       if (!addedEdge) {
352         // If we modified the edge costs let the heuristic know.
353         h.postUpdateEdgeCosts(yzeItr);
354       }
355  
356       if (nullCostEdge) {
357         // If this edge ended up null remove it.
358         if (!addedEdge) {
359           // We didn't just add it, so we need to notify the heuristic
360           // and remove it from the solver.
361           h.handleRemoveEdge(yzeItr, ynItr);
362           h.handleRemoveEdge(yzeItr, znItr);
363           removeSolverEdge(yzeItr);
364         }
365         g.removeEdge(yzeItr);
366       } else if (addedEdge) {
367         // If the edge was added, and non-null, finish setting it up, add it to
368         // the solver & notify heuristic.
369         edgeDataList.push_back(EdgeData());
370         g.setEdgeData(yzeItr, &edgeDataList.back());
371         addSolverEdge(yzeItr);
372         h.handleAddEdge(yzeItr);
373       }
374
375       h.handleRemoveEdge(yxeItr, ynItr);
376       removeSolverEdge(yxeItr);
377       h.handleRemoveEdge(zxeItr, znItr);
378       removeSolverEdge(zxeItr);
379
380       pushToStack(xnItr);
381     }
382
383   private:
384
385     NodeData& getSolverNodeData(Graph::NodeItr nItr) {
386       return *static_cast<NodeData*>(g.getNodeData(nItr));
387     }
388
389     EdgeData& getSolverEdgeData(Graph::EdgeItr eItr) {
390       return *static_cast<EdgeData*>(g.getEdgeData(eItr));
391     }
392
393     void addSolverEdge(Graph::EdgeItr eItr) {
394       EdgeData &eData = getSolverEdgeData(eItr);
395       NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
396                &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
397
398       eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eItr));
399       eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eItr));
400     }
401
402     void setup() {
403       if (h.solverRunSimplify()) {
404         simplify();
405       }
406
407       // Create node data objects.
408       for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
409            nItr != nEnd; ++nItr) {
410         nodeDataList.push_back(NodeData());
411         g.setNodeData(nItr, &nodeDataList.back());
412       }
413
414       // Create edge data objects.
415       for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
416            eItr != eEnd; ++eItr) {
417         edgeDataList.push_back(EdgeData());
418         g.setEdgeData(eItr, &edgeDataList.back());
419         addSolverEdge(eItr);
420       }
421     }
422
423     void simplify() {
424       disconnectTrivialNodes();
425       eliminateIndependentEdges();
426     }
427
428     // Eliminate trivial nodes.
429     void disconnectTrivialNodes() {
430       unsigned numDisconnected = 0;
431
432       for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
433            nItr != nEnd; ++nItr) {
434
435         if (g.getNodeCosts(nItr).getLength() == 1) {
436
437           std::vector<Graph::EdgeItr> edgesToRemove;
438
439           for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
440                                  aeEnd = g.adjEdgesEnd(nItr);
441                aeItr != aeEnd; ++aeItr) {
442
443             Graph::EdgeItr eItr = *aeItr;
444
445             if (g.getEdgeNode1(eItr) == nItr) {
446               Graph::NodeItr otherNodeItr = g.getEdgeNode2(eItr);
447               g.getNodeCosts(otherNodeItr) +=
448                 g.getEdgeCosts(eItr).getRowAsVector(0);
449             }
450             else {
451               Graph::NodeItr otherNodeItr = g.getEdgeNode1(eItr);
452               g.getNodeCosts(otherNodeItr) +=
453                 g.getEdgeCosts(eItr).getColAsVector(0);
454             }
455
456             edgesToRemove.push_back(eItr);
457           }
458
459           if (!edgesToRemove.empty())
460             ++numDisconnected;
461
462           while (!edgesToRemove.empty()) {
463             g.removeEdge(edgesToRemove.back());
464             edgesToRemove.pop_back();
465           }
466         }
467       }
468     }
469
470     void eliminateIndependentEdges() {
471       std::vector<Graph::EdgeItr> edgesToProcess;
472       unsigned numEliminated = 0;
473
474       for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
475            eItr != eEnd; ++eItr) {
476         edgesToProcess.push_back(eItr);
477       }
478
479       while (!edgesToProcess.empty()) {
480         if (tryToEliminateEdge(edgesToProcess.back()))
481           ++numEliminated;
482         edgesToProcess.pop_back();
483       }
484     }
485
486     bool tryToEliminateEdge(Graph::EdgeItr eItr) {
487       if (tryNormaliseEdgeMatrix(eItr)) {
488         g.removeEdge(eItr);
489         return true; 
490       }
491       return false;
492     }
493
494     bool tryNormaliseEdgeMatrix(Graph::EdgeItr &eItr) {
495
496       const PBQPNum infinity = std::numeric_limits<PBQPNum>::infinity();
497
498       Matrix &edgeCosts = g.getEdgeCosts(eItr);
499       Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eItr)),
500              &vCosts = g.getNodeCosts(g.getEdgeNode2(eItr));
501
502       for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
503         PBQPNum rowMin = infinity;
504
505         for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
506           if (vCosts[c] != infinity && edgeCosts[r][c] < rowMin)
507             rowMin = edgeCosts[r][c];
508         }
509
510         uCosts[r] += rowMin;
511
512         if (rowMin != infinity) {
513           edgeCosts.subFromRow(r, rowMin);
514         }
515         else {
516           edgeCosts.setRow(r, 0);
517         }
518       }
519
520       for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
521         PBQPNum colMin = infinity;
522
523         for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
524           if (uCosts[r] != infinity && edgeCosts[r][c] < colMin)
525             colMin = edgeCosts[r][c];
526         }
527
528         vCosts[c] += colMin;
529
530         if (colMin != infinity) {
531           edgeCosts.subFromCol(c, colMin);
532         }
533         else {
534           edgeCosts.setCol(c, 0);
535         }
536       }
537
538       return edgeCosts.isZero();
539     }
540
541     void backpropagate() {
542       while (!stack.empty()) {
543         computeSolution(stack.back());
544         stack.pop_back();
545       }
546     }
547
548     void computeSolution(Graph::NodeItr nItr) {
549
550       NodeData &nodeData = getSolverNodeData(nItr);
551
552       Vector v(g.getNodeCosts(nItr));
553
554       // Solve based on existing solved edges.
555       for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(),
556                          solvedEdgeEnd = nodeData.solverEdgesEnd();
557            solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) {
558
559         Graph::EdgeItr eItr(*solvedEdgeItr);
560         Matrix &edgeCosts = g.getEdgeCosts(eItr);
561
562         if (nItr == g.getEdgeNode1(eItr)) {
563           Graph::NodeItr adjNode(g.getEdgeNode2(eItr));
564           unsigned adjSolution = s.getSelection(adjNode);
565           v += edgeCosts.getColAsVector(adjSolution);
566         }
567         else {
568           Graph::NodeItr adjNode(g.getEdgeNode1(eItr));
569           unsigned adjSolution = s.getSelection(adjNode);
570           v += edgeCosts.getRowAsVector(adjSolution);
571         }
572
573       }
574
575       setSolution(nItr, v.minIndex());
576     }
577
578     void cleanup() {
579       h.cleanup();
580       nodeDataList.clear();
581       edgeDataList.clear();
582     }
583   };
584
585   /// \brief PBQP heuristic solver class.
586   ///
587   /// Given a PBQP Graph g representing a PBQP problem, you can find a solution
588   /// by calling
589   /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt>
590   ///
591   /// The choice of heuristic for the H parameter will affect both the solver
592   /// speed and solution quality. The heuristic should be chosen based on the
593   /// nature of the problem being solved.
594   /// Currently the only solver included with LLVM is the Briggs heuristic for
595   /// register allocation.
596   template <typename HImpl>
597   class HeuristicSolver {
598   public:
599     static Solution solve(Graph &g) {
600       HeuristicSolverImpl<HImpl> hs(g);
601       return hs.computeSolution();
602     }
603   };
604
605 }
606
607 #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H