Add the long awaited memory operand folding support for linear scan
authorAlkis Evlogimenos <alkis@evlogimenos.com>
Mon, 1 Mar 2004 20:05:10 +0000 (20:05 +0000)
committerAlkis Evlogimenos <alkis@evlogimenos.com>
Mon, 1 Mar 2004 20:05:10 +0000 (20:05 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@12058 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/LiveIntervalAnalysis.h
lib/CodeGen/LiveIntervalAnalysis.cpp
lib/CodeGen/LiveIntervalAnalysis.h
lib/CodeGen/RegAllocLinearScan.cpp
lib/CodeGen/VirtRegMap.cpp
lib/CodeGen/VirtRegMap.h

index 24bc8956b3f94b1caedc5a73f7b1f33c59e915f5..5b78342e281d4747e8407546da76edfcdcb4b6f7 100644 (file)
@@ -28,6 +28,7 @@ namespace llvm {
 
     class LiveVariables;
     class MRegisterInfo;
+    class VirtRegMap;
 
     class LiveIntervals : public MachineFunctionPass
     {
@@ -164,7 +165,7 @@ namespace llvm {
 
         Intervals& getIntervals() { return intervals_; }
 
-        void updateSpilledInterval(Interval& i, int slot);
+        void updateSpilledInterval(Interval& i, VirtRegMap& vrm, int slot);
 
     private:
         /// computeIntervals - compute live intervals
index fc46de2be9285c72da251c2dfd7ecaf70c410fd6..d6cc357fd5425b00f49f5413c7b1a25b93645aa5 100644 (file)
@@ -31,6 +31,7 @@
 #include "Support/Debug.h"
 #include "Support/Statistic.h"
 #include "Support/STLExtras.h"
+#include "VirtRegMap.h"
 #include <cmath>
 #include <iostream>
 #include <limits>
@@ -184,7 +185,9 @@ bool LiveIntervals::runOnMachineFunction(MachineFunction &fn) {
     return true;
 }
 
-void LiveIntervals::updateSpilledInterval(Interval& li, int slot)
+void LiveIntervals::updateSpilledInterval(Interval& li,
+                                          VirtRegMap& vrm,
+                                          int slot)
 {
     assert(li.weight != std::numeric_limits<float>::infinity() &&
            "attempt to spill already spilled interval!");
@@ -202,27 +205,40 @@ void LiveIntervals::updateSpilledInterval(Interval& li, int slot)
             while (!getInstructionFromIndex(index)) index += InstrSlots::NUM;
             MachineBasicBlock::iterator mi = getInstructionFromIndex(index);
 
+        for_operand:
             for (unsigned i = 0; i < mi->getNumOperands(); ++i) {
                 MachineOperand& mop = mi->getOperand(i);
                 if (mop.isRegister() && mop.getReg() == li.reg) {
-                    // This is tricky. We need to add information in
-                    // the interval about the spill code so we have to
-                    // use our extra load/store slots.
-                    //
-                    // If we have a use we are going to have a load so
-                    // we start the interval from the load slot
-                    // onwards. Otherwise we start from the def slot.
-                    unsigned start = (mop.isUse() ?
-                                      getLoadIndex(index) :
-                                      getDefIndex(index));
-                    // If we have a def we are going to have a store
-                    // right after it so we end the interval after the
-                    // use of the next instruction. Otherwise we end
-                    // after the use of this instruction.
-                    unsigned end = 1 + (mop.isDef() ?
-                                        getUseIndex(index+InstrSlots::NUM) :
-                                        getUseIndex(index));
-                    li.addRange(start, end);
+                    MachineInstr* old = mi;
+                    if (mri_->foldMemoryOperand(mi, i, slot)) {
+                        lv_->instructionChanged(old, mi);
+                        vrm.virtFolded(li.reg, old, mi);
+                        mi2iMap_.erase(old);
+                        i2miMap_[index/InstrSlots::NUM] = mi;
+                        mi2iMap_[mi] = index;
+                        ++numFolded;
+                        goto for_operand;
+                    }
+                    else {
+                        // This is tricky. We need to add information in
+                        // the interval about the spill code so we have to
+                        // use our extra load/store slots.
+                        //
+                        // If we have a use we are going to have a load so
+                        // we start the interval from the load slot
+                        // onwards. Otherwise we start from the def slot.
+                        unsigned start = (mop.isUse() ?
+                                          getLoadIndex(index) :
+                                          getDefIndex(index));
+                        // If we have a def we are going to have a store
+                        // right after it so we end the interval after the
+                        // use of the next instruction. Otherwise we end
+                        // after the use of this instruction.
+                        unsigned end = 1 + (mop.isDef() ?
+                                            getUseIndex(index+InstrSlots::NUM) :
+                                            getUseIndex(index));
+                        li.addRange(start, end);
+                    }
                 }
             }
         }
index 24bc8956b3f94b1caedc5a73f7b1f33c59e915f5..5b78342e281d4747e8407546da76edfcdcb4b6f7 100644 (file)
@@ -28,6 +28,7 @@ namespace llvm {
 
     class LiveVariables;
     class MRegisterInfo;
+    class VirtRegMap;
 
     class LiveIntervals : public MachineFunctionPass
     {
@@ -164,7 +165,7 @@ namespace llvm {
 
         Intervals& getIntervals() { return intervals_; }
 
-        void updateSpilledInterval(Interval& i, int slot);
+        void updateSpilledInterval(Interval& i, VirtRegMap& vrm, int slot);
 
     private:
         /// computeIntervals - compute live intervals
index 9e3961868276941b546e16ac860bbe029e2ecae9..d6c53cd0a72e4827a21967620684fe7972370e22 100644 (file)
@@ -385,7 +385,7 @@ void RA::assignRegOrStackSlotAtInterval(IntervalPtrs::value_type cur)
     if (cur->weight <= minWeight) {
         DEBUG(std::cerr << "\t\t\tspilling(c): " << *cur << '\n';);
         int slot = vrm_->assignVirt2StackSlot(cur->reg);
-        li_->updateSpilledInterval(*cur, slot);
+        li_->updateSpilledInterval(*cur, *vrm_, slot);
 
         // if we didn't eliminate the interval find where to add it
         // back to unhandled. We need to scan since unhandled are
@@ -424,7 +424,7 @@ void RA::assignRegOrStackSlotAtInterval(IntervalPtrs::value_type cur)
             DEBUG(std::cerr << "\t\t\tspilling(a): " << **i << '\n');
             earliestStart = std::min(earliestStart, (*i)->start());
             int slot = vrm_->assignVirt2StackSlot((*i)->reg);
-            li_->updateSpilledInterval(**i, slot);
+            li_->updateSpilledInterval(**i, *vrm_, slot);
         }
     }
     for (IntervalPtrs::iterator i = inactive_.begin();
@@ -436,7 +436,7 @@ void RA::assignRegOrStackSlotAtInterval(IntervalPtrs::value_type cur)
             DEBUG(std::cerr << "\t\t\tspilling(i): " << **i << '\n');
             earliestStart = std::min(earliestStart, (*i)->start());
             int slot = vrm_->assignVirt2StackSlot((*i)->reg);
-            li_->updateSpilledInterval(**i, slot);
+            li_->updateSpilledInterval(**i, *vrm_, slot);
         }
     }
 
index e517cb371e65b5c3ee874c1befaa2d742061f84d..1238a2cefbb43938beaf8dbfb3e8cb9c58df0b42 100644 (file)
@@ -19,6 +19,7 @@
 #include "VirtRegMap.h"
 #include "llvm/Function.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetInstrInfo.h"
 #include "Support/Statistic.h"
@@ -49,6 +50,24 @@ int VirtRegMap::assignVirt2StackSlot(unsigned virtReg)
     return frameIndex;
 }
 
+void VirtRegMap::virtFolded(unsigned virtReg,
+                            MachineInstr* oldMI,
+                            MachineInstr* newMI)
+{
+    // move previous memory references folded to new instruction
+    MI2VirtMap::iterator i, e;
+    std::vector<MI2VirtMap::mapped_type> regs;
+    for (tie(i, e) = mi2vMap_.equal_range(oldMI); i != e; ) {
+        regs.push_back(i->second);
+        mi2vMap_.erase(i++);
+    }
+    for (unsigned i = 0, e = regs.size(); i != e; ++i)
+        mi2vMap_.insert(std::make_pair(newMI, i));
+
+    // add new memory reference
+    mi2vMap_.insert(std::make_pair(newMI, virtReg));
+}
+
 std::ostream& llvm::operator<<(std::ostream& os, const VirtRegMap& vrm)
 {
     const MRegisterInfo* mri = vrm.mf_->getTarget().getRegisterInfo();
@@ -129,9 +148,9 @@ namespace {
                                          vrm_.getStackSlot(virtReg),
                                          mri_.getRegClass(physReg));
                 ++numStores;
-                DEBUG(std::cerr << "\t\tadded: ";
+                DEBUG(std::cerr << "added: ";
                       prior(nextLastRef)->print(std::cerr, tm_);
-                      std::cerr << "\t\tafter: ";
+                      std::cerr << "after: ";
                       lastDef->print(std::cerr, tm_));
                 lastDef_[virtReg] = 0;
             }
@@ -161,10 +180,8 @@ namespace {
                                               vrm_.getStackSlot(virtReg),
                                               mri_.getRegClass(physReg));
                     ++numLoads;
-                    DEBUG(std::cerr << "\t\tadded: ";
-                          prior(mii)->print(std::cerr,tm_);
-                          std::cerr << "\t\tbefore: ";
-                          mii->print(std::cerr, tm_));
+                    DEBUG(std::cerr << "added: ";
+                          prior(mii)->print(std::cerr,tm_));
                     lastDef_[virtReg] = mii;
                 }
             }
@@ -186,6 +203,16 @@ namespace {
         void eliminateVirtRegsInMbb(MachineBasicBlock& mbb) {
             for (MachineBasicBlock::iterator mii = mbb.begin(),
                      mie = mbb.end(); mii != mie; ++mii) {
+
+                // if we have references to memory operands make sure
+                // we clear all physical registers that may contain
+                // the value of the spilled virtual register
+                VirtRegMap::MI2VirtMap::const_iterator i, e;
+                for (tie(i, e) = vrm_.getFoldedVirts(mii); i != e; ++i) {
+                    unsigned physReg = vrm_.getPhys(i->second);
+                    if (physReg) vacateJustPhysReg(mbb, mii, physReg);
+                }
+
                 // rewrite all used operands
                 for (unsigned i = 0, e = mii->getNumOperands(); i != e; ++i) {
                     MachineOperand& op = mii->getOperand(i);
index b0af8c5c739ded595bca0cad7aad6244a9308a14..90cc44d31c4261e0182e413ac77e1ed2453b8dcf 100644 (file)
 #include "llvm/CodeGen/SSARegMap.h"
 #include "Support/DenseMap.h"
 #include <climits>
+#include <map>
 
 namespace llvm {
 
+    class MachineInstr;
+
     class VirtRegMap {
     public:
         typedef DenseMap<unsigned, VirtReg2IndexFunctor> Virt2PhysMap;
         typedef DenseMap<int, VirtReg2IndexFunctor> Virt2StackSlotMap;
+        typedef std::multimap<MachineInstr*, unsigned> MI2VirtMap;
 
     private:
         MachineFunction* mf_;
         Virt2PhysMap v2pMap_;
         Virt2StackSlotMap v2ssMap_;
+        MI2VirtMap mi2vMap_;
 
         // do not implement
         VirtRegMap(const VirtRegMap& rhs);
@@ -89,6 +94,15 @@ namespace llvm {
 
         int assignVirt2StackSlot(unsigned virtReg);
 
+        void virtFolded(unsigned virtReg,
+                        MachineInstr* oldMI,
+                        MachineInstr* newMI);
+
+        std::pair<MI2VirtMap::const_iterator, MI2VirtMap::const_iterator>
+        getFoldedVirts(MachineInstr* MI) const {
+            return mi2vMap_.equal_range(MI);
+        }
+
         friend std::ostream& operator<<(std::ostream& os, const VirtRegMap& li);
     };