Implement a FIXME, recusively reassociating
authorChris Lattner <sabre@nondot.org>
Tue, 14 Mar 2006 16:04:29 +0000 (16:04 +0000)
committerChris Lattner <sabre@nondot.org>
Tue, 14 Mar 2006 16:04:29 +0000 (16:04 +0000)
A*A*B + A*A*C   -->   A*(A*B+A*C)   -->   A*(A*(B+C))

This implements Reassociate/mul-factor3.ll

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@26757 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/Reassociate.cpp

index dc44ad593f2471f80776c81cff196dddbb6021ec..e495ffafbb9b2e0380fa5431495f2957d9845c99 100644 (file)
@@ -79,8 +79,8 @@ namespace {
     void BuildRankMap(Function &F);
     unsigned getRank(Value *V);
     void ReassociateExpression(BinaryOperator *I);
-    void RewriteExprTree(BinaryOperator *I, unsigned Idx,
-                         std::vector<ValueEntry> &Ops);
+    void RewriteExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops,
+                         unsigned Idx = 0);
     Value *OptimizeExpression(BinaryOperator *I, std::vector<ValueEntry> &Ops);
     void LinearizeExprTree(BinaryOperator *I, std::vector<ValueEntry> &Ops);
     void LinearizeExpr(BinaryOperator *I);
@@ -174,7 +174,7 @@ unsigned Reassociate::getRank(Value *V) {
 /// isReassociableOp - Return true if V is an instruction of the specified
 /// opcode and if it only has one use.
 static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) {
-  if (V->hasOneUse() && isa<Instruction>(V) &&
+  if ((V->hasOneUse() || V->use_empty()) && isa<Instruction>(V) &&
       cast<Instruction>(V)->getOpcode() == Opcode)
     return cast<BinaryOperator>(V);
   return 0;
@@ -234,6 +234,10 @@ void Reassociate::LinearizeExpr(BinaryOperator *I) {
 /// form of the the expression (((a+b)+c)+d), and collects information about the
 /// rank of the non-tree operands.
 ///
+/// NOTE: These intentionally destroys the expression tree operands (turning
+/// them into undef values) to reduce #uses of the values.  This means that the
+/// caller MUST use something like RewriteExprTree to put the values back in.
+///
 void Reassociate::LinearizeExprTree(BinaryOperator *I,
                                     std::vector<ValueEntry> &Ops) {
   Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
@@ -262,6 +266,10 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
       // such, just remember these operands and their rank.
       Ops.push_back(ValueEntry(getRank(LHS), LHS));
       Ops.push_back(ValueEntry(getRank(RHS), RHS));
+      
+      // Clear the leaves out.
+      I->setOperand(0, UndefValue::get(I->getType()));
+      I->setOperand(1, UndefValue::get(I->getType()));
       return;
     } else {
       // Turn X+(Y+Z) -> (Y+Z)+X
@@ -293,13 +301,17 @@ void Reassociate::LinearizeExprTree(BinaryOperator *I,
 
   // Remember the RHS operand and its rank.
   Ops.push_back(ValueEntry(getRank(RHS), RHS));
+  
+  // Clear the RHS leaf out.
+  I->setOperand(1, UndefValue::get(I->getType()));
 }
 
 // RewriteExprTree - Now that the operands for this expression tree are
 // linearized and optimized, emit them in-order.  This function is written to be
 // tail recursive.
-void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i,
-                                  std::vector<ValueEntry> &Ops) {
+void Reassociate::RewriteExprTree(BinaryOperator *I,
+                                  std::vector<ValueEntry> &Ops,
+                                  unsigned i) {
   if (i+2 == Ops.size()) {
     if (I->getOperand(0) != Ops[i].Op ||
         I->getOperand(1) != Ops[i+1].Op) {
@@ -334,7 +346,7 @@ void Reassociate::RewriteExprTree(BinaryOperator *I, unsigned i,
   // Compactify the tree instructions together with each other to guarantee
   // that the expression tree is dominated by all of Ops.
   LHS->moveBefore(I);
-  RewriteExprTree(LHS, i+1, Ops);
+  RewriteExprTree(LHS, Ops, i+1);
 }
 
 
@@ -474,14 +486,36 @@ Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
       Factors.erase(Factors.begin()+i);
       break;
     }
-  if (!FoundFactor) return 0;
+  if (!FoundFactor) {
+    // Make sure to restore the operands to the expression tree.
+    RewriteExprTree(BO, Factors);
+    return 0;
+  }
   
   if (Factors.size() == 1) return Factors[0].Op;
   
-  RewriteExprTree(BO, 0, Factors);
+  RewriteExprTree(BO, Factors);
   return BO;
 }
 
+/// FindSingleUseMultiplyFactors - If V is a single-use multiply, recursively
+/// add its operands as factors, otherwise add V to the list of factors.
+static void FindSingleUseMultiplyFactors(Value *V,
+                                         std::vector<Value*> &Factors) {
+  BinaryOperator *BO;
+  if ((!V->hasOneUse() && !V->use_empty()) ||
+      !(BO = dyn_cast<BinaryOperator>(V)) ||
+      BO->getOpcode() != Instruction::Mul) {
+    Factors.push_back(V);
+    return;
+  }
+  
+  // Otherwise, add the LHS and RHS to the list of factors.
+  FindSingleUseMultiplyFactors(BO->getOperand(1), Factors);
+  FindSingleUseMultiplyFactors(BO->getOperand(0), Factors);
+}
+
+
 
 Value *Reassociate::OptimizeExpression(BinaryOperator *I,
                                        std::vector<ValueEntry> &Ops) {
@@ -627,26 +661,26 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I,
     if (!I->getType()->isFloatingPoint()) {
       for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
         if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op))
-          if (BOp->getOpcode() == Instruction::Mul && BOp->hasOneUse()) {
+          if (BOp->getOpcode() == Instruction::Mul && BOp->use_empty()) {
             // Compute all of the factors of this added value.
-            std::vector<ValueEntry> Factors;
-            LinearizeExprTree(BOp, Factors);
+            std::vector<Value*> Factors;
+            FindSingleUseMultiplyFactors(BOp, Factors);
             assert(Factors.size() > 1 && "Bad linearize!");
             
             // Add one to FactorOccurrences for each unique factor in this op.
             if (Factors.size() == 2) {
-              unsigned Occ = ++FactorOccurrences[Factors[0].Op];
-              if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0].Op; }
-              if (Factors[0].Op != Factors[1].Op) {   // Don't double count A*A.
-                Occ = ++FactorOccurrences[Factors[1].Op];
-                if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1].Op; }
+              unsigned Occ = ++FactorOccurrences[Factors[0]];
+              if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; }
+              if (Factors[0] != Factors[1]) {   // Don't double count A*A.
+                Occ = ++FactorOccurrences[Factors[1]];
+                if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; }
               }
             } else {
               std::set<Value*> Duplicates;
               for (unsigned i = 0, e = Factors.size(); i != e; ++i)
-                if (Duplicates.insert(Factors[i].Op).second) {
-                  unsigned Occ = ++FactorOccurrences[Factors[i].Op];
-                  if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i].Op; }
+                if (Duplicates.insert(Factors[i]).second) {
+                  unsigned Occ = ++FactorOccurrences[Factors[i]];
+                  if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; }
                 }
             }
           }
