InstCombine: Make switch folding with equality compares more aggressive by trying...
authorBenjamin Kramer <benny.kra@googlemail.com>
Fri, 27 May 2011 13:00:16 +0000 (13:00 +0000)
committerBenjamin Kramer <benny.kra@googlemail.com>
Fri, 27 May 2011 13:00:16 +0000 (13:00 +0000)
Stuff like "x == y ? y : x&y" now folds into "x&y".

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@132185 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineSelect.cpp
test/Transforms/InstCombine/select.ll

index 75a0b43bb69a48b39f962c79ee5df0fba26ee82e..2a451ce5f12ace995269c29ffcf04a43b443fae8 100644 (file)
@@ -278,6 +278,49 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
   return 0;
 }
 
+/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
+/// replaced with RepOp.
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const TargetData *TD) {
+  // Trivial replacement.
+  if (V == Op)
+    return RepOp;
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return 0;
+
+  // If this is a binary operator, try to simplify it with the replaced op.
+  if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
+    if (B->getOperand(0) == Op)
+      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD);
+    if (B->getOperand(1) == Op)
+      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD);
+  }
+
+  // If all operands are constant after substituting Op for RepOp then we can
+  // constant fold the instruction.
+  if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
+    // Build a list of all constant operands.
+    SmallVector<Constant*, 8> ConstOps;
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+      if (I->getOperand(i) == Op)
+        ConstOps.push_back(CRepOp);
+      else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
+        ConstOps.push_back(COp);
+      else
+        break;
+    }
+
+    // All operands were constants, fold it.
+    if (ConstOps.size() == I->getNumOperands())
+      return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                      ConstOps.data(), ConstOps.size(), TD);
+  }
+
+  return 0;
+}
+
 /// visitSelectInstWithICmp - Visit a SelectInst that has an
 /// ICmpInst as its first operand.
 ///
@@ -416,25 +459,21 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
-    // Transform (X == Y) ? X : Y  -> Y
-    if (Pred == ICmpInst::ICMP_EQ)
+  // If we have an equality comparison then we know the value in one of the
+  // arms of the select. See if substituting this value into the arm and
+  // simplifying the result yields the same value as the other arm.
+  if (Pred == ICmpInst::ICMP_EQ) {
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD) == TrueVal)
       return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? X : Y  -> X
-    if (Pred == ICmpInst::ICMP_NE)
+  } else if (Pred == ICmpInst::ICMP_NE) {
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD) == FalseVal)
       return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
-
-  } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) {
-    // Transform (X == Y) ? Y : X  -> X
-    if (Pred == ICmpInst::ICMP_EQ)
-      return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? Y : X  -> Y
-    if (Pred == ICmpInst::ICMP_NE)
-      return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
   }
 
+  // NOTE: if we wanted to, this is where to detect integer MIN/MAX
+
   if (isa<Constant>(CmpRHS)) {
     if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
       // Transform (X == C) ? X : Y -> (X == C) ? C : Y
index 39259078b877d236c95d134f1783ed12bc630832..379228512cd1b9f41266563782a5284dd412d5d6 100644 (file)
@@ -749,3 +749,43 @@ define i1 @test55(i1 %X, i32 %Y, i32 %Z) {
 ; CHECK: icmp eq
 ; CHECK: ret i1
 }
+
+define i32 @test56(i16 %x) nounwind {
+  %tobool = icmp eq i16 %x, 0
+  %conv = zext i16 %x to i32
+  %cond = select i1 %tobool, i32 0, i32 %conv
+  ret i32 %cond
+; CHECK: @test56
+; CHECK-NEXT: zext
+; CHECK-NEXT: ret
+}
+
+define i32 @test57(i32 %x, i32 %y) nounwind {
+  %and = and i32 %x, %y
+  %tobool = icmp eq i32 %x, 0
+  %.and = select i1 %tobool, i32 0, i32 %and
+  ret i32 %.and
+; CHECK: @test57
+; CHECK-NEXT: and i32 %x, %y
+; CHECK-NEXT: ret
+}
+
+define i32 @test58(i16 %x) nounwind {
+  %tobool = icmp ne i16 %x, 1
+  %conv = zext i16 %x to i32
+  %cond = select i1 %tobool, i32 %conv, i32 1
+  ret i32 %cond
+; CHECK: @test58
+; CHECK-NEXT: zext
+; CHECK-NEXT: ret
+}
+
+define i32 @test59(i32 %x, i32 %y) nounwind {
+  %and = and i32 %x, %y
+  %tobool = icmp ne i32 %x, %y
+  %.and = select i1 %tobool, i32 %and, i32 %y
+  ret i32 %.and
+; CHECK: @test59
+; CHECK-NEXT: and i32 %x, %y
+; CHECK-NEXT: ret
+}