Teach jump threading to look at comparisons between phi nodes and non-constants.
authorNick Lewycky <nicholas@mxc.ca>
Fri, 19 Jun 2009 04:56:29 +0000 (04:56 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Fri, 19 Jun 2009 04:56:29 +0000 (04:56 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@73755 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/JumpThreading.cpp
test/Transforms/JumpThreading/branch-no-const.ll [new file with mode: 0644]

index c0ca2df1ce11daa7259854e3df712fbc6cc5cd7c..ed84ec1b965403419e39c19580da7b91dd80e321 100644 (file)
@@ -76,7 +76,7 @@ namespace {
     bool ProcessBlock(BasicBlock *BB);
     bool ThreadEdge(BasicBlock *BB, BasicBlock *PredBB, BasicBlock *SuccBB,
                     unsigned JumpThreadCost);
-    BasicBlock *FactorCommonPHIPreds(PHINode *PN, Constant *CstVal);
+    BasicBlock *FactorCommonPHIPreds(PHINode *PN, Value *Val);
     bool ProcessBranchOnDuplicateCond(BasicBlock *PredBB, BasicBlock *DestBB);
     bool ProcessSwitchOnDuplicateCond(BasicBlock *PredBB, BasicBlock *DestBB);
 
@@ -163,10 +163,10 @@ void JumpThreading::FindLoopHeaders(Function &F) {
 /// This is important for things like "phi i1 [true, true, false, true, x]"
 /// where we only need to clone the block for the true blocks once.
 ///
-BasicBlock *JumpThreading::FactorCommonPHIPreds(PHINode *PN, Constant *CstVal) {
+BasicBlock *JumpThreading::FactorCommonPHIPreds(PHINode *PN, Value *Val) {
   SmallVector<BasicBlock*, 16> CommonPreds;
   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
-    if (PN->getIncomingValue(i) == CstVal)
+    if (PN->getIncomingValue(i) == Val)
       CommonPreds.push_back(PN->getIncomingBlock(i));
   
   if (CommonPreds.size() == 1)
@@ -346,13 +346,19 @@ bool JumpThreading::ProcessBlock(BasicBlock *BB) {
                              CondInst->getOpcode() == Instruction::And))
     return true;
   
-  // If we have "br (phi != 42)" and the phi node has any constant values as 
-  // operands, we can thread through this block.
-  if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst))
-    if (isa<PHINode>(CondCmp->getOperand(0)) &&
-        isa<Constant>(CondCmp->getOperand(1)) &&
-        ProcessBranchOnCompare(CondCmp, BB))
-      return true;
+  if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) {
+    if (isa<PHINode>(CondCmp->getOperand(0))) {
+      // If we have "br (phi != 42)" and the phi node has any constant values
+      // as operands, we can thread through this block.
+      // 
+      // If we have "br (cmp phi, x)" and the phi node contains x such that the
+      // comparison uniquely identifies the branch target, we can thread
+      // through this block.
+
+      if (ProcessBranchOnCompare(CondCmp, BB))
+        return true;      
+    }
+  }
 
   // Check for some cases that are worth simplifying.  Right now we want to look
   // for loads that are used by a switch or by the condition for the branch.  If
@@ -770,12 +776,30 @@ bool JumpThreading::ProcessBranchOnLogical(Value *V, BasicBlock *BB,
   return ThreadEdge(BB, PredBB, SuccBB, JumpThreadCost);
 }
 
+/// GetResultOfComparison - Given an icmp/fcmp predicate and the left and right
+/// hand sides of the compare instruction, try to determine the result. If the
+/// result can not be determined, a null pointer is returned.
+static Constant *GetResultOfComparison(CmpInst::Predicate pred,
+                                       Value *LHS, Value *RHS) {
+  if (Constant *CLHS = dyn_cast<Constant>(LHS))
+    if (Constant *CRHS = dyn_cast<Constant>(RHS))
+      return ConstantExpr::getCompare(pred, CLHS, CRHS);
+
+  if (LHS == RHS)
+    if (isa<IntegerType>(LHS->getType()) || isa<PointerType>(LHS->getType()))
+      return ICmpInst::isTrueWhenEqual(pred) ? 
+                 ConstantInt::getTrue() : ConstantInt::getFalse();
+
+  return 0;
+}
+
 /// ProcessBranchOnCompare - We found a branch on a comparison between a phi
-/// node and a constant.  If the PHI node contains any constants as inputs, we
-/// can fold the compare for that edge and thread through it.
+/// node and a value.  If we can identify when the comparison is true between
+/// the phi inputs and the value, we can fold the compare for that edge and
+/// thread through it.
 bool JumpThreading::ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB) {
   PHINode *PN = cast<PHINode>(Cmp->getOperand(0));
-  Constant *RHS = cast<Constant>(Cmp->getOperand(1));
+  Value *RHS = Cmp->getOperand(1);
   
   // If the phi isn't in the current block, an incoming edge to this block
   // doesn't control the destination.
@@ -784,18 +808,17 @@ bool JumpThreading::ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB) {
   
   // We can do this simplification if any comparisons fold to true or false.
   // See if any do.
-  Constant *PredCst = 0;
+  Value *PredVal = 0;
   bool TrueDirection = false;
   for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
-    PredCst = dyn_cast<Constant>(PN->getIncomingValue(i));
-    if (PredCst == 0) continue;
+    PredVal = PN->getIncomingValue(i);
+    
+    Constant *Res = GetResultOfComparison(Cmp->getPredicate(), PredVal, RHS);
+    if (!Res) {
+      PredVal = 0;
+      continue;
+    }
     
-    Constant *Res;
-    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Cmp))
-      Res = ConstantExpr::getICmp(ICI->getPredicate(), PredCst, RHS);
-    else
-      Res = ConstantExpr::getFCmp(cast<FCmpInst>(Cmp)->getPredicate(),
-                                  PredCst, RHS);
     // If this folded to a constant expr, we can't do anything.
     if (ConstantInt *ResC = dyn_cast<ConstantInt>(Res)) {
       TrueDirection = ResC->getZExtValue();
@@ -808,11 +831,11 @@ bool JumpThreading::ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB) {
     }
     
     // Otherwise, we can't fold this input.
