When factoring multiply expressions across adds, factor both
authorChris Lattner <sabre@nondot.org>
Fri, 1 Jan 2010 01:13:15 +0000 (01:13 +0000)
committerChris Lattner <sabre@nondot.org>
Fri, 1 Jan 2010 01:13:15 +0000 (01:13 +0000)
positive and negative forms of constants together.  This
allows us to compile:

int foo(int x, int y) {
    return (x-y) + (x-y) + (x-y);
}

into:

_foo:                                                       ## @foo
subl %esi, %edi
leal (%rdi,%rdi,2), %eax
ret

instead of (where the 3 and -3 were not factored):

_foo:
        imull   $-3, 8(%esp), %ecx
        imull   $3, 4(%esp), %eax
        addl    %ecx, %eax
        ret

this started out as:
    movl    12(%ebp), %ecx
    imull   $3, 8(%ebp), %eax
    subl    %ecx, %eax
    subl    %ecx, %eax
    subl    %ecx, %eax
    ret

This comes from PR5359.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92381 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/Reassociate.cpp
test/Transforms/Reassociate/basictest.ll

index a4c84863d92c9f0e4335bde244da7b5baeb43472..827b47d3feeb8718aa6afc0d322b537f34c035ad 100644 (file)
@@ -510,7 +510,8 @@ static Instruction *ConvertShiftToMul(Instruction *Shl,
 }
 
 // Scan backwards and forwards among values with the same rank as element i to
-// see if X exists.  If X does not exist, return i.
+// see if X exists.  If X does not exist, return i.  This is useful when
+// scanning for 'x' when we see '-x' because they both get the same rank.
 static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i,
                                   Value *X) {
   unsigned XRank = Ops[i].Rank;
@@ -518,7 +519,7 @@ static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i,
   for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j)
     if (Ops[j].Op == X)
       return j;
-  // Scan backwards
+  // Scan backwards.
   for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j)
     if (Ops[j].Op == X)
       return j;
@@ -547,28 +548,47 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
   LinearizeExprTree(BO, Factors);
 
   bool FoundFactor = false;
-  for (unsigned i = 0, e = Factors.size(); i != e; ++i)
+  bool NeedsNegate = false;
+  for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
     if (Factors[i].Op == Factor) {
       FoundFactor = true;
       Factors.erase(Factors.begin()+i);
       break;
     }
+    
+    // If this is a negative version of this factor, remove it.
+    if (ConstantInt *FC1 = dyn_cast<ConstantInt>(Factor))
+      if (ConstantInt *FC2 = dyn_cast<ConstantInt>(Factors[i].Op))
+        if (FC1->getValue() == -FC2->getValue()) {
+          FoundFactor = NeedsNegate = true;
+          Factors.erase(Factors.begin()+i);
+          break;
+        }
+  }
+  
   if (!FoundFactor) {
     // Make sure to restore the operands to the expression tree.
     RewriteExprTree(BO, Factors);
     return 0;
   }
   
+  BasicBlock::iterator InsertPt = BO; ++InsertPt;
+  
   // If this was just a single multiply, remove the multiply and return the only
   // remaining operand.
   if (Factors.size() == 1) {
     ValueRankMap.erase(BO);
     BO->eraseFromParent();
-    return Factors[0].Op;
+    V = Factors[0].Op;
+  } else {
+    RewriteExprTree(BO, Factors);
+    V = BO;
   }
   
-  RewriteExprTree(BO, Factors);
-  return BO;
+  if (NeedsNegate)
+    V = BinaryOperator::CreateNeg(V, "neg", InsertPt);
+  
+  return V;
 }
 
 /// FindSingleUseMultiplyFactors - If V is a single-use multiply, recursively
@@ -645,6 +665,9 @@ Value *Reassociate::OptimizeAdd(Instruction *I,
   // Scan the operand lists looking for X and -X pairs.  If we find any, we
   // can simplify the expression. X+-X == 0.  While we're at it, scan for any
   // duplicates.  We want to canonicalize Y+Y+Y+Z -> 3*Y+Z.
+  //
+  // TODO: We could handle "X + ~X" -> "-1" if we wanted, since "-X = ~X+1".
+  //
   for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
     Value *TheOp = Ops[i].Op;
     // Check to see if we've seen this operand before.  If so, we factor all
@@ -730,21 +753,26 @@ Value *Reassociate::OptimizeAdd(Instruction *I,
     assert(Factors.size() > 1 && "Bad linearize!");
     
     // Add one to FactorOccurrences for each unique factor in this op.
-    if (Factors.size() == 2) {
-      unsigned Occ = ++FactorOccurrences[Factors[0]];
-      if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; }
-      if (Factors[0] != Factors[1]) {   // Don't double count A*A.
-        Occ = ++FactorOccurrences[Factors[1]];
-        if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; }
-      }
-    } else {
-      SmallPtrSet<Value*, 4> Duplicates;
-      for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
-        if (!Duplicates.insert(Factors[i])) continue;
-        
-        unsigned Occ = ++FactorOccurrences[Factors[i]];
-        if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; }
-      }
+    SmallPtrSet<Value*, 8> Duplicates;
+    for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
+      Value *Factor = Factors[i];
+      if (!Duplicates.insert(Factor)) continue;
+      
+      unsigned Occ = ++FactorOccurrences[Factor];
+      if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factor; }
+      
+      // If Factor is a negative constant, add the negated value as a factor
+      // because we can percolate the negate out.  Watch for minint, which
+      // cannot be positivified.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor))
+        if (CI->getValue().isNegative() && !CI->getValue().isMinSignedValue()) {
+          Factor = ConstantInt::get(CI->getContext(), -CI->getValue());
+          assert(!Duplicates.count(Factor) &&
+                 "Shouldn't have two constant factors, missed a canonicalize");
+          
+          unsigned Occ = ++FactorOccurrences[Factor];
+          if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factor; }
+        }
     }
   }
   
index ba0c9f210fcfeb14fb25a277c947e7cb28b68097..e77d83d160e40f91121b7b7d9e1dcb56ffd24fb4 100644 (file)
@@ -191,3 +191,16 @@ define i32 @test13(i32 %X1, i32 %X2, i32 %X3) {
 ; CHECK-NEXT: ret i32
 }
 
+; PR5359
+define i32 @test14(i32 %X1, i32 %X2) {
+  %B = mul i32 %X1, 47   ; X1*47
+  %C = mul i32 %X2, -47  ; X2*-47
+  %D = add i32 %B, %C    ; X1*47 + X2*-47 -> 47*(X1-X2)
+  ret i32 %D
+; CHECK: @test14
+; CHECK-NEXT: sub i32 %X1, %X2
+; CHECK-NEXT: mul i32 {{.*}}, 47
+; CHECK-NEXT: ret i32
+}
+
+