If a loop termination compare instruction is the only use of its stride,
authorEvan Cheng <evan.cheng@apple.com>
Thu, 25 Oct 2007 09:11:16 +0000 (09:11 +0000)
committerEvan Cheng <evan.cheng@apple.com>
Thu, 25 Oct 2007 09:11:16 +0000 (09:11 +0000)
and the compaison is against a constant value, try eliminate the stride
by moving the compare instruction to another stride and change its
constant operand accordingly. e.g.

loop:
...
v1 = v1 + 3
v2 = v2 + 1
if (v2 < 10) goto loop
=>
loop:
...
v1 = v1 + 3
if (v1 < 30) goto loop

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@43336 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoopStrengthReduce.cpp
test/CodeGen/X86/loop-strength-reduce3.ll [new file with mode: 0644]
test/CodeGen/X86/loop-strength-reduce4.ll [new file with mode: 0644]

index a58356542db5ea1f3fb3ab1661a1f82eb678fba0..a011dd732f5687e53021c700543c9bf3c1b1fb59 100644 (file)
 #include <set>
 using namespace llvm;
 
-STATISTIC(NumReduced , "Number of GEPs strength reduced");
-STATISTIC(NumInserted, "Number of PHIs inserted");
-STATISTIC(NumVariable, "Number of PHIs with variable strides");
+STATISTIC(NumReduced ,    "Number of GEPs strength reduced");
+STATISTIC(NumInserted,    "Number of PHIs inserted");
+STATISTIC(NumVariable,    "Number of PHIs with variable strides");
+STATISTIC(NumEliminated , "Number of strides eliminated");
 
 namespace {
 
@@ -170,18 +171,17 @@ private:
     bool AddUsersIfInteresting(Instruction *I, Loop *L,
                                std::set<Instruction*> &Processed);
     SCEVHandle GetExpressionSCEV(Instruction *E, Loop *L);
-
+    ICmpInst *ChangeCompareStride(Loop *L, ICmpInst *Cond,
+                                  IVStrideUse* &CondUse,
+                                  const SCEVHandle* &CondStride);
     void OptimizeIndvars(Loop *L);
     bool FindIVForUser(ICmpInst *Cond, IVStrideUse *&CondUse,
                        const SCEVHandle *&CondStride);
-
     unsigned CheckForIVReuse(bool, const SCEVHandle&,
                              IVExpr&, const Type*,
                              const std::vector<BasedUser>& UsersToProcess);
-
     bool ValidStride(bool, int64_t,
                      const std::vector<BasedUser>& UsersToProcess);
-
     void StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
                                       IVUsersOfOneStride &Uses,
                                       Loop *L, bool isOnlyStride);
