Don't try to simplify urem and srem using arithmetic rules that don't work
authorNick Lewycky <nicholas@mxc.ca>
Thu, 6 Mar 2008 06:48:30 +0000 (06:48 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Thu, 6 Mar 2008 06:48:30 +0000 (06:48 +0000)
under modulo (overflow). Fixes PR1933.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47987 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/InstructionCombining.cpp
test/Transforms/InstCombine/rem.ll

index 1000ba60367001db25ad3ec80d1c951d5bd949a1..8e99dcc7db2cfe893a2d90df2da8ed43bbd5f089 100644 (file)
@@ -834,6 +834,49 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero,
       return;
     }
     break;
+  case Instruction::SRem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+        APInt LowBits = RA.isStrictlyPositive() ? ((RA - 1) | RA) : ~RA;
+        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+        ComputeMaskedBits(I->getOperand(0), Mask2,KnownZero2,KnownOne2,Depth+1);
+
+        // The sign of a remainder is equal to the sign of the first
+        // operand (zero being positive).
+        if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits))
+          KnownZero2 |= ~LowBits;
+        else if (KnownOne2[BitWidth-1])
+          KnownOne2 |= ~LowBits;
+
+        KnownZero |= KnownZero2 & Mask;
+        KnownOne |= KnownOne2 & Mask;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    }
+    break;
+  case Instruction::URem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isStrictlyPositive() && RA.isPowerOf2()) {
+        APInt LowBits = (RA - 1) | RA;
+        APInt Mask2 = LowBits & Mask;
+        KnownZero |= ~LowBits & Mask;
+        ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne,Depth+1);
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+      }
+    } else {
+      // Since the result is less than or equal to RHS, any leading zero bits
+      // in RHS must also exist in the result.
+      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+      ComputeMaskedBits(I->getOperand(1), AllOnes, KnownZero2, KnownOne2, Depth+1);
+
+      uint32_t Leaders = KnownZero2.countLeadingOnes();
+      KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & Mask;
+      assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+    }
+    break;
   }
 }
 
@@ -1418,6 +1461,52 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, APInt DemandedMask,
       }
     }
     break;
+  case Instruction::SRem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+        APInt LowBits = RA.isStrictlyPositive() ? (RA - 1) | RA : ~RA;
+        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+        if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+                                 LHSKnownZero, LHSKnownOne, Depth+1))
+          return true;
+
+        if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+          LHSKnownZero |= ~LowBits;
+        else if (LHSKnownOne[BitWidth-1])
+          LHSKnownOne |= ~LowBits;
+
+        KnownZero |= LHSKnownZero & DemandedMask;
+        KnownOne |= LHSKnownOne & DemandedMask;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    }
+    break;
+  case Instruction::URem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2()) {
+        APInt LowBits = (RA - 1) | RA;
+        APInt Mask2 = LowBits & DemandedMask;
+        KnownZero |= ~LowBits & DemandedMask;
+        if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+                                 KnownZero, KnownOne, Depth+1))
+          return true;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    } else {
+      APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
+      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+      if (SimplifyDemandedBits(I->getOperand(1), AllOnes,
+                               KnownZero2, KnownOne2, Depth+1))
+        return true;
+
+      uint32_t Leaders = KnownZero2.countLeadingOnes();
+      KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+    }
+    break;
   }
   
   // If the client is only demanding bits that we know, return the known
@@ -2780,46 +2869,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
   return commonDivTransforms(I);
 }
 
-/// GetFactor - If we can prove that the specified value is at least a multiple
-/// of some factor, return that factor.
-static Constant *GetFactor(Value *V) {
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
-    return CI;
-  
-  // Unless we can be tricky, we know this is a multiple of 1.
-  Constant *Result = ConstantInt::get(V->getType(), 1);
-  
-  Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return Result;
-  
-  if (I->getOpcode() == Instruction::Mul) {
-    // Handle multiplies by a constant, etc.
-    return ConstantExpr::getMul(GetFactor(I->getOperand(0)),
-                                GetFactor(I->getOperand(1)));
-  } else if (I->getOpcode() == Instruction::Shl) {
-    // (X<<C) -> X * (1 << C)
-    if (Constant *ShRHS = dyn_cast<Constant>(I->getOperand(1))) {
-      ShRHS = ConstantExpr::getShl(Result, ShRHS);
-      return ConstantExpr::getMul(GetFactor(I->getOperand(0)), ShRHS);
-    }
-  } else if (I->getOpcode() == Instruction::And) {
-    if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      // X & 0xFFF0 is known to be a multiple of 16.
-      uint32_t Zeros = RHS->getValue().countTrailingZeros();
-      if (Zeros != V->getType()->getPrimitiveSizeInBits())// don't shift by "32"
-        return ConstantExpr::getShl(Result, 
-                                    ConstantInt::get(Result->getType(), Zeros));
-    }
-  } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
-    // Only handle int->int casts.
-    if (!CI->isIntegerCast())
-      return Result;
-    Value *Op = CI->getOperand(0);
-    return ConstantExpr::getCast(CI->getOpcode(), GetFactor(Op), V->getType());
-  }    
-  return Result;
-}
-
 /// This function implements the transforms on rem instructions that work
 /// regardless of the kind of rem instruction it is (urem, srem, or frem). It 
 /// is used by the visitors to those instructions.
@@ -2901,9 +2950,13 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
         if (Instruction *NV = FoldOpIntoPhi(I))
           return NV;
       }
