MC: Use MCSymbol in RelAndSymbol, NFC
[oota-llvm.git] / lib / Target / AArch64 / AArch64A57FPLoadBalancing.cpp
index 195a48e54a5b33ba7e8407ef9a98332fcb71912c..bffd9e6e8c76afec57e83bd80995772c1e66c0a8 100644 (file)
@@ -38,8 +38,8 @@
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/CodeGen/RegisterClassInfo.h"
+#include "llvm/CodeGen/RegisterScavenging.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -73,8 +73,6 @@ static bool isMul(MachineInstr *MI) {
   case AArch64::FNMULSrr:
   case AArch64::FMULDrr:
   case AArch64::FNMULDrr:
-
-  case AArch64::FMULv2f32:
     return true;
   default:
     return false;
@@ -92,34 +90,38 @@ static bool isMla(MachineInstr *MI) {
   case AArch64::FMADDDrrr:
   case AArch64::FNMSUBDrrr:
   case AArch64::FNMADDDrrr:
-
-  case AArch64::FMLAv2f32:
-  case AArch64::FMLSv2f32:
     return true;
   default:
     return false;
   }
 }
 
+namespace llvm {
+static void initializeAArch64A57FPLoadBalancingPass(PassRegistry &);
+}
+
 //===----------------------------------------------------------------------===//
 
 namespace {
 /// A "color", which is either even or odd. Yes, these aren't really colors
 /// but the algorithm is conceptually doing two-color graph coloring.
 enum class Color { Even, Odd };
+#ifndef NDEBUG
 static const char *ColorNames[2] = { "Even", "Odd" };
+#endif
 
 class Chain;
 
 class AArch64A57FPLoadBalancing : public MachineFunctionPass {
-  const AArch64InstrInfo *TII;
   MachineRegisterInfo *MRI;
   const TargetRegisterInfo *TRI;
   RegisterClassInfo RCI;
 
 public:
   static char ID;
-  explicit AArch64A57FPLoadBalancing() : MachineFunctionPass(ID) {}
+  explicit AArch64A57FPLoadBalancing() : MachineFunctionPass(ID) {
+    initializeAArch64A57FPLoadBalancingPass(*PassRegistry::getPassRegistry());
+  }
 
   bool runOnMachineFunction(MachineFunction &F) override;
 
@@ -139,15 +141,23 @@ private:
   bool colorChain(Chain *G, Color C, MachineBasicBlock &MBB);
   int scavengeRegister(Chain *G, Color C, MachineBasicBlock &MBB);
   void scanInstruction(MachineInstr *MI, unsigned Idx,
-                       std::map<unsigned, Chain*> &Chains,
-                       std::set<Chain*> &ChainSet);
+                       std::map<unsigned, Chain*> &Active,
+                       std::vector<std::unique_ptr<Chain>> &AllChains);
   void maybeKillChain(MachineOperand &MO, unsigned Idx,
                       std::map<unsigned, Chain*> &RegChains);
   Color getColor(unsigned Register);
   Chain *getAndEraseNext(Color PreferredColor, std::vector<Chain*> &L);
 };
+}
+
 char AArch64A57FPLoadBalancing::ID = 0;
 