@@ -981,8 +981,6 @@ unsigned LoopStrengthReduce::CheckForIVReuse(bool HasBaseReg,
                                 const SCEVHandle &Stride, 
                                 IVExpr &IV, const Type *Ty,
                                 const std::vector<BasedUser>& UsersToProcess) {
-  if (!TLI) return 0;
-
   if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Stride)) {
     int64_t SInt = SC->getValue()->getSExtValue();
     if (SInt == 1) return 0;
@@ -1039,6 +1037,10 @@ void LoopStrengthReduce::StrengthReduceStridedIVUsers(const SCEVHandle &Stride,
                                                       IVUsersOfOneStride &Uses,
                                                       Loop *L,
                                                       bool isOnlyStride) {
+  // If all the users are moved to another stride, then there is nothing to do.
+  if (Uses.Users.size() == 0)
+    return;
+
   // Transform our list of users and offsets to a bit more complex table.  In
   // this new vector, each 'BasedUser' contains 'Base' the base of the
   // strided accessas well as the old information from Uses.  We progressively
@@ -1377,6 +1379,154 @@ bool LoopStrengthReduce::FindIVForUser(ICmpInst *Cond, IVStrideUse *&CondUse,
   return false;
 }    
 
+namespace {
+  // Constant strides come first which in turns are sorted by their absolute
+  // values. If absolute values are the same, then positive strides comes first.
+  // e.g.
+  // 4, -1, X, 1, 2 ==> 1, -1, 2, 4, X
+  struct StrideCompare {
+    bool operator()(const SCEVHandle &LHS, const SCEVHandle &RHS) {
+      SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS);
+      SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
+      if (LHSC && RHSC) {
+        int64_t  LV = LHSC->getValue()->getSExtValue();
+        int64_t  RV = RHSC->getValue()->getSExtValue();
+        uint64_t ALV = (LV < 0) ? -LV : LV;
+        uint64_t ARV = (RV < 0) ? -RV : RV;
+        if (ALV == ARV)
+          return LV > RV;
+        else
+          return ALV < ARV;
+      }
+      return (LHSC && !RHSC);
+    }
+  };
+}
+
+/// ChangeCompareStride - If a loop termination compare instruction is the
+/// only use of its stride, and the compaison is against a constant value,
+/// try eliminate the stride by moving the compare instruction to another
+/// stride and change its constant operand accordingly. e.g.
+///
+/// loop:
+/// ...
+/// v1 = v1 + 3
+/// v2 = v2 + 1
+/// if (v2 < 10) goto loop
+/// =>
+/// loop:
+/// ...
+/// v1 = v1 + 3
+/// if (v1 < 30) goto loop
+ICmpInst *LoopStrengthReduce::ChangeCompareStride(Loop *L, ICmpInst *Cond,
+                                                  IVStrideUse* &CondUse,
+                                                const SCEVHandle* &CondStride) {
+  if (StrideOrder.size() < 2 ||
+      IVUsesByStride[*CondStride].Users.size() != 1)
+    return Cond;
+  // FIXME: loosen this restriction?
+  if (!isa<SCEVConstant>(CondUse->Offset))
+    return Cond;
+  const SCEVConstant *SC = dyn_cast<SCEVConstant>(*CondStride);
+  if (!SC) return Cond;
+  ConstantInt *C = dyn_cast<ConstantInt>(Cond->getOperand(1));
+  if (!C) return Cond;
+
+  ICmpInst::Predicate Predicate = Cond->getPredicate();
+  bool isSigned = ICmpInst::isSignedPredicate(Predicate);
+  int64_t CmpSSInt = SC->getValue()->getSExtValue();
+  int64_t CmpVal = C->getValue().getSExtValue();
+  uint64_t SignBit = 1ULL << (C->getValue().getBitWidth()-1);
+  int64_t NewCmpVal = CmpVal;
+  SCEVHandle *NewStride = NULL;
+  Value *NewIncV = NULL;
+  int64_t Scale = 1;
+  const Type *CmpTy = C->getType();
+  const Type *NewCmpTy = NULL;
+
+  // Look for a suitable stride / iv as replacement.
+  std::stable_sort(StrideOrder.begin(), StrideOrder.end(), StrideCompare());
+  for (unsigned i = 0, e = StrideOrder.size(); i != e; ++i) {
+    std::map<SCEVHandle, IVUsersOfOneStride>::iterator SI = 
+      IVUsesByStride.find(StrideOrder[i]);
+    if (!isa<SCEVConstant>(SI->first))
+      continue;
+    int64_t SSInt = cast<SCEVConstant>(SI->first)->getValue()->getSExtValue();
+    if (abs(SSInt) < abs(CmpSSInt) && (CmpSSInt % SSInt) == 0) {
+      Scale = CmpSSInt / SSInt;
+      NewCmpVal = CmpVal / Scale;
+    } else if (abs(SSInt) > abs(CmpSSInt) && (SSInt % CmpSSInt) == 0) {
+      Scale = SSInt / CmpSSInt;
+      NewCmpVal = CmpVal * Scale;
+    } else
+      continue;
+
+    // Watch out for overflow.
+    if (isSigned && (CmpVal & SignBit) != (NewCmpVal & SignBit))
+      NewCmpVal = CmpVal;
+    if (NewCmpVal != CmpVal) {
+      // Pick the best iv to use trying to avoid a cast.
+      NewIncV = NULL;
+      for (std::vector<IVStrideUse>::iterator UI = SI->second.Users.begin(),
+             E = SI->second.Users.end(); UI != E; ++UI) {
+        //        if (!isa<SCEVConstant>(UI->Offset))
+        //          continue;
+        NewIncV = UI->OperandValToReplace;
+        if (NewIncV->getType() == CmpTy)
+          break;
+      }
+      if (!NewIncV) {
+        NewCmpVal = CmpVal;
+        continue;
+      }
+
+      // FIXME: allow reuse of iv of a smaller type?
+      NewCmpTy = NewIncV->getType();
+      if (!CmpTy->canLosslesslyBitCastTo(NewCmpTy) &&
+          !(isa<PointerType>(NewCmpTy) &&
+            CmpTy->canLosslesslyBitCastTo(UIntPtrTy))) {
+        NewCmpVal = CmpVal;
+        continue;
+      }
+
+      // If scale is negative, use inverse predicate unless it's testing
+      // for equality.
+      if (Scale < 0 && !Cond->isEquality())
+        Predicate = ICmpInst::getInversePredicate(Predicate);
+
+      NewStride = &StrideOrder[i];
+      break;
+    }
+  }
+
+  if (NewCmpVal != CmpVal) {
+    // Create a new compare instruction using new stride / iv.
+    ICmpInst *OldCond = Cond;
+    Value *RHS = ConstantInt::get(C->getType(), NewCmpVal);
+    // Both sides of a ICmpInst must be of the same type.
+    if (NewCmpTy != CmpTy) {
+      if (isa<PointerType>(NewCmpTy) && !isa<PointerType>(CmpTy))
+        RHS= SCEVExpander::InsertCastOfTo(Instruction::IntToPtr, RHS, NewCmpTy);
+      else
+        RHS = SCEVExpander::InsertCastOfTo(Instruction::BitCast, RHS, NewCmpTy);
+    }
+    Cond = new ICmpInst(Predicate, NewIncV, RHS);
+    Cond->setName(L->getHeader()->getName() + ".termcond");
+    OldCond->getParent()->getInstList().insert(OldCond, Cond);
+    OldCond->replaceAllUsesWith(Cond);
+    OldCond->eraseFromParent();
+    IVUsesByStride[*CondStride].Users.pop_back();
+    SCEVHandle NewOffset = SE->getMulExpr(CondUse->Offset,
+          SE->getConstant(ConstantInt::get(CondUse->Offset->getType(), Scale)));
+    IVUsesByStride[*NewStride].addUser(NewOffset, Cond, NewIncV);
+    CondUse = &IVUsesByStride[*NewStride].Users.back();
+    CondStride = NewStride;
+    ++NumEliminated;
+  }
+
+  return Cond;
+}
+
 // OptimizeIndvars - Now that IVUsesByStride is set up with all of the indvar
 // uses in the loop, look to see if we can eliminate some, in favor of using
 // common indvars for the different uses.
@@ -1403,7 +1553,10 @@ void LoopStrengthReduce::OptimizeIndvars(Loop *L) {
 
   if (!FindIVForUser(Cond, CondUse, CondStride))
     return; // setcc doesn't use the IV.
-  
+
+  // If possible, change stride and operands of the compare instruction to
+  // eliminate one stride.
+  Cond = ChangeCompareStride(L, Cond, CondUse, CondStride);
 
   // It's possible for the setcc instruction to be anywhere in the loop, and
   // possible for it to have multiple users.  If it is not immediately before
@@ -1431,30 +1584,6 @@ void LoopStrengthReduce::OptimizeIndvars(Loop *L) {
   CondUse->isUseOfPostIncrementedValue = true;
 }
 
-namespace {
-  // Constant strides come first which in turns are sorted by their absolute
-  // values. If absolute values are the same, then positive strides comes first.
-  // e.g.
-  // 4, -1, X, 1, 2 ==> 1, -1, 2, 4, X
-  struct StrideCompare {
-    bool operator()(const SCEVHandle &LHS, const SCEVHandle &RHS) {
-      SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS);
-      SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
-      if (LHSC && RHSC) {
-        int64_t  LV = LHSC->getValue()->getSExtValue();
-        int64_t  RV = RHSC->getValue()->getSExtValue();
-        uint64_t ALV = (LV < 0) ? -LV : LV;
-        uint64_t ARV = (RV < 0) ? -RV : RV;
-        if (ALV == ARV)
-          return LV > RV;
-        else
-          return ALV < ARV;
-      }
-      return (LHSC && !RHSC);
-    }
-  };
-}
-
 bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager &LPM) {
 
   LI = &getAnalysis<LoopInfo>();
diff --git a/test/CodeGen/X86/loop-strength-reduce3.ll b/test/CodeGen/X86/loop-strength-reduce3.ll
new file mode 100644 (file)
index 0000000..4e95bdd
--- /dev/null
@@ -0,0 +1,37 @@
+; RUN: llvm-as < %s | llc -march=x86 | grep cmp | grep 240
+; RUN: llvm-as < %s | llc -march=x86 | grep inc | count 1
+
+define i32 @foo(i32 %A, i32 %B, i32 %C, i32 %D) {
+entry:
+       %tmp2955 = icmp sgt i32 %C, 0           ; <i1> [#uses=1]
+       br i1 %tmp2955, label %bb26.outer.us, label %bb40.split
+
+bb26.outer.us:         ; preds = %bb26.bb32_crit_edge.us, %entry
+       %i.044.0.ph.us = phi i32 [ 0, %entry ], [ %indvar.next57, %bb26.bb32_crit_edge.us ]             ; <i32> [#uses=2]
+       %k.1.ph.us = phi i32 [ 0, %entry ], [ %k.0.us, %bb26.bb32_crit_edge.us ]                ; <i32> [#uses=1]
+       %tmp3.us = mul i32 %i.044.0.ph.us, 6            ; <i32> [#uses=1]
+       br label %bb1.us
+
+bb1.us:                ; preds = %bb1.us, %bb26.outer.us
+       %j.053.us = phi i32 [ 0, %bb26.outer.us ], [ %tmp25.us, %bb1.us ]               ; <i32> [#uses=2]
+       %k.154.us = phi i32 [ %k.1.ph.us, %bb26.outer.us ], [ %k.0.us, %bb1.us ]                ; <i32> [#uses=1]
+       %tmp5.us = add i32 %tmp3.us, %j.053.us          ; <i32> [#uses=1]
+       %tmp7.us = shl i32 %D, %tmp5.us         ; <i32> [#uses=2]
+       %tmp9.us = icmp eq i32 %tmp7.us, %B             ; <i1> [#uses=1]
+       %tmp910.us = zext i1 %tmp9.us to i32            ; <i32> [#uses=1]
+       %tmp12.us = and i32 %tmp7.us, %A                ; <i32> [#uses=1]
+       %tmp19.us = and i32 %tmp12.us, %tmp910.us               ; <i32> [#uses=1]
+       %k.0.us = add i32 %tmp19.us, %k.154.us          ; <i32> [#uses=3]
+       %tmp25.us = add i32 %j.053.us, 1                ; <i32> [#uses=2]
+       %tmp29.us = icmp slt i32 %tmp25.us, %C          ; <i1> [#uses=1]
+       br i1 %tmp29.us, label %bb1.us, label %bb26.bb32_crit_edge.us
+
+bb26.bb32_crit_edge.us:                ; preds = %bb1.us
+       %indvar.next57 = add i32 %i.044.0.ph.us, 1              ; <i32> [#uses=2]
+       %exitcond = icmp eq i32 %indvar.next57, 40              ; <i1> [#uses=1]
+       br i1 %exitcond, label %bb40.split, label %bb26.outer.us
+
+bb40.split:            ; preds = %bb26.bb32_crit_edge.us, %entry
+       %k.1.lcssa.lcssa.us-lcssa = phi i32 [ %k.0.us, %bb26.bb32_crit_edge.us ], [ 0, %entry ]         ; <i32> [#uses=1]
+       ret i32 %k.1.lcssa.lcssa.us-lcssa
+}
diff --git a/test/CodeGen/X86/loop-strength-reduce4.ll b/test/CodeGen/X86/loop-strength-reduce4.ll
new file mode 100644 (file)
index 0000000..711f223
--- /dev/null
@@ -0,0 +1,49 @@
+; RUN: llvm-as < %s | llc -march=x86 | grep cmp | grep 64
+; RUN: llvm-as < %s | llc -march=x86 | not grep inc
+
+@state = external global [0 x i32]             ; <[0 x i32]*> [#uses=4]
+@S = external global [0 x i32]         ; <[0 x i32]*> [#uses=4]
+
+define i32 @foo() {
+entry:
+       br label %bb
+
+bb:            ; preds = %bb, %entry
+       %indvar = phi i32 [ 0, %entry ], [ %indvar.next, %bb ]          ; <i32> [#uses=2]
+       %t.063.0 = phi i32 [ 0, %entry ], [ %tmp47, %bb ]               ; <i32> [#uses=1]
+       %j.065.0 = shl i32 %indvar, 2           ; <i32> [#uses=4]
+       %tmp3 = getelementptr [0 x i32]* @state, i32 0, i32 %j.065.0            ; <i32*> [#uses=2]
+       %tmp4 = load i32* %tmp3, align 4                ; <i32> [#uses=1]
+       %tmp6 = getelementptr [0 x i32]* @S, i32 0, i32 %t.063.0                ; <i32*> [#uses=1]
+       %tmp7 = load i32* %tmp6, align 4                ; <i32> [#uses=1]
+       %tmp8 = xor i32 %tmp7, %tmp4            ; <i32> [#uses=2]
+       store i32 %tmp8, i32* %tmp3, align 4
+       %tmp1378 = or i32 %j.065.0, 1           ; <i32> [#uses=1]
+       %tmp16 = getelementptr [0 x i32]* @state, i32 0, i32 %tmp1378           ; <i32*> [#uses=2]
+       %tmp17 = load i32* %tmp16, align 4              ; <i32> [#uses=1]
+       %tmp19 = getelementptr [0 x i32]* @S, i32 0, i32 %tmp8          ; <i32*> [#uses=1]
+       %tmp20 = load i32* %tmp19, align 4              ; <i32> [#uses=1]
+       %tmp21 = xor i32 %tmp20, %tmp17         ; <i32> [#uses=2]
+       store i32 %tmp21, i32* %tmp16, align 4
+       %tmp2680 = or i32 %j.065.0, 2           ; <i32> [#uses=1]
+       %tmp29 = getelementptr [0 x i32]* @state, i32 0, i32 %tmp2680           ; <i32*> [#uses=2]
+       %tmp30 = load i32* %tmp29, align 4              ; <i32> [#uses=1]
+       %tmp32 = getelementptr [0 x i32]* @S, i32 0, i32 %tmp21         ; <i32*> [#uses=1]
+       %tmp33 = load i32* %tmp32, align 4              ; <i32> [#uses=1]
+       %tmp34 = xor i32 %tmp33, %tmp30         ; <i32> [#uses=2]
+       store i32 %tmp34, i32* %tmp29, align 4
+       %tmp3982 = or i32 %j.065.0, 3           ; <i32> [#uses=1]
+       %tmp42 = getelementptr [0 x i32]* @state, i32 0, i32 %tmp3982           ; <i32*> [#uses=2]
+       %tmp43 = load i32* %tmp42, align 4              ; <i32> [#uses=1]
+       %tmp45 = getelementptr [0 x i32]* @S, i32 0, i32 %tmp34         ; <i32*> [#uses=1]
+       %tmp46 = load i32* %tmp45, align 4              ; <i32> [#uses=1]
+       %tmp47 = xor i32 %tmp46, %tmp43         ; <i32> [#uses=3]
+       store i32 %tmp47, i32* %tmp42, align 4
+       %indvar.next = add i32 %indvar, 1               ; <i32> [#uses=2]
+       %exitcond = icmp eq i32 %indvar.next, 4         ; <i1> [#uses=1]
+       br i1 %exitcond, label %bb57, label %bb
+
+bb57:          ; preds = %bb
+       %tmp59 = and i32 %tmp47, 255            ; <i32> [#uses=1]
+       ret i32 %tmp59
+}