X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTarget%2FAArch64%2FAArch64PBQPRegAlloc.cpp;h=5394875a6bc12f96ef219bddf8f9fafe56dcea05;hb=15c5be1ee58a67965bee79832441f1136a7698dc;hp=86e21732a132c3c6b761617c5caf1208825ae8e7;hpb=4a93e8dd02f6fea5cbd1c5aefb2b7cd035386420;p=oota-llvm.git diff --git a/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp b/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp index 86e21732a13..5394875a6bc 100644 --- a/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp +++ b/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp @@ -18,9 +18,8 @@ #define DEBUG_TYPE "aarch64-pbqp" #include "AArch64.h" +#include "AArch64PBQPRegAlloc.h" #include "AArch64RegisterInfo.h" - -#include "llvm/ADT/SetVector.h" #include "llvm/CodeGen/LiveIntervalAnalysis.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineFunction.h" @@ -30,8 +29,6 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#define PBQP_BUILDER PBQPBuilderWithCoalescing - using namespace llvm; namespace { @@ -157,64 +154,42 @@ bool haveSameParity(unsigned reg1, unsigned reg2) { return isOdd(reg1) == isOdd(reg2); } -class A57PBQPBuilder : public PBQP_BUILDER { -public: - A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {} - - // Build a PBQP instance to represent the register allocation problem for - // the given MachineFunction. - std::unique_ptr - build(MachineFunction *MF, const LiveIntervals *LI, - const MachineBlockFrequencyInfo *blockInfo, - const RegSet &VRegs) override; - -private: - const AArch64RegisterInfo *TRI; - const LiveIntervals *LIs; - SmallSetVector Chains; - - // Return true if reg is a physical register - bool isPhysicalReg(unsigned reg) const { - return TRI->isPhysicalRegister(reg); - } - - // Add the accumulator chaining constraint, inside the chain, i.e. so that - // parity(Rd) == parity(Ra). - // \return true if a constraint was added - bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra); - - // Add constraints between existing chains - void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra); -}; -} // Anonymous namespace +} -bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, - unsigned Ra) { +bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd, + unsigned Ra) { if (Rd == Ra) return false; - if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) { - DEBUG(dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n'); - DEBUG(dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n'); + LiveIntervals &LIs = G.getMetadata().LIS; + + if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) { + DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd) + << '\n'); + DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra) + << '\n'); return false; } - const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd); - const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra); + PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); + PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra); + + const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = + &G.getNodeMetadata(node1).getAllowedRegs(); + const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed = + &G.getNodeMetadata(node2).getAllowedRegs(); - PBQPRAGraph &g = p->getGraph(); - PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd); - PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra); - PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2); + PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); // The edge does not exist. Create one with the appropriate interference // costs. - if (edge == g.invalidEdgeId()) { - const LiveInterval &ld = LIs->getInterval(Rd); - const LiveInterval &la = LIs->getInterval(Ra); + if (edge == G.invalidEdgeId()) { + const LiveInterval &ld = LIs.getInterval(Rd); + const LiveInterval &la = LIs.getInterval(Ra); bool livesOverlap = ld.overlaps(la); - PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0); + PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1, + vRaAllowed->size() + 1, 0); for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { unsigned pRd = (*vRdAllowed)[i]; for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) { @@ -225,17 +200,17 @@ bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0; } } - g.addEdge(node1, node2, std::move(costs)); + G.addEdge(node1, node2, std::move(costs)); return true; } - if (g.getEdgeNode1Id(edge) == node2) { + if (G.getEdgeNode1Id(edge) == node2) { std::swap(node1, node2); std::swap(vRdAllowed, vRaAllowed); } // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass)) - PBQP::Matrix costs(g.getEdgeCosts(edge)); + PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge)); for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { unsigned pRd = (*vRdAllowed)[i]; @@ -260,14 +235,15 @@ bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, costs[i + 1][j + 1] = sameParityMax + 1.0; } } - g.setEdgeCosts(edge, costs); + G.updateEdgeCosts(edge, std::move(costs)); return true; } -void -A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, - unsigned Ra) { +void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd, + unsigned Ra) { + LiveIntervals &LIs = G.getMetadata().LIS; + // Do some Chain management if (Chains.count(Ra)) { if (Rd != Ra) { @@ -282,33 +258,36 @@ A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, Chains.insert(Rd); } - const LiveInterval &ld = LIs->getInterval(Rd); + PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd); + + const LiveInterval &ld = LIs.getInterval(Rd); for (auto r : Chains) { // Skip self if (r == Rd) continue; - const LiveInterval &lr = LIs->getInterval(r); + const LiveInterval &lr = LIs.getInterval(r); if (ld.overlaps(lr)) { - const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd); - const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r); - - PBQPRAGraph &g = p->getGraph(); - PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd); - PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r); - PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2); - assert(edge != g.invalidEdgeId() && + const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed = + &G.getNodeMetadata(node1).getAllowedRegs(); + + PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r); + const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed = + &G.getNodeMetadata(node2).getAllowedRegs(); + + PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2); + assert(edge != G.invalidEdgeId() && "PBQP error ! The edge should exist !"); DEBUG(dbgs() << "Refining constraint !\n";); - if (g.getEdgeNode1Id(edge) == node2) { + if (G.getEdgeNode1Id(edge) == node2) { std::swap(node1, node2); std::swap(vRdAllowed, vRrAllowed); } // Enforce that cost is higher with all other Chains of the same parity - PBQP::Matrix costs(g.getEdgeCosts(edge)); + PBQP::Matrix costs(G.getEdgeCosts(edge)); for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) { unsigned pRd = (*vRdAllowed)[i]; @@ -333,34 +312,46 @@ A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, costs[i + 1][j + 1] = sameParityMax + 1.0; } } - g.setEdgeCosts(edge, costs); + G.updateEdgeCosts(edge, std::move(costs)); } } } -std::unique_ptr -A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI, - const MachineBlockFrequencyInfo *blockInfo, - const RegSet &VRegs) { - std::unique_ptr p = - PBQP_BUILDER::build(MF, LI, blockInfo, VRegs); +static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg, + const MachineInstr &MI) { + const LiveInterval &LI = LIs.getInterval(reg); + SlotIndex SI = LIs.getInstructionIndex(&MI); + return LI.expiredAt(SI); +} - TRI = static_cast( - MF->getTarget().getSubtargetImpl()->getRegisterInfo()); - LIs = LI; +void A57ChainingConstraint::apply(PBQPRAGraph &G) { + const MachineFunction &MF = G.getMetadata().MF; + LiveIntervals &LIs = G.getMetadata().LIS; - DEBUG(MF->dump();); + TRI = MF.getSubtarget().getRegisterInfo(); + DEBUG(MF.dump()); - for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end(); - mbbItr != mbbEnd; ++mbbItr) { - const MachineBasicBlock *MBB = &*mbbItr; + for (const auto &MBB: MF) { Chains.clear(); // FIXME: really needed ? Could not work at MF level ? - for (MachineBasicBlock::const_iterator miItr = MBB->begin(), - miEnd = MBB->end(); - miItr != miEnd; ++miItr) { - const MachineInstr *MI = &*miItr; - switch (MI->getOpcode()) { + for (const auto &MI: MBB) { + + // Forget Chains which have expired + for (auto r : Chains) { + SmallVector toDel; + if(regJustKilledBefore(LIs, r, MI)) { + DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at "; + MI.print(dbgs());); + toDel.push_back(r); + } + + while (!toDel.empty()) { + Chains.remove(toDel.back()); + toDel.pop_back(); + } + } + + switch (MI.getOpcode()) { case AArch64::FMSUBSrrr: case AArch64::FMADDSrrr: case AArch64::FNMSUBSrrr: @@ -369,46 +360,24 @@ A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI, case AArch64::FMADDDrrr: case AArch64::FNMSUBDrrr: case AArch64::FNMADDDrrr: { - unsigned Rd = MI->getOperand(0).getReg(); - unsigned Ra = MI->getOperand(3).getReg(); + unsigned Rd = MI.getOperand(0).getReg(); + unsigned Ra = MI.getOperand(3).getReg(); - if (addIntraChainConstraint(p.get(), Rd, Ra)) - addInterChainConstraint(p.get(), Rd, Ra); + if (addIntraChainConstraint(G, Rd, Ra)) + addInterChainConstraint(G, Rd, Ra); break; } case AArch64::FMLAv2f32: case AArch64::FMLSv2f32: { - unsigned Rd = MI->getOperand(0).getReg(); - addInterChainConstraint(p.get(), Rd, Rd); + unsigned Rd = MI.getOperand(0).getReg(); + addInterChainConstraint(G, Rd, Rd); break; } default: - // Forget Chains which have been killed - for (auto r : Chains) { - SmallVector toDel; - if (MI->killsRegister(r)) { - DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at "; - MI->print(dbgs());); - toDel.push_back(r); - } - - while (!toDel.empty()) { - Chains.remove(toDel.back()); - toDel.pop_back(); - } - } + break; } } } - - return p; -} - -// Factory function used by AArch64TargetMachine to add the pass to the -// passmanager. -FunctionPass *llvm::createAArch64A57PBQPRegAlloc() { - std::unique_ptr builder = llvm::make_unique(); - return createPBQPRegisterAllocator(std::move(builder), nullptr); }