+INITIALIZE_PASS_BEGIN(AArch64A57FPLoadBalancing, DEBUG_TYPE,
+                      "AArch64 A57 FP Load-Balancing", false, false)
+INITIALIZE_PASS_END(AArch64A57FPLoadBalancing, DEBUG_TYPE,
+                    "AArch64 A57 FP Load-Balancing", false, false)
+
+namespace {
 /// A Chain is a sequence of instructions that are linked together by 
 /// an accumulation operand. For example:
 ///
@@ -191,10 +201,10 @@ public:
   /// instruction can be more tricky.
   Color LastColor;
 
-  Chain(MachineInstr *MI, unsigned Idx, Color C) :
-  StartInst(MI), LastInst(MI), KillInst(NULL),
-  StartInstIdx(Idx), LastInstIdx(Idx), KillInstIdx(0),
-  LastColor(C) {
+  Chain(MachineInstr *MI, unsigned Idx, Color C)
+      : StartInst(MI), LastInst(MI), KillInst(nullptr),
+        StartInstIdx(Idx), LastInstIdx(Idx), KillInstIdx(0),
+        LastColor(C) {
     Insts.insert(MI);
   }
 
@@ -204,6 +214,9 @@ public:
     LastInst = MI;
     LastInstIdx = Idx;
     LastColor = C;
+    assert((KillInstIdx == 0 || LastInstIdx < KillInstIdx) &&
+           "Chain: broken invariant. A Chain can only be killed after its last "
+           "def");
 
     Insts.insert(MI);
   }
@@ -222,6 +235,9 @@ public:
     KillInst = MI;
     KillInstIdx = Idx;
     KillIsImmutable = Immutable;
+    assert((KillInstIdx == 0 || LastInstIdx < KillInstIdx) &&
+           "Chain: broken invariant. A Chain can only be killed after its last "
+           "def");
   }
 
   /// Return the first instruction in the chain.
@@ -232,7 +248,7 @@ public:
   MachineInstr *getKill() const { return KillInst; }
   /// Return an instruction that can be used as an iterator for the end
   /// of the chain. This is the maximum of KillInst (if set) and LastInst.
-  MachineInstr *getEnd() const {
+  MachineBasicBlock::iterator getEnd() const {
     return ++MachineBasicBlock::iterator(KillInst ? KillInst : LastInst);
   }
 
@@ -247,16 +263,16 @@ public:
   }
 
   /// Return true if this chain (StartInst..KillInst) overlaps with Other.
