Fix "the the" and similar typos.
[oota-llvm.git] / lib / CodeGen / PBQP / HeuristicSolver.h
1 //===-- HeuristicSolver.h - Heuristic PBQP Solver --------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Heuristic PBQP solver. This solver is able to perform optimal reductions for
11 // nodes of degree 0, 1 or 2. For nodes of degree >2 a plugable heuristic is
12 // used to select a node for reduction. 
13 //
14 //===----------------------------------------------------------------------===//
15
16 #ifndef LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
17 #define LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H
18
19 #include "Graph.h"
20 #include "Solution.h"
21 #include "llvm/Support/raw_ostream.h"
22 #include <vector>
23 #include <limits>
24
25 namespace PBQP {
26
27   /// \brief Heuristic PBQP solver implementation.
28   ///
29   /// This class should usually be created (and destroyed) indirectly via a call
30   /// to HeuristicSolver<HImpl>::solve(Graph&).
31   /// See the comments for HeuristicSolver.
32   ///
33   /// HeuristicSolverImpl provides the R0, R1 and R2 reduction rules,
34   /// backpropagation phase, and maintains the internal copy of the graph on
35   /// which the reduction is carried out (the original being kept to facilitate
36   /// backpropagation).
37   template <typename HImpl>
38   class HeuristicSolverImpl {
39   private:
40
41     typedef typename HImpl::NodeData HeuristicNodeData;
42     typedef typename HImpl::EdgeData HeuristicEdgeData;
43
44     typedef std::list<Graph::EdgeItr> SolverEdges;
45
46   public:
47   
48     /// \brief Iterator type for edges in the solver graph.
49     typedef SolverEdges::iterator SolverEdgeItr;
50
51   private:
52
53     class NodeData {
54     public:
55       NodeData() : solverDegree(0) {}
56
57       HeuristicNodeData& getHeuristicData() { return hData; }
58
59       SolverEdgeItr addSolverEdge(Graph::EdgeItr eItr) {
60         ++solverDegree;
61         return solverEdges.insert(solverEdges.end(), eItr);
62       }
63
64       void removeSolverEdge(SolverEdgeItr seItr) {
65         --solverDegree;
66         solverEdges.erase(seItr);
67       }
68
69       SolverEdgeItr solverEdgesBegin() { return solverEdges.begin(); }
70       SolverEdgeItr solverEdgesEnd() { return solverEdges.end(); }
71       unsigned getSolverDegree() const { return solverDegree; }
72       void clearSolverEdges() {
73         solverDegree = 0;
74         solverEdges.clear(); 
75       }
76       
77     private:
78       HeuristicNodeData hData;
79       unsigned solverDegree;
80       SolverEdges solverEdges;
81     };
82  
83     class EdgeData {
84     public:
85       HeuristicEdgeData& getHeuristicData() { return hData; }
86
87       void setN1SolverEdgeItr(SolverEdgeItr n1SolverEdgeItr) {
88         this->n1SolverEdgeItr = n1SolverEdgeItr;
89       }
90
91       SolverEdgeItr getN1SolverEdgeItr() { return n1SolverEdgeItr; }
92
93       void setN2SolverEdgeItr(SolverEdgeItr n2SolverEdgeItr){
94         this->n2SolverEdgeItr = n2SolverEdgeItr;
95       }
96
97       SolverEdgeItr getN2SolverEdgeItr() { return n2SolverEdgeItr; }
98
99     private:
100
101       HeuristicEdgeData hData;
102       SolverEdgeItr n1SolverEdgeItr, n2SolverEdgeItr;
103     };
104
105     Graph &g;
106     HImpl h;
107     Solution s;
108     std::vector<Graph::NodeItr> stack;
109
110     typedef std::list<NodeData> NodeDataList;
111     NodeDataList nodeDataList;
112
113     typedef std::list<EdgeData> EdgeDataList;
114     EdgeDataList edgeDataList;
115
116   public:
117
118     /// \brief Construct a heuristic solver implementation to solve the given
119     ///        graph.
120     /// @param g The graph representing the problem instance to be solved.
121     HeuristicSolverImpl(Graph &g) : g(g), h(*this) {}  
122
123     /// \brief Get the graph being solved by this solver.
124     /// @return The graph representing the problem instance being solved by this
125     ///         solver.
126     Graph& getGraph() { return g; }
127
128     /// \brief Get the heuristic data attached to the given node.
129     /// @param nItr Node iterator.
130     /// @return The heuristic data attached to the given node.
131     HeuristicNodeData& getHeuristicNodeData(Graph::NodeItr nItr) {
132       return getSolverNodeData(nItr).getHeuristicData();
133     }
134
135     /// \brief Get the heuristic data attached to the given edge.
136     /// @param eItr Edge iterator.
137     /// @return The heuristic data attached to the given node.
138     HeuristicEdgeData& getHeuristicEdgeData(Graph::EdgeItr eItr) {
139       return getSolverEdgeData(eItr).getHeuristicData();
140     }
141
142     /// \brief Begin iterator for the set of edges adjacent to the given node in
143     ///        the solver graph.
144     /// @param nItr Node iterator.
145     /// @return Begin iterator for the set of edges adjacent to the given node
146     ///         in the solver graph. 
147     SolverEdgeItr solverEdgesBegin(Graph::NodeItr nItr) {
148       return getSolverNodeData(nItr).solverEdgesBegin();
149     }
150
151     /// \brief End iterator for the set of edges adjacent to the given node in
152     ///        the solver graph.
153     /// @param nItr Node iterator.
154     /// @return End iterator for the set of edges adjacent to the given node in
155     ///         the solver graph. 
156     SolverEdgeItr solverEdgesEnd(Graph::NodeItr nItr) {
157       return getSolverNodeData(nItr).solverEdgesEnd();
158     }
159
160     /// \brief Remove a node from the solver graph.
161     /// @param eItr Edge iterator for edge to be removed.
162     ///
163     /// Does <i>not</i> notify the heuristic of the removal. That should be
164     /// done manually if necessary.
165     void removeSolverEdge(Graph::EdgeItr eItr) {
166       EdgeData &eData = getSolverEdgeData(eItr);
167       NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
168                &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
169
170       n1Data.removeSolverEdge(eData.getN1SolverEdgeItr());
171       n2Data.removeSolverEdge(eData.getN2SolverEdgeItr());
172     }
173
174     /// \brief Compute a solution to the PBQP problem instance with which this
175     ///        heuristic solver was constructed.
176     /// @return A solution to the PBQP problem.
177     ///
178     /// Performs the full PBQP heuristic solver algorithm, including setup,
179     /// calls to the heuristic (which will call back to the reduction rules in
180     /// this class), and cleanup.
181     Solution computeSolution() {
182       setup();
183       h.setup();
184       h.reduce();
185       backpropagate();
186       h.cleanup();
187       cleanup();
188       return s;
189     }
190
191     /// \brief Add to the end of the stack.
192     /// @param nItr Node iterator to add to the reduction stack.
193     void pushToStack(Graph::NodeItr nItr) {
194       getSolverNodeData(nItr).clearSolverEdges();
195       stack.push_back(nItr);
196     }
197
198     /// \brief Returns the solver degree of the given node.
199     /// @param nItr Node iterator for which degree is requested.
200     /// @return Node degree in the <i>solver</i> graph (not the original graph).
201     unsigned getSolverDegree(Graph::NodeItr nItr) {
202       return  getSolverNodeData(nItr).getSolverDegree();
203     }
204
205     /// \brief Set the solution of the given node.
206     /// @param nItr Node iterator to set solution for.
207     /// @param selection Selection for node.
208     void setSolution(const Graph::NodeItr &nItr, unsigned selection) {
209       s.setSelection(nItr, selection);
210
211       for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
212                              aeEnd = g.adjEdgesEnd(nItr);
213            aeItr != aeEnd; ++aeItr) {
214         Graph::EdgeItr eItr(*aeItr);
215         Graph::NodeItr anItr(g.getEdgeOtherNode(eItr, nItr));
216         getSolverNodeData(anItr).addSolverEdge(eItr);
217       }
218     }
219
220     /// \brief Apply rule R0.
221     /// @param nItr Node iterator for node to apply R0 to.
222     ///
223     /// Node will be automatically pushed to the solver stack.
224     void applyR0(Graph::NodeItr nItr) {
225       assert(getSolverNodeData(nItr).getSolverDegree() == 0 &&
226              "R0 applied to node with degree != 0.");
227
228       // Nothing to do. Just push the node onto the reduction stack.
229       pushToStack(nItr);
230     }
231
232     /// \brief Apply rule R1.
233     /// @param nItr Node iterator for node to apply R1 to.
234     ///
235     /// Node will be automatically pushed to the solver stack.
236     void applyR1(Graph::NodeItr xnItr) {
237       NodeData &nd = getSolverNodeData(xnItr);
238       assert(nd.getSolverDegree() == 1 &&
239              "R1 applied to node with degree != 1.");
240
241       Graph::EdgeItr eItr = *nd.solverEdgesBegin();
242
243       const Matrix &eCosts = g.getEdgeCosts(eItr);
244       const Vector &xCosts = g.getNodeCosts(xnItr);
245       
246       // Duplicate a little to avoid transposing matrices.
247       if (xnItr == g.getEdgeNode1(eItr)) {
248         Graph::NodeItr ynItr = g.getEdgeNode2(eItr);
249         Vector &yCosts = g.getNodeCosts(ynItr);
250         for (unsigned j = 0; j < yCosts.getLength(); ++j) {
251           PBQPNum min = eCosts[0][j] + xCosts[0];
252           for (unsigned i = 1; i < xCosts.getLength(); ++i) {
253             PBQPNum c = eCosts[i][j] + xCosts[i];
254             if (c < min)
255               min = c;
256           }
257           yCosts[j] += min;
258         }
259         h.handleRemoveEdge(eItr, ynItr);
260      } else {
261         Graph::NodeItr ynItr = g.getEdgeNode1(eItr);
262         Vector &yCosts = g.getNodeCosts(ynItr);
263         for (unsigned i = 0; i < yCosts.getLength(); ++i) {
264           PBQPNum min = eCosts[i][0] + xCosts[0];
265           for (unsigned j = 1; j < xCosts.getLength(); ++j) {
266             PBQPNum c = eCosts[i][j] + xCosts[j];
267             if (c < min)
268               min = c;
269           }
270           yCosts[i] += min;
271         }
272         h.handleRemoveEdge(eItr, ynItr);
273       }
274       removeSolverEdge(eItr);
275       assert(nd.getSolverDegree() == 0 &&
276              "Degree 1 with edge removed should be 0.");
277       pushToStack(xnItr);
278     }
279
280     /// \brief Apply rule R2.
281     /// @param nItr Node iterator for node to apply R2 to.
282     ///
283     /// Node will be automatically pushed to the solver stack.
284     void applyR2(Graph::NodeItr xnItr) {
285       assert(getSolverNodeData(xnItr).getSolverDegree() == 2 &&
286              "R2 applied to node with degree != 2.");
287
288       NodeData &nd = getSolverNodeData(xnItr);
289       const Vector &xCosts = g.getNodeCosts(xnItr);
290
291       SolverEdgeItr aeItr = nd.solverEdgesBegin();
292       Graph::EdgeItr yxeItr = *aeItr,
293                      zxeItr = *(++aeItr);
294
295       Graph::NodeItr ynItr = g.getEdgeOtherNode(yxeItr, xnItr),
296                      znItr = g.getEdgeOtherNode(zxeItr, xnItr);
297
298       bool flipEdge1 = (g.getEdgeNode1(yxeItr) == xnItr),
299            flipEdge2 = (g.getEdgeNode1(zxeItr) == xnItr);
300
301       const Matrix *yxeCosts = flipEdge1 ?
302         new Matrix(g.getEdgeCosts(yxeItr).transpose()) :
303         &g.getEdgeCosts(yxeItr);
304
305       const Matrix *zxeCosts = flipEdge2 ?
306         new Matrix(g.getEdgeCosts(zxeItr).transpose()) :
307         &g.getEdgeCosts(zxeItr);
308
309       unsigned xLen = xCosts.getLength(),
310                yLen = yxeCosts->getRows(),
311                zLen = zxeCosts->getRows();
312                
313       Matrix delta(yLen, zLen);
314
315       for (unsigned i = 0; i < yLen; ++i) {
316         for (unsigned j = 0; j < zLen; ++j) {
317           PBQPNum min = (*yxeCosts)[i][0] + (*zxeCosts)[j][0] + xCosts[0];
318           for (unsigned k = 1; k < xLen; ++k) {
319             PBQPNum c = (*yxeCosts)[i][k] + (*zxeCosts)[j][k] + xCosts[k];
320             if (c < min) {
321               min = c;
322             }
323           }
324           delta[i][j] = min;
325         }
326       }
327
328       if (flipEdge1)
329         delete yxeCosts;
330
331       if (flipEdge2)
332         delete zxeCosts;
333
334       Graph::EdgeItr yzeItr = g.findEdge(ynItr, znItr);
335       bool addedEdge = false;
336
337       if (yzeItr == g.edgesEnd()) {
338         yzeItr = g.addEdge(ynItr, znItr, delta);
339         addedEdge = true;
340       } else {
341         Matrix &yzeCosts = g.getEdgeCosts(yzeItr);
342         h.preUpdateEdgeCosts(yzeItr);
343         if (ynItr == g.getEdgeNode1(yzeItr)) {
344           yzeCosts += delta;
345         } else {
346           yzeCosts += delta.transpose();
347         }
348       }
349
350       bool nullCostEdge = tryNormaliseEdgeMatrix(yzeItr);
351
352       if (!addedEdge) {
353         // If we modified the edge costs let the heuristic know.
354         h.postUpdateEdgeCosts(yzeItr);
355       }
356  
357       if (nullCostEdge) {
358         // If this edge ended up null remove it.
359         if (!addedEdge) {
360           // We didn't just add it, so we need to notify the heuristic
361           // and remove it from the solver.
362           h.handleRemoveEdge(yzeItr, ynItr);
363           h.handleRemoveEdge(yzeItr, znItr);
364           removeSolverEdge(yzeItr);
365         }
366         g.removeEdge(yzeItr);
367       } else if (addedEdge) {
368         // If the edge was added, and non-null, finish setting it up, add it to
369         // the solver & notify heuristic.
370         edgeDataList.push_back(EdgeData());
371         g.setEdgeData(yzeItr, &edgeDataList.back());
372         addSolverEdge(yzeItr);
373         h.handleAddEdge(yzeItr);
374       }
375
376       h.handleRemoveEdge(yxeItr, ynItr);
377       removeSolverEdge(yxeItr);
378       h.handleRemoveEdge(zxeItr, znItr);
379       removeSolverEdge(zxeItr);
380
381       pushToStack(xnItr);
382     }
383
384   private:
385
386     NodeData& getSolverNodeData(Graph::NodeItr nItr) {
387       return *static_cast<NodeData*>(g.getNodeData(nItr));
388     }
389
390     EdgeData& getSolverEdgeData(Graph::EdgeItr eItr) {
391       return *static_cast<EdgeData*>(g.getEdgeData(eItr));
392     }
393
394     void addSolverEdge(Graph::EdgeItr eItr) {
395       EdgeData &eData = getSolverEdgeData(eItr);
396       NodeData &n1Data = getSolverNodeData(g.getEdgeNode1(eItr)),
397                &n2Data = getSolverNodeData(g.getEdgeNode2(eItr));
398
399       eData.setN1SolverEdgeItr(n1Data.addSolverEdge(eItr));
400       eData.setN2SolverEdgeItr(n2Data.addSolverEdge(eItr));
401     }
402
403     void setup() {
404       if (h.solverRunSimplify()) {
405         simplify();
406       }
407
408       // Create node data objects.
409       for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
410                nItr != nEnd; ++nItr) {
411         nodeDataList.push_back(NodeData());
412         g.setNodeData(nItr, &nodeDataList.back());
413       }
414
415       // Create edge data objects.
416       for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
417            eItr != eEnd; ++eItr) {
418         edgeDataList.push_back(EdgeData());
419         g.setEdgeData(eItr, &edgeDataList.back());
420         addSolverEdge(eItr);
421       }
422     }
423
424     void simplify() {
425       disconnectTrivialNodes();
426       eliminateIndependentEdges();
427     }
428
429     // Eliminate trivial nodes.
430     void disconnectTrivialNodes() {
431       unsigned numDisconnected = 0;
432
433       for (Graph::NodeItr nItr = g.nodesBegin(), nEnd = g.nodesEnd();
434            nItr != nEnd; ++nItr) {
435
436         if (g.getNodeCosts(nItr).getLength() == 1) {
437
438           std::vector<Graph::EdgeItr> edgesToRemove;
439
440           for (Graph::AdjEdgeItr aeItr = g.adjEdgesBegin(nItr),
441                                  aeEnd = g.adjEdgesEnd(nItr);
442                aeItr != aeEnd; ++aeItr) {
443
444             Graph::EdgeItr eItr = *aeItr;
445
446             if (g.getEdgeNode1(eItr) == nItr) {
447               Graph::NodeItr otherNodeItr = g.getEdgeNode2(eItr);
448               g.getNodeCosts(otherNodeItr) +=
449                 g.getEdgeCosts(eItr).getRowAsVector(0);
450             }
451             else {
452               Graph::NodeItr otherNodeItr = g.getEdgeNode1(eItr);
453               g.getNodeCosts(otherNodeItr) +=
454                 g.getEdgeCosts(eItr).getColAsVector(0);
455             }
456
457             edgesToRemove.push_back(eItr);
458           }
459
460           if (!edgesToRemove.empty())
461             ++numDisconnected;
462
463           while (!edgesToRemove.empty()) {
464             g.removeEdge(edgesToRemove.back());
465             edgesToRemove.pop_back();
466           }
467         }
468       }
469     }
470
471     void eliminateIndependentEdges() {
472       std::vector<Graph::EdgeItr> edgesToProcess;
473       unsigned numEliminated = 0;
474
475       for (Graph::EdgeItr eItr = g.edgesBegin(), eEnd = g.edgesEnd();
476            eItr != eEnd; ++eItr) {
477         edgesToProcess.push_back(eItr);
478       }
479
480       while (!edgesToProcess.empty()) {
481         if (tryToEliminateEdge(edgesToProcess.back()))
482           ++numEliminated;
483         edgesToProcess.pop_back();
484       }
485     }
486
487     bool tryToEliminateEdge(Graph::EdgeItr eItr) {
488       if (tryNormaliseEdgeMatrix(eItr)) {
489         g.removeEdge(eItr);
490         return true; 
491       }
492       return false;
493     }
494
495     bool tryNormaliseEdgeMatrix(Graph::EdgeItr &eItr) {
496
497       Matrix &edgeCosts = g.getEdgeCosts(eItr);
498       Vector &uCosts = g.getNodeCosts(g.getEdgeNode1(eItr)),
499              &vCosts = g.getNodeCosts(g.getEdgeNode2(eItr));
500
501       for (unsigned r = 0; r < edgeCosts.getRows(); ++r) {
502         PBQPNum rowMin = edgeCosts.getRowMin(r);
503         uCosts[r] += rowMin;
504         if (rowMin != std::numeric_limits<PBQPNum>::infinity()) {
505           edgeCosts.subFromRow(r, rowMin);
506         }
507         else {
508           edgeCosts.setRow(r, 0);
509         }
510       }
511
512       for (unsigned c = 0; c < edgeCosts.getCols(); ++c) {
513         PBQPNum colMin = edgeCosts.getColMin(c);
514         vCosts[c] += colMin;
515         if (colMin != std::numeric_limits<PBQPNum>::infinity()) {
516           edgeCosts.subFromCol(c, colMin);
517         }
518         else {
519           edgeCosts.setCol(c, 0);
520         }
521       }
522
523       return edgeCosts.isZero();
524     }
525
526     void backpropagate() {
527       while (!stack.empty()) {
528         computeSolution(stack.back());
529         stack.pop_back();
530       }
531     }
532
533     void computeSolution(Graph::NodeItr nItr) {
534
535       NodeData &nodeData = getSolverNodeData(nItr);
536
537       Vector v(g.getNodeCosts(nItr));
538
539       // Solve based on existing solved edges.
540       for (SolverEdgeItr solvedEdgeItr = nodeData.solverEdgesBegin(),
541                          solvedEdgeEnd = nodeData.solverEdgesEnd();
542            solvedEdgeItr != solvedEdgeEnd; ++solvedEdgeItr) {
543
544         Graph::EdgeItr eItr(*solvedEdgeItr);
545         Matrix &edgeCosts = g.getEdgeCosts(eItr);
546
547         if (nItr == g.getEdgeNode1(eItr)) {
548           Graph::NodeItr adjNode(g.getEdgeNode2(eItr));
549           unsigned adjSolution = s.getSelection(adjNode);
550           v += edgeCosts.getColAsVector(adjSolution);
551         }
552         else {
553           Graph::NodeItr adjNode(g.getEdgeNode1(eItr));
554           unsigned adjSolution = s.getSelection(adjNode);
555           v += edgeCosts.getRowAsVector(adjSolution);
556         }
557
558       }
559
560       setSolution(nItr, v.minIndex());
561     }
562
563     void cleanup() {
564       h.cleanup();
565       nodeDataList.clear();
566       edgeDataList.clear();
567     }
568   };
569
570   /// \brief PBQP heuristic solver class.
571   ///
572   /// Given a PBQP Graph g representing a PBQP problem, you can find a solution
573   /// by calling
574   /// <tt>Solution s = HeuristicSolver<H>::solve(g);</tt>
575   ///
576   /// The choice of heuristic for the H parameter will affect both the solver
577   /// speed and solution quality. The heuristic should be chosen based on the
578   /// nature of the problem being solved.
579   /// Currently the only solver included with LLVM is the Briggs heuristic for
580   /// register allocation.
581   template <typename HImpl>
582   class HeuristicSolver {
583   public:
584     static Solution solve(Graph &g) {
585       HeuristicSolverImpl<HImpl> hs(g);
586       return hs.computeSolution();
587     }
588   };
589
590 }
591
592 #endif // LLVM_CODEGEN_PBQP_HEURISTICSOLVER_H