X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FTransforms%2FInstCombine%2FInstCombineAddSub.cpp;h=c608f84bc7bb5ec015480e9dced46671daac7dd2;hp=c37a9cf2ef9fb98f5b05826bfed2ffd625ea54c0;hb=529919ff310cbfce1ba55ea252ff738d5b56b93d;hpb=e04c0e3f8d7ae06db85079948745a855f299ba14 diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index c37a9cf2ef9..c608f84bc7b 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -11,7 +11,7 @@ // //===----------------------------------------------------------------------===// -#include "InstCombine.h" +#include "InstCombineInternal.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/DataLayout.h" @@ -32,7 +32,7 @@ namespace { /// class FAddendCoef { public: - // The constructor has to initialize a APFloat, which is uncessary for + // The constructor has to initialize a APFloat, which is unnecessary for // most addends which have coefficient either 1 or -1. So, the constructor // is expensive. In order to avoid the cost of the constructor, we should // reuse some instances whenever possible. The pre-created instances @@ -751,8 +751,7 @@ Value *FAddCombine::createNaryFAdd return LastVal; } -Value *FAddCombine::createFSub - (Value *Opnd0, Value *Opnd1) { +Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFSub(Opnd0, Opnd1); if (Instruction *I = dyn_cast(V)) createInstPostProc(I); @@ -760,15 +759,14 @@ Value *FAddCombine::createFSub } Value *FAddCombine::createFNeg(Value *V) { - Value *Zero = cast(ConstantFP::get(V->getType(), 0.0)); + Value *Zero = cast(ConstantFP::getZeroValueForNegation(V->getType())); Value *NewV = createFSub(Zero, V); if (Instruction *I = dyn_cast(NewV)) createInstPostProc(I, true); // fneg's don't receive instruction numbers. return NewV; } -Value *FAddCombine::createFAdd - (Value *Opnd0, Value *Opnd1) { +Value *FAddCombine::createFAdd(Value *Opnd0, Value *Opnd1) { Value *V = Builder->CreateFAdd(Opnd0, Opnd1); if (Instruction *I = dyn_cast(V)) createInstPostProc(I); @@ -789,8 +787,7 @@ Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) { return V; } -void FAddCombine::createInstPostProc(Instruction *NewInstr, - bool NoNumber) { +void FAddCombine::createInstPostProc(Instruction *NewInstr, bool NoNumber) { NewInstr->setDebugLoc(Instr->getDebugLoc()); // Keep track of the number of instruction created. @@ -840,8 +837,7 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { // "fmul V, C" false // // NOTE: Keep this function in sync with FAddCombine::calcInstrNumber. -Value *FAddCombine::createAddendVal - (const FAddend &Opnd, bool &NeedNeg) { +Value *FAddCombine::createAddendVal(const FAddend &Opnd, bool &NeedNeg) { const FAddendCoef &Coeff = Opnd.getCoef(); if (Opnd.isConstant()) { @@ -865,57 +861,192 @@ Value *FAddCombine::createAddendVal return createFMul(OpndVal, Coeff.getValue(Instr->getType())); } -// dyn_castFoldableMul - If this value is a multiply that can be folded into -// other computations (because it has a constant operand), return the -// non-constant operand of the multiply, and set CST to point to the multiplier. -// Otherwise, return null. 
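// ---- Editor's aside: illustration only, not part of this patch -----------
// Why createFNeg now subtracts from ConstantFP::getZeroValueForNegation()
// (which, for floating-point types, should be negative zero) rather than the
// literal 0.0 used before: under IEEE-754, "+0.0 - x" is not a faithful
// negation when x == +0.0, while "-0.0 - x" is. A minimal standalone check of
// that fact (plain ISO C++, no LLVM APIs assumed):
#include <cassert>
#include <cmath>

int main() {
  double x = 0.0;
  assert(!std::signbit(0.0 - x));  // +0.0 - (+0.0) == +0.0: the sign of -x is lost
  assert(std::signbit(-0.0 - x));  // -0.0 - (+0.0) == -0.0, which matches -x
  assert(std::signbit(-x));
  return 0;
}
// ---------------------------------------------------------------------------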
-// -static inline Value *dyn_castFoldableMul(Value *V, Constant *&CST) { - if (!V->hasOneUse() || !V->getType()->isIntOrIntVectorTy()) - return nullptr; - - Instruction *I = dyn_cast(V); - if (!I) return nullptr; - - if (I->getOpcode() == Instruction::Mul) - if ((CST = dyn_cast(I->getOperand(1)))) - return I->getOperand(0); - if (I->getOpcode() == Instruction::Shl) - if ((CST = dyn_cast(I->getOperand(1)))) { - // The multiplier is really 1 << CST. - CST = ConstantExpr::getShl(ConstantInt::get(V->getType(), 1), CST); - return I->getOperand(0); - } - return nullptr; +// If one of the operands only has one non-zero bit, and if the other +// operand has a known-zero bit in a more significant place than it (not +// including the sign bit) the ripple may go up to and fill the zero, but +// won't change the sign. For example, (X & ~4) + 1. +static bool checkRippleForAdd(const APInt &Op0KnownZero, + const APInt &Op1KnownZero) { + APInt Op1MaybeOne = ~Op1KnownZero; + // Make sure that one of the operand has at most one bit set to 1. + if (Op1MaybeOne.countPopulation() != 1) + return false; + + // Find the most significant known 0 other than the sign bit. + int BitWidth = Op0KnownZero.getBitWidth(); + APInt Op0KnownZeroTemp(Op0KnownZero); + Op0KnownZeroTemp.clearBit(BitWidth - 1); + int Op0ZeroPosition = BitWidth - Op0KnownZeroTemp.countLeadingZeros() - 1; + + int Op1OnePosition = BitWidth - Op1MaybeOne.countLeadingZeros() - 1; + assert(Op1OnePosition >= 0); + + // This also covers the case of no known zero, since in that case + // Op0ZeroPosition is -1. + return Op0ZeroPosition >= Op1OnePosition; } - /// WillNotOverflowSignedAdd - Return true if we can prove that: /// (sext (add LHS, RHS)) === (add (sext LHS), (sext RHS)) /// This basically requires proving that the add in the original type would not /// overflow to change the sign bit or have a carry out. -bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) { +bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, + Instruction &CxtI) { // There are different heuristics we can use for this. Here are some simple // ones. - // Add has the property that adding any two 2's complement numbers can only - // have one carry bit which can change a sign. As such, if LHS and RHS each - // have at least two sign bits, we know that the addition of the two values - // will sign extend fine. - if (ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1) + // If LHS and RHS each have at least two sign bits, the addition will look + // like + // + // XX..... + + // YY..... + // + // If the carry into the most significant position is 0, X and Y can't both + // be 1 and therefore the carry out of the addition is also 0. + // + // If the carry into the most significant position is 1, X and Y can't both + // be 0 and therefore the carry out of the addition is also 1. + // + // Since the carry into the most significant position is always equal to + // the carry out of the addition, there is no signed overflow. + if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && + ComputeNumSignBits(RHS, 0, &CxtI) > 1) + return true; + + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI); + + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); + + // Addition of two 2's compliment numbers having opposite signs will never + // overflow. 
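// ---- Editor's aside: illustration only, not part of this patch -----------
// Exhaustive 8-bit sanity checks of the two no-signed-overflow arguments made
// above (plain ISO C++, independent of the APInt/known-bits machinery):
#include <cassert>
#include <cstdint>

int main() {
  // Adding two two's-complement numbers with opposite sign bits never
  // overflows the signed range.
  for (int x = INT8_MIN; x <= INT8_MAX; ++x)
    for (int y = INT8_MIN; y <= INT8_MAX; ++y)
      if ((x < 0) != (y < 0))
        assert(x + y >= INT8_MIN && x + y <= INT8_MAX);

  // The checkRippleForAdd example (X & ~4) + 1: the carry ripple out of bit 0
  // is absorbed by the known-zero bit 2 and never reaches the sign bit.
  for (int x = INT8_MIN; x <= INT8_MAX; ++x) {
    int r = (x & ~4) + 1;
    assert(r >= INT8_MIN && r <= INT8_MAX);
  }
  return 0;
}
// ---------------------------------------------------------------------------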
+ if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) + return true; + + // Check if carry bit of addition will not cause overflow. + if (checkRippleForAdd(LHSKnownZero, RHSKnownZero)) + return true; + if (checkRippleForAdd(RHSKnownZero, LHSKnownZero)) + return true; + + return false; +} + +/// \brief Return true if we can prove that: +/// (sub LHS, RHS) === (sub nsw LHS, RHS) +/// This basically requires proving that the add in the original type would not +/// overflow to change the sign bit or have a carry out. +/// TODO: Handle this for Vectors. +bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, + Instruction &CxtI) { + // If LHS and RHS each have at least two sign bits, the subtraction + // cannot overflow. + if (ComputeNumSignBits(LHS, 0, &CxtI) > 1 && + ComputeNumSignBits(RHS, 0, &CxtI) > 1) return true; + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &CxtI); - // If one of the operands only has one non-zero bit, and if the other operand - // has a known-zero bit in a more significant place than it (not including the - // sign bit) the ripple may go up to and fill the zero, but won't change the - // sign. For example, (X & ~4) + 1. + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); - // TODO: Implement. + // Subtraction of two 2's compliment numbers having identical signs will + // never overflow. + if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || + (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) + return true; + + // TODO: implement logic similar to checkRippleForAdd + return false; +} + +/// \brief Return true if we can prove that: +/// (sub LHS, RHS) === (sub nuw LHS, RHS) +bool InstCombiner::WillNotOverflowUnsignedSub(Value *LHS, Value *RHS, + Instruction &CxtI) { + // If the LHS is negative and the RHS is non-negative, no unsigned wrap. + bool LHSKnownNonNegative, LHSKnownNegative; + bool RHSKnownNonNegative, RHSKnownNegative; + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, /*Depth=*/0, + &CxtI); + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, /*Depth=*/0, + &CxtI); + if (LHSKnownNegative && RHSKnownNonNegative) + return true; return false; } +// Checks if any operand is negative and we can convert add to sub. +// This function checks for following negative patterns +// ADD(XOR(OR(Z, NOT(C)), C)), 1) == NEG(AND(Z, C)) +// ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C)) +// XOR(AND(Z, C), (C + 1)) == NEG(OR(Z, ~C)) if C is even +static Value *checkForNegativeOperand(BinaryOperator &I, + InstCombiner::BuilderTy *Builder) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + // This function creates 2 instructions to replace ADD, we need at least one + // of LHS or RHS to have one use to ensure benefit in transform. 
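// ---- Editor's aside: illustration only, not part of this patch -----------
// The "negative operand" patterns listed above are ordinary two's-complement
// identities. Exhaustive 8-bit check of the first two (the third additionally
// needs an even C so that the +1 cannot carry into the masked bits):
#include <cassert>

int main() {
  auto low8 = [](unsigned v) { return v & 0xFFu; };
  for (unsigned Z = 0; Z < 256; ++Z)
    for (unsigned C = 0; C < 256; ++C) {
      unsigned NotC = low8(~C);
      // ADD(XOR(OR(Z, NOT(C)), C), 1) == NEG(AND(Z, C))
      assert(low8(((Z | NotC) ^ C) + 1) == low8(-(Z & C)));
      // ADD(XOR(AND(Z, C), C), 1) == NEG(OR(Z, ~C))
      assert(low8(((Z & C) ^ C) + 1) == low8(-(Z | NotC)));
    }
  return 0;
}
// ---------------------------------------------------------------------------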
+ if (!LHS->hasOneUse() && !RHS->hasOneUse()) + return nullptr; + + Value *X = nullptr, *Y = nullptr, *Z = nullptr; + const APInt *C1 = nullptr, *C2 = nullptr; + + // if ONE is on other side, swap + if (match(RHS, m_Add(m_Value(X), m_One()))) + std::swap(LHS, RHS); + + if (match(LHS, m_Add(m_Value(X), m_One()))) { + // if XOR on other side, swap + if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1)))) + std::swap(X, RHS); + + if (match(X, m_Xor(m_Value(Y), m_APInt(C1)))) { + // X = XOR(Y, C1), Y = OR(Z, C2), C2 = NOT(C1) ==> X == NOT(AND(Z, C1)) + // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, AND(Z, C1)) + if (match(Y, m_Or(m_Value(Z), m_APInt(C2))) && (*C2 == ~(*C1))) { + Value *NewAnd = Builder->CreateAnd(Z, *C1); + return Builder->CreateSub(RHS, NewAnd, "sub"); + } else if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && (*C1 == *C2)) { + // X = XOR(Y, C1), Y = AND(Z, C2), C2 == C1 ==> X == NOT(OR(Z, ~C1)) + // ADD(ADD(X, 1), RHS) == ADD(X, ADD(RHS, 1)) == SUB(RHS, OR(Z, ~C1)) + Value *NewOr = Builder->CreateOr(Z, ~(*C1)); + return Builder->CreateSub(RHS, NewOr, "sub"); + } + } + } + + // Restore LHS and RHS + LHS = I.getOperand(0); + RHS = I.getOperand(1); + + // if XOR is on other side, swap + if (match(RHS, m_Xor(m_Value(Y), m_APInt(C1)))) + std::swap(LHS, RHS); + + // C2 is ODD + // LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2)) + // ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2)) + if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1)))) + if (C1->countTrailingZeros() == 0) + if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) { + Value *NewOr = Builder->CreateOr(Z, ~(*C2)); + return Builder->CreateSub(RHS, NewOr, "sub"); + } + return nullptr; +} + Instruction *InstCombiner::visitAdd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); @@ -924,10 +1055,10 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); - // (A*B)+(A*C) -> A*(B+C) etc + // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); @@ -963,7 +1094,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (ExtendAmt) { APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt); - if (!MaskedValueIsZero(XorLHS, Mask)) + if (!MaskedValueIsZero(XorLHS, Mask, 0, &I)) ExtendAmt = 0; } @@ -979,7 +1110,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { IntegerType *IT = cast(I.getType()); APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(XorLHS, LHSKnownZero, LHSKnownOne, 0, &I); if ((XorRHS->getValue() | LHSKnownZero).isAllOnesValue()) return BinaryOperator::CreateSub(ConstantExpr::getAdd(XorRHS, CI), XorLHS); @@ -1025,33 +1156,18 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = dyn_castNegVal(RHS)) return BinaryOperator::CreateSub(LHS, V); - - { - Constant *C2; - if (Value *X = dyn_castFoldableMul(LHS, C2)) { - if (X == RHS) // X*C + X --> X * (C+1) - return BinaryOperator::CreateMul(RHS, AddOne(C2)); - - // X*C1 + X*C2 --> X * (C1+C2) - Constant *C1; - if (X == dyn_castFoldableMul(RHS, C1)) - return BinaryOperator::CreateMul(X, ConstantExpr::getAdd(C1, C2)); - } - - // X + X*C --> X * (C+1) - if 
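// ---- Editor's aside: illustration only, not part of this patch -----------
// The XorLHS/ExtendAmt logic earlier in this hunk recognizes (as I read it)
// the classic "sign-extend in register" idiom: for an 8-bit payload held
// zero-extended in a wider integer, (x ^ 0x80) + (-0x80) equals a genuine
// sign extension. Standalone check, assuming the usual two's-complement
// narrowing conversions:
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned v = 0; v < 256; ++v) {
    uint32_t viaXorAdd = (v ^ 0x80u) - 0x80u;                           // the idiom
    uint32_t viaSext   = static_cast<uint32_t>(static_cast<int8_t>(v)); // real sext
    assert(viaXorAdd == viaSext);
  }
  return 0;
}
// ---------------------------------------------------------------------------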
(dyn_castFoldableMul(RHS, C2) == LHS) - return BinaryOperator::CreateMul(LHS, AddOne(C2)); - } + if (Value *V = checkForNegativeOperand(I, Builder)) + return ReplaceInstUsesWith(I, V); // A+B --> A|B iff A and B have no bits set in common. if (IntegerType *IT = dyn_cast(I.getType())) { APInt LHSKnownOne(IT->getBitWidth(), 0); APInt LHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, 0, &I); if (LHSKnownZero != 0) { APInt RHSKnownOne(IT->getBitWidth(), 0); APInt RHSKnownZero(IT->getBitWidth(), 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &I); // No bits in common -> bitwise or. if ((LHSKnownZero|RHSKnownZero).isAllOnesValue()) @@ -1059,29 +1175,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - // W*X + Y*Z --> W * (X+Z) iff W == Y - { - Value *W, *X, *Y, *Z; - if (match(LHS, m_Mul(m_Value(W), m_Value(X))) && - match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) { - if (W != Y) { - if (W == Z) { - std::swap(Y, Z); - } else if (Y == X) { - std::swap(W, X); - } else if (X == Z) { - std::swap(Y, Z); - std::swap(W, X); - } - } - - if (W == Y) { - Value *NewAdd = Builder->CreateAdd(X, Z, LHS->getName()); - return BinaryOperator::CreateMul(W, NewAdd); - } - } - } - if (Constant *CRHS = dyn_cast(RHS)) { Value *X; if (match(LHS, m_Not(m_Value(X)))) // ~X + C --> (C-1) - X @@ -1152,7 +1245,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSExt(CI, I.getType()) == RHSC && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1165,10 +1258,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Only do this if x/y have the same type, if at last one of them has a // single use (so we don't increase the number of sexts), and if the // integer add will not overflow. - if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& + if (LHSConv->getOperand(0)->getType() == + RHSConv->getOperand(0)->getType() && (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), I)) { // Insert the new integer add. 
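// ---- Editor's aside: illustration only, not part of this patch -----------
// Two of the integer add folds in this hunk, checked exhaustively over 8 bits:
//   * A + B == A | B whenever A and B share no set bits (no carries occur).
//   * ~X + C == (C - 1) - X, the identity behind "~X + C --> (C-1) - X".
#include <cassert>

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      if ((a & b) == 0)
        assert(((a + b) & 0xFFu) == (a | b));
      assert((((a ^ 0xFFu) + b) & 0xFFu) == ((b - 1 - a) & 0xFFu)); // ~a + b
    }
  return 0;
}
// ---------------------------------------------------------------------------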
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); @@ -1177,7 +1271,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - // Check for (x & y) + (x ^ y) + // (add (xor A, B) (and A, B)) --> (or A, B) { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && @@ -1191,6 +1285,42 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateOr(A, B); } + // (add (or A, B) (and A, B)) --> (add A, B) + { + Value *A = nullptr, *B = nullptr; + if (match(RHS, m_Or(m_Value(A), m_Value(B))) && + (match(LHS, m_And(m_Specific(A), m_Specific(B))) || + match(LHS, m_And(m_Specific(B), m_Specific(A))))) { + auto *New = BinaryOperator::CreateAdd(A, B); + New->setHasNoSignedWrap(I.hasNoSignedWrap()); + New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return New; + } + + if (match(LHS, m_Or(m_Value(A), m_Value(B))) && + (match(RHS, m_And(m_Specific(A), m_Specific(B))) || + match(RHS, m_And(m_Specific(B), m_Specific(A))))) { + auto *New = BinaryOperator::CreateAdd(A, B); + New->setHasNoSignedWrap(I.hasNoSignedWrap()); + New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return New; + } + } + + // TODO(jingyue): Consider WillNotOverflowSignedAdd and + // WillNotOverflowUnsignedAdd to reduce the number of invocations of + // computeKnownBits. + if (!I.hasNoSignedWrap() && WillNotOverflowSignedAdd(LHS, RHS, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && + computeOverflowForUnsignedAdd(LHS, RHS, &I) == + OverflowResult::NeverOverflows) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + return Changed ? &I : nullptr; } @@ -1201,7 +1331,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); if (isa(RHS)) { @@ -1243,7 +1374,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSIToFP(CI, I.getType()) == CFP && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) { + WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new integer add. Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); @@ -1256,10 +1387,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { // Only do this if x/y have the same type, if at last one of them has a // single use (so we don't increase the number of int->fp conversions), // and if the integer add will not overflow. - if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&& + if (LHSConv->getOperand(0)->getType() == + RHSConv->getOperand(0)->getType() && (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0))) { + RHSConv->getOperand(0), I)) { // Insert the new integer add. 
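// ---- Editor's aside: illustration only, not part of this patch -----------
// The (A^B)+(A&B) and (A|B)+(A&B) folds in this hunk are plain bitwise
// identities: A^B and A&B occupy disjoint bits, so their sum is A|B, and
// A + B == (A|B) + (A&B), so adding A&B back to A|B reconstructs the plain
// sum. Exhaustive 8-bit check:
#include <cassert>

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      assert((((a ^ b) + (a & b)) & 0xFFu) == (a | b));           // (xor)+(and) == or
      assert((((a | b) + (a & b)) & 0xFFu) == ((a + b) & 0xFFu)); // (or)+(and) == add
    }
  return 0;
}
// ---------------------------------------------------------------------------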
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0),"addconv"); @@ -1281,11 +1413,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { Z2 = dyn_cast(B2); B = B1; } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) { Z1 = dyn_cast(B1); B = B2; - Z2 = dyn_cast(A2); A = A1; + Z2 = dyn_cast(A2); A = A1; } - - if (Z1 && Z2 && - (I.hasNoSignedZeros() || + + if (Z1 && Z2 && + (I.hasNoSignedZeros() || (Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) { return SelectInst::Create(C, A, B); } @@ -1308,8 +1440,6 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty) { - assert(DL && "Must have target data info for this"); - // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize // this. bool Swapped = false; @@ -1372,7 +1502,6 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, return Builder->CreateIntCast(Result, Ty, true); } - Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1380,18 +1509,27 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return ReplaceInstUsesWith(I, V); if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL)) + I.hasNoUnsignedWrap(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); // (A*B)-(A*C) -> A*(B-C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return ReplaceInstUsesWith(I, V); - // If this is a 'B = x-(-A)', change to B = x+A. This preserves NSW/NUW. + // If this is a 'B = x-(-A)', change to B = x+A. if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); - Res->setHasNoSignedWrap(I.hasNoSignedWrap()); - Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + + if (const auto *BO = dyn_cast(Op1)) { + assert(BO->getOpcode() == Instruction::Sub && + "Expected a subtraction operator!"); + if (BO->hasNoSignedWrap() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } else { + if (cast(Op1)->isNotMinSignedValue() && I.hasNoSignedWrap()) + Res->setHasNoSignedWrap(true); + } + return Res; } @@ -1436,21 +1574,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // -(X >>u 31) -> (X >>s 31) // -(X >>s 31) -> (X >>u 31) if (C->isZero()) { - Value *X; ConstantInt *CI; + Value *X; + ConstantInt *CI; if (match(Op1, m_LShr(m_Value(X), m_ConstantInt(CI))) && // Verify we are shifting out everything but the sign bit. - CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) return BinaryOperator::CreateAShr(X, CI); if (match(Op1, m_AShr(m_Value(X), m_ConstantInt(CI))) && // Verify we are shifting out everything but the sign bit. 
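// ---- Editor's aside: illustration only, not part of this patch -----------
// These two folds rely on the fact that shifting out everything but the sign
// bit yields either {0, 1} (logical shift) or {0, -1} (arithmetic shift), and
// the two results are each other's negations. 8-bit check; the arithmetic
// shift of a negative value assumes the usual two's-complement behaviour
// (guaranteed since C++20):
#include <cassert>
#include <cstdint>

int main() {
  for (int v = INT8_MIN; v <= INT8_MAX; ++v) {
    uint8_t x    = static_cast<uint8_t>(v);
    uint8_t lshr = x >> 7;                                            // X >>u 7
    uint8_t ashr = static_cast<uint8_t>(static_cast<int8_t>(x) >> 7); // X >>s 7
    assert(static_cast<uint8_t>(-lshr) == ashr);  // -(X >>u 7) == (X >>s 7)
    assert(static_cast<uint8_t>(-ashr) == lshr);  // -(X >>s 7) == (X >>u 7)
  }
  return 0;
}
// ---------------------------------------------------------------------------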
- CI->getValue() == I.getType()->getPrimitiveSizeInBits()-1) + CI->getValue() == I.getType()->getPrimitiveSizeInBits() - 1) return BinaryOperator::CreateLShr(X, CI); } } - { Value *Y; + { + Value *Y; // X-(X+Y) == -Y X-(Y+X) == -Y if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) @@ -1461,6 +1601,24 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } + // (sub (or A, B) (xor A, B)) --> (and A, B) + { + Value *A = nullptr, *B = nullptr; + if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && + (match(Op0, m_Or(m_Specific(A), m_Specific(B))) || + match(Op0, m_Or(m_Specific(B), m_Specific(A))))) + return BinaryOperator::CreateAnd(A, B); + } + + if (Op0->hasOneUse()) { + Value *Y = nullptr; + // ((X | Y) - X) --> (~X & Y) + if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) || + match(Op0, m_Or(m_Specific(Op1), m_Value(Y)))) + return BinaryOperator::CreateAnd( + Y, Builder->CreateNot(Op1, Op1->getName() + ".not")); + } + if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; @@ -1478,9 +1636,9 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateAnd(Op0, Builder->CreateNot(Y, Y->getName() + ".not")); - // 0 - (X sdiv C) -> (X sdiv -C) - if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && - match(Op0, m_Zero())) + // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. + if (match(Op1, m_SDiv(m_Value(X), m_Constant(C))) && match(Op0, m_Zero()) && + C->isNotMinSignedValue() && !C->isOneValue()) return BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(C)); // 0 - (X << Y) -> (-X << Y) when X is freely negatable. @@ -1488,19 +1646,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Value *XNeg = dyn_castNegVal(X)) return BinaryOperator::CreateShl(XNeg, Y); - // X - X*C --> X * (1-C) - if (match(Op1, m_Mul(m_Specific(Op0), m_Constant(CI)))) { - Constant *CP1 = ConstantExpr::getSub(ConstantInt::get(I.getType(),1), CI); - return BinaryOperator::CreateMul(Op0, CP1); - } - - // X - X< X * (1-(1< X + A*B // X - -A*B -> X + A*B Value *A, *B; @@ -1517,33 +1662,31 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } } - Constant *C1; - if (Value *X = dyn_castFoldableMul(Op0, C1)) { - if (X == Op1) // X*C - X --> X * (C-1) - return BinaryOperator::CreateMul(Op1, SubOne(C1)); - - Constant *C2; // X*C1 - X*C2 -> X * (C1-C2) - if (X == dyn_castFoldableMul(Op1, C2)) - return BinaryOperator::CreateMul(X, ConstantExpr::getSub(C1, C2)); - } - // Optimize pointer differences into the same array into a size. Consider: // &A[10] - &A[0]: we should compile this to "10". 
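// ---- Editor's aside: illustration only, not part of this patch -----------
// The new subtraction folds above are also plain bitwise identities: A|B
// splits into the disjoint parts A^B and A&B, and (X|Y) - X leaves exactly
// the bits of Y that X does not already have. Exhaustive 8-bit check:
#include <cassert>

int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      assert((((a | b) - (a ^ b)) & 0xFFu) == (a & b));     // (or)-(xor) == and
      assert((((a | b) - a) & 0xFFu) == (~a & b & 0xFFu));  // (X|Y)-X == ~X & Y
    }
  return 0;
}
// ---------------------------------------------------------------------------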
- if (DL) { - Value *LHSOp, *RHSOp; - if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && - match(Op1, m_PtrToInt(m_Value(RHSOp)))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); - - // trunc(p)-trunc(q) -> trunc(p-q) - if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && - match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) - return ReplaceInstUsesWith(I, Res); + Value *LHSOp, *RHSOp; + if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && + match(Op1, m_PtrToInt(m_Value(RHSOp)))) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + return ReplaceInstUsesWith(I, Res); + + // trunc(p)-trunc(q) -> trunc(p-q) + if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && + match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + return ReplaceInstUsesWith(I, Res); + + bool Changed = false; + if (!I.hasNoSignedWrap() && WillNotOverflowSignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && WillNotOverflowUnsignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); } - return nullptr; + return Changed ? &I : nullptr; } Instruction *InstCombiner::visitFSub(BinaryOperator &I) { @@ -1552,9 +1695,18 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return ReplaceInstUsesWith(I, V); - if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL)) + if (Value *V = + SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) return ReplaceInstUsesWith(I, V); + // fsub nsz 0, X ==> fsub nsz -0.0, X + if (I.getFastMathFlags().noSignedZeros() && match(Op0, m_Zero())) { + // Subtraction from -0.0 is the canonical form of fneg. + Instruction *NewI = BinaryOperator::CreateFNeg(Op1); + NewI->copyFastMathFlags(&I); + return NewI; + } + if (isa(Op0)) if (SelectInst *SI = dyn_cast(Op1)) if (Instruction *NV = FoldOpIntoSelect(I, SI))
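// ---- Editor's aside: illustration only, not part of this patch -----------
// The conditions used above to set nsw/nuw on subtractions, checked
// exhaustively over 8 bits: operands with identical sign bits cannot
// signed-overflow a sub, and a negative LHS with a non-negative RHS cannot
// unsigned-wrap one.
#include <cassert>
#include <cstdint>

int main() {
  for (int x = INT8_MIN; x <= INT8_MAX; ++x)
    for (int y = INT8_MIN; y <= INT8_MAX; ++y) {
      if ((x < 0) == (y < 0))                // identical sign bits -> nsw holds
        assert(x - y >= INT8_MIN && x - y <= INT8_MAX);
      if (x < 0 && y >= 0) {                 // LHS negative, RHS non-negative -> nuw holds
        unsigned ux = static_cast<uint8_t>(x);
        unsigned uy = static_cast<uint8_t>(y);
        assert(ux >= uy);                    // the unsigned sub never borrows
      }
    }
  return 0;
}
// ---------------------------------------------------------------------------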