From: Nick Lewycky Date: Thu, 6 Mar 2008 06:48:30 +0000 (+0000) Subject: Don't try to simplify urem and srem using arithmetic rules that don't work X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=c1a2a612019ea1c764f3ccb5959104aea3d4df2f;p=oota-llvm.git Don't try to simplify urem and srem using arithmetic rules that don't work under modulo (overflow). Fixes PR1933. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47987 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp index 1000ba60367..8e99dcc7db2 100644 --- a/lib/Transforms/Scalar/InstructionCombining.cpp +++ b/lib/Transforms/Scalar/InstructionCombining.cpp @@ -834,6 +834,49 @@ static void ComputeMaskedBits(Value *V, const APInt &Mask, APInt& KnownZero, return; } break; + case Instruction::SRem: + if (ConstantInt *Rem = dyn_cast(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isPowerOf2() || (-RA).isPowerOf2()) { + APInt LowBits = RA.isStrictlyPositive() ? ((RA - 1) | RA) : ~RA; + APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); + ComputeMaskedBits(I->getOperand(0), Mask2,KnownZero2,KnownOne2,Depth+1); + + // The sign of a remainder is equal to the sign of the first + // operand (zero being positive). + if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits)) + KnownZero2 |= ~LowBits; + else if (KnownOne2[BitWidth-1]) + KnownOne2 |= ~LowBits; + + KnownZero |= KnownZero2 & Mask; + KnownOne |= KnownOne2 & Mask; + + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + } + break; + case Instruction::URem: + if (ConstantInt *Rem = dyn_cast(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isStrictlyPositive() && RA.isPowerOf2()) { + APInt LowBits = (RA - 1) | RA; + APInt Mask2 = LowBits & Mask; + KnownZero |= ~LowBits & Mask; + ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne,Depth+1); + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + } else { + // Since the result is less than or equal to RHS, any leading zero bits + // in RHS must also exist in the result. + APInt AllOnes = APInt::getAllOnesValue(BitWidth); + ComputeMaskedBits(I->getOperand(1), AllOnes, KnownZero2, KnownOne2, Depth+1); + + uint32_t Leaders = KnownZero2.countLeadingOnes(); + KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & Mask; + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + break; } } @@ -1418,6 +1461,52 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, APInt DemandedMask, } } break; + case Instruction::SRem: + if (ConstantInt *Rem = dyn_cast(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isPowerOf2() || (-RA).isPowerOf2()) { + APInt LowBits = RA.isStrictlyPositive() ? (RA - 1) | RA : ~RA; + APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); + if (SimplifyDemandedBits(I->getOperand(0), Mask2, + LHSKnownZero, LHSKnownOne, Depth+1)) + return true; + + if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits)) + LHSKnownZero |= ~LowBits; + else if (LHSKnownOne[BitWidth-1]) + LHSKnownOne |= ~LowBits; + + KnownZero |= LHSKnownZero & DemandedMask; + KnownOne |= LHSKnownOne & DemandedMask; + + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + } + break; + case Instruction::URem: + if (ConstantInt *Rem = dyn_cast(I->getOperand(1))) { + APInt RA = Rem->getValue(); + if (RA.isPowerOf2()) { + APInt LowBits = (RA - 1) | RA; + APInt Mask2 = LowBits & DemandedMask; + KnownZero |= ~LowBits & DemandedMask; + if (SimplifyDemandedBits(I->getOperand(0), Mask2, + KnownZero, KnownOne, Depth+1)) + return true; + + assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); + } + } else { + APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0); + APInt AllOnes = APInt::getAllOnesValue(BitWidth); + if (SimplifyDemandedBits(I->getOperand(1), AllOnes, + KnownZero2, KnownOne2, Depth+1)) + return true; + + uint32_t Leaders = KnownZero2.countLeadingOnes(); + KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; + } + break; } // If the client is only demanding bits that we know, return the known @@ -2780,46 +2869,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return commonDivTransforms(I); } -/// GetFactor - If we can prove that the specified value is at least a multiple -/// of some factor, return that factor. -static Constant *GetFactor(Value *V) { - if (ConstantInt *CI = dyn_cast(V)) - return CI; - - // Unless we can be tricky, we know this is a multiple of 1. - Constant *Result = ConstantInt::get(V->getType(), 1); - - Instruction *I = dyn_cast(V); - if (!I) return Result; - - if (I->getOpcode() == Instruction::Mul) { - // Handle multiplies by a constant, etc. - return ConstantExpr::getMul(GetFactor(I->getOperand(0)), - GetFactor(I->getOperand(1))); - } else if (I->getOpcode() == Instruction::Shl) { - // (X< X * (1 << C) - if (Constant *ShRHS = dyn_cast(I->getOperand(1))) { - ShRHS = ConstantExpr::getShl(Result, ShRHS); - return ConstantExpr::getMul(GetFactor(I->getOperand(0)), ShRHS); - } - } else if (I->getOpcode() == Instruction::And) { - if (ConstantInt *RHS = dyn_cast(I->getOperand(1))) { - // X & 0xFFF0 is known to be a multiple of 16. - uint32_t Zeros = RHS->getValue().countTrailingZeros(); - if (Zeros != V->getType()->getPrimitiveSizeInBits())// don't shift by "32" - return ConstantExpr::getShl(Result, - ConstantInt::get(Result->getType(), Zeros)); - } - } else if (CastInst *CI = dyn_cast(I)) { - // Only handle int->int casts. - if (!CI->isIntegerCast()) - return Result; - Value *Op = CI->getOperand(0); - return ConstantExpr::getCast(CI->getOpcode(), GetFactor(Op), V->getType()); - } - return Result; -} - /// This function implements the transforms on rem instructions that work /// regardless of the kind of rem instruction it is (urem, srem, or frem). It /// is used by the visitors to those instructions. @@ -2901,9 +2950,13 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *NV = FoldOpIntoPhi(I)) return NV; } - // (X * C1) % C2 --> 0 iff C1 % C2 == 0 - if (ConstantExpr::getSRem(GetFactor(Op0I), RHS)->isNullValue()) - return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType())); + + // See if we can fold away this rem instruction. + uint32_t BitWidth = cast(I.getType())->getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth), + KnownZero, KnownOne)) + return &I; } } diff --git a/test/Transforms/InstCombine/rem.ll b/test/Transforms/InstCombine/rem.ll index c0e0fa4dc3c..8b2263d9b23 100644 --- a/test/Transforms/InstCombine/rem.ll +++ b/test/Transforms/InstCombine/rem.ll @@ -1,76 +1,83 @@ ; This test makes sure that these instructions are properly eliminated. ; ; RUN: llvm-as < %s | opt -instcombine | llvm-dis | not grep rem +; END. define i32 @test1(i32 %A) { - %B = srem i32 %A, 1 ; [#uses=1] + %B = srem i32 %A, 1 ; ISA constant 0 ret i32 %B } -define i32 @test2(i32 %A) { - %B = srem i32 0, %A ; [#uses=1] +define i32 @test2(i32 %A) { ; 0 % X = 0, we don't need to preserve traps + %B = srem i32 0, %A ret i32 %B } define i32 @test3(i32 %A) { - %B = urem i32 %A, 8 ; [#uses=1] + %B = urem i32 %A, 8 ret i32 %B } define i1 @test3a(i32 %A) { - %B = srem i32 %A, -8 ; [#uses=1] - %C = icmp ne i32 %B, 0 ; [#uses=1] + %B = srem i32 %A, -8 + %C = icmp ne i32 %B, 0 ret i1 %C } define i32 @test4(i32 %X, i1 %C) { - %V = select i1 %C, i32 1, i32 8 ; [#uses=1] - %R = urem i32 %X, %V ; [#uses=1] + %V = select i1 %C, i32 1, i32 8 + %R = urem i32 %X, %V ret i32 %R } define i32 @test5(i32 %X, i8 %B) { - %shift.upgrd.1 = zext i8 %B to i32 ; [#uses=1] - %Amt = shl i32 32, %shift.upgrd.1 ; [#uses=1] - %V = urem i32 %X, %Amt ; [#uses=1] + %shift.upgrd.1 = zext i8 %B to i32 + %Amt = shl i32 32, %shift.upgrd.1 + %V = urem i32 %X, %Amt ret i32 %V } define i32 @test6(i32 %A) { - %B = srem i32 %A, 0 ; [#uses=1] + %B = srem i32 %A, 0 ;; undef ret i32 %B } define i32 @test7(i32 %A) { - %B = mul i32 %A, 26 ; [#uses=1] - %C = srem i32 %B, 13 ; [#uses=1] + %B = mul i32 %A, 8 + %C = srem i32 %B, 4 ret i32 %C } define i32 @test8(i32 %A) { - %B = shl i32 %A, 4 ; [#uses=1] - %C = srem i32 %B, 8 ; [#uses=1] + %B = shl i32 %A, 4 + %C = srem i32 %B, 8 ret i32 %C } define i32 @test9(i32 %A) { - %B = mul i32 %A, 124 ; [#uses=1] - %C = urem i32 %B, 62 ; [#uses=1] + %B = mul i32 %A, 64 + %C = urem i32 %B, 32 ret i32 %C } define i32 @test10(i8 %c) { - %tmp.1 = zext i8 %c to i32 ; [#uses=1] - %tmp.2 = mul i32 %tmp.1, 3 ; [#uses=1] - %tmp.3 = sext i32 %tmp.2 to i64 ; [#uses=1] - %tmp.5 = urem i64 %tmp.3, 3 ; [#uses=1] - %tmp.6 = trunc i64 %tmp.5 to i32 ; [#uses=1] + %tmp.1 = zext i8 %c to i32 + %tmp.2 = mul i32 %tmp.1, 4 + %tmp.3 = sext i32 %tmp.2 to i64 + %tmp.5 = urem i64 %tmp.3, 4 + %tmp.6 = trunc i64 %tmp.5 to i32 ret i32 %tmp.6 } define i32 @test11(i32 %i) { - %tmp.1 = and i32 %i, -2 ; [#uses=1] - %tmp.3 = mul i32 %tmp.1, 3 ; [#uses=1] - %tmp.5 = srem i32 %tmp.3, 6 ; [#uses=1] + %tmp.1 = and i32 %i, -2 + %tmp.3 = mul i32 %tmp.1, 2 + %tmp.5 = urem i32 %tmp.3, 4 + ret i32 %tmp.5 +} + +define i32 @test12(i32 %i) { + %tmp.1 = and i32 %i, -4 + %tmp.5 = srem i32 %tmp.1, 2 ret i32 %tmp.5 }