Do not lose rematerialization info when spilling already split live intervals.
[oota-llvm.git] / lib / CodeGen / LiveIntervalAnalysis.cpp
index 75d5aac404856bb6586437ae557558a283125402..7d627bb6b2a0d52d7cc0d54def5157d04aa04043 100644 (file)
@@ -362,7 +362,8 @@ void LiveIntervals::handleVirtualRegisterDef(MachineBasicBlock *mbb,
         DOUT << " Removing [" << Start << "," << End << "] from: ";
         interval.print(DOUT, mri_); DOUT << "\n";
         interval.removeRange(Start, End);
-        interval.addKill(VNI, Start+1); // odd # means phi node
+        interval.addKill(VNI, Start);
+        VNI->hasPHIKill = true;
         DOUT << " RESULT: "; interval.print(DOUT, mri_);
 
         // Replace the interval with one of a NEW value number.  Note that this
@@ -392,7 +393,8 @@ void LiveIntervals::handleVirtualRegisterDef(MachineBasicBlock *mbb,
       unsigned killIndex = getInstructionIndex(&mbb->back()) + InstrSlots::NUM;
       LiveRange LR(defIndex, killIndex, ValNo);
       interval.addRange(LR);
-      interval.addKill(ValNo, killIndex+1); // odd # means phi node
+      interval.addKill(ValNo, killIndex);
+      ValNo->hasPHIKill = true;
       DOUT << " +" << LR;
     }
   }
