X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineSelect.cpp;h=2df6193d512d70cac0f8ee3e4cae2b72727db4e0;hb=81e467d35217e7c331048c474f13bc91c942a911;hp=69380fc41d3a1b83a95f69c12a7d996dea6a7019;hpb=944d86558eae35e40e9f8a6bbfd626bea939abf5;p=oota-llvm.git diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 69380fc41d3..2df6193d512 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -14,85 +14,58 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -/// MatchSelectPattern - Pattern match integer [SU]MIN, [SU]MAX, and ABS idioms, -/// returning the kind and providing the out parameter results if we -/// successfully match. static SelectPatternFlavor -MatchSelectPattern(Value *V, Value *&LHS, Value *&RHS) { - SelectInst *SI = dyn_cast(V); - if (!SI) return SPF_UNKNOWN; - - ICmpInst *ICI = dyn_cast(SI->getCondition()); - if (!ICI) return SPF_UNKNOWN; - - ICmpInst::Predicate Pred = ICI->getPredicate(); - Value *CmpLHS = ICI->getOperand(0); - Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - - LHS = CmpLHS; - RHS = CmpRHS; - - // (icmp X, Y) ? X : Y - if (TrueVal == CmpLHS && FalseVal == CmpRHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMAX; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMAX; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMIN; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMIN; - } - } - - // (icmp X, Y) ? Y : X - if (TrueVal == CmpRHS && FalseVal == CmpLHS) { - switch (Pred) { - default: return SPF_UNKNOWN; // Equality. - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: return SPF_UMIN; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: return SPF_SMIN; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: return SPF_UMAX; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: return SPF_SMAX; - } +getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) { + switch (SPF) { + default: + llvm_unreachable("unhandled!"); + + case SPF_SMIN: + return SPF_SMAX; + case SPF_UMIN: + return SPF_UMAX; + case SPF_SMAX: + return SPF_SMIN; + case SPF_UMAX: + return SPF_UMIN; } +} - if (ConstantInt *C1 = dyn_cast(CmpRHS)) { - if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) || - (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) { - - // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X - // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X - if (Pred == ICmpInst::ICMP_SGT && (C1->isZero() || C1->isMinusOne())) { - return (CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS; - } - - // ABS(X) ==> (X (X isZero() || C1->isOne())) { - return (CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS; - } - } +static CmpInst::Predicate getCmpPredicateForMinMax(SelectPatternFlavor SPF, + bool Ordered=false) { + switch (SPF) { + default: + llvm_unreachable("unhandled!"); + + case SPF_SMIN: + return ICmpInst::ICMP_SLT; + case SPF_UMIN: + return ICmpInst::ICMP_ULT; + case SPF_SMAX: + return ICmpInst::ICMP_SGT; + case SPF_UMAX: + return ICmpInst::ICMP_UGT; + case SPF_FMINNUM: + return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; + case SPF_FMAXNUM: + return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; } - - // TODO: (X > 4) ? X : 5 --> (X >= 5) ? X : 5 --> MAX(X, 5) - - return SPF_UNKNOWN; } +static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder, + SelectPatternFlavor SPF, Value *A, + Value *B) { + CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF); + assert(CmpInst::isIntPredicate(Pred)); + return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B); +} /// GetSelectFoldableOperands - We want to turn code that looks like this: /// %C = or %A, %B @@ -309,72 +282,6 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal, return nullptr; } -/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is -/// replaced with RepOp. -static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const DataLayout *TD, - const TargetLibraryInfo *TLI, - DominatorTree *DT, AssumptionCache *AC) { - // Trivial replacement. - if (V == Op) - return RepOp; - - Instruction *I = dyn_cast(V); - if (!I) - return nullptr; - - // If this is a binary operator, try to simplify it with the replaced op. - if (BinaryOperator *B = dyn_cast(I)) { - if (B->getOperand(0) == Op) - return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI); - if (B->getOperand(1) == Op) - return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI); - } - - // Same for CmpInsts. - if (CmpInst *C = dyn_cast(I)) { - if (C->getOperand(0) == Op) - return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD, - TLI, DT, AC); - if (C->getOperand(1) == Op) - return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD, - TLI, DT, AC); - } - - // TODO: We could hand off more cases to instsimplify here. - - // If all operands are constant after substituting Op for RepOp then we can - // constant fold the instruction. - if (Constant *CRepOp = dyn_cast(RepOp)) { - // Build a list of all constant operands. - SmallVector ConstOps; - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { - if (I->getOperand(i) == Op) - ConstOps.push_back(CRepOp); - else if (Constant *COp = dyn_cast(I->getOperand(i))) - ConstOps.push_back(COp); - else - break; - } - - // All operands were constants, fold it. - if (ConstOps.size() == I->getNumOperands()) { - if (CmpInst *C = dyn_cast(I)) - return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], - ConstOps[1], TD, TLI); - - if (LoadInst *LI = dyn_cast(I)) - if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(ConstOps[0], TD); - - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), - ConstOps, TD, TLI); - } - } - - return nullptr; -} - /// foldSelectICmpAndOr - We want to turn: /// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) /// into: @@ -482,16 +389,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, match(Count, m_Intrinsic(m_Specific(CmpLHS)))) { IntrinsicInst *II = cast(Count); IRBuilder<> Builder(II); - if (cast(II->getArgOperand(1))->isOne()) { - // Explicitly clear the 'undef_on_zero' flag. - IntrinsicInst *NewI = cast(II->clone()); - Type *Ty = NewI->getArgOperand(1)->getType(); - NewI->setArgOperand(1, Constant::getNullValue(Ty)); - Builder.Insert(NewI); - Count = NewI; - } - - return Builder.CreateZExtOrTrunc(Count, ValueOnZero->getType()); + // Explicitly clear the 'undef_on_zero' flag. + IntrinsicInst *NewI = cast(II->clone()); + Type *Ty = NewI->getArgOperand(1)->getType(); + NewI->setArgOperand(1, Constant::getNullValue(Ty)); + Builder.Insert(NewI); + return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } return nullptr; @@ -514,14 +417,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, // here, so make sure the select is the only user. if (ICI->hasOneUse()) if (ConstantInt *CI = dyn_cast(CmpRHS)) { - // X < MIN ? T : F --> F - if ((Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) - && CI->isMinValue(Pred == ICmpInst::ICMP_SLT)) - return ReplaceInstUsesWith(SI, FalseVal); - // X > MAX ? T : F --> F - else if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT) - && CI->isMaxValue(Pred == ICmpInst::ICMP_SGT)) - return ReplaceInstUsesWith(SI, FalseVal); switch (Pred) { default: break; case ICmpInst::ICMP_ULT: @@ -635,33 +530,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - // If we have an equality comparison then we know the value in one of the - // arms of the select. See if substituting this value into the arm and - // simplifying the result yields the same value as the other arm. - if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - TrueVal) - return ReplaceInstUsesWith(SI, FalseVal); - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - FalseVal) - return ReplaceInstUsesWith(SI, FalseVal); - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - FalseVal) - return ReplaceInstUsesWith(SI, TrueVal); - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) == - TrueVal) - return ReplaceInstUsesWith(SI, TrueVal); - } - // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa(CmpRHS)) { @@ -676,7 +544,8 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI, } } - if (unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits()) { + { + unsigned BitWidth = DL.getTypeSizeInBits(TrueVal->getType()); APInt MinSignedValue = APInt::getSignBit(BitWidth); Value *X; const APInt *Y, *C; @@ -833,6 +702,52 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner, SI->getCondition(), SI->getFalseValue(), SI->getTrueValue()); return ReplaceInstUsesWith(Outer, NewSI); } + + auto IsFreeOrProfitableToInvert = + [&](Value *V, Value *&NotV, bool &ElidesXor) { + if (match(V, m_Not(m_Value(NotV)))) { + // If V has at most 2 uses then we can get rid of the xor operation + // entirely. + ElidesXor |= !V->hasNUsesOrMore(3); + return true; + } + + if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) { + NotV = nullptr; + return true; + } + + return false; + }; + + Value *NotA, *NotB, *NotC; + bool ElidesXor = false; + + // MIN(MIN(~A, ~B), ~C) == ~MAX(MAX(A, B), C) + // MIN(MAX(~A, ~B), ~C) == ~MAX(MIN(A, B), C) + // MAX(MIN(~A, ~B), ~C) == ~MIN(MAX(A, B), C) + // MAX(MAX(~A, ~B), ~C) == ~MIN(MIN(A, B), C) + // + // This transform is performance neutral if we can elide at least one xor from + // the set of three operands, since we'll be tacking on an xor at the very + // end. + if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && + IsFreeOrProfitableToInvert(B, NotB, ElidesXor) && + IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) { + if (!NotA) + NotA = Builder->CreateNot(A); + if (!NotB) + NotB = Builder->CreateNot(B); + if (!NotC) + NotC = Builder->CreateNot(C); + + Value *NewInner = generateMinMaxSelectPattern( + Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB); + Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern( + Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC)); + return ReplaceInstUsesWith(Outer, NewOuter); + } + return nullptr; } @@ -931,7 +846,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return BinaryOperator::CreateAnd(NotCond, FalseVal); } if (ConstantInt *C = dyn_cast(FalseVal)) { - if (C->getZExtValue() == false) { + if (!C->getZExtValue()) { // Change: A = select B, C, false --> A = and B, C return BinaryOperator::CreateAnd(CondVal, TrueVal); } @@ -1017,6 +932,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); + IRBuilder<>::FastMathFlagGuard FMFG(*Builder); + Builder->SetFastMathFlags(FCI->getFastMathFlags()); Value *NewCond = Builder->CreateFCmp(InvPred, TrueVal, FalseVal, FCI->getName() + ".inv"); @@ -1058,6 +975,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // (X ugt Y) ? X : Y -> (X ole Y) ? X : Y if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); + IRBuilder<>::FastMathFlagGuard FMFG(*Builder); + Builder->SetFastMathFlags(FCI->getFastMathFlags()); Value *NewCond = Builder->CreateFCmp(InvPred, FalseVal, TrueVal, FCI->getName() + ".inv"); @@ -1145,26 +1064,78 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // See if we can fold the select into one of our operands. - if (SI.getType()->isIntegerTy()) { + if (SI.getType()->isIntOrIntVectorTy() || SI.getType()->isFPOrFPVectorTy()) { if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal)) return FoldI; - // MAX(MAX(a, b), a) -> MAX(a, b) - // MIN(MIN(a, b), a) -> MIN(a, b) - // MAX(MIN(a, b), a) -> a - // MIN(MAX(a, b), a) -> a Value *LHS, *RHS, *LHS2, *RHS2; - if (SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS)) { - if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2)) + Instruction::CastOps CastOp; + SelectPatternResult SPR = matchSelectPattern(&SI, LHS, RHS, &CastOp); + auto SPF = SPR.Flavor; + + if (SPF) { + // Canonicalize so that type casts are outside select patterns. + if (LHS->getType()->getPrimitiveSizeInBits() != + SI.getType()->getPrimitiveSizeInBits()) { + CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); + + Value *Cmp; + if (CmpInst::isIntPredicate(Pred)) { + Cmp = Builder->CreateICmp(Pred, LHS, RHS); + } else { + IRBuilder<>::FastMathFlagGuard FMFG(*Builder); + auto FMF = cast(SI.getCondition())->getFastMathFlags(); + Builder->SetFastMathFlags(FMF); + Cmp = Builder->CreateFCmp(Pred, LHS, RHS); + } + + Value *NewSI = Builder->CreateCast(CastOp, + Builder->CreateSelect(Cmp, LHS, RHS), + SI.getType()); + return ReplaceInstUsesWith(SI, NewSI); + } + + // MAX(MAX(a, b), a) -> MAX(a, b) + // MIN(MIN(a, b), a) -> MIN(a, b) + // MAX(MIN(a, b), a) -> a + // MIN(MAX(a, b), a) -> a + if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor) if (Instruction *R = FoldSPFofSPF(cast(LHS),SPF2,LHS2,RHS2, SI, SPF, RHS)) return R; - if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2)) + if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2).Flavor) if (Instruction *R = FoldSPFofSPF(cast(RHS),SPF2,LHS2,RHS2, SI, SPF, LHS)) return R; } + // MAX(~a, ~b) -> ~MIN(a, b) + if (SPF == SPF_SMAX || SPF == SPF_UMAX) { + if (IsFreeToInvert(LHS, LHS->hasNUses(2)) && + IsFreeToInvert(RHS, RHS->hasNUses(2))) { + + // This transform adds a xor operation and that extra cost needs to be + // justified. We look for simplifications that will result from + // applying this rule: + + bool Profitable = + (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) || + (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) || + (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value()))); + + if (Profitable) { + Value *NewLHS = Builder->CreateNot(LHS); + Value *NewRHS = Builder->CreateNot(RHS); + Value *NewCmp = SPF == SPF_SMAX + ? Builder->CreateICmpSLT(NewLHS, NewRHS) + : Builder->CreateICmpULT(NewLHS, NewRHS); + Value *NewSI = + Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS)); + return ReplaceInstUsesWith(SI, NewSI); + } + } + } + // TODO. // ABS(-X) -> ABS(X) } @@ -1178,19 +1149,41 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return NV; if (SelectInst *TrueSI = dyn_cast(TrueVal)) { - if (TrueSI->getCondition() == CondVal) { - if (SI.getTrueValue() == TrueSI->getTrueValue()) - return nullptr; - SI.setOperand(1, TrueSI->getTrueValue()); - return &SI; + if (TrueSI->getCondition()->getType() == CondVal->getType()) { + // select(C, select(C, a, b), c) -> select(C, a, c) + if (TrueSI->getCondition() == CondVal) { + if (SI.getTrueValue() == TrueSI->getTrueValue()) + return nullptr; + SI.setOperand(1, TrueSI->getTrueValue()); + return &SI; + } + // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) + // We choose this as normal form to enable folding on the And and shortening + // paths for the values (this helps GetUnderlyingObjects() for example). + if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { + Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition()); + SI.setOperand(0, And); + SI.setOperand(1, TrueSI->getTrueValue()); + return &SI; + } } } if (SelectInst *FalseSI = dyn_cast(FalseVal)) { - if (FalseSI->getCondition() == CondVal) { - if (SI.getFalseValue() == FalseSI->getFalseValue()) - return nullptr; - SI.setOperand(2, FalseSI->getFalseValue()); - return &SI; + if (FalseSI->getCondition()->getType() == CondVal->getType()) { + // select(C, a, select(C, b, c)) -> select(C, a, c) + if (FalseSI->getCondition() == CondVal) { + if (SI.getFalseValue() == FalseSI->getFalseValue()) + return nullptr; + SI.setOperand(2, FalseSI->getFalseValue()); + return &SI; + } + // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) + if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { + Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition()); + SI.setOperand(0, Or); + SI.setOperand(2, FalseSI->getFalseValue()); + return &SI; + } } }