1 //===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
9 // This file contains the AArch64 / Cortex-A57 specific register allocation
10 // constraints for use by the PBQP register allocator.
12 // It is essentially a transcription of what is contained in
13 // AArch64A57FPLoadBalancing, which tries to use a balanced
14 // mix of odd and even D-registers when performing a critical sequence of
15 // independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16 //===----------------------------------------------------------------------===//
18 #define DEBUG_TYPE "aarch64-pbqp"
21 #include "AArch64RegisterInfo.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
25 #include "llvm/CodeGen/MachineBasicBlock.h"
26 #include "llvm/CodeGen/MachineFunction.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/CodeGen/RegAllocPBQP.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 #include "llvm/Support/raw_ostream.h"
// PBQP_BUILDER names the builder class that A57PBQPBuilder derives from and
// whose build() it calls before adding the A57-specific constraints. The
// commented-out alternative below would select PBQPBuilder instead of
// PBQPBuilderWithCoalescing.
33 #define PBQP_BUILDER PBQPBuilderWithCoalescing
34 //#define PBQP_BUILDER PBQPBuilder
// Returns true if reg is contained in the FPR32, FPR64 or FPR128 register
// class. haveSameParity() asserts this for both of its arguments.
41 bool isFPReg(unsigned reg) {
42 return AArch64::FPR32RegClass.contains(reg) ||
43 AArch64::FPR64RegClass.contains(reg) ||
44 AArch64::FPR128RegClass.contains(reg);
// Parity predicate used by haveSameParity(): two registers have the same
// parity when isOdd() returns the same value for both.
// NOTE(review): only the llvm_unreachable() line of the body is visible in
// this copy of the file; presumably it is the fallback taken when reg is not
// one of the expected FP registers — confirm against the complete source.
48 bool isOdd(unsigned reg) {
51 llvm_unreachable("Register is not from the expected class !");
// Returns true when isOdd() gives the same answer for reg1 and reg2.
// Both registers are asserted to satisfy isFPReg(); with assertions disabled
// an unexpected register is only caught by whatever isOdd() does with it.
154 bool haveSameParity(unsigned reg1, unsigned reg2) {
155 assert(isFPReg(reg1) && "Expecting an FP register for reg1");
156 assert(isFPReg(reg2) && "Expecting an FP register for reg2");
158 return isOdd(reg1) == isOdd(reg2);
// PBQP problem builder derived from PBQP_BUILDER. Its build() override runs
// the base builder and then calls addIntraChainConstraint() and
// addInterChainConstraint() for selected floating-point multiply-accumulate
// instructions, adjusting edge costs according to register parity.
161 class A57PBQPBuilder : public PBQP_BUILDER {
// Starts with null TRI/LIs and an empty chain set; build() assigns TRI before
// it calls any of the helpers below.
163 A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
165 // Build a PBQP instance to represent the register allocation problem for
166 // the given MachineFunction.
167 std::unique_ptr<PBQPRAProblem>
168 build(MachineFunction *MF, const LiveIntervals *LI,
169 const MachineBlockFrequencyInfo *blockInfo,
170 const RegSet &VRegs) override;
// Target register info, set at the start of build(); used for
// isPhysicalRegister(), regsOverlap() and register printing.
173 const AArch64RegisterInfo *TRI;
// Live intervals queried through getInterval() by both constraint helpers.
// NOTE(review): no assignment to LIs is visible in this copy of the file —
// confirm that build() stores its LI argument here before the helpers run.
174 const LiveIntervals *LIs;
// Registers currently tracked as accumulator chains (the debug output calls
// them "acc chain"); build() clears this at the start of every basic block.
175 SmallSetVector<unsigned, 32> Chains;
177 // Return true if reg is a physical register
// Dereferences TRI, so it must only be called after build() has set it.
178 bool isPhysicalReg(unsigned reg) const {
179 return TRI->isPhysicalRegister(reg);
182 // Add the accumulator chaining constraint, inside the chain, i.e. so that
183 // parity(Rd) == parity(Ra).
184 // \return true if a constraint was added
185 bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
187 // Add constraints between existing chains
// Adjusts the costs of edges between Rd and every tracked chain whose live
// interval overlaps Rd's, making same-parity assignments more expensive.
188 void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
190 } // Anonymous namespace
// Biases the PBQP edge between Rd and Ra so that physical-register pairs of
// the same parity are cheaper than pairs of different parity.
//   p  - problem whose graph is modified in place.
//   Rd - destination register of the multiply-accumulate.
//   Ra - accumulator input register of the same instruction.
// Two cases are handled: when no edge exists between the two nodes a new cost
// matrix is created; otherwise the existing matrix is copied, adjusted and
// written back with setEdgeCosts().
192 bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
// Physical registers are only reported here; the statements that follow the
// two prints are not visible in this copy of the file.
// NOTE(review): unlike the other diagnostics in this file, these dbgs() calls
// are not wrapped in DEBUG(), so they are emitted unconditionally.
197 if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
198 dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n';
199 dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n';
// Pointers rather than references so that the two sets can be swapped below.
203 const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
204 const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
206 PBQPRAGraph &g = p->getGraph();
207 PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
208 PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
209 PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
211 // The edge does not exist. Create one with the appropriate interference
213 if (edge == g.invalidEdgeId()) {
214 const LiveInterval &ld = LIs->getInterval(Rd);
215 const LiveInterval &la = LIs->getInterval(Ra);
216 bool livesOverlap = ld.overlaps(la);
// The matrix has one extra row and column: allowed register i/j maps to entry
// [i + 1][j + 1], and row 0 / column 0 keep the initial value 0.
218 PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
219 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
220 unsigned pRd = (*vRdAllowed)[i];
221 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
222 unsigned pRa = (*vRaAllowed)[j];
// Overlapping physical registers are forbidden (infinite cost) when the two
// live intervals overlap; otherwise same parity costs 0 and different parity 1.
// NOTE(review): presumably the line missing between the two assignments is an
// `else`; without it the infinite cost would always be overwritten.
223 if (livesOverlap && TRI->regsOverlap(pRd, pRa))
224 costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
226 costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
229 g.addEdge(node1, node2, std::move(costs));
// The edge already exists. If it is stored with Ra's node first, swap the
// nodes and the allowed sets — presumably so that the row index below follows
// the edge's first node.
233 if (g.getEdgeNode1Id(edge) == node2) {
234 std::swap(node1, node2);
235 std::swap(vRdAllowed, vRaAllowed);
// NOTE(review): the inequality in the next comment reads the wrong way round;
// the code below raises the different-parity costs above the largest finite
// same-parity cost of each row.
238 // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
239 PBQP::Matrix costs(g.getEdgeCosts(edge));
240 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
241 unsigned pRd = (*vRdAllowed)[i];
243 // Get the maximum cost (excluding unallocatable reg) for same parity
// NOTE(review): PBQPNum is compared against infinity(), so it is a
// floating-point type, and numeric_limits<...>::min() is then the smallest
// positive normal value rather than the lowest one. The maximum therefore
// never drops to 0, and different-parity entries at cost 0 are always raised
// below. Confirm this is intended before changing the start value.
245 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
246 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
247 unsigned pRa = (*vRaAllowed)[j];
248 if (haveSameParity(pRd, pRa))
249 if (costs[i + 1][j + 1] !=
250 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
251 costs[i + 1][j + 1] > sameParityMax)
252 sameParityMax = costs[i + 1][j + 1];
255 // Ensure all registers with a different parity have a higher cost
256 // than sameParityMax
257 for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
258 unsigned pRa = (*vRaAllowed)[j];
259 if (!haveSameParity(pRd, pRa))
260 if (sameParityMax > costs[i + 1][j + 1])
261 costs[i + 1][j + 1] = sameParityMax + 1.0;
// Write the adjusted copy back to the graph.
264 g.setEdgeCosts(edge, costs);
// Makes it more expensive for Rd to share a parity with any tracked chain
// whose live interval overlaps Rd's.
//   p  - problem whose graph is modified in place.
//   Rd - register that becomes (or continues) a chain.
//   Ra - accumulator input; if it is already a tracked chain, the chain is
//        reported as moving from Ra to Rd, otherwise a new chain is reported.
// NOTE(review): the statements that actually update Chains are not visible in
// this copy of the file; only the debug messages are.
270 A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
272 // Do some Chain management
273 if (Chains.count(Ra)) {
275 DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
276 << PrintReg(Rd, TRI) << '\n';);
281 DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
// Compare Rd against every tracked chain register r.
286 const LiveInterval &ld = LIs->getInterval(Rd);
287 for (auto r : Chains) {
292 const LiveInterval &lr = LIs->getInterval(r);
// Only chains that are live at the same time as Rd are constrained.
293 if (ld.overlaps(lr)) {
294 const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
295 const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
297 PBQPRAGraph &g = p->getGraph();
298 PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
299 PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
300 PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
// Unlike addIntraChainConstraint(), a missing edge is not created here; it is
// treated as an internal error.
301 assert(edge != g.invalidEdgeId() &&
302 "PBQP error ! The edge should exist !");
304 DEBUG(dbgs() << "Refining constraint !\n";);
// If the edge is stored with r's node first, swap the nodes and the allowed
// sets — presumably so that the row index below follows the edge's first node.
306 if (g.getEdgeNode1Id(edge) == node2) {
307 std::swap(node1, node2);
308 std::swap(vRdAllowed, vRrAllowed);
311 // Enforce that cost is higher with all other Chains of the same parity
312 PBQP::Matrix costs(g.getEdgeCosts(edge));
313 for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
314 unsigned pRd = (*vRdAllowed)[i];
316 // Get the maximum cost (excluding unallocatable reg) for all other
// NOTE(review): despite its name, sameParityMax holds the largest finite cost
// among the DIFFERENT-parity entries of this row here. As the start value,
// numeric_limits<...>::min() is the smallest positive normal value for a
// floating-point PBQPNum, not the lowest value.
318 PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
319 for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
// pRa is a physical register from r's allowed set, not from Ra's.
320 unsigned pRa = (*vRrAllowed)[j];
321 if (!haveSameParity(pRd, pRa))
322 if (costs[i + 1][j + 1] !=
323 std::numeric_limits<PBQP::PBQPNum>::infinity() &&
324 costs[i + 1][j + 1] > sameParityMax)
325 sameParityMax = costs[i + 1][j + 1];
328 // Ensure all registers with same parity have a higher cost
329 // than sameParityMax
330 for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
331 unsigned pRa = (*vRrAllowed)[j];
332 if (haveSameParity(pRd, pRa))
333 if (sameParityMax > costs[i + 1][j + 1])
334 costs[i + 1][j + 1] = sameParityMax + 1.0;
// Write the adjusted copy back to the graph.
337 g.setEdgeCosts(edge, costs);
// Builds the PBQP problem with the base PBQP_BUILDER and then walks every
// instruction of every basic block in MF, adding parity constraints for the
// floating-point multiply-accumulate opcodes listed in the switch below.
// All four arguments are forwarded unchanged to PBQP_BUILDER::build().
342 std::unique_ptr<PBQPRAProblem>
343 A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
344 const MachineBlockFrequencyInfo *blockInfo,
345 const RegSet &VRegs) {
346 std::unique_ptr<PBQPRAProblem> p =
347 PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
// Cache the AArch64 register info; the constraint helpers dereference TRI.
// NOTE(review): the helpers also dereference LIs, but no assignment to it is
// visible in this copy of the file — confirm it is set from LI near here.
349 TRI = static_cast<const AArch64RegisterInfo *>(
350 MF->getTarget().getSubtargetImpl()->getRegisterInfo());
355 for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
356 mbbItr != mbbEnd; ++mbbItr) {
357 const MachineBasicBlock *MBB = &*mbbItr;
// Chains are tracked per basic block: the set is emptied for each block.
358 Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
360 for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
362 miItr != miEnd; ++miItr) {
363 const MachineInstr *MI = &*miItr;
364 switch (MI->getOpcode()) {
// Scalar multiply-accumulate forms: operand 0 is read as the destination Rd
// and operand 3 as the accumulator Ra.
365 case AArch64::FMSUBSrrr:
366 case AArch64::FMADDSrrr:
367 case AArch64::FNMSUBSrrr:
368 case AArch64::FNMADDSrrr:
369 case AArch64::FMSUBDrrr:
370 case AArch64::FMADDDrrr:
371 case AArch64::FNMSUBDrrr:
372 case AArch64::FNMADDDrrr: {
373 unsigned Rd = MI->getOperand(0).getReg();
374 unsigned Ra = MI->getOperand(3).getReg();
// The inter-chain constraint is added only when the intra-chain one was.
376 if (addIntraChainConstraint(p.get(), Rd, Ra))
377 addInterChainConstraint(p.get(), Rd, Ra);
// Vector FMLA/FMLS: only operand 0 is read, and it is passed as both Rd and
// Ra, so no intra-chain constraint is added for these opcodes.
381 case AArch64::FMLAv2f32:
382 case AArch64::FMLSv2f32: {
383 unsigned Rd = MI->getOperand(0).getReg();
384 addInterChainConstraint(p.get(), Rd, Rd);
389 // Forget Chains which have been killed
// NOTE(review): toDel is declared inside the range-for over Chains, so the
// draining loop below must also be inside it and Chains.remove() runs while
// Chains is being iterated. That may invalidate the loop's iterators;
// consider collecting all killed registers first and removing them after the
// loop has finished.
390 for (auto r : Chains) {
391 SmallVector<unsigned, 8> toDel;
392 if (MI->killsRegister(r)) {
393 DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
398 while (!toDel.empty()) {
399 Chains.remove(toDel.back());
// Creates an A57PBQPBuilder, transfers its ownership to
// createPBQPRegisterAllocator() (with nullptr as the second argument) and
// returns the resulting FunctionPass.
// NOTE(review): the original comment below ends mid-sentence in this copy of
// the file.
410 // Factory function used by AArch64TargetMachine to add the pass to the
412 FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
413 std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
414 return createPBQPRegisterAllocator(std::move(builder), nullptr);