@@ -675,21 +709,26 @@ Value *Reassociate::OptimizeExpression(BinaryOperator *I,
       // No need for extra uses anymore.
       delete DummyInst;
 
+      unsigned NumAddedValues = NewMulOps.size();
       Value *V = EmitAddTreeOfValues(I, NewMulOps);
-      // FIXME: Must optimize V now, to handle this case:
-      // A*A*B + A*A*C -> A*(A*B+A*C)   -> A*(A*(B+C))
-      V = BinaryOperator::createMul(V, MaxOccVal, "tmp", I);
+      Value *V2 = BinaryOperator::createMul(V, MaxOccVal, "tmp", I);
 
+      // Now that we have inserted V and its sole use, optimize it. This allows
+      // us to handle cases that require multiple factoring steps, such as this:
+      // A*A*B + A*A*C   -->   A*(A*B+A*C)   -->   A*(A*(B+C))
+      if (NumAddedValues > 1)
+        ReassociateExpression(cast<BinaryOperator>(V));
+      
       ++NumFactor;
       
       if (Ops.size() == 0)
-        return V;
+        return V2;
 
       // Add the new value to the list of things being added.
-      Ops.insert(Ops.begin(), ValueEntry(getRank(V), V));
+      Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2));
       
       // Rewrite the tree so that there is now a use of V.
-      RewriteExprTree(I, 0, Ops);
+      RewriteExprTree(I, Ops);
       return OptimizeExpression(I, Ops);
     }
     break;
@@ -808,7 +847,7 @@ void Reassociate::ReassociateExpression(BinaryOperator *I) {
   } else {
     // Now that we ordered and optimized the expressions, splat them back into
     // the expression tree, removing any unneeded nodes.
-    RewriteExprTree(I, 0, Ops);
+    RewriteExprTree(I, Ops);
   }
 }