-    PredCst = 0;
+    PredVal = 0;
   }
   
   // If no match, bail out.
-  if (PredCst == 0)
+  if (PredVal == 0)
     return false;
   
   // See if the cost of duplicating this block is low enough.
@@ -825,7 +848,7 @@ bool JumpThreading::ProcessBranchOnCompare(CmpInst *Cmp, BasicBlock *BB) {
   
   // If so, we can actually do this threading.  Merge any common predecessors
   // that will act the same.
-  BasicBlock *PredBB = FactorCommonPHIPreds(PN, PredCst);
+  BasicBlock *PredBB = FactorCommonPHIPreds(PN, PredVal);
   
   // Next, get our successor.
   BasicBlock *SuccBB = BB->getTerminator()->getSuccessor(!TrueDirection);
diff --git a/test/Transforms/JumpThreading/branch-no-const.ll b/test/Transforms/JumpThreading/branch-no-const.ll
new file mode 100644 (file)
index 0000000..0ea2431
--- /dev/null
@@ -0,0 +1,21 @@
+; RUN: llvm-as < %s | opt -jump-threading | llvm-dis | not grep phi
+
+declare i8 @mcguffin()
+
+define i32 @test(i1 %foo, i8 %b) {
+entry:
+  %a = call i8 @mcguffin()
+  br i1 %foo, label %bb1, label %bb2
+bb1:
+  br label %jt
+bb2:
+  br label %jt
+jt:
+  %x = phi i8 [%a, %bb1], [%b, %bb2]
+  %A = icmp eq i8 %x, %a
+  br i1 %A, label %rt, label %rf
+rt:
+  ret i32 7
+rf:
+  ret i32 8
+}