-  bool rangeOverlapsWith(Chain *Other) {
+  bool rangeOverlapsWith(const Chain &Other) const {
     unsigned End = KillInst ? KillInstIdx : LastInstIdx;
-    unsigned OtherEnd = Other->KillInst ?
-      Other->KillInstIdx : Other->LastInstIdx;
+    unsigned OtherEnd = Other.KillInst ?
+      Other.KillInstIdx : Other.LastInstIdx;
 
-    return StartInstIdx <= OtherEnd && Other->StartInstIdx <= End;
+    return StartInstIdx <= OtherEnd && Other.StartInstIdx <= End;
   }
 
   /// Return true if this chain starts before Other.
-  bool startsBefore(Chain *Other) {
+  bool startsBefore(const Chain *Other) const {
     return StartInstIdx < Other->StartInstIdx;
   }
 
@@ -271,12 +287,12 @@ public:
     raw_string_ostream OS(S);
     
     OS << "{";
-    StartInst->print(OS, NULL, true);
+    StartInst->print(OS, /* SkipOpers= */true);
     OS << " -> ";
-    LastInst->print(OS, NULL, true);
+    LastInst->print(OS, /* SkipOpers= */true);
     if (KillInst) {
       OS << " (kill @ ";
-      KillInst->print(OS, NULL, true);
+      KillInst->print(OS, /* SkipOpers= */true);
       OS << ")";
     }
     OS << "}";
@@ -291,13 +307,16 @@ public:
 //===----------------------------------------------------------------------===//
 
 bool AArch64A57FPLoadBalancing::runOnMachineFunction(MachineFunction &F) {
+  // Don't do anything if this isn't an A53 or A57.
+  if (!(F.getSubtarget<AArch64Subtarget>().isCortexA53() ||
+        F.getSubtarget<AArch64Subtarget>().isCortexA57()))
+    return false;
+
   bool Changed = false;
   DEBUG(dbgs() << "***** AArch64A57FPLoadBalancing *****\n");
 
-  const TargetMachine &TM = F.getTarget();
   MRI = &F.getRegInfo();
   TRI = F.getRegInfo().getTargetRegisterInfo();
-  TII = TM.getSubtarget<AArch64Subtarget>().getInstrInfo();
   RCI.runOnMachineFunction(F);
 
   for (auto &MBB : F) {
@@ -317,7 +336,7 @@ bool AArch64A57FPLoadBalancing::runOnBasicBlock(MachineBasicBlock &MBB) {
   // been killed yet. This is keyed by register - all chains can only have one
   // "link" register between each inst in the chain.
   std::map<unsigned, Chain*> ActiveChains;
-  std::set<Chain*> AllChains;
+  std::vector<std::unique_ptr<Chain>> AllChains;
   unsigned Idx = 0;
   for (auto &MI : MBB)
     scanInstruction(&MI, Idx++, ActiveChains, AllChains);
@@ -332,15 +351,13 @@ bool AArch64A57FPLoadBalancing::runOnBasicBlock(MachineBasicBlock &MBB) {
   //       range of chains is quite small and they are clustered between loads
   //       and stores.
   EquivalenceClasses<Chain*> EC;
-  for (auto *I : AllChains)
-    EC.insert(I);
+  for (auto &I : AllChains)
+    EC.insert(I.get());
 
-  for (auto *I : AllChains) {
-    for (auto *J : AllChains) {
-      if (I != J && I->rangeOverlapsWith(J))
-        EC.unionSets(I, J);
-    }
-  }
+  for (auto &I : AllChains)
+    for (auto &J : AllChains)
+      if (I != J && I->rangeOverlapsWith(*J))
+        EC.unionSets(I.get(), J.get());
   DEBUG(dbgs() << "Created " << EC.getNumClasses() << " disjoint sets.\n");
 
   // Now we assume that every member of an equivalence class interferes
@@ -351,7 +368,7 @@ bool AArch64A57FPLoadBalancing::runOnBasicBlock(MachineBasicBlock &MBB) {
   for (auto I = EC.begin(), E = EC.end(); I != E; ++I) {
     std::vector<Chain*> Cs(EC.member_begin(I), EC.member_end());
     if (Cs.empty()) continue;
-    V.push_back(Cs);
+    V.push_back(std::move(Cs));
   }
 
   // Now we have a set of sets, order them by start address so
@@ -376,10 +393,7 @@ bool AArch64A57FPLoadBalancing::runOnBasicBlock(MachineBasicBlock &MBB) {
   int Parity = 0;
 
   for (auto &I : V)
-    Changed |= colorChainSet(I, MBB, Parity);
-
-  for (auto *C : AllChains)
-    delete C;
+    Changed |= colorChainSet(std::move(I), MBB, Parity);
 
   return Changed;
 }
@@ -433,10 +447,17 @@ bool AArch64A57FPLoadBalancing::colorChainSet(std::vector<Chain*> GV,
   // chains that we cannot change before we look at those we can,
   // so the parity counter is updated and we know what color we should
   // change them to!
+  // Final tie-break with instruction order so pass output is stable (i.e. not
+  // dependent on malloc'd pointer values).
   std::sort(GV.begin(), GV.end(), [](const Chain *G1, const Chain *G2) {
       if (G1->size() != G2->size())
         return G1->size() > G2->size();
-      return G1->requiresFixup() > G2->requiresFixup();
+      if (G1->requiresFixup() != G2->requiresFixup())
+        return G1->requiresFixup() > G2->requiresFixup();
+      // Make sure startsBefore() produces a stable final order.
+      assert((G1 == G2 || (G1->startsBefore(G2) ^ G2->startsBefore(G1))) &&
+             "Starts before not total order!");
+      return G1->startsBefore(G2);
     });
 
   Color PreferredColor = Parity < 0 ? Color::Even : Color::Odd;
@@ -483,10 +504,16 @@ int AArch64A57FPLoadBalancing::scavengeRegister(Chain *G, Color C,
     RS.forward(I);
     AvailableRegs &= RS.getRegsAvailable(TRI->getRegClass(RegClassID));
 
-    // Remove any registers clobbered by a regmask.
+    // Remove any registers clobbered by a regmask or any def register that is
+    // immediately dead.
     for (auto J : I->operands()) {
       if (J.isRegMask())
         AvailableRegs.clearBitsNotInMask(J.getRegMask());
+
+      if (J.isReg() && J.isDef() && AvailableRegs[J.getReg()]) {
+        assert(J.isDead() && "Non-dead def should have been removed by now!");
+        AvailableRegs.reset(J.getReg());
+      }
     }
   }
 
@@ -576,15 +603,16 @@ bool AArch64A57FPLoadBalancing::colorChain(Chain *G, Color C,
   return Changed;
 }
 
-void AArch64A57FPLoadBalancing::
-scanInstruction(MachineInstr *MI, unsigned Idx, 
-                std::map<unsigned, Chain*> &ActiveChains,
-                std::set<Chain*> &AllChains) {
+void AArch64A57FPLoadBalancing::scanInstruction(
+    MachineInstr *MI, unsigned Idx, std::map<unsigned, Chain *> &ActiveChains,
+    std::vector<std::unique_ptr<Chain>> &AllChains) {
   // Inspect "MI", updating ActiveChains and AllChains.
 
   if (isMul(MI)) {
 
-    for (auto &I : MI->operands())
+    for (auto &I : MI->uses())
+      maybeKillChain(I, Idx, ActiveChains);
+    for (auto &I : MI->defs())
       maybeKillChain(I, Idx, ActiveChains);
 
     // Create a new chain. Multiplies don't require forwarding so can go on any
@@ -594,9 +622,9 @@ scanInstruction(MachineInstr *MI, unsigned Idx,
     DEBUG(dbgs() << "New chain started for register "
           << TRI->getName(DestReg) << " at " << *MI);
 
-    Chain *G = new Chain(MI, Idx, getColor(DestReg));
-    ActiveChains[DestReg] = G;
-    AllChains.insert(G);
+    auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
+    ActiveChains[DestReg] = G.get();
+    AllChains.push_back(std::move(G));
 
   } else if (isMla(MI)) {
 
@@ -624,7 +652,10 @@ scanInstruction(MachineInstr *MI, unsigned Idx,
         DEBUG(dbgs() << "Instruction was successfully added to chain.\n");
         ActiveChains[AccumReg]->add(MI, Idx, getColor(DestReg));
         // Handle cases where the destination is not the same as the accumulator.
-        ActiveChains[DestReg] = ActiveChains[AccumReg];
+        if (DestReg != AccumReg) {
+          ActiveChains[DestReg] = ActiveChains[AccumReg];
+          ActiveChains.erase(AccumReg);
+        }
         return;
       }
 
@@ -635,15 +666,17 @@ scanInstruction(MachineInstr *MI, unsigned Idx,
 
     DEBUG(dbgs() << "Creating new chain for dest register "
           << TRI->getName(DestReg) << "\n");
-    Chain *G = new Chain(MI, Idx, getColor(DestReg));
-    ActiveChains[DestReg] = G;
-    AllChains.insert(G);
+    auto G = llvm::make_unique<Chain>(MI, Idx, getColor(DestReg));
+    ActiveChains[DestReg] = G.get();
+    AllChains.push_back(std::move(G));
 
   } else {
 
     // Non-MUL or MLA instruction. Invalidate any chain in the uses or defs
     // lists.
-    for (auto &I : MI->operands())
+    for (auto &I : MI->uses())
+      maybeKillChain(I, Idx, ActiveChains);
+    for (auto &I : MI->defs())
       maybeKillChain(I, Idx, ActiveChains);
 
   }
@@ -669,13 +702,14 @@ maybeKillChain(MachineOperand &MO, unsigned Idx,
   } else if (MO.isRegMask()) {
 
     for (auto I = ActiveChains.begin(), E = ActiveChains.end();
-         I != E; ++I) {
+         I != E;) {
       if (MO.clobbersPhysReg(I->first)) {
         DEBUG(dbgs() << "Kill (regmask) seen for chain "
               << TRI->getName(I->first) << "\n");
         I->second->setKill(MI, Idx, /*Immutable=*/true);
-        ActiveChains.erase(I);
-      }
+        ActiveChains.erase(I++);
+      } else
+        ++I;
     }
 
   }