X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FCodeGen%2FSelectionDAG%2FDAGCombiner.cpp;h=a1edb29e4860ba3605791d7bacf5429f857a2ca8;hb=b8f7aeb21802f773effd28222f24f00942b50c50;hp=53d07b385e317295baca39a819bb2fb6db378c3e;hpb=19a3f63b5477afc8de8a5de22ecd975050e5cef2;p=oota-llvm.git diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 53d07b385e3..a1edb29e486 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -156,13 +156,16 @@ namespace { void deleteAndRecombine(SDNode *N); bool recursivelyDeleteUnusedNodes(SDNode *N); + /// Replaces all uses of the results of one DAG node with new values. SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo, bool AddTo = true); + /// Replaces all uses of the results of one DAG node with new values. SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) { return CombineTo(N, &Res, 1, AddTo); } + /// Replaces all uses of the results of one DAG node with new values. SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1, bool AddTo = true) { SDValue To[] = { Res0, Res1 }; @@ -233,18 +236,16 @@ namespace { SDValue visitADDE(SDNode *N); SDValue visitSUBE(SDNode *N); SDValue visitMUL(SDNode *N); + SDValue useDivRem(SDNode *N); SDValue visitSDIV(SDNode *N); SDValue visitUDIV(SDNode *N); - SDValue visitSREM(SDNode *N); - SDValue visitUREM(SDNode *N); + SDValue visitREM(SDNode *N); SDValue visitMULHU(SDNode *N); SDValue visitMULHS(SDNode *N); SDValue visitSMUL_LOHI(SDNode *N); SDValue visitUMUL_LOHI(SDNode *N); SDValue visitSMULO(SDNode *N); SDValue visitUMULO(SDNode *N); - SDValue visitSDIVREM(SDNode *N); - SDValue visitUDIVREM(SDNode *N); SDValue visitIMINMAX(SDNode *N); SDValue visitAND(SDNode *N); SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *LocReference); @@ -299,6 +300,10 @@ namespace { SDValue visitBRCOND(SDNode *N); SDValue visitBR_CC(SDNode *N); SDValue visitLOAD(SDNode *N); + + SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain); + SDValue replaceStoreOfFPConstant(StoreSDNode *ST); + SDValue visitSTORE(SDNode *N); SDValue visitINSERT_VECTOR_ELT(SDNode *N); SDValue visitEXTRACT_VECTOR_ELT(SDNode *N); @@ -313,9 +318,11 @@ namespace { SDValue visitMGATHER(SDNode *N); SDValue visitMSCATTER(SDNode *N); SDValue visitFP_TO_FP16(SDNode *N); + SDValue visitFP16_TO_FP(SDNode *N); SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); + SDValue visitFMULForFMACombine(SDNode *N); SDValue XformToShuffleWithZero(SDNode *N); SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); @@ -344,10 +351,12 @@ namespace { SDValue BuildSDIV(SDNode *N); SDValue BuildSDIVPow2(SDNode *N); SDValue BuildUDIV(SDNode *N); - SDValue BuildReciprocalEstimate(SDValue Op); - SDValue BuildRsqrtEstimate(SDValue Op); - SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations); - SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations); + SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags); + SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags); + SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations, + SDNodeFlags *Flags); + SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations, + SDNodeFlags *Flags); SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1, bool DemandHighBits = true); SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1); @@ -376,6 +385,10 @@ namespace { /// chain (aliasing node.) SDValue FindBetterChain(SDNode *N, SDValue Chain); + /// Do FindBetterChain for a store and any possibly adjacent stores on + /// consecutive chains. + bool findBetterNeighborChains(StoreSDNode *St); + /// Holds a pointer to an LSBaseSDNode as well as information on where it /// is located in a sequence of memory operations connected by a chain. struct MemOpLink { @@ -390,11 +403,20 @@ namespace { unsigned SequenceNum; }; + /// This is a helper function for visitMUL to check the profitability + /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). + /// MulNode is the original multiply, AddNode is (add x, c1), + /// and ConstNode is c2. + bool isMulAddWithConstProfitable(SDNode *MulNode, + SDValue &AddNode, + SDValue &ConstNode); + /// This is a helper function for MergeStoresOfConstantsOrVecElts. Returns a /// constant build_vector of the stored constant values in Stores. SDValue getMergedConstantVectorStore(SelectionDAG &DAG, SDLoc SL, ArrayRef Stores, + SmallVectorImpl &Chains, EVT Ty) const; /// This is a helper function for MergeConsecutiveStores. When the source @@ -402,7 +424,7 @@ namespace { /// vector elements, try to merge them into one larger store. /// \return True if a merged store was created. bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl &StoreNodes, - EVT MemVT, unsigned NumElem, + EVT MemVT, unsigned NumStores, bool IsConstantSrc, bool UseVector); /// This is a helper function for MergeConsecutiveStores. @@ -606,6 +628,9 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, assert(Op.hasOneUse() && "Unknown reuse!"); assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree"); + + const SDNodeFlags *Flags = Op.getNode()->getFlags(); + switch (Op.getOpcode()) { default: llvm_unreachable("Unknown code"); case ISD::ConstantFP: { @@ -623,12 +648,12 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), GetNegatedExpression(Op.getOperand(0), DAG, LegalOperations, Depth+1), - Op.getOperand(1)); + Op.getOperand(1), Flags); // fold (fneg (fadd A, B)) -> (fsub (fneg B), A) return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), GetNegatedExpression(Op.getOperand(1), DAG, LegalOperations, Depth+1), - Op.getOperand(0)); + Op.getOperand(0), Flags); case ISD::FSUB: // We can't turn -(A-B) into B-A when we honor signed zeros. assert(Options.UnsafeFPMath); @@ -640,7 +665,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, // fold (fneg (fsub A, B)) -> (fsub B, A) return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(), - Op.getOperand(1), Op.getOperand(0)); + Op.getOperand(1), Op.getOperand(0), Flags); case ISD::FMUL: case ISD::FDIV: @@ -652,13 +677,13 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG, return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), GetNegatedExpression(Op.getOperand(0), DAG, LegalOperations, Depth+1), - Op.getOperand(1)); + Op.getOperand(1), Flags); // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y)) return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(), Op.getOperand(0), GetNegatedExpression(Op.getOperand(1), DAG, - LegalOperations, Depth+1)); + LegalOperations, Depth+1), Flags); case ISD::FP_EXTEND: case ISD::FSIN: @@ -1332,16 +1357,14 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::MUL: return visitMUL(N); case ISD::SDIV: return visitSDIV(N); case ISD::UDIV: return visitUDIV(N); - case ISD::SREM: return visitSREM(N); - case ISD::UREM: return visitUREM(N); + case ISD::SREM: + case ISD::UREM: return visitREM(N); case ISD::MULHU: return visitMULHU(N); case ISD::MULHS: return visitMULHS(N); case ISD::SMUL_LOHI: return visitSMUL_LOHI(N); case ISD::UMUL_LOHI: return visitUMUL_LOHI(N); case ISD::SMULO: return visitSMULO(N); case ISD::UMULO: return visitUMULO(N); - case ISD::SDIVREM: return visitSDIVREM(N); - case ISD::UDIVREM: return visitUDIVREM(N); case ISD::SMIN: case ISD::SMAX: case ISD::UMIN: @@ -1411,6 +1434,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::MSCATTER: return visitMSCATTER(N); case ISD::MSTORE: return visitMSTORE(N); case ISD::FP_TO_FP16: return visitFP_TO_FP16(N); + case ISD::FP16_TO_FP: return visitFP16_TO_FP(N); } return SDValue(); } @@ -1473,13 +1497,8 @@ SDValue DAGCombiner::combine(SDNode *N) { // Constant operands are canonicalized to RHS. if (isa(N0) || !isa(N1)) { SDValue Ops[] = {N1, N0}; - SDNode *CSENode; - if (const auto *BinNode = dyn_cast(N)) { - CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops, - &BinNode->Flags); - } else { - CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops); - } + SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops, + N->getFlags()); if (CSENode) return SDValue(CSENode, 0); } @@ -1724,22 +1743,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) { return SDValue(N, 0); // fold (a+b) -> (a|b) iff a and b share no bits. - if (VT.isInteger() && !VT.isVector()) { - APInt LHSZero, LHSOne; - APInt RHSZero, RHSOne; - DAG.computeKnownBits(N0, LHSZero, LHSOne); - - if (LHSZero.getBoolValue()) { - DAG.computeKnownBits(N1, RHSZero, RHSOne); - - // If all possibly-set bits on the LHS are clear on the RHS, return an OR. - // If all possibly-set bits on the RHS are clear on the LHS, return an OR. - if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero){ - if (!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) - return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1); - } - } - } + if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) && + VT.isInteger() && !VT.isVector() && DAG.haveNoCommonBitsSet(N0, N1)) + return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1); // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n)) if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB && @@ -1974,31 +1980,26 @@ SDValue DAGCombiner::visitSUBC(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N0.getValueType(); + SDLoc DL(N); // If the flag result is dead, turn this into an SUB. if (!N->hasAnyUseOfValue(1)) - return CombineTo(N, DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, N1), - DAG.getNode(ISD::CARRY_FALSE, SDLoc(N), - MVT::Glue)); + return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1), + DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); // fold (subc x, x) -> 0 + no borrow - if (N0 == N1) { - SDLoc DL(N); + if (N0 == N1) return CombineTo(N, DAG.getConstant(0, DL, VT), - DAG.getNode(ISD::CARRY_FALSE, DL, - MVT::Glue)); - } + DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); // fold (subc x, 0) -> x + no borrow if (isNullConstant(N1)) - return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, SDLoc(N), - MVT::Glue)); + return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow if (isAllOnesConstant(N0)) - return CombineTo(N, DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0), - DAG.getNode(ISD::CARRY_FALSE, SDLoc(N), - MVT::Glue)); + return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0), + DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue)); return SDValue(); } @@ -2133,14 +2134,15 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { } // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2) - if (N1IsConst && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() && - (isConstantSplatVector(N0.getOperand(1).getNode(), Val) || - isa(N0.getOperand(1)))) - return DAG.getNode(ISD::ADD, SDLoc(N), VT, - DAG.getNode(ISD::MUL, SDLoc(N0), VT, - N0.getOperand(0), N1), - DAG.getNode(ISD::MUL, SDLoc(N1), VT, - N0.getOperand(1), N1)); + if (isConstantIntBuildVectorOrConstantInt(N1) && + N0.getOpcode() == ISD::ADD && + isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) && + isMulAddWithConstProfitable(N, N0, N1)) + return DAG.getNode(ISD::ADD, SDLoc(N), VT, + DAG.getNode(ISD::MUL, SDLoc(N0), VT, + N0.getOperand(0), N1), + DAG.getNode(ISD::MUL, SDLoc(N1), VT, + N0.getOperand(1), N1)); // reassociate mul if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1)) @@ -2149,6 +2151,88 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { return SDValue(); } +/// Return true if divmod libcall is available. +static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned, + const TargetLowering &TLI) { + RTLIB::Libcall LC; + switch (Node->getSimpleValueType(0).SimpleTy) { + default: return false; // No libcall for vector types. + case MVT::i8: LC= isSigned ? RTLIB::SDIVREM_I8 : RTLIB::UDIVREM_I8; break; + case MVT::i16: LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break; + case MVT::i32: LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break; + case MVT::i64: LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break; + case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break; + } + + return TLI.getLibcallName(LC) != nullptr; +} + +/// Issue divrem if both quotient and remainder are needed. +SDValue DAGCombiner::useDivRem(SDNode *Node) { + if (Node->use_empty()) + return SDValue(); // This is a dead node, leave it alone. + + EVT VT = Node->getValueType(0); + if (!TLI.isTypeLegal(VT)) + return SDValue(); + + unsigned Opcode = Node->getOpcode(); + bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM); + + unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM; + // If DIVREM is going to get expanded into a libcall, + // but there is no libcall available, then don't combine. + if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) && + !isDivRemLibcallAvailable(Node, isSigned, TLI)) + return SDValue(); + + // If div is legal, it's better to do the normal expansion + unsigned OtherOpcode = 0; + if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) { + OtherOpcode = isSigned ? ISD::SREM : ISD::UREM; + if (TLI.isOperationLegalOrCustom(Opcode, VT)) + return SDValue(); + } else { + OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV; + if (TLI.isOperationLegalOrCustom(OtherOpcode, VT)) + return SDValue(); + } + + SDValue Op0 = Node->getOperand(0); + SDValue Op1 = Node->getOperand(1); + SDValue combined; + for (SDNode::use_iterator UI = Op0.getNode()->use_begin(), + UE = Op0.getNode()->use_end(); UI != UE; ++UI) { + SDNode *User = *UI; + if (User == Node || User->use_empty()) + continue; + // Convert the other matching node(s), too; + // otherwise, the DIVREM may get target-legalized into something + // target-specific that we won't be able to recognize. + unsigned UserOpc = User->getOpcode(); + if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) && + User->getOperand(0) == Op0 && + User->getOperand(1) == Op1) { + if (!combined) { + if (UserOpc == OtherOpcode) { + SDVTList VTs = DAG.getVTList(VT, VT); + combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1); + } else if (UserOpc == DivRemOpc) { + combined = SDValue(User, 0); + } else { + assert(UserOpc == Opcode); + continue; + } + } + if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV) + CombineTo(User, combined); + else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM) + CombineTo(User, combined.getValue(1)); + } + } + return combined; +} + SDValue DAGCombiner::visitSDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2159,29 +2243,28 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; + SDLoc DL(N); + // fold (sdiv c1, c2) -> c1/c2 ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque()) - return DAG.FoldConstantArithmetic(ISD::SDIV, SDLoc(N), VT, N0C, N1C); + return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C); // fold (sdiv X, 1) -> X if (N1C && N1C->isOne()) return N0; // fold (sdiv X, -1) -> 0-X - if (N1C && N1C->isAllOnesValue()) { - SDLoc DL(N); + if (N1C && N1C->isAllOnesValue()) return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), N0); - } + // If we know the sign bits of both operands are zero, strength reduce to a // udiv instead. Handles (X&15) /s 4 -> X&15 >> 2 if (!VT.isVector()) { if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::UDIV, SDLoc(N), N1.getValueType(), - N0, N1); + return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1); } - bool MinSize = DAG.getMachineFunction().getFunction()->optForMinSize(); // fold (sdiv X, pow2) -> simple ops after legalize // FIXME: We check for the exact bit here because the generic lowering gives // better results in that case. The target-specific lowering should learn how @@ -2190,16 +2273,11 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { !cast(N)->Flags.hasExact() && (N1C->getAPIntValue().isPowerOf2() || (-N1C->getAPIntValue()).isPowerOf2())) { - // If integer division is cheap, then don't perform the following fold. - if (TLI.isIntDivCheap(N->getValueType(0), MinSize)) - return SDValue(); - // Target-specific implementation of sdiv x, pow2. if (SDValue Res = BuildSDIVPow2(N)) return Res; unsigned lg2 = N1C->getAPIntValue().countTrailingZeros(); - SDLoc DL(N); // Splat the sign bit into the register SDValue SGN = @@ -2230,14 +2308,23 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { } // If integer divide is expensive and we satisfy the requirements, emit an - // alternate sequence. - if (N1C && !TLI.isIntDivCheap(N->getValueType(0), MinSize)) + // alternate sequence. Targets may check function attributes for size/speed + // trade-offs. + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); + if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildSDIV(N)) return Op; + // sdiv, srem -> sdivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true. + // Otherwise, we break the simplification logic in visitREM(). + if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + // undef / X -> 0 if (N0.getOpcode() == ISD::UNDEF) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); // X / undef -> undef if (N1.getOpcode() == ISD::UNDEF) return N1; @@ -2255,26 +2342,26 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; + SDLoc DL(N); + // fold (udiv c1, c2) -> c1/c2 ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N0C && N1C) - if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, SDLoc(N), VT, + if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT, N0C, N1C)) return Folded; // fold (udiv x, (1 << c)) -> x >>u c - if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2()) { - SDLoc DL(N); + if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2()) return DAG.getNode(ISD::SRL, DL, VT, N0, DAG.getConstant(N1C->getAPIntValue().logBase2(), DL, getShiftAmountTy(N0.getValueType()))); - } + // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2 if (N1.getOpcode() == ISD::SHL) { if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) { if (SHC->getAPIntValue().isPowerOf2()) { EVT ADDVT = N1.getOperand(1).getValueType(); - SDLoc DL(N); SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT, N1.getOperand(1), DAG.getConstant(SHC->getAPIntValue() @@ -2285,16 +2372,23 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { } } } - + // fold (udiv x, c) -> alternate - bool MinSize = DAG.getMachineFunction().getFunction()->optForMinSize(); - if (N1C && !TLI.isIntDivCheap(N->getValueType(0), MinSize)) + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); + if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr)) if (SDValue Op = BuildUDIV(N)) return Op; + // sdiv, srem -> sdivrem + // If the divisor is constant, then return DIVREM only if isIntDivCheap() is true. + // Otherwise, we break the simplification logic in visitREM(). + if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr)) + if (SDValue DivRem = useDivRem(N)) + return DivRem; + // undef / X -> 0 if (N0.getOpcode() == ISD::UNDEF) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); // X / undef -> undef if (N1.getOpcode() == ISD::UNDEF) return N1; @@ -2302,102 +2396,83 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitSREM(SDNode *N) { +// handles ISD::SREM and ISD::UREM +SDValue DAGCombiner::visitREM(SDNode *N) { + unsigned Opcode = N->getOpcode(); SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + bool isSigned = (Opcode == ISD::SREM); + SDLoc DL(N); - // fold (srem c1, c2) -> c1%c2 + // fold (rem c1, c2) -> c1%c2 ConstantSDNode *N0C = isConstOrConstSplat(N0); ConstantSDNode *N1C = isConstOrConstSplat(N1); if (N0C && N1C) - if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::SREM, SDLoc(N), VT, - N0C, N1C)) + if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C)) return Folded; - // If we know the sign bits of both operands are zero, strength reduce to a - // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15 - if (!VT.isVector()) { - if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) - return DAG.getNode(ISD::UREM, SDLoc(N), VT, N0, N1); - } - // If X/C can be simplified by the division-by-constant logic, lower - // X%C to the equivalent of X-X/C*C. - if (N1C && !N1C->isNullValue()) { - SDValue Div = DAG.getNode(ISD::SDIV, SDLoc(N), VT, N0, N1); - AddToWorklist(Div.getNode()); - SDValue OptimizedDiv = combine(Div.getNode()); - if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) { - SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, - OptimizedDiv, N1); - SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, Mul); - AddToWorklist(Mul.getNode()); - return Sub; + if (isSigned) { + // If we know the sign bits of both operands are zero, strength reduce to a + // urem instead. Handles (X & 0x0FFFFFFF) %s 16 -> X&15 + if (!VT.isVector()) { + if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0)) + return DAG.getNode(ISD::UREM, DL, VT, N0, N1); } - } - - // undef % X -> 0 - if (N0.getOpcode() == ISD::UNDEF) - return DAG.getConstant(0, SDLoc(N), VT); - // X % undef -> undef - if (N1.getOpcode() == ISD::UNDEF) - return N1; - - return SDValue(); -} - -SDValue DAGCombiner::visitUREM(SDNode *N) { - SDValue N0 = N->getOperand(0); - SDValue N1 = N->getOperand(1); - EVT VT = N->getValueType(0); - - // fold (urem c1, c2) -> c1%c2 - ConstantSDNode *N0C = isConstOrConstSplat(N0); - ConstantSDNode *N1C = isConstOrConstSplat(N1); - if (N0C && N1C) - if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UREM, SDLoc(N), VT, - N0C, N1C)) - return Folded; - // fold (urem x, pow2) -> (and x, pow2-1) - if (N1C && !N1C->isNullValue() && !N1C->isOpaque() && - N1C->getAPIntValue().isPowerOf2()) { - SDLoc DL(N); - return DAG.getNode(ISD::AND, DL, VT, N0, - DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT)); - } - // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1)) - if (N1.getOpcode() == ISD::SHL) { - if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) { - if (SHC->getAPIntValue().isPowerOf2()) { - SDLoc DL(N); - SDValue Add = - DAG.getNode(ISD::ADD, DL, VT, N1, + } else { + // fold (urem x, pow2) -> (and x, pow2-1) + if (N1C && !N1C->isNullValue() && !N1C->isOpaque() && + N1C->getAPIntValue().isPowerOf2()) { + return DAG.getNode(ISD::AND, DL, VT, N0, + DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT)); + } + // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1)) + if (N1.getOpcode() == ISD::SHL) { + if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) { + if (SHC->getAPIntValue().isPowerOf2()) { + SDValue Add = + DAG.getNode(ISD::ADD, DL, VT, N1, DAG.getConstant(APInt::getAllOnesValue(VT.getSizeInBits()), DL, VT)); - AddToWorklist(Add.getNode()); - return DAG.getNode(ISD::AND, DL, VT, N0, Add); + AddToWorklist(Add.getNode()); + return DAG.getNode(ISD::AND, DL, VT, N0, Add); + } } } } + AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes(); + // If X/C can be simplified by the division-by-constant logic, lower // X%C to the equivalent of X-X/C*C. - if (N1C && !N1C->isNullValue()) { - SDValue Div = DAG.getNode(ISD::UDIV, SDLoc(N), VT, N0, N1); + // To avoid mangling nodes, this simplification requires that the combine() + // call for the speculative DIV must not cause a DIVREM conversion. We guard + // against this by skipping the simplification if isIntDivCheap(). When + // div is not cheap, combine will not return a DIVREM. Regardless, + // checking cheapness here makes sense since the simplification results in + // fatter code. + if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap(VT, Attr)) { + unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV; + SDValue Div = DAG.getNode(DivOpcode, DL, VT, N0, N1); AddToWorklist(Div.getNode()); SDValue OptimizedDiv = combine(Div.getNode()); if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) { - SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT, - OptimizedDiv, N1); - SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, Mul); + assert((OptimizedDiv.getOpcode() != ISD::UDIVREM) && + (OptimizedDiv.getOpcode() != ISD::SDIVREM)); + SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1); + SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul); AddToWorklist(Mul.getNode()); return Sub; } } + // sdiv, srem -> sdivrem + if (SDValue DivRem = useDivRem(N)) + return DivRem.getValue(1); + // undef % X -> 0 if (N0.getOpcode() == ISD::UNDEF) - return DAG.getConstant(0, SDLoc(N), VT); + return DAG.getConstant(0, DL, VT); // X % undef -> undef if (N1.getOpcode() == ISD::UNDEF) return N1; @@ -2615,20 +2690,6 @@ SDValue DAGCombiner::visitUMULO(SDNode *N) { return SDValue(); } -SDValue DAGCombiner::visitSDIVREM(SDNode *N) { - if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::SDIV, ISD::SREM)) - return Res; - - return SDValue(); -} - -SDValue DAGCombiner::visitUDIVREM(SDNode *N) { - if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::UDIV, ISD::UREM)) - return Res; - - return SDValue(); -} - SDValue DAGCombiner::visitIMINMAX(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -2874,10 +2935,13 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1, if (Result != ISD::SETCC_INVALID && (!LegalOperations || (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) && - TLI.isOperationLegal(ISD::SETCC, - getSetCCResultType(N0.getSimpleValueType()))))) - return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), - LL, LR, Result); + TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) { + EVT CCVT = getSetCCResultType(LL.getValueType()); + if (N0.getValueType() == CCVT || + (!LegalOperations && N0.getValueType() == MVT::i1)) + return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), + LL, LR, Result); + } } } @@ -3129,7 +3193,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // be expensive (and would be wrong if the type is not byte sized). if (!LN0->isVolatile() && LoadedVT.bitsGT(ExtVT) && ExtVT.isRound() && (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, - ExtVT))) { + ExtVT)) && + TLI.shouldReduceLoadWidth(LN0, ISD::ZEXTLOAD, ExtVT)) { EVT PtrType = LN0->getOperand(1).getValueType(); unsigned Alignment = LN0->getAlignment(); @@ -3532,10 +3597,13 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) { if (Result != ISD::SETCC_INVALID && (!LegalOperations || (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) && - TLI.isOperationLegal(ISD::SETCC, - getSetCCResultType(N0.getValueType()))))) - return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), - LL, LR, Result); + TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) { + EVT CCVT = getSetCCResultType(LL.getValueType()); + if (N0.getValueType() == CCVT || + (!LegalOperations && N0.getValueType() == MVT::i1)) + return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(), + LL, LR, Result); + } } } @@ -3732,7 +3800,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { /// Match "(X shl/srl V1) & V2" where V2 may not be present. static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { if (Op.getOpcode() == ISD::AND) { - if (isa(Op.getOperand(1))) { + if (isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) { Mask = Op.getOperand(1); Op = Op.getOperand(0); } else { @@ -3749,105 +3817,106 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) { } // Return true if we can prove that, whenever Neg and Pos are both in the -// range [0, OpSize), Neg == (Pos == 0 ? 0 : OpSize - Pos). This means that +// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos). This means that // for two opposing shifts shift1 and shift2 and a value X with OpBits bits: // // (or (shift1 X, Neg), (shift2 X, Pos)) // // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate -// in direction shift1 by Neg. The range [0, OpSize) means that we only need +// in direction shift1 by Neg. The range [0, EltSize) means that we only need // to consider shift amounts with defined behavior. -static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned OpSize) { - // If OpSize is a power of 2 then: +static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) { + // If EltSize is a power of 2 then: // - // (a) (Pos == 0 ? 0 : OpSize - Pos) == (OpSize - Pos) & (OpSize - 1) - // (b) Neg == Neg & (OpSize - 1) whenever Neg is in [0, OpSize). + // (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1) + // (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize). // - // So if OpSize is a power of 2 and Neg is (and Neg', OpSize-1), we check + // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check // for the stronger condition: // - // Neg & (OpSize - 1) == (OpSize - Pos) & (OpSize - 1) [A] + // Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1) [A] // - // for all Neg and Pos. Since Neg & (OpSize - 1) == Neg' & (OpSize - 1) + // for all Neg and Pos. Since Neg & (EltSize - 1) == Neg' & (EltSize - 1) // we can just replace Neg with Neg' for the rest of the function. // // In other cases we check for the even stronger condition: // - // Neg == OpSize - Pos [B] + // Neg == EltSize - Pos [B] // // for all Neg and Pos. Note that the (or ...) then invokes undefined - // behavior if Pos == 0 (and consequently Neg == OpSize). + // behavior if Pos == 0 (and consequently Neg == EltSize). // - // We could actually use [A] whenever OpSize is a power of 2, but the + // We could actually use [A] whenever EltSize is a power of 2, but the // only extra cases that it would match are those uninteresting ones // where Neg and Pos are never in range at the same time. E.g. for - // OpSize == 32, using [A] would allow a Neg of the form (sub 64, Pos) + // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos) // as well as (sub 32, Pos), but: // // (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos)) // // always invokes undefined behavior for 32-bit X. // - // Below, Mask == OpSize - 1 when using [A] and is all-ones otherwise. + // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise. unsigned MaskLoBits = 0; - if (Neg.getOpcode() == ISD::AND && - isPowerOf2_64(OpSize) && - Neg.getOperand(1).getOpcode() == ISD::Constant && - cast(Neg.getOperand(1))->getAPIntValue() == OpSize - 1) { - Neg = Neg.getOperand(0); - MaskLoBits = Log2_64(OpSize); + if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) { + if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) { + if (NegC->getAPIntValue() == EltSize - 1) { + Neg = Neg.getOperand(0); + MaskLoBits = Log2_64(EltSize); + } + } } // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1. if (Neg.getOpcode() != ISD::SUB) return 0; - ConstantSDNode *NegC = dyn_cast(Neg.getOperand(0)); + ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0)); if (!NegC) return 0; SDValue NegOp1 = Neg.getOperand(1); - // On the RHS of [A], if Pos is Pos' & (OpSize - 1), just replace Pos with + // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with // Pos'. The truncation is redundant for the purpose of the equality. - if (MaskLoBits && - Pos.getOpcode() == ISD::AND && - Pos.getOperand(1).getOpcode() == ISD::Constant && - cast(Pos.getOperand(1))->getAPIntValue() == OpSize - 1) - Pos = Pos.getOperand(0); + if (MaskLoBits && Pos.getOpcode() == ISD::AND) + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) + if (PosC->getAPIntValue() == EltSize - 1) + Pos = Pos.getOperand(0); // The condition we need is now: // - // (NegC - NegOp1) & Mask == (OpSize - Pos) & Mask + // (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask // // If NegOp1 == Pos then we need: // - // OpSize & Mask == NegC & Mask + // EltSize & Mask == NegC & Mask // // (because "x & Mask" is a truncation and distributes through subtraction). APInt Width; if (Pos == NegOp1) Width = NegC->getAPIntValue(); + // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC. // Then the condition we want to prove becomes: // - // (NegC - NegOp1) & Mask == (OpSize - (NegOp1 + PosC)) & Mask + // (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask // // which, again because "x & Mask" is a truncation, becomes: // - // NegC & Mask == (OpSize - PosC) & Mask - // OpSize & Mask == (NegC + PosC) & Mask - else if (Pos.getOpcode() == ISD::ADD && - Pos.getOperand(0) == NegOp1 && - Pos.getOperand(1).getOpcode() == ISD::Constant) - Width = (cast(Pos.getOperand(1))->getAPIntValue() + - NegC->getAPIntValue()); - else + // NegC & Mask == (EltSize - PosC) & Mask + // EltSize & Mask == (NegC + PosC) & Mask + else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) { + if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1))) + Width = PosC->getAPIntValue() + NegC->getAPIntValue(); + else + return false; + } else return false; - // Now we just need to check that OpSize & Mask == Width & Mask. + // Now we just need to check that EltSize & Mask == Width & Mask. if (MaskLoBits) - // Opsize & Mask is 0 since Mask is Opsize - 1. + // EltSize & Mask is 0 since Mask is EltSize - 1. return Width.getLoBits(MaskLoBits) == 0; - return Width == OpSize; + return Width == EltSize; } // A subroutine of MatchRotate used once we have found an OR of two opposite @@ -3867,7 +3936,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos, // (srl x, (*ext y))) -> // (rotr x, y) or (rotl x, (sub 32, y)) EVT VT = Shifted.getValueType(); - if (matchRotateSub(InnerPos, InnerNeg, VT.getSizeInBits())) { + if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) { bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT); return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted, HasPos ? Pos : Neg).getNode(); @@ -3910,10 +3979,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { if (RHSShift.getOpcode() == ISD::SHL) { std::swap(LHS, RHS); std::swap(LHSShift, RHSShift); - std::swap(LHSMask , RHSMask ); + std::swap(LHSMask, RHSMask); } - unsigned OpSizeInBits = VT.getSizeInBits(); + unsigned EltSizeInBits = VT.getScalarSizeInBits(); SDValue LHSShiftArg = LHSShift.getOperand(0); SDValue LHSShiftAmt = LHSShift.getOperand(1); SDValue RHSShiftArg = RHSShift.getOperand(0); @@ -3921,11 +3990,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1) // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2) - if (LHSShiftAmt.getOpcode() == ISD::Constant && - RHSShiftAmt.getOpcode() == ISD::Constant) { - uint64_t LShVal = cast(LHSShiftAmt)->getZExtValue(); - uint64_t RShVal = cast(RHSShiftAmt)->getZExtValue(); - if ((LShVal + RShVal) != OpSizeInBits) + if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) { + uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue(); + uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue(); + if ((LShVal + RShVal) != EltSizeInBits) return nullptr; SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, @@ -3933,18 +4001,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) { // If there is an AND of either shifted operand, apply it to the result. if (LHSMask.getNode() || RHSMask.getNode()) { - APInt Mask = APInt::getAllOnesValue(OpSizeInBits); + APInt AllBits = APInt::getAllOnesValue(EltSizeInBits); + SDValue Mask = DAG.getConstant(AllBits, DL, VT); if (LHSMask.getNode()) { - APInt RHSBits = APInt::getLowBitsSet(OpSizeInBits, LShVal); - Mask &= cast(LHSMask)->getAPIntValue() | RHSBits; + APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal); + Mask = DAG.getNode(ISD::AND, DL, VT, Mask, + DAG.getNode(ISD::OR, DL, VT, LHSMask, + DAG.getConstant(RHSBits, DL, VT))); } if (RHSMask.getNode()) { - APInt LHSBits = APInt::getHighBitsSet(OpSizeInBits, RShVal); - Mask &= cast(RHSMask)->getAPIntValue() | LHSBits; + APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal); + Mask = DAG.getNode(ISD::AND, DL, VT, Mask, + DAG.getNode(ISD::OR, DL, VT, RHSMask, + DAG.getConstant(LHSBits, DL, VT))); } - Rot = DAG.getNode(ISD::AND, DL, VT, Rot, DAG.getConstant(Mask, DL, VT)); + Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask); } return Rot.getNode(); @@ -4458,8 +4531,9 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2) if (N1C && N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse()) { if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) { - SDValue Folded = DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, N0C1, N1C); - return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Folded); + if (SDValue Folded = + DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, N0C1, N1C)) + return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Folded); } } @@ -6052,7 +6126,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { if (!VT.isVector()) { EVT SetCCVT = getSetCCResultType(N0.getOperand(0).getValueType()); - if (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, SetCCVT)) { + if (!LegalOperations || + TLI.isOperationLegal(ISD::SETCC, N0.getOperand(0).getValueType())) { SDLoc DL(N); ISD::CondCode CC = cast(N0.getOperand(2))->get(); SDValue SetCC = DAG.getSetCC(DL, SetCCVT, @@ -6774,8 +6849,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { unsigned VTBits = VT.getScalarType().getSizeInBits(); unsigned EVTBits = EVT.getScalarType().getSizeInBits(); + if (N0.isUndef()) + return DAG.getUNDEF(VT); + // fold (sext_in_reg c1) -> c1 - if (isa(N0) || N0.getOpcode() == ISD::UNDEF) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1); // If the input is already sign extended, just drop the extension. @@ -6868,29 +6946,6 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { BSwap, N1); } - // Fold a sext_inreg of a build_vector of ConstantSDNodes or undefs - // into a build_vector. - if (ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) { - SmallVector Elts; - unsigned NumElts = N0->getNumOperands(); - unsigned ShAmt = VTBits - EVTBits; - - for (unsigned i = 0; i != NumElts; ++i) { - SDValue Op = N0->getOperand(i); - if (Op->getOpcode() == ISD::UNDEF) { - Elts.push_back(Op); - continue; - } - - ConstantSDNode *CurrentND = cast(Op); - const APInt &C = APInt(VTBits, CurrentND->getAPIntValue().getZExtValue()); - Elts.push_back(DAG.getConstant(C.shl(ShAmt).ashr(ShAmt).getZExtValue(), - SDLoc(Op), Op.getValueType())); - } - - return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Elts); - } - return SDValue(); } @@ -7466,25 +7521,23 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; - bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath); + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && - TLI.isOperationLegal(ISD::FMAD, VT)); + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. - bool HasFMA = ((!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - UnsafeFPMath); + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); // Always prefer FMAD to FMA for precision. - unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); bool LookThroughFPExt = TLI.isFPExtFree(VT); @@ -7512,7 +7565,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { } // Look through FP_EXTEND nodes to do more combining. - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -7538,7 +7591,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { } // More folding opportunities when target permits. - if ((UnsafeFPMath || HasFMAD) && Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL) { @@ -7561,7 +7614,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { N0)); } - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&] ( @@ -7651,25 +7704,23 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; - bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath); + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && - TLI.isOperationLegal(ISD::FMAD, VT)); + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. - bool HasFMA = ((!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - UnsafeFPMath); + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); // Always prefer FMAD to FMA for precision. - unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); bool LookThroughFPExt = TLI.isFPExtFree(VT); @@ -7702,7 +7753,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { } // Look through FP_EXTEND nodes to do more combining. - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fsub (fpext (fmul x, y)), z) // -> (fma (fpext x), (fpext y), (fneg z)) if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -7778,7 +7829,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { } // More folding opportunities when target permits. - if ((UnsafeFPMath || HasFMAD) && Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fsub (fma x, y, (fmul u, v)), z) // -> (fma x, y (fma u, v, (fneg z))) if (N0.getOpcode() == PreferredFusedOpcode && @@ -7808,7 +7859,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N21, N0)); } - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fsub (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) if (N0.getOpcode() == PreferredFusedOpcode) { @@ -7909,14 +7960,97 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { return SDValue(); } +/// Try to perform FMA combining on a given FMUL node. +SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + SDLoc SL(N); + + assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation"); + + const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); + + // Floating-point multiply-add with intermediate rounding. + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); + + // Floating-point multiply-add without intermediate rounding. + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); + + // No valid opcode, do not combine. + if (!HasFMAD && !HasFMA) + return SDValue(); + + // Always prefer FMAD to FMA for precision. + unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + + // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y) + // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y)) + auto FuseFADD = [&](SDValue X, SDValue Y) { + if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) { + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + } + return SDValue(); + }; + + if (SDValue FMA = FuseFADD(N0, N1)) + return FMA; + if (SDValue FMA = FuseFADD(N1, N0)) + return FMA; + + // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y) + // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y)) + // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y)) + // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y) + auto FuseFSUB = [&](SDValue X, SDValue Y) { + if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) { + auto XC0 = isConstOrConstSplatFP(X.getOperand(0)); + if (XC0 && XC0->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + Y); + if (XC0 && XC0->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + } + return SDValue(); + }; + + if (SDValue FMA = FuseFSUB(N0, N1)) + return FMA; + if (SDValue FMA = FuseFSUB(N1, N0)) + return FMA; + + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = dyn_cast(N0); - ConstantFPSDNode *N1CFP = dyn_cast(N1); + bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0); + bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1); EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -7925,23 +8059,23 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { // fold (fadd c1, c2) -> c1 + c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FADD, DL, VT, N0, N1); + return DAG.getNode(ISD::FADD, DL, VT, N0, N1, Flags); // canonicalize constant to RHS if (N0CFP && !N1CFP) - return DAG.getNode(ISD::FADD, DL, VT, N1, N0); + return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags); // fold (fadd A, (fneg B)) -> (fsub A, B) if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2) return DAG.getNode(ISD::FSUB, DL, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), Flags); // fold (fadd (fneg A), B) -> (fsub B, A) if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && isNegatibleForFree(N0, LegalOperations, TLI, &Options) == 2) return DAG.getNode(ISD::FSUB, DL, VT, N1, - GetNegatedExpression(N0, DAG, LegalOperations)); + GetNegatedExpression(N0, DAG, LegalOperations), Flags); // If 'unsafe math' is enabled, fold lots of things. if (Options.UnsafeFPMath) { @@ -7950,14 +8084,17 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { bool AllowNewConst = (Level < AfterLegalizeDAG); // fold (fadd A, 0) -> A - if (N1CFP && N1CFP->isZero()) - return N0; + if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1)) + if (N1C->isZero()) + return N0; // fold (fadd (fadd x, c1), c2) -> (fadd x, (fadd c1, c2)) if (N1CFP && N0.getOpcode() == ISD::FADD && N0.getNode()->hasOneUse() && - isa(N0.getOperand(1))) + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1)); + DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, + Flags), + Flags); // If allowed, fold (fadd (fneg x), x) -> 0.0 if (AllowNewConst && N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) @@ -7972,64 +8109,64 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { // of rounding steps. if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) { if (N0.getOpcode() == ISD::FMUL) { - ConstantFPSDNode *CFP00 = dyn_cast(N0.getOperand(0)); - ConstantFPSDNode *CFP01 = dyn_cast(N0.getOperand(1)); + bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); + bool CFP01 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(1)); // (fadd (fmul x, c), x) -> (fmul x, c+1) if (CFP01 && !CFP00 && N0.getOperand(0) == N1) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP01, 0), - DAG.getConstantFP(1.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), + DAG.getConstantFP(1.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP, Flags); } // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2) if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP01, 0), - DAG.getConstantFP(2.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), + DAG.getConstantFP(2.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP, Flags); } } if (N1.getOpcode() == ISD::FMUL) { - ConstantFPSDNode *CFP10 = dyn_cast(N1.getOperand(0)); - ConstantFPSDNode *CFP11 = dyn_cast(N1.getOperand(1)); + bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); + bool CFP11 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(1)); // (fadd x, (fmul x, c)) -> (fmul x, c+1) if (CFP11 && !CFP10 && N1.getOperand(0) == N0) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP11, 0), - DAG.getConstantFP(1.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1), + DAG.getConstantFP(1.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP, Flags); } // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2) if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N0.getOperand(0)) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP11, 0), - DAG.getConstantFP(2.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1), + DAG.getConstantFP(2.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP, Flags); } } if (N0.getOpcode() == ISD::FADD && AllowNewConst) { - ConstantFPSDNode *CFP = dyn_cast(N0.getOperand(0)); + bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); // (fadd (fadd x, x), x) -> (fmul x, 3.0) - if (!CFP && N0.getOperand(0) == N0.getOperand(1) && + if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) && (N0.getOperand(0) == N1)) { return DAG.getNode(ISD::FMUL, DL, VT, - N1, DAG.getConstantFP(3.0, DL, VT)); + N1, DAG.getConstantFP(3.0, DL, VT), Flags); } } if (N1.getOpcode() == ISD::FADD && AllowNewConst) { - ConstantFPSDNode *CFP10 = dyn_cast(N1.getOperand(0)); + bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); // (fadd x, (fadd x, x)) -> (fmul x, 3.0) if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) && N1.getOperand(0) == N0) { return DAG.getNode(ISD::FMUL, DL, VT, - N0, DAG.getConstantFP(3.0, DL, VT)); + N0, DAG.getConstantFP(3.0, DL, VT), Flags); } } @@ -8039,8 +8176,8 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { - return DAG.getNode(ISD::FMUL, DL, VT, - N0.getOperand(0), DAG.getConstantFP(4.0, DL, VT)); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), + DAG.getConstantFP(4.0, DL, VT), Flags); } } } // enable-unsafe-fp-math @@ -8062,6 +8199,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { EVT VT = N->getValueType(0); SDLoc dl(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -8070,12 +8208,12 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { // fold (fsub c1, c2) -> c1-c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FSUB, dl, VT, N0, N1); + return DAG.getNode(ISD::FSUB, dl, VT, N0, N1, Flags); // fold (fsub A, (fneg B)) -> (fadd A, B) if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return DAG.getNode(ISD::FADD, dl, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), Flags); // If 'unsafe math' is enabled, fold lots of things. if (Options.UnsafeFPMath) { @@ -8126,6 +8264,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) { @@ -8136,12 +8275,12 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // fold (fmul c1, c2) -> c1*c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FMUL, DL, VT, N0, N1); + return DAG.getNode(ISD::FMUL, DL, VT, N0, N1, Flags); // canonicalize constant to RHS if (isConstantFPBuildVectorOrConstantFP(N0) && !isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(ISD::FMUL, DL, VT, N1, N0); + return DAG.getNode(ISD::FMUL, DL, VT, N1, N0, Flags); // fold (fmul A, 1.0) -> A if (N1CFP && N1CFP->isExactlyValue(1.0)) @@ -8170,8 +8309,8 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // the second operand of the outer multiply are constants. if ((N1CFP && isConstOrConstSplatFP(N01)) || (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) { - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1); - return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts); + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); } } } @@ -8184,14 +8323,14 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { (N0.getOperand(0) == N0.getOperand(1)) && N0.hasOneUse()) { const SDValue Two = DAG.getConstantFP(2.0, DL, VT); - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1); - return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts); + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags); } } // fold (fmul X, 2.0) -> (fadd X, X) if (N1CFP && N1CFP->isExactlyValue(+2.0)) - return DAG.getNode(ISD::FADD, DL, VT, N0, N0); + return DAG.getNode(ISD::FADD, DL, VT, N0, N0, Flags); // fold (fmul X, -1.0) -> (fneg X) if (N1CFP && N1CFP->isExactlyValue(-1.0)) @@ -8206,10 +8345,17 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { if (LHSNeg == 2 || RHSNeg == 2) return DAG.getNode(ISD::FMUL, DL, VT, GetNegatedExpression(N0, DAG, LegalOperations), - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), + Flags); } } + // FMUL -> FMA combines: + if (SDValue Fused = visitFMULForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); } @@ -8236,62 +8382,77 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (N1CFP && N1CFP->isZero()) return N2; } + // TODO: The FMA node should have flags that propagate to these nodes. if (N0CFP && N0CFP->isExactlyValue(1.0)) return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2); if (N1CFP && N1CFP->isExactlyValue(1.0)) return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2); // Canonicalize (fma c, x, y) -> (fma x, c, y) - if (N0CFP && !N1CFP) + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); - // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) - if (Options.UnsafeFPMath && N1CFP && - N2.getOpcode() == ISD::FMUL && - N0 == N2.getOperand(0) && - N2.getOperand(1).getOpcode() == ISD::ConstantFP) { - return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1))); - } + // TODO: FMA nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + if (Options.UnsafeFPMath) { + // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) + if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) && + isConstantFPBuildVectorOrConstantFP(N1) && + isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) { + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1), + &Flags), &Flags); + } - // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) - if (Options.UnsafeFPMath && - N0.getOpcode() == ISD::FMUL && N1CFP && - N0.getOperand(1).getOpcode() == ISD::ConstantFP) { - return DAG.getNode(ISD::FMA, dl, VT, - N0.getOperand(0), - DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1)), - N2); + // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) + if (N0.getOpcode() == ISD::FMUL && + isConstantFPBuildVectorOrConstantFP(N1) && + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { + return DAG.getNode(ISD::FMA, dl, VT, + N0.getOperand(0), + DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1), + &Flags), + N2); + } } // (fma x, 1, y) -> (fadd x, y) // (fma x, -1, y) -> (fadd (fneg x), y) if (N1CFP) { if (N1CFP->isExactlyValue(1.0)) + // TODO: The FMA node should have flags that propagate to this node. return DAG.getNode(ISD::FADD, dl, VT, N0, N2); if (N1CFP->isExactlyValue(-1.0) && (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) { SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0); AddToWorklist(RHSNeg.getNode()); + // TODO: The FMA node should have flags that propagate to this node. return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg); } } - // (fma x, c, x) -> (fmul x, (c+1)) - if (Options.UnsafeFPMath && N1CFP && N0 == N2) - return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, - N1, DAG.getConstantFP(1.0, dl, VT))); - - // (fma x, c, (fneg x)) -> (fmul x, (c-1)) - if (Options.UnsafeFPMath && N1CFP && - N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) + if (Options.UnsafeFPMath) { + // (fma x, c, x) -> (fmul x, (c+1)) + if (N1CFP && N0 == N2) { return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, - N1, DAG.getConstantFP(-1.0, dl, VT))); + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(1.0, dl, VT), + &Flags), &Flags); + } + // (fma x, c, (fneg x)) -> (fmul x, (c-1)) + if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) { + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(-1.0, dl, VT), + &Flags), &Flags); + } + } return SDValue(); } @@ -8304,7 +8465,9 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL". SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { - if (!DAG.getTarget().Options.UnsafeFPMath) + bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath; + const SDNodeFlags *Flags = N->getFlags(); + if (!UnsafeMath && !Flags->hasAllowReciprocal()) return SDValue(); // Skip if current node is a reciprocal. @@ -8323,9 +8486,14 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { // Find all FDIV users of the same divisor. // Use a set because duplicates may be present in the user list. SetVector Users; - for (auto *U : N1->uses()) - if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) - Users.insert(U); + for (auto *U : N1->uses()) { + if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) { + // This division is eligible for optimization only if global unsafe math + // is enabled or if this division allows reciprocal formation. + if (UnsafeMath || U->getFlags()->hasAllowReciprocal()) + Users.insert(U); + } + } // Now that we have the actual number of divisor uses, make sure it meets // the minimum threshold specified by the target. @@ -8335,17 +8503,14 @@ SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); SDValue FPOne = DAG.getConstantFP(1.0, DL, VT); - // FIXME: This optimization requires some level of fast-math, so the - // created reciprocal node should at least have the 'allowReciprocal' - // fast-math-flag set. - SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1); + SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags); // Dividend / Divisor -> Dividend * Reciprocal for (auto *U : Users) { SDValue Dividend = U->getOperand(0); if (Dividend != FPOne) { SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend, - Reciprocal); + Reciprocal, Flags); CombineTo(U, NewNode); } else if (U != Reciprocal.getNode()) { // In the absence of fast-math-flags, this user node is always the @@ -8364,6 +8529,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -8372,7 +8538,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { // fold (fdiv c1, c2) -> c1/c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1); + return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags); if (Options.UnsafeFPMath) { // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. @@ -8391,28 +8557,30 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) || TLI.isFPImmLegal(Recip, VT))) return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getConstantFP(Recip, DL, VT)); + DAG.getConstantFP(Recip, DL, VT), Flags); } // If this FDIV is part of a reciprocal square root, it may be folded // into a target-specific square root estimate instruction. if (N1.getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0))) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) { + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FP_EXTEND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) { + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), + Flags)) { RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FP_ROUND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) { + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), + Flags)) { RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1)); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FMUL) { // Look through an FMUL. Even though this won't remove the FDIV directly, @@ -8429,18 +8597,18 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (SqrtOp.getNode()) { // We found a FSQRT, so try to make this fold: // x / (y * sqrt(z)) -> x * (rsqrt(z) / y) - if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0))) { - RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp); + if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) { + RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } } // Fold into a reciprocal estimate and multiply instead of a real divide. - if (SDValue RV = BuildReciprocalEstimate(N1)) { + if (SDValue RV = BuildReciprocalEstimate(N1, Flags)) { AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } @@ -8452,7 +8620,8 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (LHSNeg == 2 || RHSNeg == 2) return DAG.getNode(ISD::FDIV, SDLoc(N), VT, GetNegatedExpression(N0, DAG, LegalOperations), - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), + Flags); } } @@ -8471,7 +8640,8 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { // fold (frem c1, c2) -> fmod(c1,c2) if (N0CFP && N1CFP) - return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1); + return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, + &cast(N)->Flags); return SDValue(); } @@ -8480,20 +8650,25 @@ SDValue DAGCombiner::visitFSQRT(SDNode *N) { if (!DAG.getTarget().Options.UnsafeFPMath || TLI.isFsqrtCheap()) return SDValue(); + // TODO: FSQRT nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5) - SDValue RV = BuildRsqrtEstimate(N->getOperand(0)); + SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags); if (!RV) return SDValue(); EVT VT = RV.getValueType(); SDLoc DL(N); - RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV); + RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags); AddToWorklist(RV.getNode()); // Unfortunately, RV is now NaN if the input was exactly 0. // Select out this case and force the answer to 0. SDValue Zero = DAG.getConstantFP(0.0, DL, VT); - EVT CCVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + EVT CCVT = getSetCCResultType(VT); SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ); AddToWorklist(ZeroCmp.getNode()); AddToWorklist(RV.getNode()); @@ -8902,9 +9077,10 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { if (Level >= AfterLegalizeDAG && (TLI.isFPImmLegal(CVal, N->getValueType(0)) || TLI.isOperationLegal(ISD::ConstantFP, N->getValueType(0)))) - return DAG.getNode( - ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1))); + return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N0.getOperand(1)), + &cast(N0)->Flags); } } @@ -8914,20 +9090,20 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { SDValue DAGCombiner::visitFMINNUM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - const ConstantFPSDNode *N0CFP = dyn_cast(N0); - const ConstantFPSDNode *N1CFP = dyn_cast(N1); + EVT VT = N->getValueType(0); + const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), N->getValueType(0)); + return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), VT); } - if (N0CFP) { - EVT VT = N->getValueType(0); - // Canonicalize to constant on RHS. + // Canonicalize to constant on RHS. + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMINNUM, SDLoc(N), VT, N1, N0); - } return SDValue(); } @@ -8935,20 +9111,20 @@ SDValue DAGCombiner::visitFMINNUM(SDNode *N) { SDValue DAGCombiner::visitFMAXNUM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - const ConstantFPSDNode *N0CFP = dyn_cast(N0); - const ConstantFPSDNode *N1CFP = dyn_cast(N1); + EVT VT = N->getValueType(0); + const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), N->getValueType(0)); + return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), VT); } - if (N0CFP) { - EVT VT = N->getValueType(0); - // Canonicalize to constant on RHS. + // Canonicalize to constant on RHS. + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMAXNUM, SDLoc(N), VT, N1, N0); - } return SDValue(); } @@ -10687,30 +10863,109 @@ struct BaseIndexOffset { }; } // namespace +// This is a helper function for visitMUL to check the profitability +// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). +// MulNode is the original multiply, AddNode is (add x, c1), +// and ConstNode is c2. +// +// If the (add x, c1) has multiple uses, we could increase +// the number of adds if we make this transformation. +// It would only be worth doing this if we can remove a +// multiply in the process. Check for that here. +// To illustrate: +// (A + c1) * c3 +// (A + c2) * c3 +// We're checking for cases where we have common "c3 * A" expressions. +bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, + SDValue &AddNode, + SDValue &ConstNode) { + APInt Val; + + // If the add only has one use, this would be OK to do. + if (AddNode.getNode()->hasOneUse()) + return true; + + // Walk all the users of the constant with which we're multiplying. + for (SDNode *Use : ConstNode->uses()) { + + if (Use == MulNode) // This use is the one we're on right now. Skip it. + continue; + + if (Use->getOpcode() == ISD::MUL) { // We have another multiply use. + SDNode *OtherOp; + SDNode *MulVar = AddNode.getOperand(0).getNode(); + + // OtherOp is what we're multiplying against the constant. + if (Use->getOperand(0) == ConstNode) + OtherOp = Use->getOperand(1).getNode(); + else + OtherOp = Use->getOperand(0).getNode(); + + // Check to see if multiply is with the same operand of our "add". + // + // ConstNode = CONST + // Use = ConstNode * A <-- visiting Use. OtherOp is A. + // ... + // AddNode = (A + c1) <-- MulVar is A. + // = AddNode * ConstNode <-- current visiting instruction. + // + // If we make this transformation, we will have a common + // multiply (ConstNode * A) that we can save. + if (OtherOp == MulVar) + return true; + + // Now check to see if a future expansion will give us a common + // multiply. + // + // ConstNode = CONST + // AddNode = (A + c1) + // ... = AddNode * ConstNode <-- current visiting instruction. + // ... + // OtherOp = (A + c2) + // Use = OtherOp * ConstNode <-- visiting Use. + // + // If we make this transformation, we will have a common + // multiply (CONST * A) after we also do the same transformation + // to the "t2" instruction. + if (OtherOp->getOpcode() == ISD::ADD && + isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) && + OtherOp->getOperand(0).getNode() == MulVar) + return true; + } + } + + // Didn't find a case where this would be profitable. + return false; +} + SDValue DAGCombiner::getMergedConstantVectorStore(SelectionDAG &DAG, SDLoc SL, ArrayRef Stores, + SmallVectorImpl &Chains, EVT Ty) const { SmallVector BuildVector; - for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) - BuildVector.push_back(cast(Stores[I].MemNode)->getValue()); + for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) { + StoreSDNode *St = cast(Stores[I].MemNode); + Chains.push_back(St->getChain()); + BuildVector.push_back(St->getValue()); + } return DAG.getNode(ISD::BUILD_VECTOR, SL, Ty, BuildVector); } bool DAGCombiner::MergeStoresOfConstantsOrVecElts( SmallVectorImpl &StoreNodes, EVT MemVT, - unsigned NumElem, bool IsConstantSrc, bool UseVector) { + unsigned NumStores, bool IsConstantSrc, bool UseVector) { // Make sure we have something to merge. - if (NumElem < 2) + if (NumStores < 2) return false; int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8; LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; unsigned LatestNodeUsed = 0; - for (unsigned i=0; i < NumElem; ++i) { + for (unsigned i=0; i < NumStores; ++i) { // Find a chain for the new wide-store operand. Notice that some // of the store nodes that we found may not be selected for inclusion // in the wide store. The chain we use needs to be the chain of the @@ -10719,45 +10974,57 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( LatestNodeUsed = i; } + SmallVector Chains; + // The latest Node in the DAG. LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; SDLoc DL(StoreNodes[0].MemNode); SDValue StoredVal; if (UseVector) { - // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT, NumElem); + bool IsVec = MemVT.isVector(); + unsigned Elts = NumStores; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); + } + // Get the type for the merged vector store. + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); assert(TLI.isTypeLegal(Ty) && "Illegal vector store"); + if (IsConstantSrc) { - StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Ty); + StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Chains, Ty); } else { SmallVector Ops; - for (unsigned i = 0; i < NumElem ; ++i) { + for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); SDValue Val = St->getValue(); - // All of the operands of a BUILD_VECTOR must have the same type. + // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type. if (Val.getValueType() != MemVT) return false; Ops.push_back(Val); + Chains.push_back(St->getChain()); } // Build the extracted vector elements back into a vector. - StoredVal = DAG.getNode(ISD::BUILD_VECTOR, DL, Ty, Ops); - } + StoredVal = DAG.getNode(IsVec ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR, + DL, Ty, Ops); } } else { // We should always use a vector store when merging extracted vector // elements, so this path implies a store of constants. assert(IsConstantSrc && "Merged vector elements should use vector store"); - unsigned SizeInBits = NumElem * ElementSizeBytes * 8; + unsigned SizeInBits = NumStores * ElementSizeBytes * 8; APInt StoreInt(SizeInBits, 0); // Construct a single integer constant which is made of the smaller // constant inputs. bool IsLE = DAG.getDataLayout().isLittleEndian(); - for (unsigned i = 0; i < NumElem ; ++i) { - unsigned Idx = IsLE ? (NumElem - 1 - i) : i; + for (unsigned i = 0; i < NumStores; ++i) { + unsigned Idx = IsLE ? (NumStores - 1 - i) : i; StoreSDNode *St = cast(StoreNodes[Idx].MemNode); + Chains.push_back(St->getChain()); + SDValue Val = St->getValue(); StoreInt <<= ElementSizeBytes * 8; if (ConstantSDNode *C = dyn_cast(Val)) { @@ -10774,7 +11041,10 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( StoredVal = DAG.getConstant(StoreInt, DL, StoreTy); } - SDValue NewStore = DAG.getStore(LatestOp->getChain(), DL, StoredVal, + assert(!Chains.empty()); + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); + SDValue NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, @@ -10783,7 +11053,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( // Replace the last store with the new store CombineTo(LatestOp, NewStore); // Erase all other stores. - for (unsigned i = 0; i < NumElem ; ++i) { + for (unsigned i = 0; i < NumStores; ++i) { if (StoreNodes[i].MemNode == LatestOp) continue; StoreSDNode *St = cast(StoreNodes[i].MemNode); @@ -10826,6 +11096,38 @@ void DAGCombiner::getStoreMergeAndAliasCandidates( EVT MemVT = St->getMemoryVT(); unsigned Seq = 0; StoreSDNode *Index = St; + + + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); + + if (UseAA) { + // Look at other users of the same chain. Stores on the same chain do not + // alias. If combiner-aa is enabled, non-aliasing stores are canonicalized + // to be on the same chain, so don't bother looking at adjacent chains. + + SDValue Chain = St->getChain(); + for (auto I = Chain->use_begin(), E = Chain->use_end(); I != E; ++I) { + if (StoreSDNode *OtherST = dyn_cast(*I)) { + if (I.getOperandNo() != 0) + continue; + + if (OtherST->isVolatile() || OtherST->isIndexed()) + continue; + + if (OtherST->getMemoryVT() != MemVT) + continue; + + BaseIndexOffset Ptr = BaseIndexOffset::match(OtherST->getBasePtr()); + + if (Ptr.equalBaseIndex(BasePtr)) + StoreNodes.push_back(MemOpLink(OtherST, Ptr.Offset, Seq++)); + } + } + + return; + } + while (Index) { // If the chain has more than one use, then we can't reorder the mem ops. if (Index != St && !SDValue(Index, 0)->hasOneUse()) @@ -10851,6 +11153,13 @@ void DAGCombiner::getStoreMergeAndAliasCandidates( if (Index->getMemoryVT() != MemVT) break; + // We do not allow under-aligned stores in order to prevent + // overriding stores. NOTE: this is a bad hack. Alignment SHOULD + // be irrelevant here; what MATTERS is that we not move memory + // operations that potentially overlap past each-other. + if (Index->getAlignment() < MemVT.getStoreSize()) + break; + // We found a potential memory operand to merge. StoreNodes.push_back(MemOpLink(Index, Ptr.Offset, Seq++)); @@ -10895,8 +11204,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (ElementSizeBytes * 8 != MemVT.getSizeInBits()) return false; - // Don't merge vectors into wider inputs. - if (MemVT.isVector() || !MemVT.isSimple()) + if (!MemVT.isSimple()) return false; // Perform an early exit check. Do not bother looking at stored values that @@ -10905,9 +11213,16 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { bool IsLoadSrc = isa(StoredVal); bool IsConstantSrc = isa(StoredVal) || isa(StoredVal); - bool IsExtractVecEltSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT); + bool IsExtractVecSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT || + StoredVal.getOpcode() == ISD::EXTRACT_SUBVECTOR); - if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecEltSrc) + if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc) + return false; + + // Don't merge vectors into wider vectors if the source data comes from loads. + // TODO: This restriction can be lifted by using logic similar to the + // ExtractVecSrc case. + if (MemVT.isVector() && IsLoadSrc) return false; // Only look at ends of store sequences. @@ -10929,12 +11244,18 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (StoreNodes.size() < 2) return false; - // Sort the memory operands according to their distance from the base pointer. + // Sort the memory operands according to their distance from the + // base pointer. As a secondary criteria: make sure stores coming + // later in the code come first in the list. This is important for + // the non-UseAA case, because we're merging stores into the FINAL + // store along a chain which potentially contains aliasing stores. + // Thus, if there are multiple stores to the same address, the last + // one can be considered for merging but not the others. std::sort(StoreNodes.begin(), StoreNodes.end(), [](MemOpLink LHS, MemOpLink RHS) { return LHS.OffsetFromBase < RHS.OffsetFromBase || (LHS.OffsetFromBase == RHS.OffsetFromBase && - LHS.SequenceNum > RHS.SequenceNum); + LHS.SequenceNum < RHS.SequenceNum); }); // Scan the memory operations on the chain and find the first non-consecutive @@ -10994,9 +11315,10 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { // Find a legal type for the constant store. unsigned SizeInBits = (i+1) * ElementSizeBytes * 8; EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); + bool IsFast; if (TLI.isTypeLegal(StoreTy) && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign)) { + FirstStoreAlign, &IsFast) && IsFast) { LastLegalType = i+1; // Or check whether a truncstore is legal. } else if (TLI.getTypeAction(Context, StoreTy) == @@ -11005,32 +11327,27 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, - FirstStoreAS, FirstStoreAlign)) { + FirstStoreAS, FirstStoreAlign, &IsFast) && + IsFast) { LastLegalType = i + 1; } } - // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(Context, MemVT, i+1); - if (TLI.isTypeLegal(Ty) && - TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign)) { - LastLegalVectorType = i + 1; + // We only use vectors if the constant is known to be zero or the target + // allows it and the function is not marked with the noimplicitfloat + // attribute. + if ((!NonZero || TLI.storeOfVectorConstantIsCheap(MemVT, i+1, + FirstStoreAS)) && + !NoVectors) { + // Find a legal type for the vector store. + EVT Ty = EVT::getVectorVT(Context, MemVT, i+1); + if (TLI.isTypeLegal(Ty) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) + LastLegalVectorType = i + 1; } } - - // We only use vectors if the constant is known to be zero or the target - // allows it and the function is not marked with the noimplicitfloat - // attribute. - if (NoVectors) { - LastLegalVectorType = 0; - } else if (NonZero && !TLI.storeOfVectorConstantIsCheap(MemVT, - LastLegalVectorType, - FirstStoreAS)) { - LastLegalVectorType = 0; - } - // Check if we found a legal integer type to store. if (LastLegalType == 0 && LastLegalVectorType == 0) return false; @@ -11044,28 +11361,36 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { // When extracting multiple vector elements, try to store them // in one vector store rather than a sequence of scalar stores. - if (IsExtractVecEltSrc) { - unsigned NumElem = 0; + if (IsExtractVecSrc) { + unsigned NumStoresToMerge = 0; + bool IsVec = MemVT.isVector(); for (unsigned i = 0; i < LastConsecutiveStore + 1; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); - SDValue StoredVal = St->getValue(); + unsigned StoreValOpcode = St->getValue().getOpcode(); // This restriction could be loosened. // Bail out if any stored values are not elements extracted from a vector. // It should be possible to handle mixed sources, but load sources need // more careful handling (see the block of code below that handles // consecutive loads). - if (StoredVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT && + StoreValOpcode != ISD::EXTRACT_SUBVECTOR) return false; // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(Context, MemVT, i+1); + unsigned Elts = i + 1; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); + } + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + bool IsFast; if (TLI.isTypeLegal(Ty) && TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, - FirstStoreAlign)) - NumElem = i + 1; + FirstStoreAlign, &IsFast) && IsFast) + NumStoresToMerge = i + 1; } - return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, + return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStoresToMerge, false, true); } @@ -11147,14 +11472,14 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (CurrAddress - StartAddress != (ElementSizeBytes * i)) break; LastConsecutiveLoad = i; - // Find a legal type for the vector store. EVT StoreTy = EVT::getVectorVT(Context, MemVT, i+1); + bool IsFastSt, IsFastLd; if (TLI.isTypeLegal(StoreTy) && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign) && + FirstStoreAlign, &IsFastSt) && IsFastSt && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, - FirstLoadAlign)) { + FirstLoadAlign, &IsFastLd) && IsFastLd) { LastLegalVectorType = i + 1; } @@ -11163,9 +11488,9 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { StoreTy = EVT::getIntegerVT(Context, SizeInBits); if (TLI.isTypeLegal(StoreTy) && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, - FirstStoreAlign) && + FirstStoreAlign, &IsFastSt) && IsFastSt && TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, - FirstLoadAlign)) + FirstLoadAlign, &IsFastLd) && IsFastLd) LastLegalIntegerType = i + 1; // Or check whether a truncstore and extload is legal. else if (TLI.getTypeAction(Context, StoreTy) == @@ -11177,9 +11502,11 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, StoreTy) && TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) && TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, - FirstStoreAS, FirstStoreAlign) && + FirstStoreAS, FirstStoreAlign, &IsFastSt) && + IsFastSt && TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, - FirstLoadAS, FirstLoadAlign)) + FirstLoadAS, FirstLoadAlign, &IsFastLd) && + IsFastLd) LastLegalIntegerType = i+1; } } @@ -11197,6 +11524,10 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (NumElem < 2) return false; + // Collect the chains from all merged stores. + SmallVector MergeStoreChains; + MergeStoreChains.push_back(StoreNodes[0].MemNode->getChain()); + // The latest Node in the DAG. unsigned LatestNodeUsed = 0; for (unsigned i=1; igetChain()); } LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; @@ -11223,12 +11556,17 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { SDLoc LoadDL(LoadNodes[0].MemNode); SDLoc StoreDL(StoreNodes[0].MemNode); + // The merged loads are required to have the same chain, so using the first's + // chain is acceptable. SDValue NewLoad = DAG.getLoad( JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(), false, false, false, FirstLoadAlign); + SDValue NewStoreChain = + DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, MergeStoreChains); + SDValue NewStore = DAG.getStore( - LatestOp->getChain(), StoreDL, NewLoad, FirstInChain->getBasePtr(), + NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, FirstStoreAlign); // Replace one of the loads with the new load. @@ -11259,6 +11597,114 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { return true; } +SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) { + SDLoc SL(ST); + SDValue ReplStore; + + // Replace the chain to avoid dependency. + if (ST->isTruncatingStore()) { + ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(), + ST->getBasePtr(), ST->getMemoryVT(), + ST->getMemOperand()); + } else { + ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(), + ST->getMemOperand()); + } + + // Create token to keep both nodes around. + SDValue Token = DAG.getNode(ISD::TokenFactor, SL, + MVT::Other, ST->getChain(), ReplStore); + + // Make sure the new and old chains are cleaned up. + AddToWorklist(Token.getNode()); + + // Don't add users to work list. + return CombineTo(ST, Token, false); +} + +SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { + SDValue Value = ST->getValue(); + if (Value.getOpcode() == ISD::TargetConstantFP) + return SDValue(); + + SDLoc DL(ST); + + SDValue Chain = ST->getChain(); + SDValue Ptr = ST->getBasePtr(); + + const ConstantFPSDNode *CFP = cast(Value); + + // NOTE: If the original store is volatile, this transform must not increase + // the number of stores. For example, on x86-32 an f64 can be stored in one + // processor operation but an i64 (which is not legal) requires two. So the + // transform should not be done in this case. + + SDValue Tmp; + switch (CFP->getSimpleValueType(0).SimpleTy) { + default: + llvm_unreachable("Unknown FP type"); + case MVT::f16: // We don't do this for these yet. + case MVT::f80: + case MVT::f128: + case MVT::ppcf128: + return SDValue(); + case MVT::f32: + if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) || + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { + ; + Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF(). + bitcastToAPInt().getZExtValue(), SDLoc(CFP), + MVT::i32); + return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand()); + } + + return SDValue(); + case MVT::f64: + if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations && + !ST->isVolatile()) || + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) { + ; + Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt(). + getZExtValue(), SDLoc(CFP), MVT::i64); + return DAG.getStore(Chain, DL, Tmp, + Ptr, ST->getMemOperand()); + } + + if (!ST->isVolatile() && + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { + // Many FP stores are not made apparent until after legalize, e.g. for + // argument passing. Since this is so common, custom legalize the + // 64-bit integer store into two 32-bit stores. + uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue(); + SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32); + SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32); + if (DAG.getDataLayout().isBigEndian()) + std::swap(Lo, Hi); + + unsigned Alignment = ST->getAlignment(); + bool isVolatile = ST->isVolatile(); + bool isNonTemporal = ST->isNonTemporal(); + AAMDNodes AAInfo = ST->getAAInfo(); + + SDValue St0 = DAG.getStore(Chain, DL, Lo, + Ptr, ST->getPointerInfo(), + isVolatile, isNonTemporal, + ST->getAlignment(), AAInfo); + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(4, DL, Ptr.getValueType())); + Alignment = MinAlign(Alignment, 4U); + SDValue St1 = DAG.getStore(Chain, DL, Hi, + Ptr, ST->getPointerInfo().getWithOffset(4), + isVolatile, isNonTemporal, + Alignment, AAInfo); + return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, + St0, St1); + } + + return SDValue(); + } +} + SDValue DAGCombiner::visitSTORE(SDNode *N) { StoreSDNode *ST = cast(N); SDValue Chain = ST->getChain(); @@ -11286,81 +11732,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed()) return Chain; - // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr' - if (ConstantFPSDNode *CFP = dyn_cast(Value)) { - // NOTE: If the original store is volatile, this transform must not increase - // the number of stores. For example, on x86-32 an f64 can be stored in one - // processor operation but an i64 (which is not legal) requires two. So the - // transform should not be done in this case. - if (Value.getOpcode() != ISD::TargetConstantFP) { - SDValue Tmp; - switch (CFP->getSimpleValueType(0).SimpleTy) { - default: llvm_unreachable("Unknown FP type"); - case MVT::f16: // We don't do this for these yet. - case MVT::f80: - case MVT::f128: - case MVT::ppcf128: - break; - case MVT::f32: - if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) || - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { - ; - Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF(). - bitcastToAPInt().getZExtValue(), SDLoc(CFP), - MVT::i32); - return DAG.getStore(Chain, SDLoc(N), Tmp, - Ptr, ST->getMemOperand()); - } - break; - case MVT::f64: - if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations && - !ST->isVolatile()) || - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) { - ; - Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt(). - getZExtValue(), SDLoc(CFP), MVT::i64); - return DAG.getStore(Chain, SDLoc(N), Tmp, - Ptr, ST->getMemOperand()); - } - - if (!ST->isVolatile() && - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { - // Many FP stores are not made apparent until after legalize, e.g. for - // argument passing. Since this is so common, custom legalize the - // 64-bit integer store into two 32-bit stores. - uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue(); - SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32); - SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32); - if (DAG.getDataLayout().isBigEndian()) - std::swap(Lo, Hi); - - unsigned Alignment = ST->getAlignment(); - bool isVolatile = ST->isVolatile(); - bool isNonTemporal = ST->isNonTemporal(); - AAMDNodes AAInfo = ST->getAAInfo(); - - SDLoc DL(N); - - SDValue St0 = DAG.getStore(Chain, SDLoc(ST), Lo, - Ptr, ST->getPointerInfo(), - isVolatile, isNonTemporal, - ST->getAlignment(), AAInfo); - Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, - DAG.getConstant(4, DL, Ptr.getValueType())); - Alignment = MinAlign(Alignment, 4U); - SDValue St1 = DAG.getStore(Chain, SDLoc(ST), Hi, - Ptr, ST->getPointerInfo().getWithOffset(4), - isVolatile, isNonTemporal, - Alignment, AAInfo); - return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, - St0, St1); - } - - break; - } - } - } - // Try to infer better alignment information than the store already has. if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { @@ -11389,31 +11760,17 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { UseAA = false; #endif if (UseAA && ST->isUnindexed()) { - // Walk up chain skipping non-aliasing memory nodes. - SDValue BetterChain = FindBetterChain(N, Chain); + // FIXME: We should do this even without AA enabled. AA will just allow + // FindBetterChain to work in more situations. The problem with this is that + // any combine that expects memory operations to be on consecutive chains + // first needs to be updated to look for users of the same chain. - // If there is a better chain. - if (Chain != BetterChain) { - SDValue ReplStore; - - // Replace the chain to avoid dependency. - if (ST->isTruncatingStore()) { - ReplStore = DAG.getTruncStore(BetterChain, SDLoc(N), Value, Ptr, - ST->getMemoryVT(), ST->getMemOperand()); - } else { - ReplStore = DAG.getStore(BetterChain, SDLoc(N), Value, Ptr, - ST->getMemOperand()); - } - - // Create token to keep both nodes around. - SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N), - MVT::Other, Chain, ReplStore); - - // Make sure the new and old chains are cleaned up. - AddToWorklist(Token.getNode()); - - // Don't add users to work list. - return CombineTo(N, Token, false); + // Walk up chain skipping non-aliasing memory nodes, on this store and any + // adjacent stores. + if (findBetterNeighborChains(ST)) { + // replaceStoreChain uses CombineTo, which handled all of the worklist + // manipulation. Return the original node to not do anything else. + return SDValue(ST, 0); } } @@ -11498,6 +11855,16 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { return SDValue(N, 0); } + // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr' + // + // Make sure to do this only after attempting to merge stores in order to + // avoid changing the types of some subset of stores due to visit order, + // preventing their merging. + if (isa(Value)) { + if (SDValue NewSt = replaceStoreOfFPConstant(ST)) + return NewSt; + } + return ReduceLoadOpStoreWidth(N); } @@ -11671,7 +12038,24 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { } SDValue EltNo = N->getOperand(1); - bool ConstEltNo = isa(EltNo); + ConstantSDNode *ConstEltNo = dyn_cast(EltNo); + + // extract_vector_elt (build_vector x, y), 1 -> y + if (ConstEltNo && + InVec.getOpcode() == ISD::BUILD_VECTOR && + TLI.isTypeLegal(VT) && + (InVec.hasOneUse() || + TLI.aggressivelyPreferBuildVectorSources(VT))) { + SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue()); + EVT InEltVT = Elt.getValueType(); + + // Sometimes build_vector's scalar input types do not match result type. + if (NVT == InEltVT) + return Elt; + + // TODO: It may be useful to truncate if free if the build_vector implicitly + // converts. + } // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT. // We only perform this optimization before the op legalization phase because @@ -11679,13 +12063,11 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) { // patterns. For example on AVX, extracting elements from a wide vector // without using extract_subvector. However, if we can find an underlying // scalar value, then we can always use that. - if (InVec.getOpcode() == ISD::VECTOR_SHUFFLE - && ConstEltNo) { - int Elt = cast(EltNo)->getZExtValue(); + if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) { int NumElem = VT.getVectorNumElements(); ShuffleVectorSDNode *SVOp = cast(InVec); // Find the new index to extract from. - int OrigElt = SVOp->getMaskElt(Elt); + int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue()); // Extracting an undef index is undef. if (OrigElt == -1) @@ -12270,15 +12652,24 @@ static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { // What vector are we extracting the subvector from and at what index? SDValue ExtVec = Op.getOperand(0); + + // We want the EVT of the original extraction to correctly scale the + // extraction index. + EVT ExtVT = ExtVec.getValueType(); + + // Peek through any bitcast. + while (ExtVec.getOpcode() == ISD::BITCAST) + ExtVec = ExtVec.getOperand(0); + + // UNDEF nodes convert to UNDEF shuffle mask values. if (ExtVec.getOpcode() == ISD::UNDEF) { Mask.append((unsigned)NumOpElts, -1); continue; } - EVT ExtVT = ExtVec.getValueType(); if (!isa(Op.getOperand(1))) return SDValue(); - int ExtIdx = dyn_cast(Op.getOperand(1))->getZExtValue(); + int ExtIdx = cast(Op.getOperand(1))->getZExtValue(); // Ensure that we are extracting a subvector from a vector the same // size as the result. @@ -12635,7 +13026,7 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) { std::all_of(SVN->getMask().begin() + NumElemsPerConcat, SVN->getMask().end(), [](int i) { return i == -1; })) { N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0), N0.getOperand(1), - ArrayRef(SVN->getMask().begin(), NumElemsPerConcat)); + makeArrayRef(SVN->getMask().begin(), NumElemsPerConcat)); N1 = DAG.getUNDEF(ConcatVT); return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1); } @@ -13113,6 +13504,21 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) + if (N0->getOpcode() == ISD::AND) { + ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1)); + if (AndConst && AndConst->getAPIntValue() == 0xffff) { + return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0), + N0.getOperand(0)); + } + } + + return SDValue(); +} + /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle /// with the destination vector and a zero vector. /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> @@ -13215,56 +13621,12 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); + SDValue Ops[] = {LHS, RHS}; - // If the LHS and RHS are BUILD_VECTOR nodes, see if we can constant fold - // this operation. - if (LHS.getOpcode() == ISD::BUILD_VECTOR && - RHS.getOpcode() == ISD::BUILD_VECTOR) { - // Check if both vectors are constants. If not bail out. - if (!(cast(LHS)->isConstant() && - cast(RHS)->isConstant())) - return SDValue(); - - SmallVector Ops; - for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) { - SDValue LHSOp = LHS.getOperand(i); - SDValue RHSOp = RHS.getOperand(i); - - // Can't fold divide by zero. - if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV || - N->getOpcode() == ISD::FDIV) { - if (isNullConstant(RHSOp) || (RHSOp.getOpcode() == ISD::ConstantFP && - cast(RHSOp.getNode())->isZero())) - break; - } - - EVT VT = LHSOp.getValueType(); - EVT RVT = RHSOp.getValueType(); - if (RVT != VT) { - // Integer BUILD_VECTOR operands may have types larger than the element - // size (e.g., when the element type is not legal). Prior to type - // legalization, the types may not match between the two BUILD_VECTORS. - // Truncate one of the operands to make them match. - if (RVT.getSizeInBits() > VT.getSizeInBits()) { - RHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, RHSOp); - } else { - LHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), RVT, LHSOp); - VT = RVT; - } - } - SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(LHS), VT, - LHSOp, RHSOp); - if (FoldOp.getOpcode() != ISD::UNDEF && - FoldOp.getOpcode() != ISD::Constant && - FoldOp.getOpcode() != ISD::ConstantFP) - break; - Ops.push_back(FoldOp); - AddToWorklist(FoldOp.getNode()); - } - - if (Ops.size() == LHS.getNumOperands()) - return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), LHS.getValueType(), Ops); - } + // See if we can constant fold the vector operation. + if (SDValue Fold = DAG.FoldConstantVectorArithmetic( + N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags())) + return Fold; // Try to convert a constant mask AND into a shuffle clear mask. if (SDValue Shuffle = XformToShuffleWithZero(N)) @@ -13284,7 +13646,8 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { EVT VT = N->getValueType(0); SDValue UndefVector = LHS.getOperand(1); SDValue NewBinOp = DAG.getNode(N->getOpcode(), SDLoc(N), VT, - LHS.getOperand(0), RHS.getOperand(0)); + LHS.getOperand(0), RHS.getOperand(0), + N->getFlags()); AddUsersToWorklist(N); return DAG.getVectorShuffle(VT, SDLoc(N), NewBinOp, UndefVector, &SVN0->getMask()[0]); @@ -13657,8 +14020,7 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, // Get a SetCC of the condition // NOTE: Don't create a SETCC if it's not legal on this target. if (!LegalOperations || - TLI.isOperationLegal(ISD::SETCC, - LegalTypes ? getSetCCResultType(N0.getValueType()) : MVT::i1)) { + TLI.isOperationLegal(ISD::SETCC, N0.getValueType())) { SDValue Temp, SCC; // cast from setcc result type to select result type if (LegalTypes) { @@ -13690,51 +14052,6 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, } } - // Check to see if this is the equivalent of setcc - // FIXME: Turn all of these into setcc if setcc if setcc is legal - // otherwise, go ahead with the folds. - if (0 && isNullConstant(N3) && isOneConstant(N2)) { - EVT XType = N0.getValueType(); - if (!LegalOperations || - TLI.isOperationLegal(ISD::SETCC, getSetCCResultType(XType))) { - SDValue Res = DAG.getSetCC(DL, getSetCCResultType(XType), N0, N1, CC); - if (Res.getValueType() != VT) - Res = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Res); - return Res; - } - - // fold (seteq X, 0) -> (srl (ctlz X, log2(size(X)))) - if (isNullConstant(N1) && CC == ISD::SETEQ && - (!LegalOperations || - TLI.isOperationLegal(ISD::CTLZ, XType))) { - SDValue Ctlz = DAG.getNode(ISD::CTLZ, SDLoc(N0), XType, N0); - return DAG.getNode(ISD::SRL, DL, XType, Ctlz, - DAG.getConstant(Log2_32(XType.getSizeInBits()), - SDLoc(Ctlz), - getShiftAmountTy(Ctlz.getValueType()))); - } - // fold (setgt X, 0) -> (srl (and (-X, ~X), size(X)-1)) - if (isNullConstant(N1) && CC == ISD::SETGT) { - SDLoc DL(N0); - SDValue NegN0 = DAG.getNode(ISD::SUB, DL, - XType, DAG.getConstant(0, DL, XType), N0); - SDValue NotN0 = DAG.getNOT(DL, N0, XType); - return DAG.getNode(ISD::SRL, DL, XType, - DAG.getNode(ISD::AND, DL, XType, NegN0, NotN0), - DAG.getConstant(XType.getSizeInBits() - 1, DL, - getShiftAmountTy(XType))); - } - // fold (setgt X, -1) -> (xor (srl (X, size(X)-1), 1)) - if (isAllOnesConstant(N1) && CC == ISD::SETGT) { - SDLoc DL(N0); - SDValue Sign = DAG.getNode(ISD::SRL, DL, XType, N0, - DAG.getConstant(XType.getSizeInBits() - 1, DL, - getShiftAmountTy(N0.getValueType()))); - return DAG.getNode(ISD::XOR, DL, XType, Sign, DAG.getConstant(1, DL, - XType)); - } - } - // Check to see if this is an integer abs. // select_cc setg[te] X, 0, X, -X -> // select_cc setgt X, -1, X, -X -> @@ -13842,7 +14159,7 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) { return S; } -SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { +SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) { if (Level >= AfterLegalizeDAG) return SDValue(); @@ -13866,16 +14183,16 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { // Newton iterations: Est = Est + Est (1 - Arg * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est); + SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPOne, NewEst); + NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPOne, NewEst, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst); + NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags); AddToWorklist(NewEst.getNode()); - Est = DAG.getNode(ISD::FADD, DL, VT, Est, NewEst); + Est = DAG.getNode(ISD::FADD, DL, VT, Est, NewEst, Flags); AddToWorklist(Est.getNode()); } } @@ -13892,31 +14209,32 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { /// X_{i+1} = X_i (1.5 - A X_i^2 / 2) /// As a result, we precompute A/2 prior to the iteration loop. SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est, - unsigned Iterations) { + unsigned Iterations, + SDNodeFlags *Flags) { EVT VT = Arg.getValueType(); SDLoc DL(Arg); SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT); // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that // this entire sequence requires only one FP constant. - SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg); + SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags); AddToWorklist(HalfArg.getNode()); - HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg); + HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags); AddToWorklist(HalfArg.getNode()); // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est); + SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst); + NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst); + NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags); AddToWorklist(NewEst.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags); AddToWorklist(Est.getNode()); } return Est; @@ -13928,7 +14246,8 @@ SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est, /// => /// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0)) SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est, - unsigned Iterations) { + unsigned Iterations, + SDNodeFlags *Flags) { EVT VT = Arg.getValueType(); SDLoc DL(Arg); SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT); @@ -13936,25 +14255,25 @@ SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est, // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf); + SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags); AddToWorklist(HalfEst.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree); + Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags); AddToWorklist(Est.getNode()); } return Est; } -SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op) { +SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) { if (Level >= AfterLegalizeDAG) return SDValue(); @@ -13966,8 +14285,8 @@ SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op) { AddToWorklist(Est.getNode()); if (Iterations) { Est = UseOneConstNR ? - BuildRsqrtNROneConst(Op, Est, Iterations) : - BuildRsqrtNRTwoConst(Op, Est, Iterations); + BuildRsqrtNROneConst(Op, Est, Iterations, Flags) : + BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags); } return Est; } @@ -14131,14 +14450,12 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain, SDValue Chain = Chains.pop_back_val(); // For TokenFactor nodes, look at each operand and only continue up the - // chain until we find two aliases. If we've seen two aliases, assume we'll - // find more and revert to original chain since the xform is unlikely to be - // profitable. + // chain until we reach the depth limit. // // FIXME: The depth check could be made to return the last non-aliasing // chain we found before we hit a tokenfactor rather than the original // chain. - if (Depth > 6 || Aliases.size() == 2) { + if (Depth > TLI.getGatherAllAliasesMaxDepth()) { Aliases.clear(); Aliases.push_back(OriginalChain); return; @@ -14270,6 +14587,83 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) { return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases); } +bool DAGCombiner::findBetterNeighborChains(StoreSDNode* St) { + // This holds the base pointer, index, and the offset in bytes from the base + // pointer. + BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr()); + + // We must have a base and an offset. + if (!BasePtr.Base.getNode()) + return false; + + // Do not handle stores to undef base pointers. + if (BasePtr.Base.getOpcode() == ISD::UNDEF) + return false; + + SmallVector ChainedStores; + ChainedStores.push_back(St); + + // Walk up the chain and look for nodes with offsets from the same + // base pointer. Stop when reaching an instruction with a different kind + // or instruction which has a different base pointer. + StoreSDNode *Index = St; + while (Index) { + // If the chain has more than one use, then we can't reorder the mem ops. + if (Index != St && !SDValue(Index, 0)->hasOneUse()) + break; + + if (Index->isVolatile() || Index->isIndexed()) + break; + + // Find the base pointer and offset for this memory node. + BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr()); + + // Check that the base pointer is the same as the original one. + if (!Ptr.equalBaseIndex(BasePtr)) + break; + + // Find the next memory operand in the chain. If the next operand in the + // chain is a store then move up and continue the scan with the next + // memory operand. If the next operand is a load save it and use alias + // information to check if it interferes with anything. + SDNode *NextInChain = Index->getChain().getNode(); + while (true) { + if (StoreSDNode *STn = dyn_cast(NextInChain)) { + // We found a store node. Use it for the next iteration. + ChainedStores.push_back(STn); + Index = STn; + break; + } else if (LoadSDNode *Ldn = dyn_cast(NextInChain)) { + NextInChain = Ldn->getChain().getNode(); + continue; + } else { + Index = nullptr; + break; + } + } + } + + bool MadeChange = false; + SmallVector, 8> BetterChains; + + for (StoreSDNode *ChainedStore : ChainedStores) { + SDValue Chain = ChainedStore->getChain(); + SDValue BetterChain = FindBetterChain(ChainedStore, Chain); + + if (Chain != BetterChain) { + MadeChange = true; + BetterChains.push_back(std::make_pair(ChainedStore, BetterChain)); + } + } + + // Do all replacements after finding the replacements to make to avoid making + // the chains more complicated by introducing new TokenFactors. + for (auto Replacement : BetterChains) + replaceStoreChain(Replacement.first, Replacement.second); + + return MadeChange; +} + /// This is the entry point for the file. void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis &AA, CodeGenOpt::Level OptLevel) {