Fixed malformed -*- lines in PBQP headers.
[oota-llvm.git] / lib / CodeGen / PBQP / Heuristics / Briggs.h
1 //===-- Briggs.h --- Briggs Heuristic for PBQP ------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This class implements the Briggs test for "allocability" of nodes in a
11 // PBQP graph representing a register allocation problem. Nodes which can be
12 // proven allocable (by a safe and relatively accurate test) are removed from
13 // the PBQP graph first. If no provably allocable node is present in the graph
14 // then the node with the minimal spill-cost to degree ratio is removed.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #ifndef LLVM_CODEGEN_PBQP_HEURISTICS_BRIGGS_H
19 #define LLVM_CODEGEN_PBQP_HEURISTICS_BRIGGS_H
20
21 #include "../HeuristicSolver.h"
22
23 #include <set>
24
25 namespace PBQP {
26 namespace Heuristics {
27
28 class Briggs {
29   public:
30
31     class NodeData;
32     class EdgeData;
33
34   private:
35
36     typedef HeuristicSolverImpl<Briggs> Solver;
37     typedef HSITypes<NodeData, EdgeData> HSIT;
38     typedef HSIT::SolverGraph SolverGraph;
39     typedef HSIT::GraphNodeIterator GraphNodeIterator;
40     typedef HSIT::GraphEdgeIterator GraphEdgeIterator;
41
42     class LinkDegreeComparator {
43       public:
44         LinkDegreeComparator() : g(0) {}
45         LinkDegreeComparator(SolverGraph *g) : g(g) {}
46
47         bool operator()(const GraphNodeIterator &node1Itr,
48                         const GraphNodeIterator &node2Itr) const {
49           assert((g != 0) && "Graph object not set, cannot access node data.");
50           unsigned n1Degree = g->getNodeData(node1Itr).getLinkDegree(),
51                    n2Degree = g->getNodeData(node2Itr).getLinkDegree();
52           if (n1Degree > n2Degree) {
53             return true;
54           }
55           else if (n1Degree < n2Degree) {
56             return false;
57           }
58           // else they're "equal" by degree, differentiate based on ID.
59           return g->getNodeID(node1Itr) < g->getNodeID(node2Itr);
60         }
61
62       private:
63         SolverGraph *g;
64     };
65
66     class SpillPriorityComparator {
67       public:
68         SpillPriorityComparator() : g(0) {}
69         SpillPriorityComparator(SolverGraph *g) : g(g) {}
70
71         bool operator()(const GraphNodeIterator &node1Itr,
72                         const GraphNodeIterator &node2Itr) const {
73           assert((g != 0) && "Graph object not set, cannot access node data.");
74           PBQPNum cost1 =
75             g->getNodeCosts(node1Itr)[0] /
76             g->getNodeData(node1Itr).getLinkDegree(),
77             cost2 =
78               g->getNodeCosts(node2Itr)[0] /
79               g->getNodeData(node2Itr).getLinkDegree();
80
81           if (cost1 < cost2) {
82             return true;
83           }
84           else if (cost1 > cost2) {
85             return false;
86           }
87           // else they'er "equal" again, differentiate based on address again.
88           return g->getNodeID(node1Itr) < g->getNodeID(node2Itr);
89         }
90
91       private:
92         SolverGraph *g;
93     };
94
95     typedef std::set<GraphNodeIterator, LinkDegreeComparator>
96       RNAllocableNodeList;
97     typedef RNAllocableNodeList::iterator RNAllocableNodeListIterator;
98
99     typedef std::set<GraphNodeIterator, SpillPriorityComparator>
100       RNUnallocableNodeList;
101     typedef RNUnallocableNodeList::iterator RNUnallocableNodeListIterator;
102
103   public:
104
105     class NodeData {
106       private:
107         RNAllocableNodeListIterator rNAllocableNodeListItr;
108         RNUnallocableNodeListIterator rNUnallocableNodeListItr;
109         unsigned numRegOptions, numDenied, numSafe;
110         std::vector<unsigned> unsafeDegrees;
111         bool allocable;
112
113         void addRemoveLink(SolverGraph &g, const GraphNodeIterator &nodeItr,
114             const GraphEdgeIterator &edgeItr, bool add) {
115
116           //assume we're adding...
117           unsigned udTarget = 0, dir = 1;
118
119           if (!add) {
120             udTarget = 1;
121             dir = ~0;
122           }
123
124           EdgeData &linkEdgeData = g.getEdgeData(edgeItr).getHeuristicData();
125
126           EdgeData::ConstUnsafeIterator edgeUnsafeBegin, edgeUnsafeEnd;
127
128           if (nodeItr == g.getEdgeNode1Itr(edgeItr)) {
129             numDenied += (dir * linkEdgeData.getWorstDegree());
130             edgeUnsafeBegin = linkEdgeData.unsafeBegin();
131             edgeUnsafeEnd = linkEdgeData.unsafeEnd();
132           }
133           else {
134             numDenied += (dir * linkEdgeData.getReverseWorstDegree());
135             edgeUnsafeBegin = linkEdgeData.reverseUnsafeBegin();
136             edgeUnsafeEnd = linkEdgeData.reverseUnsafeEnd();
137           }
138
139           assert((unsafeDegrees.size() ==
140                 static_cast<unsigned>(
141                   std::distance(edgeUnsafeBegin, edgeUnsafeEnd)))
142               && "Unsafe array size mismatch.");
143
144           std::vector<unsigned>::iterator unsafeDegreesItr =
145             unsafeDegrees.begin();
146
147           for (EdgeData::ConstUnsafeIterator edgeUnsafeItr = edgeUnsafeBegin;
148               edgeUnsafeItr != edgeUnsafeEnd;
149               ++edgeUnsafeItr, ++unsafeDegreesItr) {
150
151             if ((*edgeUnsafeItr == 1) && (*unsafeDegreesItr == udTarget))  {
152               numSafe -= dir;
153             }
154             *unsafeDegreesItr += (dir * (*edgeUnsafeItr));
155           }
156
157           allocable = (numDenied < numRegOptions) || (numSafe > 0);
158         }
159
160       public:
161
162         void setup(SolverGraph &g, const GraphNodeIterator &nodeItr) {
163
164           numRegOptions = g.getNodeCosts(nodeItr).getLength() - 1;
165
166           numSafe = numRegOptions; // Optimistic, correct below.
167           numDenied = 0; // Also optimistic.
168           unsafeDegrees.resize(numRegOptions, 0);
169
170           HSIT::NodeData &nodeData = g.getNodeData(nodeItr);
171
172           for (HSIT::NodeData::AdjLinkIterator
173               adjLinkItr = nodeData.adjLinksBegin(),
174               adjLinkEnd = nodeData.adjLinksEnd();
175               adjLinkItr != adjLinkEnd; ++adjLinkItr) {
176
177             addRemoveLink(g, nodeItr, *adjLinkItr, true);
178           }
179         }
180
181         bool isAllocable() const { return allocable; }
182
183         void handleAddLink(SolverGraph &g, const GraphNodeIterator &nodeItr,
184             const GraphEdgeIterator &adjEdge) {
185           addRemoveLink(g, nodeItr, adjEdge, true);
186         }
187
188         void handleRemoveLink(SolverGraph &g, const GraphNodeIterator &nodeItr,
189             const GraphEdgeIterator &adjEdge) {
190           addRemoveLink(g, nodeItr, adjEdge, false);
191         }
192
193         void setRNAllocableNodeListItr(
194             const RNAllocableNodeListIterator &rNAllocableNodeListItr) {
195
196           this->rNAllocableNodeListItr = rNAllocableNodeListItr;
197         }
198
199         RNAllocableNodeListIterator getRNAllocableNodeListItr() const {
200           return rNAllocableNodeListItr;
201         }
202
203         void setRNUnallocableNodeListItr(
204             const RNUnallocableNodeListIterator &rNUnallocableNodeListItr) {
205
206           this->rNUnallocableNodeListItr = rNUnallocableNodeListItr;
207         }
208
209         RNUnallocableNodeListIterator getRNUnallocableNodeListItr() const {
210           return rNUnallocableNodeListItr;
211         }
212
213
214     };
215
216     class EdgeData {
217       private:
218
219         typedef std::vector<unsigned> UnsafeArray;
220
221         unsigned worstDegree,
222                  reverseWorstDegree;
223         UnsafeArray unsafe, reverseUnsafe;
224
225       public:
226
227         EdgeData() : worstDegree(0), reverseWorstDegree(0) {}
228
229         typedef UnsafeArray::const_iterator ConstUnsafeIterator;
230
231         void setup(SolverGraph &g, const GraphEdgeIterator &edgeItr) {
232           const Matrix &edgeCosts = g.getEdgeCosts(edgeItr);
233           unsigned numRegs = edgeCosts.getRows() - 1,
234                    numReverseRegs = edgeCosts.getCols() - 1;
235
236           unsafe.resize(numRegs, 0);
237           reverseUnsafe.resize(numReverseRegs, 0);
238
239           std::vector<unsigned> rowInfCounts(numRegs, 0),
240                                 colInfCounts(numReverseRegs, 0);
241
242           for (unsigned i = 0; i < numRegs; ++i) {
243             for (unsigned j = 0; j < numReverseRegs; ++j) {
244               if (edgeCosts[i + 1][j + 1] ==
245                   std::numeric_limits<PBQPNum>::infinity()) {
246                 unsafe[i] = 1;
247                 reverseUnsafe[j] = 1;
248                 ++rowInfCounts[i];
249                 ++colInfCounts[j];
250
251                 if (colInfCounts[j] > worstDegree) {
252                   worstDegree = colInfCounts[j];
253                 }
254
255                 if (rowInfCounts[i] > reverseWorstDegree) {
256                   reverseWorstDegree = rowInfCounts[i];
257                 }
258               }
259             }
260           }
261         }
262
263         unsigned getWorstDegree() const { return worstDegree; }
264         unsigned getReverseWorstDegree() const { return reverseWorstDegree; }
265         ConstUnsafeIterator unsafeBegin() const { return unsafe.begin(); }
266         ConstUnsafeIterator unsafeEnd() const { return unsafe.end(); }
267         ConstUnsafeIterator reverseUnsafeBegin() const {
268           return reverseUnsafe.begin();
269         }
270         ConstUnsafeIterator reverseUnsafeEnd() const {
271           return reverseUnsafe.end();
272         }
273     };
274
275   void initialise(Solver &solver) {
276     this->s = &solver;
277     g = &s->getGraph();
278     rNAllocableBucket = RNAllocableNodeList(LinkDegreeComparator(g));
279     rNUnallocableBucket =
280       RNUnallocableNodeList(SpillPriorityComparator(g));
281     
282     for (GraphEdgeIterator
283          edgeItr = g->edgesBegin(), edgeEnd = g->edgesEnd();
284          edgeItr != edgeEnd; ++edgeItr) {
285
286       g->getEdgeData(edgeItr).getHeuristicData().setup(*g, edgeItr);
287     }
288
289     for (GraphNodeIterator
290          nodeItr = g->nodesBegin(), nodeEnd = g->nodesEnd();
291          nodeItr != nodeEnd; ++nodeItr) {
292
293       g->getNodeData(nodeItr).getHeuristicData().setup(*g, nodeItr);
294     }
295   }
296
297   void addToRNBucket(const GraphNodeIterator &nodeItr) {
298     NodeData &nodeData = g->getNodeData(nodeItr).getHeuristicData();
299
300     if (nodeData.isAllocable()) {
301       nodeData.setRNAllocableNodeListItr(
302         rNAllocableBucket.insert(rNAllocableBucket.begin(), nodeItr));
303     }
304     else {
305       nodeData.setRNUnallocableNodeListItr(
306         rNUnallocableBucket.insert(rNUnallocableBucket.begin(), nodeItr));
307     }
308   }
309
310   void removeFromRNBucket(const GraphNodeIterator &nodeItr) {
311     NodeData &nodeData = g->getNodeData(nodeItr).getHeuristicData();
312
313     if (nodeData.isAllocable()) {
314       rNAllocableBucket.erase(nodeData.getRNAllocableNodeListItr());
315     }
316     else {
317       rNUnallocableBucket.erase(nodeData.getRNUnallocableNodeListItr());
318     }
319   }
320
321   void handleAddLink(const GraphEdgeIterator &edgeItr) {
322     // We assume that if we got here this edge is attached to at least
323     // one high degree node.
324     g->getEdgeData(edgeItr).getHeuristicData().setup(*g, edgeItr);
325
326     GraphNodeIterator n1Itr = g->getEdgeNode1Itr(edgeItr),
327                       n2Itr = g->getEdgeNode2Itr(edgeItr);
328    
329     HSIT::NodeData &n1Data = g->getNodeData(n1Itr),
330                    &n2Data = g->getNodeData(n2Itr);
331
332     if (n1Data.getLinkDegree() > 2) {
333       n1Data.getHeuristicData().handleAddLink(*g, n1Itr, edgeItr);
334     }
335     if (n2Data.getLinkDegree() > 2) {
336       n2Data.getHeuristicData().handleAddLink(*g, n2Itr, edgeItr);
337     }
338   }
339
340   void handleRemoveLink(const GraphEdgeIterator &edgeItr,
341                         const GraphNodeIterator &nodeItr) {
342     NodeData &nodeData = g->getNodeData(nodeItr).getHeuristicData();
343     nodeData.handleRemoveLink(*g, nodeItr, edgeItr);
344   }
345
346   void processRN() {
347     
348     if (!rNAllocableBucket.empty()) {
349       GraphNodeIterator selectedNodeItr = *rNAllocableBucket.begin();
350       //std::cerr << "RN safely pushing " << g->getNodeID(selectedNodeItr) << "\n";
351       rNAllocableBucket.erase(rNAllocableBucket.begin());
352       s->pushStack(selectedNodeItr);
353       s->unlinkNode(selectedNodeItr);
354     }
355     else {
356       GraphNodeIterator selectedNodeItr = *rNUnallocableBucket.begin();
357       //std::cerr << "RN optimistically pushing " << g->getNodeID(selectedNodeItr) << "\n";
358       rNUnallocableBucket.erase(rNUnallocableBucket.begin());
359       s->pushStack(selectedNodeItr);
360       s->unlinkNode(selectedNodeItr);
361     }
362  
363   }
364
365   bool rNBucketEmpty() const {
366     return (rNAllocableBucket.empty() && rNUnallocableBucket.empty());
367   }
368
369 private:
370
371   Solver *s;
372   SolverGraph *g;
373   RNAllocableNodeList rNAllocableBucket;
374   RNUnallocableNodeList rNUnallocableBucket;
375 };
376
377
378
379 }
380 }
381
382
383 #endif // LLVM_CODEGEN_PBQP_HEURISTICS_BRIGGS_H