Revert r248483, r242546, r242545, and r242409 - absdiff intrinsics
[oota-llvm.git] / lib / Target / AArch64 / AArch64PBQPRegAlloc.cpp
index 87e3502c9aa38566c2cbb03222cbd14a69fe73a4..5394875a6bc12f96ef219bddf8f9fafe56dcea05 100644 (file)
@@ -18,9 +18,8 @@
 #define DEBUG_TYPE "aarch64-pbqp"
 
 #include "AArch64.h"
+#include "AArch64PBQPRegAlloc.h"
 #include "AArch64RegisterInfo.h"
-
-#include "llvm/ADT/SetVector.h"
 #include "llvm/CodeGen/LiveIntervalAnalysis.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -30,9 +29,6 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
-#define PBQP_BUILDER PBQPBuilderWithCoalescing
-//#define PBQP_BUILDER PBQPBuilder
-
 using namespace llvm;
 
 namespace {
@@ -42,7 +38,7 @@ bool isFPReg(unsigned reg) {
   return AArch64::FPR32RegClass.contains(reg) ||
          AArch64::FPR64RegClass.contains(reg) ||
          AArch64::FPR128RegClass.contains(reg);
-};
+}
 #endif
 
 bool isOdd(unsigned reg) {
@@ -158,67 +154,45 @@ bool haveSameParity(unsigned reg1, unsigned reg2) {
   return isOdd(reg1) == isOdd(reg2);
 }
 
-class A57PBQPBuilder : public PBQP_BUILDER {
-public:
-  A57PBQPBuilder() : PBQP_BUILDER(), TRI(nullptr), LIs(nullptr), Chains() {}
-
-  // Build a PBQP instance to represent the register allocation problem for
-  // the given MachineFunction.
-  std::unique_ptr<PBQPRAProblem>
-  build(MachineFunction *MF, const LiveIntervals *LI,
-        const MachineBlockFrequencyInfo *blockInfo,
-        const RegSet &VRegs) override;
-
-private:
-  const AArch64RegisterInfo *TRI;
-  const LiveIntervals *LIs;
-  SmallSetVector<unsigned, 32> Chains;
-
-  // Return true if reg is a physical register
-  bool isPhysicalReg(unsigned reg) const {
-    return TRI->isPhysicalRegister(reg);
-  }
-
-  // Add the accumulator chaining constraint, inside the chain, i.e. so that
-  // parity(Rd) == parity(Ra).
-  // \return true if a constraint was added
-  bool addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
-
-  // Add constraints between existing chains
-  void addInterChainConstraint(PBQPRAProblem *p, unsigned Rd, unsigned Ra);
-};
-} // Anonymous namespace
+}
 
-bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
-                                             unsigned Ra) {
+bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
+                                                 unsigned Ra) {
   if (Rd == Ra)
     return false;
 
-  if (isPhysicalReg(Rd) || isPhysicalReg(Ra)) {
-    dbgs() << "Rd is a physical reg:" << isPhysicalReg(Rd) << '\n';
-    dbgs() << "Ra is a physical reg:" << isPhysicalReg(Ra) << '\n';
+  LiveIntervals &LIs = G.getMetadata().LIS;
+
+  if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
+    DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
+          << '\n');
+    DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
+          << '\n');
     return false;
   }
 
-  const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
-  const PBQPRAProblem::AllowedSet *vRaAllowed = &p->getAllowedSet(Ra);
+  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
+  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
+
+  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
+    &G.getNodeMetadata(node1).getAllowedRegs();
+  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
+    &G.getNodeMetadata(node2).getAllowedRegs();
 
-  PBQPRAGraph &g = p->getGraph();
-  PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
-  PBQPRAGraph::NodeId node2 = p->getNodeForVReg(Ra);
-  PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
+  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
 
   // The edge does not exist. Create one with the appropriate interference
   // costs.
-  if (edge == g.invalidEdgeId()) {
-    const LiveInterval &ld = LIs->getInterval(Rd);
-    const LiveInterval &la = LIs->getInterval(Ra);
+  if (edge == G.invalidEdgeId()) {
+    const LiveInterval &ld = LIs.getInterval(Rd);
+    const LiveInterval &la = LIs.getInterval(Ra);
     bool livesOverlap = ld.overlaps(la);
 
-    PBQP::Matrix costs(vRdAllowed->size() + 1, vRaAllowed->size() + 1, 0);
-    for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+    PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
+                                 vRaAllowed->size() + 1, 0);
+    for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
       unsigned pRd = (*vRdAllowed)[i];
-      for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+      for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
         unsigned pRa = (*vRaAllowed)[j];
         if (livesOverlap && TRI->regsOverlap(pRd, pRa))
           costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
@@ -226,24 +200,24 @@ bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
           costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
       }
     }