-      // (X * C1) % C2 --> 0  iff  C1 % C2 == 0
-      if (ConstantExpr::getSRem(GetFactor(Op0I), RHS)->isNullValue())
-        return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+
+      // See if we can fold away this rem instruction.
+      uint32_t BitWidth = cast<IntegerType>(I.getType())->getBitWidth();
+      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+      if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth),
+                               KnownZero, KnownOne))
+        return &I;
     }
   }
 
index c0e0fa4dc3c3673a0c34a0a1447bc1b031581d57..8b2263d9b23f882fa45e579736862a99effe6233 100644 (file)
@@ -1,76 +1,83 @@
 ; This test makes sure that these instructions are properly eliminated.
 ;
 ; RUN: llvm-as < %s | opt -instcombine | llvm-dis | not grep rem
+; END.
 
 define i32 @test1(i32 %A) {
-       %B = srem i32 %A, 1             ; <i32> [#uses=1]
+       %B = srem i32 %A, 1     ; ISA constant 0
        ret i32 %B
 }
 
-define i32 @test2(i32 %A) {
-       %B = srem i32 0, %A             ; <i32> [#uses=1]
+define i32 @test2(i32 %A) {    ; 0 % X = 0, we don't need to preserve traps
+       %B = srem i32 0, %A
        ret i32 %B
 }
 
 define i32 @test3(i32 %A) {
-       %B = urem i32 %A, 8             ; <i32> [#uses=1]
+       %B = urem i32 %A, 8
        ret i32 %B
 }
 
 define i1 @test3a(i32 %A) {
-       %B = srem i32 %A, -8            ; <i32> [#uses=1]
-       %C = icmp ne i32 %B, 0          ; <i1> [#uses=1]
+       %B = srem i32 %A, -8
+       %C = icmp ne i32 %B, 0
        ret i1 %C
 }
 
 define i32 @test4(i32 %X, i1 %C) {
-       %V = select i1 %C, i32 1, i32 8         ; <i32> [#uses=1]
-       %R = urem i32 %X, %V            ; <i32> [#uses=1]
+       %V = select i1 %C, i32 1, i32 8
+       %R = urem i32 %X, %V
        ret i32 %R
 }
 
 define i32 @test5(i32 %X, i8 %B) {
-       %shift.upgrd.1 = zext i8 %B to i32              ; <i32> [#uses=1]
-       %Amt = shl i32 32, %shift.upgrd.1               ; <i32> [#uses=1]
-       %V = urem i32 %X, %Amt          ; <i32> [#uses=1]
+       %shift.upgrd.1 = zext i8 %B to i32
+       %Amt = shl i32 32, %shift.upgrd.1
+       %V = urem i32 %X, %Amt
        ret i32 %V
 }
 
 define i32 @test6(i32 %A) {
-       %B = srem i32 %A, 0             ; <i32> [#uses=1]
+       %B = srem i32 %A, 0     ;; undef
        ret i32 %B
 }
 
 define i32 @test7(i32 %A) {
-       %B = mul i32 %A, 26             ; <i32> [#uses=1]
-       %C = srem i32 %B, 13            ; <i32> [#uses=1]
+       %B = mul i32 %A, 8
+       %C = srem i32 %B, 4
        ret i32 %C
 }
 
 define i32 @test8(i32 %A) {
-       %B = shl i32 %A, 4              ; <i32> [#uses=1]
-       %C = srem i32 %B, 8             ; <i32> [#uses=1]
+       %B = shl i32 %A, 4
+       %C = srem i32 %B, 8
        ret i32 %C
 }
 
 define i32 @test9(i32 %A) {
-       %B = mul i32 %A, 124            ; <i32> [#uses=1]
-       %C = urem i32 %B, 62            ; <i32> [#uses=1]
+       %B = mul i32 %A, 64
+       %C = urem i32 %B, 32
        ret i32 %C
 }
 
 define i32 @test10(i8 %c) {
-       %tmp.1 = zext i8 %c to i32              ; <i32> [#uses=1]
-       %tmp.2 = mul i32 %tmp.1, 3              ; <i32> [#uses=1]
-       %tmp.3 = sext i32 %tmp.2 to i64         ; <i64> [#uses=1]
-       %tmp.5 = urem i64 %tmp.3, 3             ; <i64> [#uses=1]
-       %tmp.6 = trunc i64 %tmp.5 to i32                ; <i32> [#uses=1]
+       %tmp.1 = zext i8 %c to i32
+       %tmp.2 = mul i32 %tmp.1, 4
+       %tmp.3 = sext i32 %tmp.2 to i64
+       %tmp.5 = urem i64 %tmp.3, 4
+       %tmp.6 = trunc i64 %tmp.5 to i32
        ret i32 %tmp.6
 }
 
 define i32 @test11(i32 %i) {
-       %tmp.1 = and i32 %i, -2         ; <i32> [#uses=1]
-       %tmp.3 = mul i32 %tmp.1, 3              ; <i32> [#uses=1]
-       %tmp.5 = srem i32 %tmp.3, 6             ; <i32> [#uses=1]
+       %tmp.1 = and i32 %i, -2
+       %tmp.3 = mul i32 %tmp.1, 2
+       %tmp.5 = urem i32 %tmp.3, 4
+       ret i32 %tmp.5
+}
+
+define i32 @test12(i32 %i) {
+       %tmp.1 = and i32 %i, -4
+       %tmp.5 = srem i32 %tmp.1, 2
        ret i32 %tmp.5
 }