f0105b0a2622e4bc2b28531672e969b8cff9bc1f
[oota-llvm.git] / lib / Target / AArch64 / AArch64PBQPRegAlloc.cpp
1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
11 //
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
17
18 #define DEBUG_TYPE "aarch64-pbqp"
19
20 #include "AArch64.h"
21 #include "AArch64PBQPRegAlloc.h"
22 #include "AArch64RegisterInfo.h"
23 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
24 #include "llvm/CodeGen/MachineBasicBlock.h"
25 #include "llvm/CodeGen/MachineFunction.h"
26 #include "llvm/CodeGen/MachineRegisterInfo.h"
27 #include "llvm/CodeGen/RegAllocPBQP.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/raw_ostream.h"
31
32 using namespace llvm;
33
34 namespace {
35
36 #ifndef NDEBUG
37 bool isFPReg(unsigned reg) {
38   return AArch64::FPR32RegClass.contains(reg) ||
39          AArch64::FPR64RegClass.contains(reg) ||
40          AArch64::FPR128RegClass.contains(reg);
41 }
42 #endif
43
44 bool isOdd(unsigned reg) {
45   switch (reg) {
46   default:
47     llvm_unreachable("Register is not from the expected class !");
48   case AArch64::S1:
49   case AArch64::S3:
50   case AArch64::S5:
51   case AArch64::S7:
52   case AArch64::S9:
53   case AArch64::S11:
54   case AArch64::S13:
55   case AArch64::S15:
56   case AArch64::S17:
57   case AArch64::S19:
58   case AArch64::S21:
59   case AArch64::S23:
60   case AArch64::S25:
61   case AArch64::S27:
62   case AArch64::S29:
63   case AArch64::S31:
64   case AArch64::D1:
65   case AArch64::D3:
66   case AArch64::D5:
67   case AArch64::D7:
68   case AArch64::D9:
69   case AArch64::D11:
70   case AArch64::D13:
71   case AArch64::D15:
72   case AArch64::D17:
73   case AArch64::D19:
74   case AArch64::D21:
75   case AArch64::D23:
76   case AArch64::D25:
77   case AArch64::D27:
78   case AArch64::D29:
79   case AArch64::D31:
80   case AArch64::Q1:
81   case AArch64::Q3:
82   case AArch64::Q5:
83   case AArch64::Q7:
84   case AArch64::Q9:
85   case AArch64::Q11:
86   case AArch64::Q13:
87   case AArch64::Q15:
88   case AArch64::Q17:
89   case AArch64::Q19:
90   case AArch64::Q21:
91   case AArch64::Q23:
92   case AArch64::Q25:
93   case AArch64::Q27:
94   case AArch64::Q29:
95   case AArch64::Q31:
96     return true;
97   case AArch64::S0:
98   case AArch64::S2:
99   case AArch64::S4:
100   case AArch64::S6:
101   case AArch64::S8:
102   case AArch64::S10:
103   case AArch64::S12:
104   case AArch64::S14:
105   case AArch64::S16:
106   case AArch64::S18:
107   case AArch64::S20:
108   case AArch64::S22:
109   case AArch64::S24:
110   case AArch64::S26:
111   case AArch64::S28:
112   case AArch64::S30:
113   case AArch64::D0:
114   case AArch64::D2:
115   case AArch64::D4:
116   case AArch64::D6:
117   case AArch64::D8:
118   case AArch64::D10:
119   case AArch64::D12:
120   case AArch64::D14:
121   case AArch64::D16:
122   case AArch64::D18:
123   case AArch64::D20:
124   case AArch64::D22:
125   case AArch64::D24:
126   case AArch64::D26:
127   case AArch64::D28:
128   case AArch64::D30:
129   case AArch64::Q0:
130   case AArch64::Q2:
131   case AArch64::Q4:
132   case AArch64::Q6:
133   case AArch64::Q8:
134   case AArch64::Q10:
135   case AArch64::Q12:
136   case AArch64::Q14:
137   case AArch64::Q16:
138   case AArch64::Q18:
139   case AArch64::Q20:
140   case AArch64::Q22:
141   case AArch64::Q24:
142   case AArch64::Q26:
143   case AArch64::Q28:
144   case AArch64::Q30:
145     return false;
146
147   }
148 }
149
150 bool haveSameParity(unsigned reg1, unsigned reg2) {
151   assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152   assert(isFPReg(reg2) && "Expecting an FP register for reg2");
153
154   return isOdd(reg1) == isOdd(reg2);
155 }
156
157 }
158
159 bool A57PBQPConstraints::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
160                                                  unsigned Ra) {
161   if (Rd == Ra)
162     return false;
163
164   const TargetRegisterInfo &TRI =
165     *G.getMetadata().MF.getTarget().getSubtargetImpl()->getRegisterInfo();
166   LiveIntervals &LIs = G.getMetadata().LIS;
167
168   if (TRI.isPhysicalRegister(Rd) || TRI.isPhysicalRegister(Ra)) {
169     DEBUG(dbgs() << "Rd is a physical reg:" << TRI.isPhysicalRegister(Rd)
170           << '\n');
171     DEBUG(dbgs() << "Ra is a physical reg:" << TRI.isPhysicalRegister(Ra)
172           << '\n');
173     return false;
174   }
175
176   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
177   PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
178
179   const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRdAllowed =
180     &G.getNodeMetadata(node1).getOptionRegs();
181   const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRaAllowed =
182     &G.getNodeMetadata(node2).getOptionRegs();
183
184   PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
185
186   // The edge does not exist. Create one with the appropriate interference
187   // costs.
188   if (edge == G.invalidEdgeId()) {
189     const LiveInterval &ld = LIs.getInterval(Rd);
190     const LiveInterval &la = LIs.getInterval(Ra);
191     bool livesOverlap = ld.overlaps(la);
192
193     PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
194                                  vRaAllowed->size() + 1, 0);
195     for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
196       unsigned pRd = (*vRdAllowed)[i];
197       for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
198         unsigned pRa = (*vRaAllowed)[j];
199         if (livesOverlap && TRI.regsOverlap(pRd, pRa))
200           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
201         else
202           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
203       }
204     }
205     G.addEdge(node1, node2, std::move(costs));
206     return true;
207   }
208
209   if (G.getEdgeNode1Id(edge) == node2) {
210     std::swap(node1, node2);
211     std::swap(vRdAllowed, vRaAllowed);
212   }
213
214   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
215   PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
216   for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
217     unsigned pRd = (*vRdAllowed)[i];
218
219     // Get the maximum cost (excluding unallocatable reg) for same parity
220     // registers
221     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
222     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
223       unsigned pRa = (*vRaAllowed)[j];
224       if (haveSameParity(pRd, pRa))
225         if (costs[i + 1][j + 1] !=
226                 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
227             costs[i + 1][j + 1] > sameParityMax)
228           sameParityMax = costs[i + 1][j + 1];
229     }
230
231     // Ensure all registers with a different parity have a higher cost
232     // than sameParityMax
233     for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
234       unsigned pRa = (*vRaAllowed)[j];
235       if (!haveSameParity(pRd, pRa))
236         if (sameParityMax > costs[i + 1][j + 1])
237           costs[i + 1][j + 1] = sameParityMax + 1.0;
238     }
239   }
240   G.setEdgeCosts(edge, std::move(costs));
241
242   return true;
243 }
244
245 void A57PBQPConstraints::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
246                                                  unsigned Ra) {
247   const TargetRegisterInfo &TRI =
248     *G.getMetadata().MF.getTarget().getSubtargetImpl()->getRegisterInfo();
249   (void)TRI;
250   LiveIntervals &LIs = G.getMetadata().LIS;
251
252   // Do some Chain management
253   if (Chains.count(Ra)) {
254     if (Rd != Ra) {
255       DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, &TRI) << " to "
256                    << PrintReg(Rd, &TRI) << '\n';);
257       Chains.remove(Ra);
258       Chains.insert(Rd);
259     }
260   } else {
261     DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, &TRI)
262                  << '\n';);
263     Chains.insert(Rd);
264   }
265
266   PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
267
268   const LiveInterval &ld = LIs.getInterval(Rd);
269   for (auto r : Chains) {
270     // Skip self
271     if (r == Rd)
272       continue;
273
274     const LiveInterval &lr = LIs.getInterval(r);
275     if (ld.overlaps(lr)) {
276       const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRdAllowed =
277         &G.getNodeMetadata(node1).getOptionRegs();
278
279       PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
280       const PBQPRAGraph::NodeMetadata::OptionToRegMap *vRrAllowed =
281         &G.getNodeMetadata(node2).getOptionRegs();
282
283       PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
284       assert(edge != G.invalidEdgeId() &&
285              "PBQP error ! The edge should exist !");
286
287       DEBUG(dbgs() << "Refining constraint !\n";);
288
289       if (G.getEdgeNode1Id(edge) == node2) {
290         std::swap(node1, node2);
291         std::swap(vRdAllowed, vRrAllowed);
292       }
293
294       // Enforce that cost is higher with all other Chains of the same parity
295       PBQP::Matrix costs(G.getEdgeCosts(edge));
296       for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
297         unsigned pRd = (*vRdAllowed)[i];
298
299         // Get the maximum cost (excluding unallocatable reg) for all other
300         // parity registers
301         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
302         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
303           unsigned pRa = (*vRrAllowed)[j];
304           if (!haveSameParity(pRd, pRa))
305             if (costs[i + 1][j + 1] !=
306                     std::numeric_limits<PBQP::PBQPNum>::infinity() &&
307                 costs[i + 1][j + 1] > sameParityMax)
308               sameParityMax = costs[i + 1][j + 1];
309         }
310
311         // Ensure all registers with same parity have a higher cost
312         // than sameParityMax
313         for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
314           unsigned pRa = (*vRrAllowed)[j];
315           if (haveSameParity(pRd, pRa))
316             if (sameParityMax > costs[i + 1][j + 1])
317               costs[i + 1][j + 1] = sameParityMax + 1.0;
318         }
319       }
320       G.setEdgeCosts(edge, std::move(costs));
321     }
322   }
323 }
324
325 void A57PBQPConstraints::apply(PBQPRAGraph &G) {
326   MachineFunction &MF = G.getMetadata().MF;
327
328   const TargetRegisterInfo &TRI =
329     *MF.getTarget().getSubtargetImpl()->getRegisterInfo();
330   (void)TRI;
331   DEBUG(MF.dump());
332
333   for (MachineFunction::const_iterator mbbItr = MF.begin(), mbbEnd = MF.end();
334        mbbItr != mbbEnd; ++mbbItr) {
335     const MachineBasicBlock *MBB = &*mbbItr;
336     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
337
338     for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
339                                            miEnd = MBB->end();
340          miItr != miEnd; ++miItr) {
341       const MachineInstr *MI = &*miItr;
342       switch (MI->getOpcode()) {
343       case AArch64::FMSUBSrrr:
344       case AArch64::FMADDSrrr:
345       case AArch64::FNMSUBSrrr:
346       case AArch64::FNMADDSrrr:
347       case AArch64::FMSUBDrrr:
348       case AArch64::FMADDDrrr:
349       case AArch64::FNMSUBDrrr:
350       case AArch64::FNMADDDrrr: {
351         unsigned Rd = MI->getOperand(0).getReg();
352         unsigned Ra = MI->getOperand(3).getReg();
353
354         if (addIntraChainConstraint(G, Rd, Ra))
355           addInterChainConstraint(G, Rd, Ra);
356         break;
357       }
358
359       case AArch64::FMLAv2f32:
360       case AArch64::FMLSv2f32: {
361         unsigned Rd = MI->getOperand(0).getReg();
362         addInterChainConstraint(G, Rd, Rd);
363         break;
364       }
365
366       default:
367         // Forget Chains which have been killed
368         for (auto r : Chains) {
369           SmallVector<unsigned, 8> toDel;
370           if (MI->killsRegister(r)) {
371             DEBUG(dbgs() << "Killing chain " << PrintReg(r, &TRI) << " at ";
372                   MI->print(dbgs()););
373             toDel.push_back(r);
374           }
375
376           while (!toDel.empty()) {
377             Chains.remove(toDel.back());
378             toDel.pop_back();
379           }
380         }
381       }
382     }
383   }
384 }