implement instcombine folding for things like (x >> c) < 42.
authorChris Lattner <sabre@nondot.org>
Sun, 13 Feb 2011 08:07:21 +0000 (08:07 +0000)
committerChris Lattner <sabre@nondot.org>
Sun, 13 Feb 2011 08:07:21 +0000 (08:07 +0000)
We were previously simplifying divisions, but not right shifts!

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125454 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineCompares.cpp
test/Transforms/InstCombine/exact.ll

index f3c35330b9cd86bed080e0e735f42142c381b9d0..dd2cb9ce0b2774bab698b343f64b6751208a882b 100644 (file)
@@ -794,9 +794,11 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
     return 0; // The ProdOV computation fails on divide by zero.
   if (DivIsSigned && DivRHS->isAllOnesValue())
     return 0; // The overflow computation also screws up here
-  if (DivRHS->isOne())
-    return 0; // Not worth bothering, and eliminates some funny cases
-              // with INT_MIN.
+  if (DivRHS->isOne()) {
+    // This eliminates some funny cases with INT_MIN.
+    ICI.setOperand(0, DivI->getOperand(0));   // X/1 == X.
+    return &ICI;
+  }
 
   // Compute Prod = CI * DivRHS. We are essentially solving an equation
   // of form X/C1=C2. We solve for X by multiplying C1 (DivRHS) and 
@@ -931,8 +933,6 @@ Instruction *InstCombiner::FoldICmpDivCst(ICmpInst &ICI, BinaryOperator *DivI,
 /// FoldICmpShrCst - Handle "icmp(([al]shr X, cst1), cst2)".
 Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
                                           ConstantInt *ShAmt) {
-  if (!ICI.isEquality()) return 0;
-
   const APInt &CmpRHSV = cast<ConstantInt>(ICI.getOperand(1))->getValue();
   
   // Check that the shift amount is in range.  If not, don't perform
@@ -940,9 +940,50 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
   // simplified.
   uint32_t TypeBits = CmpRHSV.getBitWidth();
   uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
-  if (ShAmtVal >= TypeBits)
+  if (ShAmtVal >= TypeBits || ShAmtVal == 0)
     return 0;
   
+  if (!ICI.isEquality()) {
+    // If we have an unsigned comparison and an ashr, we can't simplify this.
+    // Similarly for signed comparisons with lshr.
+    if (ICI.isSigned() != (Shr->getOpcode() == Instruction::AShr))
+      return 0;
+    
+    // Otherwise, all lshr and all exact ashr's are equivalent to a udiv/sdiv by
+    // a power of 2.  Since we already have logic to simplify these, transform
+    // to div and then simplify the resultant comparison.
+    if (Shr->getOpcode() == Instruction::AShr &&
+        !Shr->isExact())
+      return 0;
+    
+    // Revisit the shift (to delete it).
+    Worklist.Add(Shr);
+    
+    Constant *DivCst =
+      ConstantInt::get(Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal));
+    
+    Value *Tmp =
+      Shr->getOpcode() == Instruction::AShr ?
+      Builder->CreateSDiv(Shr->getOperand(0), DivCst, "", Shr->isExact()) :
+      Builder->CreateUDiv(Shr->getOperand(0), DivCst, "", Shr->isExact());
+    
+    ICI.setOperand(0, Tmp);
+    
+    // If the builder folded the binop, just return it.
+    BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp);
+    if (TheDiv == 0)
+      return &ICI;
+    
+    // Otherwise, fold this div/compare.
+    assert(TheDiv->getOpcode() == Instruction::SDiv ||
+           TheDiv->getOpcode() == Instruction::UDiv);
+    
+    Instruction *Res = FoldICmpDivCst(ICI, TheDiv, cast<ConstantInt>(DivCst));
+    assert(Res && "This div/cst should have folded!");
+    return Res;
+  }
+  
+  
   // If we are comparing against bits always shifted out, the
   // comparison cannot succeed.
   APInt Comp = CmpRHSV << ShAmtVal;
@@ -1266,8 +1307,9 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
     if (LHSI->hasOneUse() &&
         isSignBitCheck(ICI.getPredicate(), RHS, TrueIfSigned)) {
       // (X << 31) <s 0  --> (X&1) != 0
-      Constant *Mask = ConstantInt::get(ICI.getContext(), APInt(TypeBits, 1) <<
-                                           (TypeBits-ShAmt->getZExtValue()-1));
+      Constant *Mask = ConstantInt::get(LHSI->getOperand(0)->getType(),
+                                        APInt::getOneBitSet(TypeBits, 
+                                            TypeBits-ShAmt->getZExtValue()-1));
       Value *And =
         Builder->CreateAnd(LHSI->getOperand(0), Mask, LHSI->getName()+".mask");
       return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
index 666bcb805191b58628c767a0325b3450d2d3fb01..255bf7bced3648e6c2abb6e7e90d4c215489649c 100644 (file)
@@ -77,15 +77,24 @@ define i64 @ashr1(i64 %X) nounwind {
   ret i64 %B
 }
 
-; CHECK: @ashr_icmp
+; CHECK: @ashr_icmp1
 ; CHECK: %B = icmp eq i64 %X, 0
 ; CHECK: ret i1 %B
-define i1 @ashr_icmp(i64 %X) nounwind {
+define i1 @ashr_icmp1(i64 %X) nounwind {
   %A = ashr exact i64 %X, 2   ; X/4
   %B = icmp eq i64 %A, 0
   ret i1 %B
 }
 
+; CHECK: @ashr_icmp2
+; CHECK: %Z = icmp slt i64 %X, 16
+; CHECK: ret i1 %Z
+define i1 @ashr_icmp2(i64 %X) nounwind {
+ %Y = ashr exact i64 %X, 2  ; x / 4
+ %Z = icmp slt i64 %Y, 4    ; x < 16
+ ret i1 %Z
+}
+
 ; CHECK: @udiv_icmp1
 ; CHECK: icmp ne i64 %X, 0
 define i1 @udiv_icmp1(i64 %X) nounwind {