[PBQP] Replace PBQPBuilder with composable constraints (PBQPRAConstraint).
[oota-llvm.git] / include / llvm / CodeGen / PBQP / RegAllocSolver.h
1 //===-- RegAllocSolver.h - Heuristic PBQP Solver for reg alloc --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Heuristic PBQP solver for register allocation problems. This solver uses a
11 // graph reduction approach. Nodes of degree 0, 1 and 2 are eliminated with
12 // optimality-preserving rules (see ReductionRules.h). When no low-degree (<3)
13 // nodes are present, a heuristic derived from Briggs' graph coloring approach
14 // is used.
15 //
16 //===----------------------------------------------------------------------===//
17
18 #ifndef LLVM_CODEGEN_PBQP_REGALLOCSOLVER_H
19 #define LLVM_CODEGEN_PBQP_REGALLOCSOLVER_H
20
#include "CostAllocator.h"
#include "Graph.h"
#include "ReductionRules.h"
#include "Solution.h"
#include "llvm/Support/ErrorHandling.h"
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
28
29 namespace llvm{
30 namespace PBQP {
31   namespace RegAlloc {
32
33     /// @brief Spill option index.
34     inline unsigned getSpillOptionIdx() { return 0; }
35
36     /// \brief Metadata to speed allocatability test.
37     ///
38     /// Keeps track of the number of infinities in each row and column.
39     class MatrixMetadata {
40     private:
41       MatrixMetadata(const MatrixMetadata&);
42       void operator=(const MatrixMetadata&);
43     public:
44       MatrixMetadata(const PBQP::Matrix& M)
45         : WorstRow(0), WorstCol(0),
46           UnsafeRows(new bool[M.getRows() - 1]()),
47           UnsafeCols(new bool[M.getCols() - 1]()) {
48
49         unsigned* ColCounts = new unsigned[M.getCols() - 1]();
50
51         for (unsigned i = 1; i < M.getRows(); ++i) {
52           unsigned RowCount = 0;
53           for (unsigned j = 1; j < M.getCols(); ++j) {
54             if (M[i][j] == std::numeric_limits<PBQP::PBQPNum>::infinity()) {
55               ++RowCount;
56               ++ColCounts[j - 1];
57               UnsafeRows[i - 1] = true;
58               UnsafeCols[j - 1] = true;
59             }
60           }
61           WorstRow = std::max(WorstRow, RowCount);
62         }
63         unsigned WorstColCountForCurRow =
64           *std::max_element(ColCounts, ColCounts + M.getCols() - 1);
65         WorstCol = std::max(WorstCol, WorstColCountForCurRow);
66         delete[] ColCounts;
67       }
68
69       ~MatrixMetadata() {
70         delete[] UnsafeRows;
71         delete[] UnsafeCols;
72       }
73
74       unsigned getWorstRow() const { return WorstRow; }
75       unsigned getWorstCol() const { return WorstCol; }
76       const bool* getUnsafeRows() const { return UnsafeRows; }
77       const bool* getUnsafeCols() const { return UnsafeCols; }
78
79     private:
80       unsigned WorstRow, WorstCol;
81       bool* UnsafeRows;
82       bool* UnsafeCols;
83     };
84
85     class NodeMetadata {
86     public:
87       typedef std::vector<unsigned> OptionToRegMap;
88
89       typedef enum { Unprocessed,
90                      OptimallyReducible,
91                      ConservativelyAllocatable,
92                      NotProvablyAllocatable } ReductionState;
93
94       NodeMetadata() : RS(Unprocessed), DeniedOpts(0), OptUnsafeEdges(nullptr){}
95       ~NodeMetadata() { delete[] OptUnsafeEdges; }
96
97       void setVReg(unsigned VReg) { this->VReg = VReg; }
98       unsigned getVReg() const { return VReg; }
99
100       void setOptionRegs(OptionToRegMap OptionRegs) {
101         this->OptionRegs = std::move(OptionRegs);
102       }
103       const OptionToRegMap& getOptionRegs() const { return OptionRegs; }
104
105       void setup(const Vector& Costs) {
106         NumOpts = Costs.getLength() - 1;
107         OptUnsafeEdges = new unsigned[NumOpts]();
108       }
109
110       ReductionState getReductionState() const { return RS; }
111       void setReductionState(ReductionState RS) { this->RS = RS; }
112
113       void handleAddEdge(const MatrixMetadata& MD, bool Transpose) {
114         DeniedOpts += Transpose ? MD.getWorstCol() : MD.getWorstRow();
115         const bool* UnsafeOpts =
116           Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
117         for (unsigned i = 0; i < NumOpts; ++i)
118           OptUnsafeEdges[i] += UnsafeOpts[i];
119       }
120
121       void handleRemoveEdge(const MatrixMetadata& MD, bool Transpose) {
122         DeniedOpts -= Transpose ? MD.getWorstCol() : MD.getWorstRow();
123         const bool* UnsafeOpts =
124           Transpose ? MD.getUnsafeCols() : MD.getUnsafeRows();
125         for (unsigned i = 0; i < NumOpts; ++i)
126           OptUnsafeEdges[i] -= UnsafeOpts[i];
127       }
128
129       bool isConservativelyAllocatable() const {
130         return (DeniedOpts < NumOpts) ||
131                (std::find(OptUnsafeEdges, OptUnsafeEdges + NumOpts, 0) !=
132                   OptUnsafeEdges + NumOpts);
133       }
134
135     private:
136       ReductionState RS;
137       unsigned NumOpts;
138       unsigned DeniedOpts;
139       unsigned* OptUnsafeEdges;
140       unsigned VReg;
141       OptionToRegMap OptionRegs;
142     };
143
144     class RegAllocSolverImpl {
145     private:
146       typedef PBQP::MDMatrix<MatrixMetadata> RAMatrix;
147     public:
148       typedef PBQP::Vector RawVector;
149       typedef PBQP::Matrix RawMatrix;
150       typedef PBQP::Vector Vector;
151       typedef RAMatrix     Matrix;
152       typedef PBQP::PoolCostAllocator<
153                 Vector, PBQP::VectorComparator,
154                 Matrix, PBQP::MatrixComparator> CostAllocator;
155
156       typedef PBQP::GraphBase::NodeId NodeId;
157       typedef PBQP::GraphBase::EdgeId EdgeId;
158
159       typedef RegAlloc::NodeMetadata NodeMetadata;
160
161       struct EdgeMetadata { };
162
163       class GraphMetadata {
164       public:
165         GraphMetadata(MachineFunction &MF,
166                       LiveIntervals &LIS,
167                       MachineBlockFrequencyInfo &MBFI)
168           : MF(MF), LIS(LIS), MBFI(MBFI) {}
169
170         MachineFunction &MF;
171         LiveIntervals &LIS;
172         MachineBlockFrequencyInfo &MBFI;
173
174         void setNodeIdForVReg(unsigned VReg, GraphBase::NodeId NId) {
175           VRegToNodeId[VReg] = NId;
176         }
177
178         GraphBase::NodeId getNodeIdForVReg(unsigned VReg) const {
179           auto VRegItr = VRegToNodeId.find(VReg);
180           if (VRegItr == VRegToNodeId.end())
181             return GraphBase::invalidNodeId();
182           return VRegItr->second;
183         }
184
185         void eraseNodeIdForVReg(unsigned VReg) {
186           VRegToNodeId.erase(VReg);
187         }
188
189       private:
190         DenseMap<unsigned, NodeId> VRegToNodeId;
191       };
192
193       typedef PBQP::Graph<RegAllocSolverImpl> Graph;
194
195       RegAllocSolverImpl(Graph &G) : G(G) {}
196
197       Solution solve() {
198         G.setSolver(*this);
199         Solution S;
200         setup();
201         S = backpropagate(G, reduce());
202         G.unsetSolver();
203         return S;
204       }
205
206       void handleAddNode(NodeId NId) {
207         G.getNodeMetadata(NId).setup(G.getNodeCosts(NId));
208       }
209       void handleRemoveNode(NodeId NId) {}
210       void handleSetNodeCosts(NodeId NId, const Vector& newCosts) {}
211
212       void handleAddEdge(EdgeId EId) {
213         handleReconnectEdge(EId, G.getEdgeNode1Id(EId));
214         handleReconnectEdge(EId, G.getEdgeNode2Id(EId));
215       }
216
217       void handleRemoveEdge(EdgeId EId) {
218         handleDisconnectEdge(EId, G.getEdgeNode1Id(EId));
219         handleDisconnectEdge(EId, G.getEdgeNode2Id(EId));
220       }
221
222       void handleDisconnectEdge(EdgeId EId, NodeId NId) {
223         NodeMetadata& NMd = G.getNodeMetadata(NId);
224         const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
225         NMd.handleRemoveEdge(MMd, NId == G.getEdgeNode2Id(EId));
226         if (G.getNodeDegree(NId) == 3) {
227           // This node is becoming optimally reducible.
228           moveToOptimallyReducibleNodes(NId);
229         } else if (NMd.getReductionState() ==
230                      NodeMetadata::NotProvablyAllocatable &&
231                    NMd.isConservativelyAllocatable()) {
232           // This node just became conservatively allocatable.
233           moveToConservativelyAllocatableNodes(NId);
234         }
235       }
236
237       void handleReconnectEdge(EdgeId EId, NodeId NId) {
238         NodeMetadata& NMd = G.getNodeMetadata(NId);
239         const MatrixMetadata& MMd = G.getEdgeCosts(EId).getMetadata();
240         NMd.handleAddEdge(MMd, NId == G.getEdgeNode2Id(EId));
241       }
242
243       void handleSetEdgeCosts(EdgeId EId, const Matrix& NewCosts) {
244         handleRemoveEdge(EId);
245
246         NodeId N1Id = G.getEdgeNode1Id(EId);
247         NodeId N2Id = G.getEdgeNode2Id(EId);
248         NodeMetadata& N1Md = G.getNodeMetadata(N1Id);
249         NodeMetadata& N2Md = G.getNodeMetadata(N2Id);
250         const MatrixMetadata& MMd = NewCosts.getMetadata();
251         N1Md.handleAddEdge(MMd, N1Id != G.getEdgeNode1Id(EId));
252         N2Md.handleAddEdge(MMd, N2Id != G.getEdgeNode1Id(EId));
253       }
254
255     private:
256
257       void removeFromCurrentSet(NodeId NId) {
258         switch (G.getNodeMetadata(NId).getReductionState()) {
259           case NodeMetadata::Unprocessed: break;
260           case NodeMetadata::OptimallyReducible:
261             assert(OptimallyReducibleNodes.find(NId) !=
262                      OptimallyReducibleNodes.end() &&
263                    "Node not in optimally reducible set.");
264             OptimallyReducibleNodes.erase(NId);
265             break;
266           case NodeMetadata::ConservativelyAllocatable:
267             assert(ConservativelyAllocatableNodes.find(NId) !=
268                      ConservativelyAllocatableNodes.end() &&
269                    "Node not in conservatively allocatable set.");
270             ConservativelyAllocatableNodes.erase(NId);
271             break;
272           case NodeMetadata::NotProvablyAllocatable:
273             assert(NotProvablyAllocatableNodes.find(NId) !=
274                      NotProvablyAllocatableNodes.end() &&
275                    "Node not in not-provably-allocatable set.");
276             NotProvablyAllocatableNodes.erase(NId);
277             break;
278         }
279       }
280
281       void moveToOptimallyReducibleNodes(NodeId NId) {
282         removeFromCurrentSet(NId);
283         OptimallyReducibleNodes.insert(NId);
284         G.getNodeMetadata(NId).setReductionState(
285           NodeMetadata::OptimallyReducible);
286       }
287
288       void moveToConservativelyAllocatableNodes(NodeId NId) {
289         removeFromCurrentSet(NId);
290         ConservativelyAllocatableNodes.insert(NId);
291         G.getNodeMetadata(NId).setReductionState(
292           NodeMetadata::ConservativelyAllocatable);
293       }
294
295       void moveToNotProvablyAllocatableNodes(NodeId NId) {
296         removeFromCurrentSet(NId);
297         NotProvablyAllocatableNodes.insert(NId);
298         G.getNodeMetadata(NId).setReductionState(
299           NodeMetadata::NotProvablyAllocatable);
300       }
301
302       void setup() {
303         // Set up worklists.
304         for (auto NId : G.nodeIds()) {
305           if (G.getNodeDegree(NId) < 3)
306             moveToOptimallyReducibleNodes(NId);
307           else if (G.getNodeMetadata(NId).isConservativelyAllocatable())
308             moveToConservativelyAllocatableNodes(NId);
309           else
310             moveToNotProvablyAllocatableNodes(NId);
311         }
312       }
313
314       // Compute a reduction order for the graph by iteratively applying PBQP
315       // reduction rules. Locally optimal rules are applied whenever possible (R0,
316       // R1, R2). If no locally-optimal rules apply then any conservatively
317       // allocatable node is reduced. Finally, if no conservatively allocatable
318       // node exists then the node with the lowest spill-cost:degree ratio is
319       // selected.
320       std::vector<GraphBase::NodeId> reduce() {
321         assert(!G.empty() && "Cannot reduce empty graph.");
322
323         typedef GraphBase::NodeId NodeId;
324         std::vector<NodeId> NodeStack;
325
326         // Consume worklists.
327         while (true) {
328           if (!OptimallyReducibleNodes.empty()) {
329             NodeSet::iterator NItr = OptimallyReducibleNodes.begin();
330             NodeId NId = *NItr;
331             OptimallyReducibleNodes.erase(NItr);
332             NodeStack.push_back(NId);
333             switch (G.getNodeDegree(NId)) {
334               case 0:
335                 break;
336               case 1:
337                 applyR1(G, NId);
338                 break;
339               case 2:
340                 applyR2(G, NId);
341                 break;
342               default: llvm_unreachable("Not an optimally reducible node.");
343             }
344           } else if (!ConservativelyAllocatableNodes.empty()) {
345             // Conservatively allocatable nodes will never spill. For now just
346             // take the first node in the set and push it on the stack. When we
347             // start optimizing more heavily for register preferencing, it may
348             // would be better to push nodes with lower 'expected' or worst-case
349             // register costs first (since early nodes are the most
350             // constrained).
351             NodeSet::iterator NItr = ConservativelyAllocatableNodes.begin();
352             NodeId NId = *NItr;
353             ConservativelyAllocatableNodes.erase(NItr);
354             NodeStack.push_back(NId);
355             G.disconnectAllNeighborsFromNode(NId);
356
357           } else if (!NotProvablyAllocatableNodes.empty()) {
358             NodeSet::iterator NItr =
359               std::min_element(NotProvablyAllocatableNodes.begin(),
360                                NotProvablyAllocatableNodes.end(),
361                                SpillCostComparator(G));
362             NodeId NId = *NItr;
363             NotProvablyAllocatableNodes.erase(NItr);
364             NodeStack.push_back(NId);
365             G.disconnectAllNeighborsFromNode(NId);
366           } else
367             break;
368         }
369
370         return NodeStack;
371       }
372
373       class SpillCostComparator {
374       public:
375         SpillCostComparator(const Graph& G) : G(G) {}
376         bool operator()(NodeId N1Id, NodeId N2Id) {
377           PBQPNum N1SC = G.getNodeCosts(N1Id)[0] / G.getNodeDegree(N1Id);
378           PBQPNum N2SC = G.getNodeCosts(N2Id)[0] / G.getNodeDegree(N2Id);
379           return N1SC < N2SC;
380         }
381       private:
382         const Graph& G;
383       };
384
385       Graph& G;
386       typedef std::set<NodeId> NodeSet;
387       NodeSet OptimallyReducibleNodes;
388       NodeSet ConservativelyAllocatableNodes;
389       NodeSet NotProvablyAllocatableNodes;
390     };
391
392     class PBQPRAGraph : public PBQP::Graph<RegAllocSolverImpl> {
393     private:
394       typedef PBQP::Graph<RegAllocSolverImpl> BaseT;
395     public:
396       PBQPRAGraph(GraphMetadata Metadata) : BaseT(Metadata) {}
397     };
398
399     inline Solution solve(PBQPRAGraph& G) {
400       if (G.empty())
401         return Solution();
402       RegAllocSolverImpl RegAllocSolver(G);
403       return RegAllocSolver.solve();
404     }
405   } // namespace RegAlloc
406 } // namespace PBQP
407 } // namespace llvm
408
409 #endif // LLVM_CODEGEN_PBQP_REGALLOCSOLVER_H