@@ -701,7 +703,7 @@ rewriteInstructionForSpills(const LiveInterval &li, bool TrySplit,
                  SmallVector<int, 4> &ReMatIds,
                  unsigned &NewVReg, bool &HasDef, bool &HasUse,
                  const LoopInfo *loopInfo,
-                 std::map<unsigned,unsigned> &NewVRegs,
+                 std::map<unsigned,unsigned> &MBBVRegsMap,
                  std::vector<LiveInterval*> &NewLIs) {
  RestartInstruction:
   for (unsigned i = 0; i != MI->getNumOperands(); ++i) {
@@ -718,7 +720,7 @@ rewriteInstructionForSpills(const LiveInterval &li, bool TrySplit,
       continue;
 
     bool TryFold = !DefIsReMat;
-    bool FoldSS = true;
+    bool FoldSS = true; // Default behavior unless it's a remat.
     int FoldSlot = Slot;
     if (DefIsReMat) {
       // If this is the rematerializable definition MI itself and
@@ -732,8 +734,7 @@ rewriteInstructionForSpills(const LiveInterval &li, bool TrySplit,
 
       // If def for this use can't be rematerialized, then try folding.
       // If def is rematerializable and it's a load, also try folding.
-      TryFold = !ReMatOrigDefMI ||
-        (ReMatOrigDefMI && (MI == ReMatOrigDefMI || isLoad));
+      TryFold = !ReMatDefMI || (ReMatDefMI && (MI == ReMatOrigDefMI || isLoad));
       if (isLoad) {
         // Try fold loads (from stack slot, constant pool, etc.) into uses.
         FoldSS = isLoadSS;
@@ -808,13 +809,19 @@ rewriteInstructionForSpills(const LiveInterval &li, bool TrySplit,
       } else {
         vrm.assignVirt2StackSlot(NewVReg, Slot);
       }
+    } else if (HasUse && HasDef &&
+               vrm.getStackSlot(NewVReg) == VirtRegMap::NO_STACK_SLOT) {
+      // If this interval hasn't been assigned a stack slot (because earlier
+      // def is a deleted remat def), do it now.
+      assert(Slot != VirtRegMap::NO_STACK_SLOT);
+      vrm.assignVirt2StackSlot(NewVReg, Slot);
     }
 
     // create a new register interval for this spill / remat.
     LiveInterval &nI = getOrCreateInterval(NewVReg);
     if (CreatedNewVReg) {
       NewLIs.push_back(&nI);
-      NewVRegs.insert(std::make_pair(MI->getParent()->getNumber(), NewVReg));
+      MBBVRegsMap.insert(std::make_pair(MI->getParent()->getNumber(), NewVReg));
       if (TrySplit)
         vrm.setIsSplitFromReg(NewVReg, li.reg);
     }
@@ -859,6 +866,17 @@ bool LiveIntervals::anyKillInMBBAfterIdx(const LiveInterval &li,
   return false;
 }
 
+static const VNInfo *findDefinedVNInfo(const LiveInterval &li, unsigned DefIdx) {
+  const VNInfo *VNI = NULL;
+  for (LiveInterval::const_vni_iterator i = li.vni_begin(),
+         e = li.vni_end(); i != e; ++i)
+    if ((*i)->def == DefIdx) {
+      VNI = *i;
+      break;
+    }
+  return VNI;
+}
+
 void LiveIntervals::
 rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
                     LiveInterval::Ranges::const_iterator &I,
@@ -870,10 +888,10 @@ rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
                     SmallVector<int, 4> &ReMatIds,
                     const LoopInfo *loopInfo,
                     BitVector &SpillMBBs,
-                    std::map<unsigned, std::pair<int, bool> > &SpillIdxes,
+                    std::map<unsigned, std::vector<SRInfo> > &SpillIdxes,
                     BitVector &RestoreMBBs,
-                    std::map<unsigned, std::pair<int, bool> > &RestoreIdxes,
-                    std::map<unsigned,unsigned> &NewVRegs,
+                    std::map<unsigned, std::vector<SRInfo> > &RestoreIdxes,
+                    std::map<unsigned,unsigned> &MBBVRegsMap,
                     std::vector<LiveInterval*> &NewLIs) {
   unsigned NewVReg = 0;
   unsigned index = getBaseIndex(I->start);
@@ -890,9 +908,34 @@ rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
     NewVReg = 0;
     if (TrySplitMI) {
       std::map<unsigned,unsigned>::const_iterator NVI =
-        NewVRegs.find(MBB->getNumber());
-      if (NVI != NewVRegs.end())
+        MBBVRegsMap.find(MBB->getNumber());
+      if (NVI != MBBVRegsMap.end()) {
         NewVReg = NVI->second;
+        // One common case:
+        // x = use
+        // ...
+        // ...
+        // def = ...
+        //     = use
+        // It's better to start a new interval to avoid artifically
+        // extend the new interval.
+        // FIXME: Too slow? Can we fix it after rewriteInstructionsForSpills?
+        bool MIHasUse = false;
+        bool MIHasDef = false;
+        for (unsigned i = 0; i != MI->getNumOperands(); ++i) {
+          MachineOperand& mop = MI->getOperand(i);
+          if (!mop.isRegister() || mop.getReg() != li.reg)
+            continue;
+          if (mop.isUse())
+            MIHasUse = true;
+          else
+            MIHasDef = true;
+        }
+        if (MIHasDef && !MIHasUse) {
+          MBBVRegsMap.erase(MBB->getNumber());
+          NewVReg = 0;
+        }
+      }
     }
     bool IsNew = NewVReg == 0;
     bool HasDef = false;
@@ -901,7 +944,7 @@ rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
                                 MI, ReMatOrigDefMI, ReMatDefMI, Slot, LdSlot,
                                 isLoad, isLoadSS, DefIsReMat, CanDelete, vrm,
                                 RegMap, rc, ReMatIds, NewVReg, HasDef, HasUse,
-                                loopInfo, NewVRegs, NewLIs);
+                                loopInfo, MBBVRegsMap, NewLIs);
     if (!HasDef && !HasUse)
       continue;
 
@@ -917,62 +960,60 @@ rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
     unsigned MBBId = MBB->getNumber();
     if (HasDef) {
       if (MI != ReMatOrigDefMI || !CanDelete) {
-        // If this is a two-address code, then this index probably starts a
-        // VNInfo so we should examine all the VNInfo's.
         bool HasKill = false;
         if (!HasUse)
           HasKill = anyKillInMBBAfterIdx(li, I->valno, MBB, getDefIndex(index));
         else {
-          const VNInfo *VNI = NULL;
-          for (LiveInterval::const_vni_iterator i = li.vni_begin(),
-                 e = li.vni_end(); i != e; ++i)
-            if ((*i)->def == getDefIndex(index)) {
-              VNI = *i;
-              break;
-            }
+          // If this is a two-address code, then this index starts a new VNInfo.
+          const VNInfo *VNI = findDefinedVNInfo(li, getDefIndex(index));
           if (VNI)
             HasKill = anyKillInMBBAfterIdx(li, VNI, MBB, getDefIndex(index));
         }
         if (!HasKill) {
-          std::map<unsigned, std::pair<int, bool> >::iterator SII =
+          std::map<unsigned, std::vector<SRInfo> >::iterator SII =
             SpillIdxes.find(MBBId);
-          if (SII == SpillIdxes.end())
-            SpillIdxes[MBBId] = std::make_pair(index, true);
-          else if ((int)index > SII->second.first) {
+          if (SII == SpillIdxes.end()) {
+            std::vector<SRInfo> S;
+            S.push_back(SRInfo(index, NewVReg, true));
+            SpillIdxes.insert(std::make_pair(MBBId, S));
+          } else if (SII->second.back().vreg != NewVReg) {
+            SII->second.push_back(SRInfo(index, NewVReg, true));
+          } else if ((int)index > SII->second.back().index) {
             // If there is an earlier def and this is a two-address
             // instruction, then it's not possible to fold the store (which
             // would also fold the load).
-            SpillIdxes[MBBId] = std::make_pair(index, !HasUse);
+            SRInfo &Info = SII->second.back();
+            Info.index = index;
+            Info.canFold = !HasUse;
           }
           SpillMBBs.set(MBBId);
         }
       }
-      if (!IsNew) {
-        // It this interval hasn't been assigned a stack slot
-        // (because earlier def is remat), do it now.
-        int SS = vrm.getStackSlot(NewVReg);
-        if (SS != (int)Slot) {
-          assert(SS == VirtRegMap::NO_STACK_SLOT);
-          vrm.assignVirt2StackSlot(NewVReg, Slot);
-        }
-      }
     }
 
     if (HasUse) {
-      std::map<unsigned, std::pair<int, bool> >::iterator SII =
+      std::map<unsigned, std::vector<SRInfo> >::iterator SII =
         SpillIdxes.find(MBBId);
-      if (SII != SpillIdxes.end() && (int)index > SII->second.first)
+      if (SII != SpillIdxes.end() &&
+          SII->second.back().vreg == NewVReg &&
+          (int)index > SII->second.back().index)
         // Use(s) following the last def, it's not safe to fold the spill.
-        SII->second.second = false;
-      std::map<unsigned, std::pair<int, bool> >::iterator RII =
+        SII->second.back().canFold = false;
+      std::map<unsigned, std::vector<SRInfo> >::iterator RII =
         RestoreIdxes.find(MBBId);
-      if (RII != RestoreIdxes.end())
+      if (RII != RestoreIdxes.end() && RII->second.back().vreg == NewVReg)
         // If we are splitting live intervals, only fold if it's the first
         // use and there isn't another use later in the MBB.
-        RII->second.second = false;
+        RII->second.back().canFold = false;
       else if (IsNew) {
         // Only need a reload if there isn't an earlier def / use.
-        RestoreIdxes[MBBId] = std::make_pair(index, true);
+        if (RII == RestoreIdxes.end()) {
+          std::vector<SRInfo> Infos;
+          Infos.push_back(SRInfo(index, NewVReg, true));
+          RestoreIdxes.insert(std::make_pair(MBBId, Infos));
+        } else {
+          RII->second.push_back(SRInfo(index, NewVReg, true));
+        }
         RestoreMBBs.set(MBBId);
       }
     }
@@ -983,6 +1024,30 @@ rewriteInstructionsForSpills(const LiveInterval &li, bool TrySplit,
   }
 }
 
+bool LiveIntervals::alsoFoldARestore(int Id, int index, unsigned vr,
+                        BitVector &RestoreMBBs,
+                        std::map<unsigned,std::vector<SRInfo> > &RestoreIdxes) {
+  if (!RestoreMBBs[Id])
+    return false;
+  std::vector<SRInfo> &Restores = RestoreIdxes[Id];
+  for (unsigned i = 0, e = Restores.size(); i != e; ++i)
+    if (Restores[i].index == index &&
+        Restores[i].vreg == vr &&
+        Restores[i].canFold)
+      return true;
+  return false;
+}
+
+void LiveIntervals::eraseRestoreInfo(int Id, int index, unsigned vr,
+                        BitVector &RestoreMBBs,
+                        std::map<unsigned,std::vector<SRInfo> > &RestoreIdxes) {
+  if (!RestoreMBBs[Id])
+    return;
+  std::vector<SRInfo> &Restores = RestoreIdxes[Id];
+  for (unsigned i = 0, e = Restores.size(); i != e; ++i)
+    if (Restores[i].index == index && Restores[i].vreg)
+      Restores[i].index = -1;
+}
 
 
 std::vector<LiveInterval*> LiveIntervals::
@@ -1001,10 +1066,10 @@ addIntervalsForSpills(const LiveInterval &li,
 
   // Each bit specify whether it a spill is required in the MBB.
   BitVector SpillMBBs(mf_->getNumBlockIDs());
-  std::map<unsigned, std::pair<int, bool> > SpillIdxes;
+  std::map<unsigned, std::vector<SRInfo> > SpillIdxes;
   BitVector RestoreMBBs(mf_->getNumBlockIDs());
-  std::map<unsigned, std::pair<int, bool> > RestoreIdxes;
-  std::map<unsigned,unsigned> NewVRegs;
+  std::map<unsigned, std::vector<SRInfo> > RestoreIdxes;
+  std::map<unsigned,unsigned> MBBVRegsMap;
   std::vector<LiveInterval*> NewLIs;
   SSARegMap *RegMap = mf_->getSSARegMap();
   const TargetRegisterClass* rc = RegMap->getRegClass(li.reg);
@@ -1039,17 +1104,18 @@ addIntervalsForSpills(const LiveInterval &li,
       // are two-address instructions that re-defined the value. Only the
       // first def can be rematerialized!
       if (IsFirstRange) {
+        // Note ReMatOrigDefMI has already been deleted.
         rewriteInstructionsForSpills(li, false, I, NULL, ReMatDefMI,
                              Slot, LdSlot, isLoad, isLoadSS, DefIsReMat,
                              false, vrm, RegMap, rc, ReMatIds, loopInfo,
                              SpillMBBs, SpillIdxes, RestoreMBBs, RestoreIdxes,
-                             NewVRegs, NewLIs);
+                             MBBVRegsMap, NewLIs);
       } else {
         rewriteInstructionsForSpills(li, false, I, NULL, 0,
                              Slot, 0, false, false, false,
                              false, vrm, RegMap, rc, ReMatIds, loopInfo,
                              SpillMBBs, SpillIdxes, RestoreMBBs, RestoreIdxes,
-                             NewVRegs, NewLIs);
+                             MBBVRegsMap, NewLIs);
       }
       IsFirstRange = false;
     }
@@ -1081,21 +1147,14 @@ addIntervalsForSpills(const LiveInterval &li,
       vrm.setVirtIsReMaterialized(li.reg, ReMatDefMI);
 
       bool CanDelete = true;
-      for (unsigned j = 0, ee = VNI->kills.size(); j != ee; ++j) {
-        unsigned KillIdx = VNI->kills[j];
-        MachineInstr *KillMI = (KillIdx & 1)
-          ? NULL : getInstructionFromIndex(KillIdx);
-        // Kill is a phi node, not all of its uses can be rematerialized.
+      if (VNI->hasPHIKill) {
+        // A kill is a phi node, not all of its uses can be rematerialized.
         // It must not be deleted.
-        if (!KillMI) {
-          CanDelete = false;
-          // Need a stack slot if there is any live range where uses cannot be
-          // rematerialized.
-          NeedStackSlot = true;
-          break;
-        }
+        CanDelete = false;
+        // Need a stack slot if there is any live range where uses cannot be
+        // rematerialized.
+        NeedStackSlot = true;
       }
-
       if (CanDelete)
         ReMatDelete.set(VN);
     } else {
@@ -1124,17 +1183,21 @@ addIntervalsForSpills(const LiveInterval &li,
                                Slot, LdSlot, isLoad, isLoadSS, DefIsReMat,
                                CanDelete, vrm, RegMap, rc, ReMatIds, loopInfo,
                                SpillMBBs, SpillIdxes, RestoreMBBs, RestoreIdxes,
-                               NewVRegs, NewLIs);
+                               MBBVRegsMap, NewLIs);
   }
 
   // Insert spills / restores if we are splitting.
-  if (TrySplit) {
-    if (NeedStackSlot) {
-      int Id = SpillMBBs.find_first();
-      while (Id != -1) {
-        unsigned VReg = NewVRegs[Id];
-        int index = SpillIdxes[Id].first;
-        bool DoFold = SpillIdxes[Id].second;
+  if (!TrySplit)
+    return NewLIs;
+
+  if (NeedStackSlot) {
+    int Id = SpillMBBs.find_first();
+    while (Id != -1) {
+      std::vector<SRInfo> &spills = SpillIdxes[Id];
+      for (unsigned i = 0, e = spills.size(); i != e; ++i) {
+        int index = spills[i].index;
+        unsigned VReg = spills[i].vreg;
+        bool DoFold = spills[i].canFold;
         bool isReMat = vrm.isReMaterialized(VReg);
         MachineInstr *MI = getInstructionFromIndex(index);
         int OpIdx = -1;
@@ -1149,8 +1212,7 @@ addIntervalsForSpills(const LiveInterval &li,
               // first and only use.
               // If there are more than one uses, a load is still needed.
               if (!isReMat && !FoldedLoad &&
-                  RestoreMBBs[Id] && RestoreIdxes[Id].first == index &&
-                  RestoreIdxes[Id].second) {
+                  alsoFoldARestore(Id, index,VReg,RestoreMBBs,RestoreIdxes)) {
                 FoldedLoad = true;
                 continue;
               } else {
@@ -1165,13 +1227,10 @@ addIntervalsForSpills(const LiveInterval &li,
         if (OpIdx == -1)
           DoFold = false;
         if (DoFold) {
-          if (tryFoldMemoryOperand(MI, vrm, NULL, index, OpIdx, true, Slot,
-                                   VReg)) {
-            if (FoldedLoad) {
+          if (tryFoldMemoryOperand(MI, vrm, NULL, index,OpIdx,true,Slot,VReg)) {
+            if (FoldedLoad)
               // Folded a two-address instruction, do not issue a load.
-              RestoreMBBs.reset(Id);
-              RestoreIdxes.erase(Id);
-            }
+              eraseRestoreInfo(Id, index, VReg, RestoreMBBs, RestoreIdxes);
           } else
             DoFold = false;
         }
@@ -1179,15 +1238,20 @@ addIntervalsForSpills(const LiveInterval &li,
         // Else tell the spiller to issue a store for us.
         if (!DoFold)
           vrm.addSpillPoint(VReg, MI);
-        Id = SpillMBBs.find_next(Id);
       }
+      Id = SpillMBBs.find_next(Id);
     }
+  }
 
-    int Id = RestoreMBBs.find_first();
-    while (Id != -1) {
-      unsigned VReg = NewVRegs[Id];
-      int index = RestoreIdxes[Id].first;
-      bool DoFold = RestoreIdxes[Id].second;
+  int Id = RestoreMBBs.find_first();
+  while (Id != -1) {
+    std::vector<SRInfo> &restores = RestoreIdxes[Id];
+    for (unsigned i = 0, e = restores.size(); i != e; ++i) {
+      int index = restores[i].index;
+      if (index == -1)
+        continue;
+      unsigned VReg = restores[i].vreg;
+      bool DoFold = restores[i].canFold;
       MachineInstr *MI = getInstructionFromIndex(index);
       int OpIdx = -1;
       if (DoFold) {
@@ -1232,14 +1296,13 @@ addIntervalsForSpills(const LiveInterval &li,
       // load / rematerialization for us.
       if (!DoFold)
         vrm.addRestorePoint(VReg, MI);
-      Id = RestoreMBBs.find_next(Id);
     }
+    Id = RestoreMBBs.find_next(Id);
   }
 
   // Finalize spill weights.
-  if (TrySplit)
-    for (unsigned i = 0, e = NewLIs.size(); i != e; ++i)
-      NewLIs[i]->weight /= NewLIs[i]->getSize();
+  for (unsigned i = 0, e = NewLIs.size(); i != e; ++i)
+    NewLIs[i]->weight /= NewLIs[i]->getSize();
 
   return NewLIs;
 }