-    g.addEdge(node1, node2, std::move(costs));
+    G.addEdge(node1, node2, std::move(costs));
     return true;
   }
 
-  if (g.getEdgeNode1Id(edge) == node2) {
+  if (G.getEdgeNode1Id(edge) == node2) {
     std::swap(node1, node2);
     std::swap(vRdAllowed, vRaAllowed);
   }
 
   // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
-  PBQP::Matrix costs(g.getEdgeCosts(edge));
-  for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+  PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
+  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
     unsigned pRd = (*vRdAllowed)[i];
 
     // Get the maximum cost (excluding unallocatable reg) for same parity
     // registers
     PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
-    for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
       unsigned pRa = (*vRaAllowed)[j];
       if (haveSameParity(pRd, pRa))
         if (costs[i + 1][j + 1] !=
@@ -254,21 +228,22 @@ bool A57PBQPBuilder::addIntraChainConstraint(PBQPRAProblem *p, unsigned Rd,
 
     // Ensure all registers with a different parity have a higher cost
     // than sameParityMax
-    for (unsigned j = 0; j != vRaAllowed->size(); ++j) {
+    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
       unsigned pRa = (*vRaAllowed)[j];
       if (!haveSameParity(pRd, pRa))
         if (sameParityMax > costs[i + 1][j + 1])
           costs[i + 1][j + 1] = sameParityMax + 1.0;
     }
   }
-  g.setEdgeCosts(edge, costs);
+  G.updateEdgeCosts(edge, std::move(costs));
 
   return true;
 }
 
-void
-A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
-                                        unsigned Ra) {
+void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
+                                                 unsigned Ra) {
+  LiveIntervals &LIs = G.getMetadata().LIS;
+
   // Do some Chain management
   if (Chains.count(Ra)) {
     if (Rd != Ra) {
@@ -283,40 +258,43 @@ A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
     Chains.insert(Rd);
   }
 
-  const LiveInterval &ld = LIs->getInterval(Rd);
+  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
+
+  const LiveInterval &ld = LIs.getInterval(Rd);
   for (auto r : Chains) {
     // Skip self
     if (r == Rd)
       continue;
 
-    const LiveInterval &lr = LIs->getInterval(r);
+    const LiveInterval &lr = LIs.getInterval(r);
     if (ld.overlaps(lr)) {
-      const PBQPRAProblem::AllowedSet *vRdAllowed = &p->getAllowedSet(Rd);
-      const PBQPRAProblem::AllowedSet *vRrAllowed = &p->getAllowedSet(r);
-
-      PBQPRAGraph &g = p->getGraph();
-      PBQPRAGraph::NodeId node1 = p->getNodeForVReg(Rd);
-      PBQPRAGraph::NodeId node2 = p->getNodeForVReg(r);
-      PBQPRAGraph::EdgeId edge = g.findEdge(node1, node2);
-      assert(edge != g.invalidEdgeId() &&
+      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
+        &G.getNodeMetadata(node1).getAllowedRegs();
+
+      PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
+      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
+        &G.getNodeMetadata(node2).getAllowedRegs();
+
+      PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
+      assert(edge != G.invalidEdgeId() &&
              "PBQP error ! The edge should exist !");
 
       DEBUG(dbgs() << "Refining constraint !\n";);
 
-      if (g.getEdgeNode1Id(edge) == node2) {
+      if (G.getEdgeNode1Id(edge) == node2) {
         std::swap(node1, node2);
         std::swap(vRdAllowed, vRrAllowed);
       }
 
       // Enforce that cost is higher with all other Chains of the same parity
-      PBQP::Matrix costs(g.getEdgeCosts(edge));
-      for (unsigned i = 0; i != vRdAllowed->size(); ++i) {
+      PBQP::Matrix costs(G.getEdgeCosts(edge));
+      for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
         unsigned pRd = (*vRdAllowed)[i];
 
         // Get the maximum cost (excluding unallocatable reg) for all other
         // parity registers
         PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
-        for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
+        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
           unsigned pRa = (*vRrAllowed)[j];
           if (!haveSameParity(pRd, pRa))
             if (costs[i + 1][j + 1] !=
@@ -327,41 +305,53 @@ A57PBQPBuilder::addInterChainConstraint(PBQPRAProblem *p, unsigned Rd,
 
         // Ensure all registers with same parity have a higher cost
         // than sameParityMax
-        for (unsigned j = 0; j != vRrAllowed->size(); ++j) {
+        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
           unsigned pRa = (*vRrAllowed)[j];
           if (haveSameParity(pRd, pRa))
             if (sameParityMax > costs[i + 1][j + 1])
               costs[i + 1][j + 1] = sameParityMax + 1.0;
         }
       }
-      g.setEdgeCosts(edge, costs);
+      G.updateEdgeCosts(edge, std::move(costs));
     }
   }
 }
 
-std::unique_ptr<PBQPRAProblem>
-A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
-                      const MachineBlockFrequencyInfo *blockInfo,
-                      const RegSet &VRegs) {
-  std::unique_ptr<PBQPRAProblem> p =
-      PBQP_BUILDER::build(MF, LI, blockInfo, VRegs);
+static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
+                                const MachineInstr &MI) {
+  const LiveInterval &LI = LIs.getInterval(reg);
+  SlotIndex SI = LIs.getInstructionIndex(&MI);
+  return LI.expiredAt(SI);
+}
 
-  TRI = static_cast<const AArch64RegisterInfo *>(
-      MF->getTarget().getSubtargetImpl()->getRegisterInfo());
-  LIs = LI;
+void A57ChainingConstraint::apply(PBQPRAGraph &G) {
+  const MachineFunction &MF = G.getMetadata().MF;
+  LiveIntervals &LIs = G.getMetadata().LIS;
 
-  DEBUG(MF->dump(););
+  TRI = MF.getSubtarget().getRegisterInfo();
+  DEBUG(MF.dump());
 
-  for (MachineFunction::const_iterator mbbItr = MF->begin(), mbbEnd = MF->end();
-       mbbItr != mbbEnd; ++mbbItr) {
-    const MachineBasicBlock *MBB = &*mbbItr;
+  for (const auto &MBB: MF) {
     Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
 
-    for (MachineBasicBlock::const_iterator miItr = MBB->begin(),
-                                           miEnd = MBB->end();
-         miItr != miEnd; ++miItr) {
-      const MachineInstr *MI = &*miItr;
-      switch (MI->getOpcode()) {
+    for (const auto &MI: MBB) {
+
+      // Forget Chains which have expired
+      for (auto r : Chains) {
+        SmallVector<unsigned, 8> toDel;
+        if(regJustKilledBefore(LIs, r, MI)) {
+          DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
+                MI.print(dbgs()););
+          toDel.push_back(r);
+        }
+
+        while (!toDel.empty()) {
+          Chains.remove(toDel.back());
+          toDel.pop_back();
+        }
+      }
+
+      switch (MI.getOpcode()) {
       case AArch64::FMSUBSrrr:
       case AArch64::FMADDSrrr:
       case AArch64::FNMSUBSrrr:
@@ -370,46 +360,24 @@ A57PBQPBuilder::build(MachineFunction *MF, const LiveIntervals *LI,
       case AArch64::FMADDDrrr:
       case AArch64::FNMSUBDrrr:
       case AArch64::FNMADDDrrr: {
-        unsigned Rd = MI->getOperand(0).getReg();
-        unsigned Ra = MI->getOperand(3).getReg();
+        unsigned Rd = MI.getOperand(0).getReg();
+        unsigned Ra = MI.getOperand(3).getReg();
 
-        if (addIntraChainConstraint(p.get(), Rd, Ra))
-          addInterChainConstraint(p.get(), Rd, Ra);
+        if (addIntraChainConstraint(G, Rd, Ra))
+          addInterChainConstraint(G, Rd, Ra);
         break;
       }
 
       case AArch64::FMLAv2f32:
       case AArch64::FMLSv2f32: {
-        unsigned Rd = MI->getOperand(0).getReg();
-        addInterChainConstraint(p.get(), Rd, Rd);
+        unsigned Rd = MI.getOperand(0).getReg();
+        addInterChainConstraint(G, Rd, Rd);
         break;
       }
 
       default:
-        // Forget Chains which have been killed
-        for (auto r : Chains) {
-          SmallVector<unsigned, 8> toDel;
-          if (MI->killsRegister(r)) {
-            DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
-                  MI->print(dbgs()););
-            toDel.push_back(r);
-          }
-
-          while (!toDel.empty()) {
-            Chains.remove(toDel.back());
-            toDel.pop_back();
-          }
-        }
+        break;
       }
     }
   }
-
-  return p;
-}
-
-// Factory function used by AArch64TargetMachine to add the pass to the
-// passmanager.
-FunctionPass *llvm::createAArch64A57PBQPRegAlloc() {
-  std::unique_ptr<PBQP_BUILDER> builder = llvm::make_unique<A57PBQPBuilder>();
-  return createPBQPRegisterAllocator(std::move(builder), nullptr);
 }