X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineMulDivRem.cpp;h=173f2bf6330469c638dbaa4a8e07b755fb089f82;hb=c5a4c25b8780434a00968ed93634974a0b796a06;hp=e104a0a9798d1b7077ef0cbefbcc878dbf38f13b;hpb=9753f0b9b4ab33919c5010acb6a7b2dc1e875aff;p=oota-llvm.git diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index e104a0a9798..173f2bf6330 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// #include "InstCombine.h" -#include "llvm/IntrinsicInst.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/Support/PatternMatch.h" using namespace llvm; using namespace PatternMatch; @@ -37,7 +37,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { if (match(V, m_LShr(m_OneUse(m_Shl(m_Value(PowerOf2), m_Value(A))), m_Value(B))) && // The "1" can be any value known to be a power of 2. - isPowerOfTwo(PowerOf2, IC.getTargetData())) { + isKnownToBeAPowerOfTwo(PowerOf2)) { A = IC.Builder->CreateSub(A, B); return IC.Builder->CreateShl(PowerOf2, A); } @@ -45,8 +45,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC) { // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. if (BinaryOperator *I = dyn_cast(V)) - if (I->isLogicalShift() && - isPowerOfTwo(I->getOperand(0), IC.getTargetData())) { + if (I->isLogicalShift() && isKnownToBeAPowerOfTwo(I->getOperand(0))) { // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC)) { @@ -252,24 +251,136 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : 0; } +// +// Detect pattern: +// +// log2(Y*0.5) +// +// And check for corresponding fast math flags +// + +static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { + + if (!Op->hasOneUse()) + return; + + IntrinsicInst *II = dyn_cast(Op); + if (!II) + return; + if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + return; + Log2 = II; + + Value *OpLog2Of = II->getArgOperand(0); + if (!OpLog2Of->hasOneUse()) + return; + + Instruction *I = dyn_cast(OpLog2Of); + if (!I) + return; + if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + return; + + ConstantFP *CFP = dyn_cast(I->getOperand(0)); + if (CFP && CFP->isExactlyValue(0.5)) { + Y = I->getOperand(1); + return; + } + CFP = dyn_cast(I->getOperand(1)); + if (CFP && CFP->isExactlyValue(0.5)) + Y = I->getOperand(0); +} + +/// Helper function of InstCombiner::visitFMul(BinaryOperator(). It returns +/// true iff the given value is FMul or FDiv with one and only one operand +/// being a normal constant (i.e. not Zero/NaN/Infinity). +static bool isFMulOrFDivWithConstant(Value *V) { + Instruction *I = dyn_cast(V); + if (!I || (I->getOpcode() != Instruction::FMul && + I->getOpcode() != Instruction::FDiv)) + return false; + + ConstantFP *C0 = dyn_cast(I->getOperand(0)); + ConstantFP *C1 = dyn_cast(I->getOperand(1)); + + if (C0 && C1) + return false; + + return (C0 && C0->getValueAPF().isNormal()) || + (C1 && C1->getValueAPF().isNormal()); +} + +static bool isNormalFp(const ConstantFP *C) { + const APFloat &Flt = C->getValueAPF(); + return Flt.isNormal() && !Flt.isDenormal(); +} + +/// foldFMulConst() is a helper routine of InstCombiner::visitFMul(). +/// The input \p FMulOrDiv is a FMul/FDiv with one and only one operand +/// being a constant (i.e. isFMulOrFDivWithConstant(FMulOrDiv) == true). +/// This function is to simplify "FMulOrDiv * C" and returns the +/// resulting expression. Note that this function could return NULL in +/// case the constants cannot be folded into a normal floating-point. +/// +Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, ConstantFP *C, + Instruction *InsertBefore) { + assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid"); + + Value *Opnd0 = FMulOrDiv->getOperand(0); + Value *Opnd1 = FMulOrDiv->getOperand(1); + + ConstantFP *C0 = dyn_cast(Opnd0); + ConstantFP *C1 = dyn_cast(Opnd1); + + BinaryOperator *R = 0; + + // (X * C0) * C => X * (C0*C) + if (FMulOrDiv->getOpcode() == Instruction::FMul) { + Constant *F = ConstantExpr::getFMul(C1 ? C1 : C0, C); + if (isNormalFp(cast(F))) + R = BinaryOperator::CreateFMul(C1 ? Opnd0 : Opnd1, F); + } else { + if (C0) { + // (C0 / X) * C => (C0 * C) / X + ConstantFP *F = cast(ConstantExpr::getFMul(C0, C)); + if (isNormalFp(F)) + R = BinaryOperator::CreateFDiv(F, Opnd1); + } else { + // (X / C1) * C => X * (C/C1) if C/C1 is not a denormal + ConstantFP *F = cast(ConstantExpr::getFDiv(C, C1)); + if (isNormalFp(F)) { + R = BinaryOperator::CreateFMul(Opnd0, F); + } else { + // (X / C1) * C => X / (C1/C) + Constant *F = ConstantExpr::getFDiv(C1, C); + if (isNormalFp(cast(F))) + R = BinaryOperator::CreateFDiv(Opnd0, F); + } + } + } + + if (R) { + R->setHasUnsafeAlgebra(true); + InsertNewInstWith(R, *InsertBefore); + } + + return R; +} + Instruction *InstCombiner::visitFMul(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - // Simplify mul instructions with a constant RHS. - if (Constant *Op1C = dyn_cast(Op1)) { - if (ConstantFP *Op1F = dyn_cast(Op1C)) { - // "In IEEE floating point, x*1 is not equivalent to x for nans. However, - // ANSI says we can drop signals, so we can do this anyway." (from GCC) - if (Op1F->isExactlyValue(1.0)) - return ReplaceInstUsesWith(I, Op0); // Eliminate 'fmul double %X, 1.0' - } else if (ConstantDataVector *Op1V = dyn_cast(Op1C)) { - // As above, vector X*splat(1.0) -> X in all defined cases. - if (ConstantFP *F = dyn_cast_or_null(Op1V->getSplatValue())) - if (F->isExactlyValue(1.0)) - return ReplaceInstUsesWith(I, Op0); - } + if (isa(Op0)) + std::swap(Op0, Op1); + + if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), TD)) + return ReplaceInstUsesWith(I, V); + bool AllowReassociate = I.hasUnsafeAlgebra(); + + // Simplify mul instructions with a constant RHS. + if (isa(Op1)) { // Try to fold constant mul into select arguments. if (SelectInst *SI = dyn_cast(Op0)) if (Instruction *R = FoldOpIntoSelect(I, SI)) @@ -278,11 +389,146 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (isa(Op0)) if (Instruction *NV = FoldOpIntoPhi(I)) return NV; + + ConstantFP *C = dyn_cast(Op1); + if (C && AllowReassociate && C->getValueAPF().isNormal()) { + // Let MDC denote an expression in one of these forms: + // X * C, C/X, X/C, where C is a constant. + // + // Try to simplify "MDC * Constant" + if (isFMulOrFDivWithConstant(Op0)) { + Value *V = foldFMulConst(cast(Op0), C, &I); + if (V) + return ReplaceInstUsesWith(I, V); + } + + // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) + Instruction *FAddSub = dyn_cast(Op0); + if (FAddSub && + (FAddSub->getOpcode() == Instruction::FAdd || + FAddSub->getOpcode() == Instruction::FSub)) { + Value *Opnd0 = FAddSub->getOperand(0); + Value *Opnd1 = FAddSub->getOperand(1); + ConstantFP *C0 = dyn_cast(Opnd0); + ConstantFP *C1 = dyn_cast(Opnd1); + bool Swap = false; + if (C0) { + std::swap(C0, C1); + std::swap(Opnd0, Opnd1); + Swap = true; + } + + if (C1 && C1->getValueAPF().isNormal() && + isFMulOrFDivWithConstant(Opnd0)) { + Value *M1 = ConstantExpr::getFMul(C1, C); + Value *M0 = isNormalFp(cast(M1)) ? + foldFMulConst(cast(Opnd0), C, &I) : + 0; + if (M0 && M1) { + if (Swap && FAddSub->getOpcode() == Instruction::FSub) + std::swap(M0, M1); + + Value *R = (FAddSub->getOpcode() == Instruction::FAdd) ? + BinaryOperator::CreateFAdd(M0, M1) : + BinaryOperator::CreateFSub(M0, M1); + Instruction *RI = cast(R); + RI->copyFastMathFlags(&I); + return RI; + } + } + } + } + } + + + // Under unsafe algebra do: + // X * log2(0.5*Y) = X*log2(Y) - X + if (I.hasUnsafeAlgebra()) { + Value *OpX = NULL; + Value *OpY = NULL; + IntrinsicInst *Log2; + detectLog2OfHalf(Op0, OpY, Log2); + if (OpY) { + OpX = Op1; + } else { + detectLog2OfHalf(Op1, OpY, Log2); + if (OpY) { + OpX = Op0; + } + } + // if pattern detected emit alternate sequence + if (OpX && OpY) { + Log2->setArgOperand(0, OpY); + Value *FMulVal = Builder->CreateFMul(OpX, Log2); + Instruction *FMul = cast(FMulVal); + FMul->copyFastMathFlags(Log2); + Instruction *FSub = BinaryOperator::CreateFSub(FMulVal, OpX); + FSub->copyFastMathFlags(Log2); + return FSub; + } } - if (Value *Op0v = dyn_castFNegVal(Op0)) // -X * -Y = X*Y - if (Value *Op1v = dyn_castFNegVal(Op1)) - return BinaryOperator::CreateFMul(Op0v, Op1v); + // Handle symmetric situation in a 2-iteration loop + Value *Opnd0 = Op0; + Value *Opnd1 = Op1; + for (int i = 0; i < 2; i++) { + bool IgnoreZeroSign = I.hasNoSignedZeros(); + if (BinaryOperator::isFNeg(Opnd0, IgnoreZeroSign)) { + Value *N0 = dyn_castFNegVal(Opnd0, IgnoreZeroSign); + Value *N1 = dyn_castFNegVal(Opnd1, IgnoreZeroSign); + + // -X * -Y => X*Y + if (N1) + return BinaryOperator::CreateFMul(N0, N1); + + if (Opnd0->hasOneUse()) { + // -X * Y => -(X*Y) (Promote negation as high as possible) + Value *T = Builder->CreateFMul(N0, Opnd1); + cast(T)->setDebugLoc(I.getDebugLoc()); + Instruction *Neg = BinaryOperator::CreateFNeg(T); + if (I.getFastMathFlags().any()) { + cast(T)->copyFastMathFlags(&I); + Neg->copyFastMathFlags(&I); + } + return Neg; + } + } + + // (X*Y) * X => (X*X) * Y where Y != X + // The purpose is two-fold: + // 1) to form a power expression (of X). + // 2) potentially shorten the critical path: After transformation, the + // latency of the instruction Y is amortized by the expression of X*X, + // and therefore Y is in a "less critical" position compared to what it + // was before the transformation. + // + if (AllowReassociate) { + Value *Opnd0_0, *Opnd0_1; + if (Opnd0->hasOneUse() && + match(Opnd0, m_FMul(m_Value(Opnd0_0), m_Value(Opnd0_1)))) { + Value *Y = 0; + if (Opnd0_0 == Opnd1 && Opnd0_1 != Opnd1) + Y = Opnd0_1; + else if (Opnd0_1 == Opnd1 && Opnd0_0 != Opnd1) + Y = Opnd0_0; + + if (Y) { + Instruction *T = cast(Builder->CreateFMul(Opnd1, Opnd1)); + T->copyFastMathFlags(&I); + T->setDebugLoc(I.getDebugLoc()); + + Instruction *R = BinaryOperator::CreateFMul(T, Y); + R->copyFastMathFlags(&I); + return R; + } + } + } + + if (!isa(Op1)) + std::swap(Opnd0, Opnd1); + else + break; + } return Changed ? &I : 0; } @@ -462,13 +708,13 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { } } - // Udiv ((Lshl x, c1) , c2) -> x / (C1 * 1<(Op1)) { - Value *X = 0, *C1 = 0; - if (match(Op0, m_LShr(m_Value(X), m_Value(C1)))) { - uint64_t NC = cast(C)->getZExtValue() * - (1<< cast(C1)->getZExtValue()); - return BinaryOperator::CreateUDiv(X, ConstantInt::get(I.getType(), NC)); + // (x lshr C1) udiv C2 --> x udiv (C2 << C1) + if (ConstantInt *C2 = dyn_cast(Op1)) { + Value *X; + ConstantInt *C1; + if (match(Op0, m_LShr(m_Value(X), m_ConstantInt(C1)))) { + APInt NC = C2->getValue().shl(C1->getLimitedValue(C1->getBitWidth()-1)); + return BinaryOperator::CreateUDiv(X, Builder->getInt(NC)); } } @@ -477,7 +723,8 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { if (match(Op1, m_Shl(m_Power2(CI), m_Value(N))) || match(Op1, m_ZExt(m_Shl(m_Power2(CI), m_Value(N))))) { if (*CI != 1) - N = Builder->CreateAdd(N, ConstantInt::get(I.getType(),CI->logBase2())); + N = Builder->CreateAdd(N, + ConstantInt::get(N->getType(), CI->logBase2())); if (ZExtInst *Z = dyn_cast(Op1)) N = Builder->CreateZExt(N, Z->getDestTy()); if (I.isExact()) @@ -543,16 +790,6 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { ConstantExpr::getNeg(RHS)); } - // Sdiv ((Ashl x, c1) , c2) -> x / (C1 * 1<(Op1)) { - Value *X = 0, *C1 = 0; - if (match(Op0, m_AShr(m_Value(X), m_Value(C1)))) { - uint64_t NC = cast(C)->getZExtValue() * - (1<< cast(C1)->getZExtValue()); - return BinaryOperator::CreateSDiv(X, ConstantInt::get(I.getType(), NC)); - } - } - // If the sign bits of both operands are zero (i.e. we can prove they are // unsigned inputs), turn this into a udiv. if (I.getType()->isIntegerTy()) { @@ -576,21 +813,140 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { return 0; } +/// CvtFDivConstToReciprocal tries to convert X/C into X*1/C if C not a special +/// FP value and: +/// 1) 1/C is exact, or +/// 2) reciprocal is allowed. +/// If the convertion was successful, the simplified expression "X * 1/C" is +/// returned; otherwise, NULL is returned. +/// +static Instruction *CvtFDivConstToReciprocal(Value *Dividend, + ConstantFP *Divisor, + bool AllowReciprocal) { + const APFloat &FpVal = Divisor->getValueAPF(); + APFloat Reciprocal(FpVal.getSemantics()); + bool Cvt = FpVal.getExactInverse(&Reciprocal); + + if (!Cvt && AllowReciprocal && FpVal.isNormal()) { + Reciprocal = APFloat(FpVal.getSemantics(), 1.0f); + (void)Reciprocal.divide(FpVal, APFloat::rmNearestTiesToEven); + Cvt = !Reciprocal.isDenormal(); + } + + if (!Cvt) + return 0; + + ConstantFP *R; + R = ConstantFP::get(Dividend->getType()->getContext(), Reciprocal); + return BinaryOperator::CreateFMul(Dividend, R); +} + Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyFDivInst(Op0, Op1, TD)) return ReplaceInstUsesWith(I, V); + bool AllowReassociate = I.hasUnsafeAlgebra(); + bool AllowReciprocal = I.hasAllowReciprocal(); + if (ConstantFP *Op1C = dyn_cast(Op1)) { - const APFloat &Op1F = Op1C->getValueAPF(); - - // If the divisor has an exact multiplicative inverse we can turn the fdiv - // into a cheaper fmul. - APFloat Reciprocal(Op1F.getSemantics()); - if (Op1F.getExactInverse(&Reciprocal)) { - ConstantFP *RFP = ConstantFP::get(Builder->getContext(), Reciprocal); - return BinaryOperator::CreateFMul(Op0, RFP); + if (AllowReassociate) { + ConstantFP *C1 = 0; + ConstantFP *C2 = Op1C; + Value *X; + Instruction *Res = 0; + + if (match(Op0, m_FMul(m_Value(X), m_ConstantFP(C1)))) { + // (X*C1)/C2 => X * (C1/C2) + // + Constant *C = ConstantExpr::getFDiv(C1, C2); + const APFloat &F = cast(C)->getValueAPF(); + if (F.isNormal() && !F.isDenormal()) + Res = BinaryOperator::CreateFMul(X, C); + } else if (match(Op0, m_FDiv(m_Value(X), m_ConstantFP(C1)))) { + // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed] + // + Constant *C = ConstantExpr::getFMul(C1, C2); + const APFloat &F = cast(C)->getValueAPF(); + if (F.isNormal() && !F.isDenormal()) { + Res = CvtFDivConstToReciprocal(X, cast(C), + AllowReciprocal); + if (!Res) + Res = BinaryOperator::CreateFDiv(X, C); + } + } + + if (Res) { + Res->setFastMathFlags(I.getFastMathFlags()); + return Res; + } + } + + // X / C => X * 1/C + if (Instruction *T = CvtFDivConstToReciprocal(Op0, Op1C, AllowReciprocal)) + return T; + + return 0; + } + + if (AllowReassociate && isa(Op0)) { + ConstantFP *C1 = cast(Op0), *C2; + Constant *Fold = 0; + Value *X; + bool CreateDiv = true; + + // C1 / (X*C2) => (C1/C2) / X + if (match(Op1, m_FMul(m_Value(X), m_ConstantFP(C2)))) + Fold = ConstantExpr::getFDiv(C1, C2); + else if (match(Op1, m_FDiv(m_Value(X), m_ConstantFP(C2)))) { + // C1 / (X/C2) => (C1*C2) / X + Fold = ConstantExpr::getFMul(C1, C2); + } else if (match(Op1, m_FDiv(m_ConstantFP(C2), m_Value(X)))) { + // C1 / (C2/X) => (C1/C2) * X + Fold = ConstantExpr::getFDiv(C1, C2); + CreateDiv = false; + } + + if (Fold) { + const APFloat &FoldC = cast(Fold)->getValueAPF(); + if (FoldC.isNormal() && !FoldC.isDenormal()) { + Instruction *R = CreateDiv ? + BinaryOperator::CreateFDiv(Fold, X) : + BinaryOperator::CreateFMul(X, Fold); + R->setFastMathFlags(I.getFastMathFlags()); + return R; + } + } + return 0; + } + + if (AllowReassociate) { + Value *X, *Y; + Value *NewInst = 0; + Instruction *SimpR = 0; + + if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) { + // (X/Y) / Z => X / (Y*Z) + // + if (!isa(Y) || !isa(Op1)) { + NewInst = Builder->CreateFMul(Y, Op1); + SimpR = BinaryOperator::CreateFDiv(X, NewInst); + } + } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) { + // Z / (X/Y) => Z*Y / X + // + if (!isa(Y) || !isa(Op0)) { + NewInst = Builder->CreateFMul(Op0, Y); + SimpR = BinaryOperator::CreateFDiv(NewInst, X); + } + } + + if (NewInst) { + if (Instruction *T = dyn_cast(NewInst)) + T->setDebugLoc(I.getDebugLoc()); + SimpR->setFastMathFlags(I.getFastMathFlags()); + return SimpR; } }