diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 3c31d24fffc..f119023d217 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -156,13 +156,16 @@ namespace {
     void deleteAndRecombine(SDNode *N);
     bool recursivelyDeleteUnusedNodes(SDNode *N);
 
+    /// Replaces all uses of the results of one DAG node with new values.
     SDValue CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
                       bool AddTo = true);
 
+    /// Replaces all uses of the results of one DAG node with new values.
    SDValue CombineTo(SDNode *N, SDValue Res, bool AddTo = true) {
      return CombineTo(N, &Res, 1, AddTo);
    }
 
+    /// Replaces all uses of the results of one DAG node with new values.
     SDValue CombineTo(SDNode *N, SDValue Res0, SDValue Res1,
                       bool AddTo = true) {
       SDValue To[] = { Res0, Res1 };
@@ -233,18 +236,17 @@ namespace {
     SDValue visitADDE(SDNode *N);
     SDValue visitSUBE(SDNode *N);
     SDValue visitMUL(SDNode *N);
+    SDValue useDivRem(SDNode *N);
     SDValue visitSDIV(SDNode *N);
     SDValue visitUDIV(SDNode *N);
-    SDValue visitSREM(SDNode *N);
-    SDValue visitUREM(SDNode *N);
+    SDValue visitREM(SDNode *N);
     SDValue visitMULHU(SDNode *N);
     SDValue visitMULHS(SDNode *N);
     SDValue visitSMUL_LOHI(SDNode *N);
     SDValue visitUMUL_LOHI(SDNode *N);
     SDValue visitSMULO(SDNode *N);
     SDValue visitUMULO(SDNode *N);
-    SDValue visitSDIVREM(SDNode *N);
-    SDValue visitUDIVREM(SDNode *N);
+    SDValue visitIMINMAX(SDNode *N);
     SDValue visitAND(SDNode *N);
     SDValue visitANDLike(SDValue N0, SDValue N1, SDNode *LocReference);
     SDValue visitOR(SDNode *N);
@@ -265,6 +267,7 @@ namespace {
     SDValue visitVSELECT(SDNode *N);
     SDValue visitSELECT_CC(SDNode *N);
     SDValue visitSETCC(SDNode *N);
+    SDValue visitSETCCE(SDNode *N);
     SDValue visitSIGN_EXTEND(SDNode *N);
     SDValue visitZERO_EXTEND(SDNode *N);
     SDValue visitANY_EXTEND(SDNode *N);
@@ -298,6 +301,10 @@ namespace {
     SDValue visitBRCOND(SDNode *N);
     SDValue visitBR_CC(SDNode *N);
     SDValue visitLOAD(SDNode *N);
+
+    SDValue replaceStoreChain(StoreSDNode *ST, SDValue BetterChain);
+    SDValue replaceStoreOfFPConstant(StoreSDNode *ST);
+
     SDValue visitSTORE(SDNode *N);
     SDValue visitINSERT_VECTOR_ELT(SDNode *N);
     SDValue visitEXTRACT_VECTOR_ELT(SDNode *N);
@@ -312,9 +319,11 @@ namespace {
     SDValue visitMGATHER(SDNode *N);
     SDValue visitMSCATTER(SDNode *N);
     SDValue visitFP_TO_FP16(SDNode *N);
+    SDValue visitFP16_TO_FP(SDNode *N);
 
     SDValue visitFADDForFMACombine(SDNode *N);
     SDValue visitFSUBForFMACombine(SDNode *N);
+    SDValue visitFMULForFMACombine(SDNode *N);
 
     SDValue XformToShuffleWithZero(SDNode *N);
     SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
@@ -338,14 +347,17 @@ namespace {
                                unsigned HiOp);
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue CombineExtLoad(SDNode *N);
+    SDValue combineRepeatedFPDivisors(SDNode *N);
     SDValue ConstantFoldBITCASTofBUILD_VECTOR(SDNode *, EVT);
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
-    SDValue BuildReciprocalEstimate(SDValue Op);
-    SDValue BuildRsqrtEstimate(SDValue Op);
-    SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations);
-    SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations);
+    SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                 SDNodeFlags *Flags);
+    SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                 SDNodeFlags *Flags);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -374,6 +386,10 @@ namespace {
     /// chain (aliasing node.)
     SDValue FindBetterChain(SDNode *N, SDValue Chain);
 
+    /// Do FindBetterChain for a store and any possibly adjacent stores on
+    /// consecutive chains.
+    bool findBetterNeighborChains(StoreSDNode *St);
+
     /// Holds a pointer to an LSBaseSDNode as well as information on where it
     /// is located in a sequence of memory operations connected by a chain.
     struct MemOpLink {
@@ -388,19 +404,37 @@ namespace {
       unsigned SequenceNum;
     };
 
+    /// This is a helper function for visitMUL to check the profitability
+    /// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2).
+    /// MulNode is the original multiply, AddNode is (add x, c1),
+    /// and ConstNode is c2.
+    bool isMulAddWithConstProfitable(SDNode *MulNode,
+                                     SDValue &AddNode,
+                                     SDValue &ConstNode);
+
     /// This is a helper function for MergeStoresOfConstantsOrVecElts. Returns a
     /// constant build_vector of the stored constant values in Stores.
     SDValue getMergedConstantVectorStore(SelectionDAG &DAG, SDLoc SL,
                                          ArrayRef<MemOpLink> Stores,
+                                         SmallVectorImpl<SDValue> &Chains,
                                          EVT Ty) const;
 
+    /// This is a helper function for visitAND and visitZERO_EXTEND.  Returns
+    /// true if the (and (load x) c) pattern matches an extload.  ExtVT returns
+    /// the type of the loaded value to be extended.  LoadedVT returns the type
+    /// of the original loaded value.  NarrowLoad returns whether the load would
+    /// need to be narrowed in order to match.
+    bool isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
+                          EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT,
+                          bool &NarrowLoad);
+
     /// This is a helper function for MergeConsecutiveStores. When the source
     /// elements of the consecutive stores are all constants or all extracted
     /// vector elements, try to merge them into one larger store.
     /// \return True if a merged store was created.
     bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
-                                         EVT MemVT, unsigned NumElem,
+                                         EVT MemVT, unsigned NumStores,
                                          bool IsConstantSrc, bool UseVector);
 
     /// This is a helper function for MergeConsecutiveStores.
@@ -409,7 +443,7 @@ namespace {
     void getStoreMergeAndAliasCandidates(
         StoreSDNode* St, SmallVectorImpl<MemOpLink> &StoreNodes,
         SmallVectorImpl<LSBaseSDNode*> &AliasLoadNodes);
-    
+
     /// Merge consecutive store operations into a wide store.
     /// This optimization uses wide integers or vectors when possible.
     /// \return True if some memory operations were changed.
@@ -427,9 +461,7 @@ namespace {
     DAGCombiner(SelectionDAG &D, AliasAnalysis &A, CodeGenOpt::Level OL)
         : DAG(D), TLI(D.getTargetLoweringInfo()), Level(BeforeLegalizeTypes),
           OptLevel(OL), LegalOperations(false), LegalTypes(false), AA(A) {
-      auto *F = DAG.getMachineFunction().getFunction();
-      ForCodeSize = F->hasFnAttribute(Attribute::OptimizeForSize) ||
-                    F->hasFnAttribute(Attribute::MinSize);
+      ForCodeSize = DAG.getMachineFunction().getFunction()->optForSize();
     }
 
     /// Runs the dag combiner on all nodes in the work list
@@ -443,8 +475,9 @@ namespace {
       assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
       if (LHSTy.isVector())
         return LHSTy;
-      return LegalTypes ? TLI.getScalarShiftAmountTy(LHSTy)
-                        : TLI.getPointerTy();
+      auto &DL = DAG.getDataLayout();
+      return LegalTypes ? TLI.getScalarShiftAmountTy(DL, LHSTy)
+                        : TLI.getPointerTy(DL);
     }
 
     /// This method returns true if we are running before type legalization or
@@ -456,7 +489,7 @@ namespace {
 
     /// Convenience wrapper around TargetLowering::getSetCCResultType
     EVT getSetCCResultType(EVT VT) const {
-      return TLI.getSetCCResultType(*DAG.getContext(), VT);
+      return TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
     }
   };
 }
@@ -605,6 +638,9 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
   assert(Op.hasOneUse() && "Unknown reuse!");
 
   assert(Depth <= 6 && "GetNegatedExpression doesn't match isNegatibleForFree");
+
+  const SDNodeFlags *Flags = Op.getNode()->getFlags();
+
   switch (Op.getOpcode()) {
   default: llvm_unreachable("Unknown code");
   case ISD::ConstantFP: {
@@ -622,12 +658,12 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
       return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
                          GetNegatedExpression(Op.getOperand(0), DAG,
                                               LegalOperations, Depth+1),
-                         Op.getOperand(1));
+                         Op.getOperand(1), Flags);
     // fold (fneg (fadd A, B)) -> (fsub (fneg B), A)
     return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
                        GetNegatedExpression(Op.getOperand(1), DAG,
                                             LegalOperations, Depth+1),
-                       Op.getOperand(0));
+                       Op.getOperand(0), Flags);
   case ISD::FSUB:
     // We can't turn -(A-B) into B-A when we honor signed zeros.
     assert(Options.UnsafeFPMath);
@@ -639,7 +675,7 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
 
     // fold (fneg (fsub A, B)) -> (fsub B, A)
     return DAG.getNode(ISD::FSUB, SDLoc(Op), Op.getValueType(),
-                       Op.getOperand(1), Op.getOperand(0));
+                       Op.getOperand(1), Op.getOperand(0), Flags);
 
   case ISD::FMUL:
   case ISD::FDIV:
@@ -651,13 +687,13 @@ static SDValue GetNegatedExpression(SDValue Op, SelectionDAG &DAG,
       return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
                          GetNegatedExpression(Op.getOperand(0), DAG,
                                               LegalOperations, Depth+1),
-                         Op.getOperand(1));
+                         Op.getOperand(1), Flags);
 
     // fold (fneg (fmul X, Y)) -> (fmul X, (fneg Y))
     return DAG.getNode(Op.getOpcode(), SDLoc(Op), Op.getValueType(),
                        Op.getOperand(0),
                        GetNegatedExpression(Op.getOperand(1), DAG,
-                                            LegalOperations, Depth+1));
+                                            LegalOperations, Depth+1), Flags);
 
   case ISD::FP_EXTEND:
   case ISD::FSIN:
@@ -888,6 +924,62 @@ CommitTargetLoweringOpt(const TargetLowering::TargetLoweringOpt &TLO) {
 bool DAGCombiner::SimplifyDemandedBits(SDValue Op, const APInt &Demanded) {
   TargetLowering::TargetLoweringOpt TLO(DAG, LegalTypes, LegalOperations);
   APInt KnownZero, KnownOne;
+
+  // XXX-disabled:
+  auto Opcode = Op.getOpcode();
+  if (Opcode == ISD::AND || Opcode == ISD::OR) {
+    auto* Op1 = Op.getOperand(0).getNode();
+    auto* Op2 = Op.getOperand(1).getNode();
+    auto* Op1C = dyn_cast<ConstantSDNode>(Op1);
+    auto* Op2C = dyn_cast<ConstantSDNode>(Op2);
+
+    // and X, 0
+    if (Opcode == ISD::AND && !Op1C && Op2C && Op2C->isNullValue()) {
+      return false;
+    }
+
+    // or (and X, 0), Y
+    if (Opcode == ISD::OR) {
+      if (Op1->getOpcode() == ISD::AND) {
+        auto* Op11 = Op1->getOperand(0).getNode();
+        auto* Op12 = Op1->getOperand(1).getNode();
+        auto* Op11C = dyn_cast<ConstantSDNode>(Op11);
+        auto* Op12C = dyn_cast<ConstantSDNode>(Op12);
+        if (!Op11C && Op12C && Op12C->isNullValue()) {
+          return false;
+        }
+      }
+      if (Op1->getOpcode() == ISD::TRUNCATE) {
+        // or (trunc (and %0, 0)), Y
+        auto* Op11 = Op1->getOperand(0).getNode();
+        if (Op11->getOpcode() == ISD::AND) {
+          auto* Op111 = Op11->getOperand(0).getNode();
+          auto* Op112 = Op11->getOperand(1).getNode();
+          auto* Op111C = dyn_cast<ConstantSDNode>(Op111);
+          auto* Op112C = dyn_cast<ConstantSDNode>(Op112);
+          if (!Op111C && Op112C && Op112C->isNullValue()) {
+            // or (and X, 0), Y
+            return false;
+          }
+        }
+      }
+    }
+  }
+
+  // trunc (and X, 0)
+  if (Opcode == ISD::TRUNCATE) {
+    auto* Op1 = Op.getOperand(0).getNode();
+    if (Op1->getOpcode() == ISD::AND) {
+      auto* Op11 = Op1->getOperand(0).getNode();
+      auto* Op12 = Op1->getOperand(1).getNode();
+      auto* Op11C = dyn_cast<ConstantSDNode>(Op11);
+      auto* Op12C = dyn_cast<ConstantSDNode>(Op12);
+      if (!Op11C && Op12C && Op12C->isNullValue()) {
+        return false;
+      }
+    }
+  }
+
   if (!TLI.SimplifyDemandedBits(Op, Demanded, KnownZero, KnownOne, TLO))
     return false;
@@ -1215,9 +1307,8 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
   LegalTypes = Level >= AfterLegalizeTypes;
 
   // Add all the dag nodes to the worklist.
-  for (SelectionDAG::allnodes_iterator I = DAG.allnodes_begin(),
-       E = DAG.allnodes_end(); I != E; ++I)
-    AddToWorklist(I);
+  for (SDNode &Node : DAG.allnodes())
+    AddToWorklist(&Node);
 
   // Create a dummy node (which is not added to allnodes), that adds a reference
   // to the root node, preventing it from being deleted, and tracking any
@@ -1332,16 +1423,18 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MUL:                return visitMUL(N);
   case ISD::SDIV:               return visitSDIV(N);
   case ISD::UDIV:               return visitUDIV(N);
-  case ISD::SREM:               return visitSREM(N);
-  case ISD::UREM:               return visitUREM(N);
+  case ISD::SREM:
+  case ISD::UREM:               return visitREM(N);
   case ISD::MULHU:              return visitMULHU(N);
   case ISD::MULHS:              return visitMULHS(N);
   case ISD::SMUL_LOHI:          return visitSMUL_LOHI(N);
   case ISD::UMUL_LOHI:          return visitUMUL_LOHI(N);
   case ISD::SMULO:              return visitSMULO(N);
   case ISD::UMULO:              return visitUMULO(N);
-  case ISD::SDIVREM:            return visitSDIVREM(N);
-  case ISD::UDIVREM:            return visitUDIVREM(N);
+  case ISD::SMIN:
+  case ISD::SMAX:
+  case ISD::UMIN:
+  case ISD::UMAX:               return visitIMINMAX(N);
   case ISD::AND:                return visitAND(N);
   case ISD::OR:                 return visitOR(N);
   case ISD::XOR:                return visitXOR(N);
@@ -1360,6 +1453,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::VSELECT:            return visitVSELECT(N);
   case ISD::SELECT_CC:          return visitSELECT_CC(N);
   case ISD::SETCC:              return visitSETCC(N);
+  case ISD::SETCCE:             return visitSETCCE(N);
   case ISD::SIGN_EXTEND:        return visitSIGN_EXTEND(N);
   case ISD::ZERO_EXTEND:        return visitZERO_EXTEND(N);
   case ISD::ANY_EXTEND:         return visitANY_EXTEND(N);
@@ -1407,6 +1501,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
+  case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
   }
   return SDValue();
 }
@@ -1469,13 +1564,8 @@ SDValue DAGCombiner::combine(SDNode *N) {
     // Constant operands are canonicalized to RHS.
     if (isa<ConstantSDNode>(N0) || !isa<ConstantSDNode>(N1)) {
       SDValue Ops[] = {N1, N0};
-      SDNode *CSENode;
-      if (const auto *BinNode = dyn_cast<BinaryWithFlagsSDNode>(N)) {
-        CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
-                                      &BinNode->Flags);
-      } else {
-        CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops);
-      }
+      SDNode *CSENode = DAG.getNodeIfExists(N->getOpcode(), N->getVTList(), Ops,
+                                            N->getFlags());
       if (CSENode)
         return SDValue(CSENode, 0);
     }
@@ -1594,26 +1684,6 @@ SDValue DAGCombiner::visitMERGE_VALUES(SDNode *N) {
   return SDValue(N, 0);   // Return N so it doesn't get rechecked!
 }
 
-static bool isNullConstant(SDValue V) {
-  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
-  return Const != nullptr && Const->isNullValue();
-}
-
-static bool isNullFPConstant(SDValue V) {
-  ConstantFPSDNode *Const = dyn_cast<ConstantFPSDNode>(V);
-  return Const != nullptr && Const->isZero() && !Const->isNegative();
-}
-
-static bool isAllOnesConstant(SDValue V) {
-  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
-  return Const != nullptr && Const->isAllOnesValue();
-}
-
-static bool isOneConstant(SDValue V) {
-  ConstantSDNode *Const = dyn_cast<ConstantSDNode>(V);
-  return Const != nullptr && Const->isOne();
-}
-
 /// If \p N is a ConstantSDNode with isOpaque() == false return it casted to a
 /// ConstantSDNode pointer else nullptr.
 static ConstantSDNode *getAsNonOpaqueConstant(SDValue N) {
@@ -1720,22 +1790,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
       return SDValue(N, 0);
 
   // fold (a+b) -> (a|b) iff a and b share no bits.
-  if (VT.isInteger() && !VT.isVector()) {
-    APInt LHSZero, LHSOne;
-    APInt RHSZero, RHSOne;
-    DAG.computeKnownBits(N0, LHSZero, LHSOne);
-
-    if (LHSZero.getBoolValue()) {
-      DAG.computeKnownBits(N1, RHSZero, RHSOne);
-
-      // If all possibly-set bits on the LHS are clear on the RHS, return an OR.
-      // If all possibly-set bits on the RHS are clear on the LHS, return an OR.
-      if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero){
-        if (!LegalOperations || TLI.isOperationLegal(ISD::OR, VT))
-          return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1);
-      }
-    }
-  }
+  if ((!LegalOperations || TLI.isOperationLegal(ISD::OR, VT)) &&
+      VT.isInteger() && !VT.isVector() && DAG.haveNoCommonBitsSet(N0, N1))
+    return DAG.getNode(ISD::OR, SDLoc(N), VT, N0, N1);
 
   // fold (add x, shl(0 - y, n)) -> sub(x, shl(y, n))
   if (N1.getOpcode() == ISD::SHL && N1.getOperand(0).getOpcode() == ISD::SUB &&
@@ -1970,31 +2027,26 @@ SDValue DAGCombiner::visitSUBC(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N0.getValueType();
+  SDLoc DL(N);
 
   // If the flag result is dead, turn this into an SUB.
   if (!N->hasAnyUseOfValue(1))
-    return CombineTo(N, DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, N1),
-                     DAG.getNode(ISD::CARRY_FALSE, SDLoc(N),
-                                 MVT::Glue));
+    return CombineTo(N, DAG.getNode(ISD::SUB, DL, VT, N0, N1),
+                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
 
   // fold (subc x, x) -> 0 + no borrow
-  if (N0 == N1) {
-    SDLoc DL(N);
+  if (N0 == N1)
     return CombineTo(N, DAG.getConstant(0, DL, VT),
-                     DAG.getNode(ISD::CARRY_FALSE, DL,
-                                 MVT::Glue));
-  }
+                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
 
   // fold (subc x, 0) -> x + no borrow
   if (isNullConstant(N1))
-    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, SDLoc(N),
-                                        MVT::Glue));
+    return CombineTo(N, N0, DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
 
   // Canonicalize (sub -1, x) -> ~x, i.e. (xor x, -1) + no borrow
   if (isAllOnesConstant(N0))
-    return CombineTo(N, DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0),
-                     DAG.getNode(ISD::CARRY_FALSE, SDLoc(N),
-                                 MVT::Glue));
+    return CombineTo(N, DAG.getNode(ISD::XOR, DL, VT, N1, N0),
+                     DAG.getNode(ISD::CARRY_FALSE, DL, MVT::Glue));
 
   return SDValue();
 }
@@ -2129,14 +2181,15 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   }
 
   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
-  if (N1IsConst && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() &&
-      (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
-       isa<ConstantSDNode>(N0.getOperand(1))))
-    return DAG.getNode(ISD::ADD, SDLoc(N), VT,
-                       DAG.getNode(ISD::MUL, SDLoc(N0), VT,
-                                   N0.getOperand(0), N1),
-                       DAG.getNode(ISD::MUL, SDLoc(N1), VT,
-                                   N0.getOperand(1), N1));
+  if (isConstantIntBuildVectorOrConstantInt(N1) &&
+      N0.getOpcode() == ISD::ADD &&
+      isConstantIntBuildVectorOrConstantInt(N0.getOperand(1)) &&
+      isMulAddWithConstProfitable(N, N0, N1))
+      return DAG.getNode(ISD::ADD, SDLoc(N), VT,
+                         DAG.getNode(ISD::MUL, SDLoc(N0), VT,
+                                     N0.getOperand(0), N1),
+                         DAG.getNode(ISD::MUL, SDLoc(N1), VT,
+                                     N0.getOperand(1), N1));
 
   // reassociate mul
   if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1))
@@ -2145,6 +2198,88 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   return SDValue();
 }
 
+/// Return true if divmod libcall is available.
+static bool isDivRemLibcallAvailable(SDNode *Node, bool isSigned,
+                                     const TargetLowering &TLI) {
+  RTLIB::Libcall LC;
+  switch (Node->getSimpleValueType(0).SimpleTy) {
+  default: return false; // No libcall for vector types.
+  case MVT::i8:   LC= isSigned ? RTLIB::SDIVREM_I8  : RTLIB::UDIVREM_I8;  break;
+  case MVT::i16:  LC= isSigned ? RTLIB::SDIVREM_I16 : RTLIB::UDIVREM_I16; break;
+  case MVT::i32:  LC= isSigned ? RTLIB::SDIVREM_I32 : RTLIB::UDIVREM_I32; break;
+  case MVT::i64:  LC= isSigned ? RTLIB::SDIVREM_I64 : RTLIB::UDIVREM_I64; break;
+  case MVT::i128: LC= isSigned ? RTLIB::SDIVREM_I128:RTLIB::UDIVREM_I128; break;
+  }
+
+  return TLI.getLibcallName(LC) != nullptr;
+}
+
+/// Issue divrem if both quotient and remainder are needed.
+SDValue DAGCombiner::useDivRem(SDNode *Node) {
+  if (Node->use_empty())
+    return SDValue(); // This is a dead node, leave it alone.
+
+  EVT VT = Node->getValueType(0);
+  if (!TLI.isTypeLegal(VT))
+    return SDValue();
+
+  unsigned Opcode = Node->getOpcode();
+  bool isSigned = (Opcode == ISD::SDIV) || (Opcode == ISD::SREM);
+
+  unsigned DivRemOpc = isSigned ? ISD::SDIVREM : ISD::UDIVREM;
+  // If DIVREM is going to get expanded into a libcall,
+  // but there is no libcall available, then don't combine.
+  if (!TLI.isOperationLegalOrCustom(DivRemOpc, VT) &&
+      !isDivRemLibcallAvailable(Node, isSigned, TLI))
+    return SDValue();
+
+  // If div is legal, it's better to do the normal expansion
+  unsigned OtherOpcode = 0;
+  if ((Opcode == ISD::SDIV) || (Opcode == ISD::UDIV)) {
+    OtherOpcode = isSigned ? ISD::SREM : ISD::UREM;
+    if (TLI.isOperationLegalOrCustom(Opcode, VT))
+      return SDValue();
+  } else {
+    OtherOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
+    if (TLI.isOperationLegalOrCustom(OtherOpcode, VT))
+      return SDValue();
+  }
+
+  SDValue Op0 = Node->getOperand(0);
+  SDValue Op1 = Node->getOperand(1);
+  SDValue combined;
+  for (SDNode::use_iterator UI = Op0.getNode()->use_begin(),
+         UE = Op0.getNode()->use_end(); UI != UE; ++UI) {
+    SDNode *User = *UI;
+    if (User == Node || User->use_empty())
+      continue;
+    // Convert the other matching node(s), too;
+    // otherwise, the DIVREM may get target-legalized into something
+    // target-specific that we won't be able to recognize.
+    unsigned UserOpc = User->getOpcode();
+    if ((UserOpc == Opcode || UserOpc == OtherOpcode || UserOpc == DivRemOpc) &&
+        User->getOperand(0) == Op0 &&
+        User->getOperand(1) == Op1) {
+      if (!combined) {
+        if (UserOpc == OtherOpcode) {
+          SDVTList VTs = DAG.getVTList(VT, VT);
+          combined = DAG.getNode(DivRemOpc, SDLoc(Node), VTs, Op0, Op1);
+        } else if (UserOpc == DivRemOpc) {
+          combined = SDValue(User, 0);
+        } else {
+          assert(UserOpc == Opcode);
+          continue;
+        }
+      }
+      if (UserOpc == ISD::SDIV || UserOpc == ISD::UDIV)
+        CombineTo(User, combined);
+      else if (UserOpc == ISD::SREM || UserOpc == ISD::UREM)
+        CombineTo(User, combined.getValue(1));
+    }
+  }
+  return combined;
+}
+
 SDValue DAGCombiner::visitSDIV(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -2155,26 +2290,26 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N))
       return FoldedVOp;
 
+  SDLoc DL(N);
+
   // fold (sdiv c1, c2) -> c1/c2
   ConstantSDNode *N0C = isConstOrConstSplat(N0);
   ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N0C && N1C && !N0C->isOpaque() && !N1C->isOpaque())
-    return DAG.FoldConstantArithmetic(ISD::SDIV, SDLoc(N), VT, N0C, N1C);
+    return DAG.FoldConstantArithmetic(ISD::SDIV, DL, VT, N0C, N1C);
   // fold (sdiv X, 1) -> X
   if (N1C && N1C->isOne())
     return N0;
   // fold (sdiv X, -1) -> 0-X
-  if (N1C && N1C->isAllOnesValue()) {
-    SDLoc DL(N);
+  if (N1C && N1C->isAllOnesValue())
     return DAG.getNode(ISD::SUB, DL, VT,
                        DAG.getConstant(0, DL, VT), N0);
-  }
+
   // If we know the sign bits of both operands are zero, strength reduce to a
   // udiv instead.  Handles (X&15) /s 4 -> X&15 >> 2
   if (!VT.isVector()) {
     if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
-      return DAG.getNode(ISD::UDIV, SDLoc(N), N1.getValueType(),
-                         N0, N1);
+      return DAG.getNode(ISD::UDIV, DL, N1.getValueType(), N0, N1);
   }
 
   // fold (sdiv X, pow2) -> simple ops after legalize
@@ -2185,18 +2320,11 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
       !cast<BinaryWithFlagsSDNode>(N)->Flags.hasExact() &&
       (N1C->getAPIntValue().isPowerOf2() ||
       (-N1C->getAPIntValue()).isPowerOf2())) {
-    // If dividing by powers of two is cheap, then don't perform the following
-    // fold.
-    if (TLI.isPow2SDivCheap())
-      return SDValue();
-
     // Target-specific implementation of sdiv x, pow2.
-    SDValue Res = BuildSDIVPow2(N);
-    if (Res.getNode())
+    if (SDValue Res = BuildSDIVPow2(N))
       return Res;
 
     unsigned lg2 = N1C->getAPIntValue().countTrailingZeros();
-    SDLoc DL(N);
 
     // Splat the sign bit into the register
     SDValue SGN =
@@ -2227,15 +2355,23 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) {
   }
 
   // If integer divide is expensive and we satisfy the requirements, emit an
-  // alternate sequence.
-  if (N1C && !TLI.isIntDivCheap()) {
-    SDValue Op = BuildSDIV(N);
-    if (Op.getNode()) return Op;
-  }
+  // alternate sequence.  Targets may check function attributes for size/speed
+  // trade-offs.
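+  // For example, when integer division is expensive, BuildSDIV can rewrite
+  // (sdiv x, 7) as a multiply by a "magic" constant followed by shifts,
+  // avoiding the divide entirely.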
+  AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+  if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue Op = BuildSDIV(N))
+      return Op;
+
+  // sdiv, srem -> sdivrem
+  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
+  // true.  Otherwise, we break the simplification logic in visitREM().
+  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue DivRem = useDivRem(N))
+      return DivRem;
 
   // undef / X -> 0
   if (N0.getOpcode() == ISD::UNDEF)
-    return DAG.getConstant(0, SDLoc(N), VT);
+    return DAG.getConstant(0, DL, VT);
   // X / undef -> undef
   if (N1.getOpcode() == ISD::UNDEF)
     return N1;
@@ -2253,26 +2389,26 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
     if (SDValue FoldedVOp = SimplifyVBinOp(N))
       return FoldedVOp;
 
+  SDLoc DL(N);
+
   // fold (udiv c1, c2) -> c1/c2
   ConstantSDNode *N0C = isConstOrConstSplat(N0);
   ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N0C && N1C)
-    if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, SDLoc(N), VT,
+    if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UDIV, DL, VT,
                                                     N0C, N1C))
       return Folded;
   // fold (udiv x, (1 << c)) -> x >>u c
-  if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2()) {
-    SDLoc DL(N);
+  if (N1C && !N1C->isOpaque() && N1C->getAPIntValue().isPowerOf2())
    return DAG.getNode(ISD::SRL, DL, VT, N0,
                       DAG.getConstant(N1C->getAPIntValue().logBase2(), DL,
                                       getShiftAmountTy(N0.getValueType())));
-  }
+
   // fold (udiv x, (shl c, y)) -> x >>u (log2(c)+y) iff c is power of 2
   if (N1.getOpcode() == ISD::SHL) {
     if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) {
       if (SHC->getAPIntValue().isPowerOf2()) {
         EVT ADDVT = N1.getOperand(1).getValueType();
-        SDLoc DL(N);
         SDValue Add = DAG.getNode(ISD::ADD, DL, ADDVT,
                                   N1.getOperand(1),
                                   DAG.getConstant(SHC->getAPIntValue()
@@ -2283,15 +2419,23 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
       }
     }
   }
+
   // fold (udiv x, c) -> alternate
-  if (N1C && !TLI.isIntDivCheap()) {
-    SDValue Op = BuildUDIV(N);
-    if (Op.getNode()) return Op;
-  }
+  AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+  if (N1C && !TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue Op = BuildUDIV(N))
+      return Op;
+
+  // udiv, urem -> udivrem
+  // If the divisor is constant, then return DIVREM only if isIntDivCheap() is
+  // true.  Otherwise, we break the simplification logic in visitREM().
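+  // For example, if both x/c and x%c are live, forming UDIVREM here for a
+  // constant divisor would hide the plain UDIV that visitREM() wants to see
+  // when it rewrites x % c as x - (x / c) * c.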
+  if (!N1C || TLI.isIntDivCheap(N->getValueType(0), Attr))
+    if (SDValue DivRem = useDivRem(N))
+      return DivRem;
 
   // undef / X -> 0
   if (N0.getOpcode() == ISD::UNDEF)
-    return DAG.getConstant(0, SDLoc(N), VT);
+    return DAG.getConstant(0, DL, VT);
   // X / undef -> undef
   if (N1.getOpcode() == ISD::UNDEF)
     return N1;
@@ -2299,102 +2443,83 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) {
   return SDValue();
 }
 
-SDValue DAGCombiner::visitSREM(SDNode *N) {
+// handles ISD::SREM and ISD::UREM
+SDValue DAGCombiner::visitREM(SDNode *N) {
+  unsigned Opcode = N->getOpcode();
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   EVT VT = N->getValueType(0);
+  bool isSigned = (Opcode == ISD::SREM);
+  SDLoc DL(N);
 
-  // fold (srem c1, c2) -> c1%c2
+  // fold (rem c1, c2) -> c1%c2
   ConstantSDNode *N0C = isConstOrConstSplat(N0);
   ConstantSDNode *N1C = isConstOrConstSplat(N1);
   if (N0C && N1C)
-    if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::SREM, SDLoc(N), VT,
-                                                    N0C, N1C))
+    if (SDValue Folded = DAG.FoldConstantArithmetic(Opcode, DL, VT, N0C, N1C))
       return Folded;
-  // If we know the sign bits of both operands are zero, strength reduce to a
-  // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
-  if (!VT.isVector()) {
-    if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
-      return DAG.getNode(ISD::UREM, SDLoc(N), VT, N0, N1);
-  }
-  // If X/C can be simplified by the division-by-constant logic, lower
-  // X%C to the equivalent of X-X/C*C.
-  if (N1C && !N1C->isNullValue()) {
-    SDValue Div = DAG.getNode(ISD::SDIV, SDLoc(N), VT, N0, N1);
-    AddToWorklist(Div.getNode());
-    SDValue OptimizedDiv = combine(Div.getNode());
-    if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) {
-      SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT,
-                                OptimizedDiv, N1);
-      SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, Mul);
-      AddToWorklist(Mul.getNode());
-      return Sub;
+  if (isSigned) {
+    // If we know the sign bits of both operands are zero, strength reduce to a
+    // urem instead.  Handles (X & 0x0FFFFFFF) %s 16 -> X&15
+    if (!VT.isVector()) {
+      if (DAG.SignBitIsZero(N1) && DAG.SignBitIsZero(N0))
+        return DAG.getNode(ISD::UREM, DL, VT, N0, N1);
     }
-  }
-
-  // undef % X -> 0
-  if (N0.getOpcode() == ISD::UNDEF)
-    return DAG.getConstant(0, SDLoc(N), VT);
-  // X % undef -> undef
-  if (N1.getOpcode() == ISD::UNDEF)
-    return N1;
-
-  return SDValue();
-}
-
-SDValue DAGCombiner::visitUREM(SDNode *N) {
-  SDValue N0 = N->getOperand(0);
-  SDValue N1 = N->getOperand(1);
-  EVT VT = N->getValueType(0);
-
-  // fold (urem c1, c2) -> c1%c2
-  ConstantSDNode *N0C = isConstOrConstSplat(N0);
-  ConstantSDNode *N1C = isConstOrConstSplat(N1);
-  if (N0C && N1C)
-    if (SDValue Folded = DAG.FoldConstantArithmetic(ISD::UREM, SDLoc(N), VT,
-                                                    N0C, N1C))
-      return Folded;
-  // fold (urem x, pow2) -> (and x, pow2-1)
-  if (N1C && !N1C->isNullValue() && !N1C->isOpaque() &&
-      N1C->getAPIntValue().isPowerOf2()) {
-    SDLoc DL(N);
-    return DAG.getNode(ISD::AND, DL, VT, N0,
-                       DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT));
-  }
-  // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
-  if (N1.getOpcode() == ISD::SHL) {
-    if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) {
-      if (SHC->getAPIntValue().isPowerOf2()) {
-        SDLoc DL(N);
-        SDValue Add =
-          DAG.getNode(ISD::ADD, DL, VT, N1,
+  } else {
+    // fold (urem x, pow2) -> (and x, pow2-1)
+    if (N1C && !N1C->isNullValue() && !N1C->isOpaque() &&
+        N1C->getAPIntValue().isPowerOf2()) {
+      return DAG.getNode(ISD::AND, DL, VT, N0,
+                         DAG.getConstant(N1C->getAPIntValue() - 1, DL, VT));
+    }
+    // fold (urem x, (shl pow2, y)) -> (and x, (add (shl pow2, y), -1))
+    if (N1.getOpcode() == ISD::SHL) {
+      if (ConstantSDNode *SHC = getAsNonOpaqueConstant(N1.getOperand(0))) {
+        if (SHC->getAPIntValue().isPowerOf2()) {
+          SDValue Add =
+            DAG.getNode(ISD::ADD, DL, VT, N1,
                 DAG.getConstant(APInt::getAllOnesValue(VT.getSizeInBits()), DL,
                                 VT));
-        AddToWorklist(Add.getNode());
-        return DAG.getNode(ISD::AND, DL, VT, N0, Add);
+          AddToWorklist(Add.getNode());
+          return DAG.getNode(ISD::AND, DL, VT, N0, Add);
+        }
       }
     }
   }
 
+  AttributeSet Attr = DAG.getMachineFunction().getFunction()->getAttributes();
+
   // If X/C can be simplified by the division-by-constant logic, lower
   // X%C to the equivalent of X-X/C*C.
+  // To avoid mangling nodes, this simplification requires that the combine()
+  // call for the speculative DIV must not cause a DIVREM conversion.  We guard
+  // against this by skipping the simplification if isIntDivCheap().  When
+  // div is not cheap, combine will not return a DIVREM.  Regardless,
+  // checking cheapness here makes sense since the simplification results in
+  // fatter code.
+  if (N1C && !N1C->isNullValue() && !TLI.isIntDivCheap(VT, Attr)) {
+    unsigned DivOpcode = isSigned ? ISD::SDIV : ISD::UDIV;
+    SDValue Div = DAG.getNode(DivOpcode, DL, VT, N0, N1);
     AddToWorklist(Div.getNode());
     SDValue OptimizedDiv = combine(Div.getNode());
     if (OptimizedDiv.getNode() && OptimizedDiv.getNode() != Div.getNode()) {
-      SDValue Mul = DAG.getNode(ISD::MUL, SDLoc(N), VT,
-                                OptimizedDiv, N1);
-      SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, N0, Mul);
+      assert((OptimizedDiv.getOpcode() != ISD::UDIVREM) &&
             (OptimizedDiv.getOpcode() != ISD::SDIVREM));
+      SDValue Mul = DAG.getNode(ISD::MUL, DL, VT, OptimizedDiv, N1);
+      SDValue Sub = DAG.getNode(ISD::SUB, DL, VT, N0, Mul);
       AddToWorklist(Mul.getNode());
       return Sub;
     }
   }
 
+  // srem/urem -> sdivrem/udivrem
+  if (SDValue DivRem = useDivRem(N))
+    return DivRem.getValue(1);
+
   // undef % X -> 0
   if (N0.getOpcode() == ISD::UNDEF)
-    return DAG.getConstant(0, SDLoc(N), VT);
+    return DAG.getConstant(0, DL, VT);
   // X % undef -> undef
   if (N1.getOpcode() == ISD::UNDEF)
     return N1;
@@ -2531,8 +2656,8 @@ SDValue DAGCombiner::SimplifyNodeWithTwoResults(SDNode *N, unsigned LoOp,
 }
 
 SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHS))
+    return Res;
 
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
@@ -2562,8 +2687,8 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
 }
 
 SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU);
-  if (Res.getNode()) return Res;
+  if (SDValue Res = SimplifyNodeWithTwoResults(N, ISD::MUL, ISD::MULHU))
+    return Res;
 
   EVT VT = N->getValueType(0);
   SDLoc DL(N);
@@ -2612,16 +2737,26 @@ SDValue DAGCombiner::visitUMULO(SDNode *N) {
   return SDValue();
 }
 
-SDValue DAGCombiner::visitSDIVREM(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::SDIV, ISD::SREM);
-  if (Res.getNode()) return Res;
+SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  EVT VT = N0.getValueType();
 
-  return SDValue();
-}
+  // fold vector ops
+  if (VT.isVector())
+    if (SDValue FoldedVOp = SimplifyVBinOp(N))
+      return FoldedVOp;
+
+  // fold operation with constant operands
+  ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
+  ConstantSDNode *N1C = getAsNonOpaqueConstant(N1);
+  if (N0C && N1C)
+    return DAG.FoldConstantArithmetic(N->getOpcode(), SDLoc(N), VT, N0C, N1C);
 
-SDValue DAGCombiner::visitUDIVREM(SDNode *N) {
-  SDValue Res = SimplifyNodeWithTwoResults(N, ISD::UDIV, ISD::UREM);
-  if (Res.getNode()) return Res;
+  // canonicalize constant to RHS
+  if (isConstantIntBuildVectorOrConstantInt(N0) &&
+      !isConstantIntBuildVectorOrConstantInt(N1))
+    return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0);
 
   return SDValue();
 }
@@ -2847,10 +2982,13 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
       if (Result != ISD::SETCC_INVALID &&
          (!LegalOperations ||
           (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
-           TLI.isOperationLegal(ISD::SETCC,
-                                getSetCCResultType(N0.getSimpleValueType())))))
-        return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
-                            LL, LR, Result);
+           TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
+        EVT CCVT = getSetCCResultType(LL.getValueType());
+        if (N0.getValueType() == CCVT ||
+            (!LegalOperations && N0.getValueType() == MVT::i1))
+          return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
+                              LL, LR, Result);
+      }
     }
   }
 
@@ -2886,6 +3024,46 @@ SDValue DAGCombiner::visitANDLike(SDValue N0, SDValue N1,
   return SDValue();
 }
 
+bool DAGCombiner::isAndLoadExtLoad(ConstantSDNode *AndC, LoadSDNode *LoadN,
+                                   EVT LoadResultTy, EVT &ExtVT, EVT &LoadedVT,
+                                   bool &NarrowLoad) {
+  uint32_t ActiveBits = AndC->getAPIntValue().getActiveBits();
+
+  if (ActiveBits == 0 || !APIntOps::isMask(ActiveBits, AndC->getAPIntValue()))
+    return false;
+
+  ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+  LoadedVT = LoadN->getMemoryVT();
+
+  if (ExtVT == LoadedVT &&
+      (!LegalOperations ||
+       TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))) {
+    // ZEXTLOAD will match without needing to change the size of the value
+    // being loaded.
+    NarrowLoad = false;
+    return true;
+  }
+
+  // Do not change the width of a volatile load.
+  if (LoadN->isVolatile())
+    return false;
+
+  // Do not generate loads of non-round integer types since these can
+  // be expensive (and would be wrong if the type is not byte sized).
+  if (!LoadedVT.bitsGT(ExtVT) || !ExtVT.isRound())
+    return false;
+
+  if (LegalOperations &&
+      !TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy, ExtVT))
+    return false;
+
+  if (!TLI.shouldReduceLoadWidth(LoadN, ISD::ZEXTLOAD, ExtVT))
+    return false;
+
+  NarrowLoad = true;
+  return true;
+}
+
 SDValue DAGCombiner::visitAND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -2920,6 +3098,22 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   // fold (and c1, c2) -> c1&c2
   ConstantSDNode *N0C = getAsNonOpaqueConstant(N0);
   ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
+
+  // XXX-disabled: (and x, 0) should not be folded.
+  // (and (and x, 0), y) shouldn't either.
+  if (!N0C && N1C && N1C->isNullValue()) {
+    return SDValue();
+  }
+  if (!N0C) {
+    if (N0.getOpcode() == ISD::AND) {
+      auto* N01 = N0.getOperand(1).getNode();
+      auto* N01C = dyn_cast<ConstantSDNode>(N01);
+      if (N01C && N01C->isNullValue()) {
+        return SDValue();
+      }
+    }
+  }
+
   if (N0C && N1C && !N1C->isOpaque())
     return DAG.FoldConstantArithmetic(ISD::AND, SDLoc(N), VT, N0C, N1C);
   // canonicalize constant to RHS
@@ -3078,16 +3272,12 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
                                    : cast<LoadSDNode>(N0);
     if (LN0->getExtensionType() != ISD::SEXTLOAD && LN0->isUnindexed() &&
         N0.hasOneUse() && SDValue(LN0, 0).hasOneUse()) {
-      uint32_t ActiveBits = N1C->getAPIntValue().getActiveBits();
-      if (ActiveBits > 0 && APIntOps::isMask(ActiveBits, N1C->getAPIntValue())){
-        EVT ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
-        EVT LoadedVT = LN0->getMemoryVT();
-        EVT LoadResultTy = HasAnyExt ? LN0->getValueType(0) : VT;
-
-        if (ExtVT == LoadedVT &&
-            (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy,
-                                                    ExtVT))) {
-
+      auto NarrowLoad = false;
+      EVT LoadResultTy = HasAnyExt ? LN0->getValueType(0) : VT;
+      EVT ExtVT, LoadedVT;
+      if (isAndLoadExtLoad(N1C, LN0, LoadResultTy, ExtVT, LoadedVT,
+                           NarrowLoad)) {
+        if (!NarrowLoad) {
           SDValue NewLoad =
             DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), LoadResultTy,
                            LN0->getChain(), LN0->getBasePtr(), ExtVT,
                            LN0->getMemOperand());
           AddToWorklist(N);
           CombineTo(LN0, NewLoad, NewLoad.getValue(1));
           return SDValue(N, 0);   // Return N so it doesn't get rechecked!
-        }
-
-        // Do not change the width of a volatile load.
-        // Do not generate loads of non-round integer types since these can
-        // be expensive (and would be wrong if the type is not byte sized).
-        if (!LN0->isVolatile() && LoadedVT.bitsGT(ExtVT) && ExtVT.isRound() &&
-            (!LegalOperations || TLI.isLoadExtLegal(ISD::ZEXTLOAD, LoadResultTy,
-                                                    ExtVT))) {
+        } else {
           EVT PtrType = LN0->getOperand(1).getValueType();
 
           unsigned Alignment = LN0->getAlignment();
@@ -3141,10 +3324,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
       return Combined;
 
   // Simplify: (and (op x...), (op y...))  -> (op (and x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // fold (and (sign_extend_inreg x, i16 to i32), 1) -> (and x, 1)
   // fold (and (sra)) -> (and (srl)) when possible.
@@ -3506,10 +3688,13 @@ SDValue DAGCombiner::visitORLike(SDValue N0, SDValue N1, SDNode *LocReference) {
       if (Result != ISD::SETCC_INVALID &&
          (!LegalOperations ||
           (TLI.isCondCodeLegal(Result, LL.getSimpleValueType()) &&
-           TLI.isOperationLegal(ISD::SETCC,
-                                getSetCCResultType(N0.getValueType())))))
-        return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
-                            LL, LR, Result);
+           TLI.isOperationLegal(ISD::SETCC, LL.getValueType())))) {
+        EVT CCVT = getSetCCResultType(LL.getValueType());
+        if (N0.getValueType() == CCVT ||
+            (!LegalOperations && N0.getValueType() == MVT::i1))
+          return DAG.getSetCC(SDLoc(LocReference), N0.getValueType(),
+                              LL, LR, Result);
+      }
     }
   }
 
@@ -3664,11 +3849,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
       return Combined;
 
   // Recognize halfword bswaps as (bswap + rotl 16) or (bswap + shl 16)
-  SDValue BSwap = MatchBSwapHWord(N, N0, N1);
-  if (BSwap.getNode())
+  if (SDValue BSwap = MatchBSwapHWord(N, N0, N1))
     return BSwap;
-  BSwap = MatchBSwapHWordLow(N, N0, N1);
-  if (BSwap.getNode())
+  if (SDValue BSwap = MatchBSwapHWordLow(N, N0, N1))
    return BSwap;
 
   // reassociate or
@@ -3689,10 +3872,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
     }
   }
   // Simplify: (or (op x...), (op y...))  -> (op (or x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // See if this is some rotate idiom.
   if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
@@ -3709,7 +3891,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
 /// Match "(X shl/srl V1) & V2" where V2 may not be present.
 static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
   if (Op.getOpcode() == ISD::AND) {
-    if (isa<ConstantSDNode>(Op.getOperand(1))) {
+    if (isConstantIntBuildVectorOrConstantInt(Op.getOperand(1))) {
       Mask = Op.getOperand(1);
       Op = Op.getOperand(0);
     } else {
@@ -3726,105 +3908,106 @@ static bool MatchRotateHalf(SDValue Op, SDValue &Shift, SDValue &Mask) {
 }
 
 // Return true if we can prove that, whenever Neg and Pos are both in the
-// range [0, OpSize), Neg == (Pos == 0 ? 0 : OpSize - Pos).  This means that
+// range [0, EltSize), Neg == (Pos == 0 ? 0 : EltSize - Pos).  This means that
 // for two opposing shifts shift1 and shift2 and a value X with OpBits bits:
 //
 //     (or (shift1 X, Neg), (shift2 X, Pos))
 //
 // reduces to a rotate in direction shift2 by Pos or (equivalently) a rotate
-// in direction shift1 by Neg.  The range [0, OpSize) means that we only need
+// in direction shift1 by Neg.  The range [0, EltSize) means that we only need
 // to consider shift amounts with defined behavior.
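// For example, with EltSize == 32 this lets
//     (or (shl X, (and (sub 32, Y), 31)), (srl X, Y))
// be replaced by (rotr X, Y) for any Y in [0, 32), including Y == 0.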
-static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned OpSize) {
-  // If OpSize is a power of 2 then:
+static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize) {
+  // If EltSize is a power of 2 then:
   //
-  //  (a) (Pos == 0 ? 0 : OpSize - Pos) == (OpSize - Pos) & (OpSize - 1)
-  //  (b) Neg == Neg & (OpSize - 1) whenever Neg is in [0, OpSize).
+  //  (a) (Pos == 0 ? 0 : EltSize - Pos) == (EltSize - Pos) & (EltSize - 1)
+  //  (b) Neg == Neg & (EltSize - 1) whenever Neg is in [0, EltSize).
   //
-  // So if OpSize is a power of 2 and Neg is (and Neg', OpSize-1), we check
+  // So if EltSize is a power of 2 and Neg is (and Neg', EltSize-1), we check
   // for the stronger condition:
   //
-  //     Neg & (OpSize - 1) == (OpSize - Pos) & (OpSize - 1)    [A]
+  //     Neg & (EltSize - 1) == (EltSize - Pos) & (EltSize - 1)    [A]
   //
-  // for all Neg and Pos.  Since Neg & (OpSize - 1) == Neg' & (OpSize - 1)
+  // for all Neg and Pos.  Since Neg & (EltSize - 1) == Neg' & (EltSize - 1)
   // we can just replace Neg with Neg' for the rest of the function.
   //
   // In other cases we check for the even stronger condition:
   //
-  //     Neg == OpSize - Pos                                    [B]
+  //     Neg == EltSize - Pos                                    [B]
   //
   // for all Neg and Pos.  Note that the (or ...) then invokes undefined
-  // behavior if Pos == 0 (and consequently Neg == OpSize).
+  // behavior if Pos == 0 (and consequently Neg == EltSize).
   //
-  // We could actually use [A] whenever OpSize is a power of 2, but the
+  // We could actually use [A] whenever EltSize is a power of 2, but the
   // only extra cases that it would match are those uninteresting ones
   // where Neg and Pos are never in range at the same time.  E.g. for
-  // OpSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
+  // EltSize == 32, using [A] would allow a Neg of the form (sub 64, Pos)
   // as well as (sub 32, Pos), but:
   //
   //     (or (shift1 X, (sub 64, Pos)), (shift2 X, Pos))
   //
   // always invokes undefined behavior for 32-bit X.
   //
-  // Below, Mask == OpSize - 1 when using [A] and is all-ones otherwise.
+  // Below, Mask == EltSize - 1 when using [A] and is all-ones otherwise.
   unsigned MaskLoBits = 0;
-  if (Neg.getOpcode() == ISD::AND &&
-      isPowerOf2_64(OpSize) &&
-      Neg.getOperand(1).getOpcode() == ISD::Constant &&
-      cast<ConstantSDNode>(Neg.getOperand(1))->getAPIntValue() == OpSize - 1) {
-    Neg = Neg.getOperand(0);
-    MaskLoBits = Log2_64(OpSize);
+  if (Neg.getOpcode() == ISD::AND && isPowerOf2_64(EltSize)) {
+    if (ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(1))) {
+      if (NegC->getAPIntValue() == EltSize - 1) {
+        Neg = Neg.getOperand(0);
+        MaskLoBits = Log2_64(EltSize);
+      }
+    }
   }
 
   // Check whether Neg has the form (sub NegC, NegOp1) for some NegC and NegOp1.
   if (Neg.getOpcode() != ISD::SUB)
-    return 0;
-  ConstantSDNode *NegC = dyn_cast<ConstantSDNode>(Neg.getOperand(0));
+    return false;
+  ConstantSDNode *NegC = isConstOrConstSplat(Neg.getOperand(0));
   if (!NegC)
-    return 0;
+    return false;
   SDValue NegOp1 = Neg.getOperand(1);
 
-  // On the RHS of [A], if Pos is Pos' & (OpSize - 1), just replace Pos with
+  // On the RHS of [A], if Pos is Pos' & (EltSize - 1), just replace Pos with
   // Pos'.  The truncation is redundant for the purpose of the equality.
-  if (MaskLoBits &&
-      Pos.getOpcode() == ISD::AND &&
-      Pos.getOperand(1).getOpcode() == ISD::Constant &&
-      cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() == OpSize - 1)
-    Pos = Pos.getOperand(0);
+  if (MaskLoBits && Pos.getOpcode() == ISD::AND)
+    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
+      if (PosC->getAPIntValue() == EltSize - 1)
+        Pos = Pos.getOperand(0);
 
   // The condition we need is now:
   //
-  //     (NegC - NegOp1) & Mask == (OpSize - Pos) & Mask
+  //     (NegC - NegOp1) & Mask == (EltSize - Pos) & Mask
   //
   // If NegOp1 == Pos then we need:
   //
-  //     OpSize & Mask == NegC & Mask
+  //     EltSize & Mask == NegC & Mask
   //
   // (because "x & Mask" is a truncation and distributes through subtraction).
   APInt Width;
   if (Pos == NegOp1)
     Width = NegC->getAPIntValue();
+
   // Check for cases where Pos has the form (add NegOp1, PosC) for some PosC.
   // Then the condition we want to prove becomes:
   //
-  //     (NegC - NegOp1) & Mask == (OpSize - (NegOp1 + PosC)) & Mask
+  //     (NegC - NegOp1) & Mask == (EltSize - (NegOp1 + PosC)) & Mask
   //
   // which, again because "x & Mask" is a truncation, becomes:
   //
-  //     NegC & Mask == (OpSize - PosC) & Mask
-  //     OpSize & Mask == (NegC + PosC) & Mask
-  else if (Pos.getOpcode() == ISD::ADD &&
-           Pos.getOperand(0) == NegOp1 &&
-           Pos.getOperand(1).getOpcode() == ISD::Constant)
-    Width = (cast<ConstantSDNode>(Pos.getOperand(1))->getAPIntValue() +
-             NegC->getAPIntValue());
-  else
+  //     NegC & Mask == (EltSize - PosC) & Mask
+  //     EltSize & Mask == (NegC + PosC) & Mask
+  else if (Pos.getOpcode() == ISD::ADD && Pos.getOperand(0) == NegOp1) {
+    if (ConstantSDNode *PosC = isConstOrConstSplat(Pos.getOperand(1)))
+      Width = PosC->getAPIntValue() + NegC->getAPIntValue();
+    else
+      return false;
+  } else
     return false;
 
-  // Now we just need to check that OpSize & Mask == Width & Mask.
+  // Now we just need to check that EltSize & Mask == Width & Mask.
   if (MaskLoBits)
-    // Opsize & Mask is 0 since Mask is Opsize - 1.
+    // EltSize & Mask is 0 since Mask is EltSize - 1.
     return Width.getLoBits(MaskLoBits) == 0;
-  return Width == OpSize;
+  return Width == EltSize;
 }
 
 // A subroutine of MatchRotate used once we have found an OR of two opposite
@@ -3844,7 +4027,7 @@ SDNode *DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
   //          (srl x, (*ext y))) ->
   //   (rotr x, y) or (rotl x, (sub 32, y))
   EVT VT = Shifted.getValueType();
-  if (matchRotateSub(InnerPos, InnerNeg, VT.getSizeInBits())) {
+  if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits())) {
    bool HasPos = TLI.isOperationLegalOrCustom(PosOpcode, VT);
    return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
                       HasPos ? Pos : Neg).getNode();
   }
 
@@ -3887,10 +4070,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   if (RHSShift.getOpcode() == ISD::SHL) {
     std::swap(LHS, RHS);
     std::swap(LHSShift, RHSShift);
-    std::swap(LHSMask , RHSMask );
+    std::swap(LHSMask, RHSMask);
   }
 
-  unsigned OpSizeInBits = VT.getSizeInBits();
+  unsigned EltSizeInBits = VT.getScalarSizeInBits();
   SDValue LHSShiftArg = LHSShift.getOperand(0);
   SDValue LHSShiftAmt = LHSShift.getOperand(1);
   SDValue RHSShiftArg = RHSShift.getOperand(0);
@@ -3898,11 +4081,10 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   SDValue RHSShiftAmt = RHSShift.getOperand(1);
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
   // fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
-  if (LHSShiftAmt.getOpcode() == ISD::Constant &&
-      RHSShiftAmt.getOpcode() == ISD::Constant) {
-    uint64_t LShVal = cast<ConstantSDNode>(LHSShiftAmt)->getZExtValue();
-    uint64_t RShVal = cast<ConstantSDNode>(RHSShiftAmt)->getZExtValue();
-    if ((LShVal + RShVal) != OpSizeInBits)
+  if (isConstOrConstSplat(LHSShiftAmt) && isConstOrConstSplat(RHSShiftAmt)) {
+    uint64_t LShVal = isConstOrConstSplat(LHSShiftAmt)->getZExtValue();
+    uint64_t RShVal = isConstOrConstSplat(RHSShiftAmt)->getZExtValue();
+    if ((LShVal + RShVal) != EltSizeInBits)
       return nullptr;
 
     SDValue Rot = DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
@@ -3910,18 +4092,23 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
 
     // If there is an AND of either shifted operand, apply it to the result.
     if (LHSMask.getNode() || RHSMask.getNode()) {
-      APInt Mask = APInt::getAllOnesValue(OpSizeInBits);
+      APInt AllBits = APInt::getAllOnesValue(EltSizeInBits);
+      SDValue Mask = DAG.getConstant(AllBits, DL, VT);
 
       if (LHSMask.getNode()) {
-        APInt RHSBits = APInt::getLowBitsSet(OpSizeInBits, LShVal);
-        Mask &= cast<ConstantSDNode>(LHSMask)->getAPIntValue() | RHSBits;
+        APInt RHSBits = APInt::getLowBitsSet(EltSizeInBits, LShVal);
+        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
+                           DAG.getNode(ISD::OR, DL, VT, LHSMask,
+                                       DAG.getConstant(RHSBits, DL, VT)));
       }
       if (RHSMask.getNode()) {
-        APInt LHSBits = APInt::getHighBitsSet(OpSizeInBits, RShVal);
-        Mask &= cast<ConstantSDNode>(RHSMask)->getAPIntValue() | LHSBits;
+        APInt LHSBits = APInt::getHighBitsSet(EltSizeInBits, RShVal);
+        Mask = DAG.getNode(ISD::AND, DL, VT, Mask,
+                           DAG.getNode(ISD::OR, DL, VT, RHSMask,
+                                       DAG.getConstant(LHSBits, DL, VT)));
      }
 
-      Rot = DAG.getNode(ISD::AND, DL, VT, Rot, DAG.getConstant(Mask, DL, VT));
+      Rot = DAG.getNode(ISD::AND, DL, VT, Rot, Mask);
     }
 
     return Rot.getNode();
@@ -4111,10 +4298,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
   }
 
   // Simplify: xor (op x...), (op y...)  -> (op (xor x, y))
-  if (N0.getOpcode() == N1.getOpcode()) {
-    SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N);
-    if (Tmp.getNode()) return Tmp;
-  }
+  if (N0.getOpcode() == N1.getOpcode())
+    if (SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N))
+      return Tmp;
 
   // Simplify the expression using non-local knowledge.
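  // For example, if every demanded bit of one xor operand is known to be
  // zero, the xor simplifies to the other operand.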
   if (!VT.isVector() &&
@@ -4433,12 +4619,19 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     return DAG.getNode(ISD::ADD, SDLoc(N), VT, Shl0, Shl1);
   }
 
-  if (N1C && !N1C->isOpaque()) {
-    SDValue NewSHL = visitShiftByConstant(N, N1C);
-    if (NewSHL.getNode())
-      return NewSHL;
+  // fold (shl (mul x, c1), c2) -> (mul x, c1 << c2)
+  if (N1C && N0.getOpcode() == ISD::MUL && N0.getNode()->hasOneUse()) {
+    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
+      if (SDValue Folded =
+              DAG.FoldConstantArithmetic(ISD::SHL, SDLoc(N1), VT, N0C1, N1C))
+        return DAG.getNode(ISD::MUL, SDLoc(N), VT, N0.getOperand(0), Folded);
+    }
   }
 
+  if (N1C && !N1C->isOpaque())
+    if (SDValue NewSHL = visitShiftByConstant(N, N1C))
+      return NewSHL;
+
   return SDValue();
 }
 
@@ -4582,11 +4775,9 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   if (DAG.SignBitIsZero(N0))
     return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, N1);
 
-  if (N1C && !N1C->isOpaque()) {
-    SDValue NewSRA = visitShiftByConstant(N, N1C);
-    if (NewSRA.getNode())
+  if (N1C && !N1C->isOpaque())
+    if (SDValue NewSRA = visitShiftByConstant(N, N1C))
       return NewSRA;
-  }
 
   return SDValue();
 }
@@ -4743,8 +4934,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), (trunc c))).
   if (N1.getOpcode() == ISD::TRUNCATE &&
       N1.getOperand(0).getOpcode() == ISD::AND) {
-    SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode());
-    if (NewOp1.getNode())
+    if (SDValue NewOp1 = distributeTruncateThroughAnd(N1.getNode()))
       return DAG.getNode(ISD::SRL, SDLoc(N), VT, N0, NewOp1);
   }
 
@@ -4753,15 +4943,12 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
 
-  if (N1C && !N1C->isOpaque()) {
-    SDValue NewSRL = visitShiftByConstant(N, N1C);
-    if (NewSRL.getNode())
+  if (N1C && !N1C->isOpaque())
+    if (SDValue NewSRL = visitShiftByConstant(N, N1C))
      return NewSRL;
-  }
 
   // Attempt to convert a srl of a load into a narrower zero-extending load.
-  SDValue NarrowLoad = ReduceLoadWidth(N);
-  if (NarrowLoad.getNode())
+  if (SDValue NarrowLoad = ReduceLoadWidth(N))
     return NarrowLoad;
 
   // Here is a common situation. We want to optimize:
@@ -4972,70 +5159,47 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
   if (SimplifySelectOps(N, N1, N2))
     return SDValue(N, 0);  // Don't revisit N.
 
-  // fold selects based on a setcc into other things, such as min/max/abs
-  if (N0.getOpcode() == ISD::SETCC) {
-    // select x, y (fcmp lt x, y) -> fminnum x, y
-    // select x, y (fcmp gt x, y) -> fmaxnum x, y
-    //
-    // This is OK if we don't care about what happens if either operand is a
-    // NaN.
-    //
-
-    // FIXME: Instead of testing for UnsafeFPMath, this should be checking for
-    // no signed zeros as well as no nans.
-    const TargetOptions &Options = DAG.getTarget().Options;
-    if (Options.UnsafeFPMath &&
-        VT.isFloatingPoint() && N0.hasOneUse() &&
-        DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) {
-      ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
-
-      SDValue FMinMax =
-          combineMinNumMaxNum(SDLoc(N), VT, N0.getOperand(0), N0.getOperand(1),
-                              N1, N2, CC, TLI, DAG);
-      if (FMinMax)
-        return FMinMax;
-    }
-
-    if ((!LegalOperations &&
-         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) ||
-        TLI.isOperationLegal(ISD::SELECT_CC, VT))
-      return DAG.getNode(ISD::SELECT_CC, SDLoc(N), VT,
-                         N0.getOperand(0), N0.getOperand(1),
-                         N1, N2, N0.getOperand(2));
-    return SimplifySelect(SDLoc(N), N0, N1, N2);
-  }
-
-  if (VT0 == MVT::i1) {
-    if (TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT)) {
-      // select (and Cond0, Cond1), X, Y
-      //   -> select Cond0, (select Cond1, X, Y), Y
-      if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
-        SDValue Cond0 = N0->getOperand(0);
-        SDValue Cond1 = N0->getOperand(1);
-        SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
-                                          N1.getValueType(), Cond1, N1, N2);
-        return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0,
-                           InnerSelect, N2);
-      }
-      // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
-      if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
-        SDValue Cond0 = N0->getOperand(0);
-        SDValue Cond1 = N0->getOperand(1);
-        SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
-                                          N1.getValueType(), Cond1, N1, N2);
-        return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, N1,
-                           InnerSelect);
-      }
+  if (VT0 == MVT::i1) {
+    // The code in this block deals with the following 2 equivalences:
+    //    select(C0|C1, x, y) <=> select(C0, x, select(C1, x, y))
+    //    select(C0&C1, x, y) <=> select(C0, select(C1, x, y), y)
+    // The target can specify its preferred form with the
+    // shouldNormalizeToSelectSequence() callback.  However, we always
+    // transform to the right form if the inner select already exists in the
+    // DAG, and we always transform to the left form if we know that we can
+    // further optimize the combination of the conditions.
+    bool normalizeToSequence
+      = TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT);
+    // select (and Cond0, Cond1), X, Y
+    //   -> select Cond0, (select Cond1, X, Y), Y
+    if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) {
+      SDValue Cond0 = N0->getOperand(0);
+      SDValue Cond1 = N0->getOperand(1);
+      SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
+                                        N1.getValueType(), Cond1, N1, N2);
+      if (normalizeToSequence || !InnerSelect.use_empty())
+        return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0,
+                           InnerSelect, N2);
+    }
+    // select (or Cond0, Cond1), X, Y -> select Cond0, X, (select Cond1, X, Y)
+    if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) {
+      SDValue Cond0 = N0->getOperand(0);
+      SDValue Cond1 = N0->getOperand(1);
+      SDValue InnerSelect = DAG.getNode(ISD::SELECT, SDLoc(N),
+                                        N1.getValueType(), Cond1, N1, N2);
+      if (normalizeToSequence || !InnerSelect.use_empty())
+        return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Cond0, N1,
+                           InnerSelect);
    }
    // select Cond0, (select Cond1, X, Y), Y -> select (and Cond0, Cond1), X, Y
-    if (N1->getOpcode() == ISD::SELECT) {
+    if (N1->getOpcode() == ISD::SELECT && N1->hasOneUse()) {
      SDValue N1_0 = N1->getOperand(0);
      SDValue N1_1 = N1->getOperand(1);
      SDValue N1_2 = N1->getOperand(2);
      if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) {
        // Create the actual and node if we can generate good code for it.
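        // For example, select(C0, select(C1, X, Y), Y) flattens to
        // select(and(C0, C1), X, Y) when the AND is cheap on the target.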
-        if (!TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT)) {
+        if (!normalizeToSequence) {
          SDValue And = DAG.getNode(ISD::AND, SDLoc(N), N0.getValueType(),
                                    N0, N1_0);
          return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), And,
@@ -5048,13 +5212,13 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
      }
    }
    // select Cond0, X, (select Cond1, X, Y) -> select (or Cond0, Cond1), X, Y
-    if (N2->getOpcode() == ISD::SELECT) {
+    if (N2->getOpcode() == ISD::SELECT && N2->hasOneUse()) {
      SDValue N2_0 = N2->getOperand(0);
      SDValue N2_1 = N2->getOperand(1);
      SDValue N2_2 = N2->getOperand(2);
      if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) {
        // Create the actual or node if we can generate good code for it.
-        if (!TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT)) {
+        if (!normalizeToSequence) {
          SDValue Or = DAG.getNode(ISD::OR, SDLoc(N), N0.getValueType(),
                                   N0, N2_0);
          return DAG.getNode(ISD::SELECT, SDLoc(N), N1.getValueType(), Or,
@@ -5068,6 +5232,38 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) {
    }
  }
 
+  // fold selects based on a setcc into other things, such as min/max/abs
+  if (N0.getOpcode() == ISD::SETCC) {
+    // select x, y (fcmp lt x, y) -> fminnum x, y
+    // select x, y (fcmp gt x, y) -> fmaxnum x, y
+    //
+    // This is OK if we don't care about what happens if either operand is a
+    // NaN.
+    //
+
+    // FIXME: Instead of testing for UnsafeFPMath, this should be checking for
+    // no signed zeros as well as no nans.
+    const TargetOptions &Options = DAG.getTarget().Options;
+    if (Options.UnsafeFPMath &&
+        VT.isFloatingPoint() && N0.hasOneUse() &&
+        DAG.isKnownNeverNaN(N1) && DAG.isKnownNeverNaN(N2)) {
+      ISD::CondCode CC = cast<CondCodeSDNode>(N0.getOperand(2))->get();
+
+      if (SDValue FMinMax = combineMinNumMaxNum(SDLoc(N), VT, N0.getOperand(0),
+                                                N0.getOperand(1), N1, N2, CC,
+                                                TLI, DAG))
+        return FMinMax;
+    }
+
+    if ((!LegalOperations &&
+         TLI.isOperationLegalOrCustom(ISD::SELECT_CC, VT)) ||
+        TLI.isOperationLegal(ISD::SELECT_CC, VT))
+      return DAG.getNode(ISD::SELECT_CC, SDLoc(N), VT,
+                         N0.getOperand(0), N0.getOperand(1),
+                         N1, N2, N0.getOperand(2));
+    return SimplifySelect(SDLoc(N), N0, N1, N2);
+  }
+
   return SDValue();
 }
 
@@ -5522,8 +5718,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
   if (N1.getOpcode() == ISD::CONCAT_VECTORS &&
       N2.getOpcode() == ISD::CONCAT_VECTORS &&
      ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) {
-    SDValue CV = ConvertSelectToConcatVector(N, DAG);
-    if (CV.getNode())
+    if (SDValue CV = ConvertSelectToConcatVector(N, DAG))
      return CV;
  }
 
@@ -5579,7 +5774,20 @@ SDValue DAGCombiner::visitSETCC(SDNode *N) {
                       SDLoc(N));
 }
 
-/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or 
+SDValue DAGCombiner::visitSETCCE(SDNode *N) {
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  SDValue Carry = N->getOperand(2);
+  SDValue Cond = N->getOperand(3);
+
+  // If Carry is false, fold to a regular SETCC.
+  if (Carry.getOpcode() == ISD::CARRY_FALSE)
+    return DAG.getNode(ISD::SETCC, SDLoc(N), N->getVTList(), LHS, RHS, Cond);
+
+  return SDValue();
+}
+
+/// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
 /// This function is called by the DAGCombiner when visiting sext/zext/aext
 /// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
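/// For example, (zext (build_vector (i8 1), (i8 -1)) to v2i32) becomes
/// (build_vector (i32 1), (i32 255)).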
@@ -5836,8 +6044,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { if (N0.getOpcode() == ISD::TRUNCATE) { // fold (sext (truncate (load x))) -> (sext (smaller load x)) // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n))) - SDValue NarrowLoad = ReduceLoadWidth(N0.getNode()); - if (NarrowLoad.getNode()) { + if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { SDNode* oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); @@ -6023,7 +6230,8 @@ SDValue DAGCombiner::visitSIGN_EXTEND(SDNode *N) { if (!VT.isVector()) { EVT SetCCVT = getSetCCResultType(N0.getOperand(0).getValueType()); - if (!LegalOperations || TLI.isOperationLegal(ISD::SETCC, SetCCVT)) { + if (!LegalOperations || + TLI.isOperationLegal(ISD::SETCC, N0.getOperand(0).getValueType())) { SDLoc DL(N); ISD::CondCode CC = cast(N0.getOperand(2))->get(); SDValue SetCC = DAG.getSetCC(DL, SetCCVT, @@ -6119,8 +6327,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (truncate (load x))) -> (zext (smaller load x)) // fold (zext (truncate (srl (load x), c))) -> (zext (small load (x+c/n))) if (N0.getOpcode() == ISD::TRUNCATE) { - SDValue NarrowLoad = ReduceLoadWidth(N0.getNode()); - if (NarrowLoad.getNode()) { + if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { SDNode* oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); @@ -6132,32 +6339,45 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { } // fold (zext (truncate x)) -> (and x, mask) - if (N0.getOpcode() == ISD::TRUNCATE && - (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT))) { - + if (N0.getOpcode() == ISD::TRUNCATE) { // fold (zext (truncate (load x))) -> (zext (smaller load x)) // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n))) - SDValue NarrowLoad = ReduceLoadWidth(N0.getNode()); - if (NarrowLoad.getNode()) { - SDNode* oye = N0.getNode()->getOperand(0).getNode(); + if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { + SDNode *oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); // CombineTo deleted the truncate, if needed, but not what's under it. AddToWorklist(oye); } - return SDValue(N, 0); // Return N so it doesn't get rechecked! + return SDValue(N, 0); // Return N so it doesn't get rechecked! } - SDValue Op = N0.getOperand(0); - if (Op.getValueType().bitsLT(VT)) { - Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, Op); - AddToWorklist(Op.getNode()); - } else if (Op.getValueType().bitsGT(VT)) { - Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op); - AddToWorklist(Op.getNode()); + EVT SrcVT = N0.getOperand(0).getValueType(); + EVT MinVT = N0.getValueType(); + + // Try to mask before the extension to avoid having to generate a larger mask, + // possibly over several sub-vectors. 
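// [Illustrative sketch, not from the original patch; assuming e.g.
// SrcVT = v4i32, MinVT = v4i8, VT = v4i64] Masking in the narrow source
// type first,
//   (zext (and x, 0xFF))  rather than  (and (anyext x), 0xFF)
// keeps the AND on v4i32, a single legal operation, whereas a v4i64 AND
// may be split across several sub-vectors after legalization. Scalar
// equivalent: (uint64_t)(x & 0xFFu) == ((uint64_t)x) & 0xFF.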
+ if (SrcVT.bitsLT(VT)) { + if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) && + TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) { + SDValue Op = N0.getOperand(0); + Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); + AddToWorklist(Op.getNode()); + return DAG.getZExtOrTrunc(Op, SDLoc(N), VT); + } + } + + if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) { + SDValue Op = N0.getOperand(0); + if (SrcVT.bitsLT(VT)) { + Op = DAG.getNode(ISD::ANY_EXTEND, SDLoc(N), VT, Op); + AddToWorklist(Op.getNode()); + } else if (SrcVT.bitsGT(VT)) { + Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, Op); + AddToWorklist(Op.getNode()); + } + return DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType()); } - return DAG.getZeroExtendInReg(Op, SDLoc(N), - N0.getValueType().getScalarType()); } // Fold (zext (and (trunc x), cst)) -> (and x, cst), @@ -6218,6 +6438,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { // fold (zext (and/or/xor (load x), cst)) -> // (and/or/xor (zextload x), (zext cst)) + // Unless (and (load x) cst) will match as a zextload already and has + // additional users. if ((N0.getOpcode() == ISD::AND || N0.getOpcode() == ISD::OR || N0.getOpcode() == ISD::XOR) && isa(N0.getOperand(0)) && @@ -6228,9 +6450,20 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) { if (LN0->getExtensionType() != ISD::SEXTLOAD && LN0->isUnindexed()) { bool DoXform = true; SmallVector SetCCs; - if (!N0.hasOneUse()) - DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), ISD::ZERO_EXTEND, - SetCCs, TLI); + if (!N0.hasOneUse()) { + if (N0.getOpcode() == ISD::AND) { + auto *AndC = cast(N0.getOperand(1)); + auto NarrowLoad = false; + EVT LoadResultTy = AndC->getValueType(0); + EVT ExtVT, LoadedVT; + if (isAndLoadExtLoad(AndC, LN0, LoadResultTy, ExtVT, LoadedVT, + NarrowLoad)) + DoXform = false; + } + if (DoXform) + DoXform = ExtendUsesToFormExtLoad(N, N0.getOperand(0), + ISD::ZERO_EXTEND, SetCCs, TLI); + } if (DoXform) { SDValue ExtLoad = DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(LN0), VT, LN0->getChain(), LN0->getBasePtr(), @@ -6377,8 +6610,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) { // fold (aext (truncate (load x))) -> (aext (smaller load x)) // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n))) if (N0.getOpcode() == ISD::TRUNCATE) { - SDValue NarrowLoad = ReduceLoadWidth(N0.getNode()); - if (NarrowLoad.getNode()) { + if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) { SDNode* oye = N0.getNode()->getOperand(0).getNode(); if (NarrowLoad.getNode() != N0.getNode()) { CombineTo(N0.getNode(), NarrowLoad); @@ -6545,8 +6777,7 @@ SDValue DAGCombiner::GetDemandedBits(SDValue V, const APInt &Mask) { // Watch out for shift count overflow though. if (Amt >= Mask.getBitWidth()) break; APInt NewMask = Mask << Amt; - SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask); - if (SimplifyLHS.getNode()) + if (SDValue SimplifyLHS = GetDemandedBits(V.getOperand(0), NewMask)) return DAG.getNode(ISD::SRL, SDLoc(V), V.getValueType(), SimplifyLHS, V.getOperand(1)); } @@ -6684,9 +6915,13 @@ SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) { uint64_t PtrOff = ShAmt / 8; unsigned NewAlign = MinAlign(LN0->getAlignment(), PtrOff); SDLoc DL(LN0); + // The original load itself didn't wrap, so an offset within it doesn't. 
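// [Worked example, not from the original patch] If the original i64 load
// legally accessed bytes [p, p+8), a narrowed i32 load of the high half at
// p+4 still lies inside that range, so "p + 4" cannot wrap the address
// space and the ADD built below may carry the no-unsigned-wrap flag.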
+ SDNodeFlags Flags; + Flags.setNoUnsignedWrap(true); SDValue NewPtr = DAG.getNode(ISD::ADD, DL, PtrType, LN0->getBasePtr(), - DAG.getConstant(PtrOff, DL, PtrType)); + DAG.getConstant(PtrOff, DL, PtrType), + &Flags); AddToWorklist(NewPtr.getNode()); SDValue Load; @@ -6735,8 +6970,11 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { unsigned VTBits = VT.getScalarType().getSizeInBits(); unsigned EVTBits = EVT.getScalarType().getSizeInBits(); + if (N0.isUndef()) + return DAG.getUNDEF(VT); + // fold (sext_in_reg c1) -> c1 - if (isa(N0) || N0.getOpcode() == ISD::UNDEF) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::SIGN_EXTEND_INREG, SDLoc(N), VT, N0, N1); // If the input is already sign extended, just drop the extension. @@ -6770,8 +7008,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { // fold (sext_in_reg (load x)) -> (smaller sextload x) // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits)) - SDValue NarrowLoad = ReduceLoadWidth(N); - if (NarrowLoad.getNode()) + if (SDValue NarrowLoad = ReduceLoadWidth(N)) return NarrowLoad; // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24) @@ -6830,29 +7067,6 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) { BSwap, N1); } - // Fold a sext_inreg of a build_vector of ConstantSDNodes or undefs - // into a build_vector. - if (ISD::isBuildVectorOfConstantSDNodes(N0.getNode())) { - SmallVector Elts; - unsigned NumElts = N0->getNumOperands(); - unsigned ShAmt = VTBits - EVTBits; - - for (unsigned i = 0; i != NumElts; ++i) { - SDValue Op = N0->getOperand(i); - if (Op->getOpcode() == ISD::UNDEF) { - Elts.push_back(Op); - continue; - } - - ConstantSDNode *CurrentND = cast(Op); - const APInt &C = APInt(VTBits, CurrentND->getAPIntValue().getZExtValue()); - Elts.push_back(DAG.getConstant(C.shl(ShAmt).ashr(ShAmt).getZExtValue(), - SDLoc(Op), Op.getValueType())); - } - - return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Elts); - } - return SDValue(); } @@ -6926,7 +7140,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { SDValue EltNo = N0->getOperand(1); if (isa(EltNo) && isTypeLegal(NVT)) { int Elt = cast(EltNo)->getZExtValue(); - EVT IndexTy = TLI.getVectorIdxTy(); + EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout()); int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1)); SDValue V = DAG.getNode(ISD::BITCAST, SDLoc(N), @@ -6998,9 +7212,9 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { // fold (truncate (load x)) -> (smaller load x) // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits)) if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) { - SDValue Reduced = ReduceLoadWidth(N); - if (Reduced.getNode()) + if (SDValue Reduced = ReduceLoadWidth(N)) return Reduced; + // Handle the case where the load remains an extending load even // after truncation. if (N0.hasOneUse() && ISD::isUNINDEXEDLoad(N0.getNode())) { @@ -7106,6 +7320,12 @@ SDValue DAGCombiner::CombineConsecutiveLoads(SDNode *N, EVT VT) { return SDValue(); } +static unsigned getPPCf128HiElementSelector(const SelectionDAG &DAG) { + // On little-endian machines, bitcasting from ppcf128 to i128 does swap the Hi + // and Lo parts; on big-endian machines it doesn't. + return DAG.getDataLayout().isBigEndian() ? 1 : 0; +} + SDValue DAGCombiner::visitBITCAST(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); @@ -7150,8 +7370,8 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // Do not change the width of a volatile load. 
!cast(N0)->isVolatile() && // Do not remove the cast if the types differ in endian layout. - TLI.hasBigEndianPartOrdering(N0.getValueType()) == - TLI.hasBigEndianPartOrdering(VT) && + TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) == + TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) && (!LegalOperations || TLI.isOperationLegal(ISD::LOAD, VT)) && TLI.isLoadBitCastBeneficial(N0.getValueType(), VT)) { LoadSDNode *LN0 = cast(N0); @@ -7172,6 +7392,15 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // fold (bitconvert (fneg x)) -> (xor (bitconvert x), signbit) // fold (bitconvert (fabs x)) -> (and (bitconvert x), (not signbit)) + // + // For ppc_fp128: + // fold (bitcast (fneg x)) -> + // flipbit = signbit + // (xor (bitcast x) (build_pair flipbit, flipbit)) + // + // fold (bitcast (fabs x)) -> + // flipbit = (and (extract_element (bitcast x), 0), signbit) + // (xor (bitcast x) (build_pair flipbit, flipbit)) // This often reduces constant pool loads. if (((N0.getOpcode() == ISD::FNEG && !TLI.isFNegFree(N0.getValueType())) || (N0.getOpcode() == ISD::FABS && !TLI.isFAbsFree(N0.getValueType()))) && @@ -7182,6 +7411,29 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { AddToWorklist(NewConv.getNode()); SDLoc DL(N); + if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) { + assert(VT.getSizeInBits() == 128); + SDValue SignBit = DAG.getConstant( + APInt::getSignBit(VT.getSizeInBits() / 2), SDLoc(N0), MVT::i64); + SDValue FlipBit; + if (N0.getOpcode() == ISD::FNEG) { + FlipBit = SignBit; + AddToWorklist(FlipBit.getNode()); + } else { + assert(N0.getOpcode() == ISD::FABS); + SDValue Hi = + DAG.getNode(ISD::EXTRACT_ELEMENT, SDLoc(NewConv), MVT::i64, NewConv, + DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG), + SDLoc(NewConv))); + AddToWorklist(Hi.getNode()); + FlipBit = DAG.getNode(ISD::AND, SDLoc(N0), MVT::i64, Hi, SignBit); + AddToWorklist(FlipBit.getNode()); + } + SDValue FlipBits = + DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit); + AddToWorklist(FlipBits.getNode()); + return DAG.getNode(ISD::XOR, DL, VT, NewConv, FlipBits); + } APInt SignBit = APInt::getSignBit(VT.getSizeInBits()); if (N0.getOpcode() == ISD::FNEG) return DAG.getNode(ISD::XOR, DL, VT, @@ -7195,6 +7447,13 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { // (or (and (bitconvert x), sign), (and cst, (not sign))) // Note that we don't handle (copysign x, cst) because this can always be // folded to an fneg or fabs. 
+ // + // For ppc_fp128: + // fold (bitcast (fcopysign cst, x)) -> + // flipbit = (and (extract_element + // (xor (bitcast cst), (bitcast x)), 0), + // signbit) + // (xor (bitcast cst) (build_pair flipbit, flipbit)) if (N0.getOpcode() == ISD::FCOPYSIGN && N0.getNode()->hasOneUse() && isa(N0.getOperand(0)) && VT.isInteger() && !VT.isVector()) { @@ -7223,6 +7482,30 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { AddToWorklist(X.getNode()); } + if (N0.getValueType() == MVT::ppcf128 && !LegalTypes) { + APInt SignBit = APInt::getSignBit(VT.getSizeInBits() / 2); + SDValue Cst = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(0)), VT, + N0.getOperand(0)); + AddToWorklist(Cst.getNode()); + SDValue X = DAG.getNode(ISD::BITCAST, SDLoc(N0.getOperand(1)), VT, + N0.getOperand(1)); + AddToWorklist(X.getNode()); + SDValue XorResult = DAG.getNode(ISD::XOR, SDLoc(N0), VT, Cst, X); + AddToWorklist(XorResult.getNode()); + SDValue XorResult64 = DAG.getNode( + ISD::EXTRACT_ELEMENT, SDLoc(XorResult), MVT::i64, XorResult, + DAG.getIntPtrConstant(getPPCf128HiElementSelector(DAG), + SDLoc(XorResult))); + AddToWorklist(XorResult64.getNode()); + SDValue FlipBit = + DAG.getNode(ISD::AND, SDLoc(XorResult64), MVT::i64, XorResult64, + DAG.getConstant(SignBit, SDLoc(XorResult64), MVT::i64)); + AddToWorklist(FlipBit.getNode()); + SDValue FlipBits = + DAG.getNode(ISD::BUILD_PAIR, SDLoc(N0), VT, FlipBit, FlipBit); + AddToWorklist(FlipBits.getNode()); + return DAG.getNode(ISD::XOR, SDLoc(N), VT, Cst, FlipBits); + } APInt SignBit = APInt::getSignBit(VT.getSizeInBits()); X = DAG.getNode(ISD::AND, SDLoc(X), VT, X, DAG.getConstant(SignBit, SDLoc(X), VT)); @@ -7239,11 +7522,9 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { } // bitconvert(build_pair(ld, ld)) -> ld iff load locations are consecutive. - if (N0.getOpcode() == ISD::BUILD_PAIR) { - SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT); - if (CombineLD.getNode()) + if (N0.getOpcode() == ISD::BUILD_PAIR) + if (SDValue CombineLD = CombineConsecutiveLoads(N0.getNode(), VT)) return CombineLD; - } // Remove double bitcasts from shuffles - this is often a legacy of // XformToShuffleWithZero being used to combine bitmaskings (of @@ -7256,10 +7537,10 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { ShuffleVectorSDNode *SVN = cast(N0); // If operands are a bitcast, peek through if it casts the original VT. - // If operands are a UNDEF or constant, just bitcast back to original VT. + // If operands are a constant, just bitcast back to original VT. auto PeekThroughBitcast = [&](SDValue Op) { if (Op.getOpcode() == ISD::BITCAST && - Op.getOperand(0)->getValueType(0) == VT) + Op.getOperand(0).getValueType() == VT) return SDValue(Op.getOperand(0)); if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) @@ -7430,28 +7711,34 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; - bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath); + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && - TLI.isOperationLegal(ISD::FMAD, VT)); + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. 
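// [Numeric illustration, not from the original patch] FMAD rounds the
// product before adding, so it matches the original fmul+fadd pair
// bit-for-bit; FMA rounds only once. E.g. with z == -round(x*y), FMAD
// yields exactly 0.0 while a single-rounded FMA yields the (usually
// nonzero) rounding error of x*y. This is why HasFMAD needs no fast-math
// gate, while forming FMA below is conditioned on AllowFusion.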
- bool HasFMA = ((!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - UnsafeFPMath); + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); // Always prefer FMAD to FMA for precision. - unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); bool LookThroughFPExt = TLI.isFPExtFree(VT); + // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), + // prefer to fold the multiply with fewer uses. + if (Aggressive && N0.getOpcode() == ISD::FMUL && + N1.getOpcode() == ISD::FMUL) { + if (N0.getNode()->use_size() > N1.getNode()->use_size()) + std::swap(N0, N1); + } + // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (N0.getOpcode() == ISD::FMUL && (Aggressive || N0->hasOneUse())) { @@ -7468,7 +7755,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { } // Look through FP_EXTEND nodes to do more combining. - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) if (N0.getOpcode() == ISD::FP_EXTEND) { SDValue N00 = N0.getOperand(0); @@ -7494,7 +7781,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { } // More folding opportunities when target permits. - if ((UnsafeFPMath || HasFMAD) && Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL) { @@ -7517,7 +7804,7 @@ SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { N0)); } - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fadd (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y, (fma (fpext u), (fpext v), z)) auto FoldFAddFMAFPExtFMul = [&] ( @@ -7607,25 +7894,23 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { SDLoc SL(N); const TargetOptions &Options = DAG.getTarget().Options; - bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || - Options.UnsafeFPMath); + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); // Floating-point multiply-add with intermediate rounding. - bool HasFMAD = (LegalOperations && - TLI.isOperationLegal(ISD::FMAD, VT)); + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); // Floating-point multiply-add without intermediate rounding. - bool HasFMA = ((!LegalOperations || - TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - UnsafeFPMath); + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); // No valid opcode, do not combine. if (!HasFMAD && !HasFMA) return SDValue(); // Always prefer FMAD to FMA for precision. - unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + unsigned PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; bool Aggressive = TLI.enableAggressiveFMAFusion(VT); bool LookThroughFPExt = TLI.isFPExtFree(VT); @@ -7658,7 +7943,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { } // Look through FP_EXTEND nodes to do more combining. 
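// [Illustrative note, not from the original patch] With free fpext this
// block rewrites, e.g.,
//   fadd (fpext (fmul f32:x, f32:y)), f64:z
//     -> fma (fpext x), (fpext y), z
// which drops the f32 rounding of the product and may change the result;
// hence the AllowFusion (fast contraction / unsafe-fp-math) guard.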
- if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fsub (fpext (fmul x, y)), z) // -> (fma (fpext x), (fpext y), (fneg z)) if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -7734,7 +8019,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { } // More folding opportunities when target permits. - if ((UnsafeFPMath || HasFMAD) && Aggressive) { + if ((AllowFusion || HasFMAD) && Aggressive) { // fold (fsub (fma x, y, (fmul u, v)), z) // -> (fma x, y (fma u, v, (fneg z))) if (N0.getOpcode() == PreferredFusedOpcode && @@ -7764,7 +8049,7 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { N21, N0)); } - if (UnsafeFPMath && LookThroughFPExt) { + if (AllowFusion && LookThroughFPExt) { // fold (fsub (fma x, y, (fpext (fmul u, v))), z) // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) if (N0.getOpcode() == PreferredFusedOpcode) { @@ -7865,14 +8150,97 @@ SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { return SDValue(); } +/// Try to perform FMA combining on a given FMUL node. +SDValue DAGCombiner::visitFMULForFMACombine(SDNode *N) { + SDValue N0 = N->getOperand(0); + SDValue N1 = N->getOperand(1); + EVT VT = N->getValueType(0); + SDLoc SL(N); + + assert(N->getOpcode() == ISD::FMUL && "Expected FMUL Operation"); + + const TargetOptions &Options = DAG.getTarget().Options; + bool AllowFusion = + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); + + // Floating-point multiply-add with intermediate rounding. + bool HasFMAD = (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)); + + // Floating-point multiply-add without intermediate rounding. + bool HasFMA = + AllowFusion && TLI.isFMAFasterThanFMulAndFAdd(VT) && + (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT)); + + // No valid opcode, do not combine. + if (!HasFMAD && !HasFMA) + return SDValue(); + + // Always prefer FMAD to FMA for precision. + unsigned PreferredFusedOpcode = HasFMAD ? 
ISD::FMAD : ISD::FMA; + bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + + // fold (fmul (fadd x, +1.0), y) -> (fma x, y, y) + // fold (fmul (fadd x, -1.0), y) -> (fma x, y, (fneg y)) + auto FuseFADD = [&](SDValue X, SDValue Y) { + if (X.getOpcode() == ISD::FADD && (Aggressive || X->hasOneUse())) { + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + } + return SDValue(); + }; + + if (SDValue FMA = FuseFADD(N0, N1)) + return FMA; + if (SDValue FMA = FuseFADD(N1, N0)) + return FMA; + + // fold (fmul (fsub +1.0, x), y) -> (fma (fneg x), y, y) + // fold (fmul (fsub -1.0, x), y) -> (fma (fneg x), y, (fneg y)) + // fold (fmul (fsub x, +1.0), y) -> (fma x, y, (fneg y)) + // fold (fmul (fsub x, -1.0), y) -> (fma x, y, y) + auto FuseFSUB = [&](SDValue X, SDValue Y) { + if (X.getOpcode() == ISD::FSUB && (Aggressive || X->hasOneUse())) { + auto XC0 = isConstOrConstSplatFP(X.getOperand(0)); + if (XC0 && XC0->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + Y); + if (XC0 && XC0->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, X.getOperand(1)), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + + auto XC1 = isConstOrConstSplatFP(X.getOperand(1)); + if (XC1 && XC1->isExactlyValue(+1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, + DAG.getNode(ISD::FNEG, SL, VT, Y)); + if (XC1 && XC1->isExactlyValue(-1.0)) + return DAG.getNode(PreferredFusedOpcode, SL, VT, X.getOperand(0), Y, Y); + } + return SDValue(); + }; + + if (SDValue FMA = FuseFSUB(N0, N1)) + return FMA; + if (SDValue FMA = FuseFSUB(N1, N0)) + return FMA; + + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - ConstantFPSDNode *N0CFP = dyn_cast(N0); - ConstantFPSDNode *N1CFP = dyn_cast(N1); + bool N0CFP = isConstantFPBuildVectorOrConstantFP(N0); + bool N1CFP = isConstantFPBuildVectorOrConstantFP(N1); EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -7881,23 +8249,23 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { // fold (fadd c1, c2) -> c1 + c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FADD, DL, VT, N0, N1); + return DAG.getNode(ISD::FADD, DL, VT, N0, N1, Flags); // canonicalize constant to RHS if (N0CFP && !N1CFP) - return DAG.getNode(ISD::FADD, DL, VT, N1, N0); + return DAG.getNode(ISD::FADD, DL, VT, N1, N0, Flags); // fold (fadd A, (fneg B)) -> (fsub A, B) if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && isNegatibleForFree(N1, LegalOperations, TLI, &Options) == 2) return DAG.getNode(ISD::FSUB, DL, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), Flags); // fold (fadd (fneg A), B) -> (fsub B, A) if ((!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FSUB, VT)) && isNegatibleForFree(N0, LegalOperations, TLI, &Options) == 2) return DAG.getNode(ISD::FSUB, DL, VT, N1, - GetNegatedExpression(N0, DAG, LegalOperations)); + GetNegatedExpression(N0, DAG, LegalOperations), Flags); // If 'unsafe math' is enabled, fold lots of things. 
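// [Illustrative examples of the unsafe folds below, not from the patch]
//   fadd (fmul x, 3.0), x         -> fmul x, 4.0   (x*3 + x == x*4)
//   fadd (fadd x, x), x           -> fmul x, 3.0
//   fadd (fadd x, x), (fadd x, x) -> fmul x, 4.0
// These reassociations are exact over the reals but can change rounding
// and signed-zero behavior, hence the UnsafeFPMath guard.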
if (Options.UnsafeFPMath) { @@ -7906,14 +8274,17 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { bool AllowNewConst = (Level < AfterLegalizeDAG); // fold (fadd A, 0) -> A - if (N1CFP && N1CFP->isZero()) - return N0; + if (ConstantFPSDNode *N1C = isConstOrConstSplatFP(N1)) + if (N1C->isZero()) + return N0; // fold (fadd (fadd x, c1), c2) -> (fadd x, (fadd c1, c2)) if (N1CFP && N0.getOpcode() == ISD::FADD && N0.getNode()->hasOneUse() && - isa(N0.getOperand(1))) + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) return DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(0), - DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1)); + DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), N1, + Flags), + Flags); // If allowed, fold (fadd (fneg x), x) -> 0.0 if (AllowNewConst && N0.getOpcode() == ISD::FNEG && N0.getOperand(0) == N1) @@ -7928,64 +8299,64 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { // of rounding steps. if (TLI.isOperationLegalOrCustom(ISD::FMUL, VT) && !N0CFP && !N1CFP) { if (N0.getOpcode() == ISD::FMUL) { - ConstantFPSDNode *CFP00 = dyn_cast(N0.getOperand(0)); - ConstantFPSDNode *CFP01 = dyn_cast(N0.getOperand(1)); + bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); + bool CFP01 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(1)); // (fadd (fmul x, c), x) -> (fmul x, c+1) if (CFP01 && !CFP00 && N0.getOperand(0) == N1) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP01, 0), - DAG.getConstantFP(1.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), + DAG.getConstantFP(1.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N1, NewCFP, Flags); } // (fadd (fmul x, c), (fadd x, x)) -> (fmul x, c+2) if (CFP01 && !CFP00 && N1.getOpcode() == ISD::FADD && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP01, 0), - DAG.getConstantFP(2.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N0.getOperand(1), + DAG.getConstantFP(2.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), NewCFP, Flags); } } if (N1.getOpcode() == ISD::FMUL) { - ConstantFPSDNode *CFP10 = dyn_cast(N1.getOperand(0)); - ConstantFPSDNode *CFP11 = dyn_cast(N1.getOperand(1)); + bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); + bool CFP11 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(1)); // (fadd x, (fmul x, c)) -> (fmul x, c+1) if (CFP11 && !CFP10 && N1.getOperand(0) == N0) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP11, 0), - DAG.getConstantFP(1.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1), + DAG.getConstantFP(1.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0, NewCFP, Flags); } // (fadd (fadd x, x), (fmul x, c)) -> (fmul x, c+2) if (CFP11 && !CFP10 && N0.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N0.getOperand(0)) { - SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, SDValue(CFP11, 0), - DAG.getConstantFP(2.0, DL, VT)); - return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP); + SDValue NewCFP = DAG.getNode(ISD::FADD, DL, VT, N1.getOperand(1), + DAG.getConstantFP(2.0, DL, VT), Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N1.getOperand(0), NewCFP, Flags); } } if (N0.getOpcode() == ISD::FADD && AllowNewConst) { - 
ConstantFPSDNode *CFP = dyn_cast(N0.getOperand(0)); + bool CFP00 = isConstantFPBuildVectorOrConstantFP(N0.getOperand(0)); // (fadd (fadd x, x), x) -> (fmul x, 3.0) - if (!CFP && N0.getOperand(0) == N0.getOperand(1) && + if (!CFP00 && N0.getOperand(0) == N0.getOperand(1) && (N0.getOperand(0) == N1)) { return DAG.getNode(ISD::FMUL, DL, VT, - N1, DAG.getConstantFP(3.0, DL, VT)); + N1, DAG.getConstantFP(3.0, DL, VT), Flags); } } if (N1.getOpcode() == ISD::FADD && AllowNewConst) { - ConstantFPSDNode *CFP10 = dyn_cast(N1.getOperand(0)); + bool CFP10 = isConstantFPBuildVectorOrConstantFP(N1.getOperand(0)); // (fadd x, (fadd x, x)) -> (fmul x, 3.0) if (!CFP10 && N1.getOperand(0) == N1.getOperand(1) && N1.getOperand(0) == N0) { return DAG.getNode(ISD::FMUL, DL, VT, - N0, DAG.getConstantFP(3.0, DL, VT)); + N0, DAG.getConstantFP(3.0, DL, VT), Flags); } } @@ -7995,15 +8366,14 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { N0.getOperand(0) == N0.getOperand(1) && N1.getOperand(0) == N1.getOperand(1) && N0.getOperand(0) == N1.getOperand(0)) { - return DAG.getNode(ISD::FMUL, DL, VT, - N0.getOperand(0), DAG.getConstantFP(4.0, DL, VT)); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), + DAG.getConstantFP(4.0, DL, VT), Flags); } } } // enable-unsafe-fp-math // FADD -> FMA combines: - SDValue Fused = visitFADDForFMACombine(N); - if (Fused) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -8019,6 +8389,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { EVT VT = N->getValueType(0); SDLoc dl(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -8027,12 +8398,12 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { // fold (fsub c1, c2) -> c1-c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FSUB, dl, VT, N0, N1); + return DAG.getNode(ISD::FSUB, dl, VT, N0, N1, Flags); // fold (fsub A, (fneg B)) -> (fadd A, B) if (isNegatibleForFree(N1, LegalOperations, TLI, &Options)) return DAG.getNode(ISD::FADD, dl, VT, N0, - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), Flags); // If 'unsafe math' is enabled, fold lots of things. if (Options.UnsafeFPMath) { @@ -8067,8 +8438,7 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } // FSUB -> FMA combines: - SDValue Fused = visitFSUBForFMACombine(N); - if (Fused) { + if (SDValue Fused = visitFSUBForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -8084,6 +8454,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + const SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) { @@ -8094,12 +8465,12 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // fold (fmul c1, c2) -> c1*c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FMUL, DL, VT, N0, N1); + return DAG.getNode(ISD::FMUL, DL, VT, N0, N1, Flags); // canonicalize constant to RHS if (isConstantFPBuildVectorOrConstantFP(N0) && !isConstantFPBuildVectorOrConstantFP(N1)) - return DAG.getNode(ISD::FMUL, DL, VT, N1, N0); + return DAG.getNode(ISD::FMUL, DL, VT, N1, N0, Flags); // fold (fmul A, 1.0) -> A if (N1CFP && N1CFP->isExactlyValue(1.0)) @@ -8128,8 +8499,8 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // the second operand of the outer multiply are constants. 
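// [Illustrative note, not from the original patch] The fold guarded below
// rewrites, e.g.,
//   fmul (fmul x, 3.0), 4.0 -> fmul x, 12.0
// folding the constants at compile time; under unsafe-fp-math one rounding
// of 12.0*x is accepted in place of two.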
if ((N1CFP && isConstOrConstSplatFP(N01)) || (BV1 && BV01 && BV1->isConstant() && BV01->isConstant())) { - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1); - return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts); + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, N01, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N00, MulConsts, Flags); } } } @@ -8138,16 +8509,18 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // Undo the fmul 2.0, x -> fadd x, x transformation, since if it occurs // during an early run of DAGCombiner can prevent folding with fmuls // inserted during lowering. - if (N0.getOpcode() == ISD::FADD && N0.getOperand(0) == N0.getOperand(1)) { + if (N0.getOpcode() == ISD::FADD && + (N0.getOperand(0) == N0.getOperand(1)) && + N0.hasOneUse()) { const SDValue Two = DAG.getConstantFP(2.0, DL, VT); - SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1); - return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts); + SDValue MulConsts = DAG.getNode(ISD::FMUL, DL, VT, Two, N1, Flags); + return DAG.getNode(ISD::FMUL, DL, VT, N0.getOperand(0), MulConsts, Flags); } } // fold (fmul X, 2.0) -> (fadd X, X) if (N1CFP && N1CFP->isExactlyValue(+2.0)) - return DAG.getNode(ISD::FADD, DL, VT, N0, N0); + return DAG.getNode(ISD::FADD, DL, VT, N0, N0, Flags); // fold (fmul X, -1.0) -> (fneg X) if (N1CFP && N1CFP->isExactlyValue(-1.0)) @@ -8162,10 +8535,17 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { if (LHSNeg == 2 || RHSNeg == 2) return DAG.getNode(ISD::FMUL, DL, VT, GetNegatedExpression(N0, DAG, LegalOperations), - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), + Flags); } } + // FMUL -> FMA combines: + if (SDValue Fused = visitFMULForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); } @@ -8192,66 +8572,145 @@ SDValue DAGCombiner::visitFMA(SDNode *N) { if (N1CFP && N1CFP->isZero()) return N2; } + // TODO: The FMA node should have flags that propagate to these nodes. if (N0CFP && N0CFP->isExactlyValue(1.0)) return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2); if (N1CFP && N1CFP->isExactlyValue(1.0)) return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2); // Canonicalize (fma c, x, y) -> (fma x, c, y) - if (N0CFP && !N1CFP) + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2); - // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) - if (Options.UnsafeFPMath && N1CFP && - N2.getOpcode() == ISD::FMUL && - N0 == N2.getOperand(0) && - N2.getOperand(1).getOpcode() == ISD::ConstantFP) { - return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1))); - } + // TODO: FMA nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. 
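// [Illustrative examples of the unsafe FMA folds below, not from the
// patch]
//   fma x, c1, (fmul x, c2) -> fmul x, c1+c2   e.g. x*3 + x*4 -> x*7
//   fma (fmul x, c1), c2, y -> fma x, c1*c2, y
// Both are exact over the reals but merge rounding steps, so they sit
// behind the UnsafeFPMath check; the hand-built Flags object marks the new
// nodes unsafe-algebra until FMA nodes carry their own flags.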
+ SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); + if (Options.UnsafeFPMath) { + // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2) + if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) && + isConstantFPBuildVectorOrConstantFP(N1) && + isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) { + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, N1, N2.getOperand(1), + &Flags), &Flags); + } - // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) - if (Options.UnsafeFPMath && - N0.getOpcode() == ISD::FMUL && N1CFP && - N0.getOperand(1).getOpcode() == ISD::ConstantFP) { - return DAG.getNode(ISD::FMA, dl, VT, - N0.getOperand(0), - DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1)), - N2); + // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y) + if (N0.getOpcode() == ISD::FMUL && + isConstantFPBuildVectorOrConstantFP(N1) && + isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) { + return DAG.getNode(ISD::FMA, dl, VT, + N0.getOperand(0), + DAG.getNode(ISD::FMUL, dl, VT, N1, N0.getOperand(1), + &Flags), + N2); + } } // (fma x, 1, y) -> (fadd x, y) // (fma x, -1, y) -> (fadd (fneg x), y) if (N1CFP) { if (N1CFP->isExactlyValue(1.0)) + // TODO: The FMA node should have flags that propagate to this node. return DAG.getNode(ISD::FADD, dl, VT, N0, N2); if (N1CFP->isExactlyValue(-1.0) && (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) { SDValue RHSNeg = DAG.getNode(ISD::FNEG, dl, VT, N0); AddToWorklist(RHSNeg.getNode()); + // TODO: The FMA node should have flags that propagate to this node. return DAG.getNode(ISD::FADD, dl, VT, N2, RHSNeg); } } - // (fma x, c, x) -> (fmul x, (c+1)) - if (Options.UnsafeFPMath && N1CFP && N0 == N2) - return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, - N1, DAG.getConstantFP(1.0, dl, VT))); - - // (fma x, c, (fneg x)) -> (fmul x, (c-1)) - if (Options.UnsafeFPMath && N1CFP && - N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) + if (Options.UnsafeFPMath) { + // (fma x, c, x) -> (fmul x, (c+1)) + if (N1CFP && N0 == N2) { return DAG.getNode(ISD::FMUL, dl, VT, N0, - DAG.getNode(ISD::FADD, dl, VT, - N1, DAG.getConstantFP(-1.0, dl, VT))); + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(1.0, dl, VT), + &Flags), &Flags); + } + // (fma x, c, (fneg x)) -> (fmul x, (c-1)) + if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) { + return DAG.getNode(ISD::FMUL, dl, VT, N0, + DAG.getNode(ISD::FADD, dl, VT, + N1, DAG.getConstantFP(-1.0, dl, VT), + &Flags), &Flags); + } + } return SDValue(); } +// Combine multiple FDIVs with the same divisor into multiple FMULs by the +// reciprocal. +// E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip) +// Notice that this is not always beneficial. One reason is different target +// may have different costs for FDIV and FMUL, so sometimes the cost of two +// FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason +// is the critical path is increased from "one FDIV" to "one FDIV + one FMUL". +SDValue DAGCombiner::combineRepeatedFPDivisors(SDNode *N) { + bool UnsafeMath = DAG.getTarget().Options.UnsafeFPMath; + const SDNodeFlags *Flags = N->getFlags(); + if (!UnsafeMath && !Flags->hasAllowReciprocal()) + return SDValue(); + + // Skip if current node is a reciprocal. 
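// [Worked example, not from the original patch; assuming a target whose
// combineRepeatedFPDivisors() threshold is 2] For
//   a/d; b/d; c/d
// we emit r = 1.0/d once and rewrite the divides as a*r, b*r, c*r: one
// FDIV plus three cheaper FMULs instead of three FDIVs. A node that is
// itself 1.0/d is skipped below, since rewriting it would merely recreate
// the reciprocal.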
+ SDValue N0 = N->getOperand(0); + ConstantFPSDNode *N0CFP = dyn_cast(N0); + if (N0CFP && N0CFP->isExactlyValue(1.0)) + return SDValue(); + + // Exit early if the target does not want this transform or if there can't + // possibly be enough uses of the divisor to make the transform worthwhile. + SDValue N1 = N->getOperand(1); + unsigned MinUses = TLI.combineRepeatedFPDivisors(); + if (!MinUses || N1->use_size() < MinUses) + return SDValue(); + + // Find all FDIV users of the same divisor. + // Use a set because duplicates may be present in the user list. + SetVector Users; + for (auto *U : N1->uses()) { + if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) { + // This division is eligible for optimization only if global unsafe math + // is enabled or if this division allows reciprocal formation. + if (UnsafeMath || U->getFlags()->hasAllowReciprocal()) + Users.insert(U); + } + } + + // Now that we have the actual number of divisor uses, make sure it meets + // the minimum threshold specified by the target. + if (Users.size() < MinUses) + return SDValue(); + + EVT VT = N->getValueType(0); + SDLoc DL(N); + SDValue FPOne = DAG.getConstantFP(1.0, DL, VT); + SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1, Flags); + + // Dividend / Divisor -> Dividend * Reciprocal + for (auto *U : Users) { + SDValue Dividend = U->getOperand(0); + if (Dividend != FPOne) { + SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend, + Reciprocal, Flags); + CombineTo(U, NewNode); + } else if (U != Reciprocal.getNode()) { + // In the absence of fast-math-flags, this user node is always the + // same node as Reciprocal, but with FMF they may be different nodes. + CombineTo(U, Reciprocal); + } + } + return SDValue(N, 0); // N was replaced. +} + SDValue DAGCombiner::visitFDIV(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -8260,6 +8719,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); const TargetOptions &Options = DAG.getTarget().Options; + SDNodeFlags *Flags = &cast(N)->Flags; // fold vector ops if (VT.isVector()) @@ -8268,7 +8728,7 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { // fold (fdiv c1, c2) -> c1/c2 if (N0CFP && N1CFP) - return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1); + return DAG.getNode(ISD::FDIV, SDLoc(N), VT, N0, N1, Flags); if (Options.UnsafeFPMath) { // fold (fdiv X, c2) -> fmul X, 1/c2 if losing precision is acceptable. @@ -8287,28 +8747,30 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { TLI.isOperationLegal(llvm::ISD::ConstantFP, VT) || TLI.isFPImmLegal(Recip, VT))) return DAG.getNode(ISD::FMUL, DL, VT, N0, - DAG.getConstantFP(Recip, DL, VT)); + DAG.getConstantFP(Recip, DL, VT), Flags); } // If this FDIV is part of a reciprocal square root, it may be folded // into a target-specific square root estimate instruction. 
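// [Illustrative note, not from the original patch] I.e. x / sqrt(y)
// becomes x * rsqrt(y), where rsqrt(y) is a low-precision hardware
// estimate refined by the Newton-Raphson iterations BuildRsqrtEstimate
// emits; the fp_extend, fp_round and x / (y * sqrt(z)) shapes below reuse
// the same idea.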
if (N1.getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0))) { - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) { + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FP_EXTEND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) { + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), + Flags)) { RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FP_ROUND && N1.getOperand(0).getOpcode() == ISD::FSQRT) { - if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) { + if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0), + Flags)) { RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1)); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } else if (N1.getOpcode() == ISD::FMUL) { // Look through an FMUL. Even though this won't remove the FDIV directly, @@ -8325,18 +8787,18 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (SqrtOp.getNode()) { // We found a FSQRT, so try to make this fold: // x / (y * sqrt(z)) -> x * (rsqrt(z) / y) - if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0))) { - RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp); + if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) { + RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags); AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } } // Fold into a reciprocal estimate and multiply instead of a real divide. - if (SDValue RV = BuildReciprocalEstimate(N1)) { + if (SDValue RV = BuildReciprocalEstimate(N1, Flags)) { AddToWorklist(RV.getNode()); - return DAG.getNode(ISD::FMUL, DL, VT, N0, RV); + return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags); } } @@ -8348,45 +8810,13 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { if (LHSNeg == 2 || RHSNeg == 2) return DAG.getNode(ISD::FDIV, SDLoc(N), VT, GetNegatedExpression(N0, DAG, LegalOperations), - GetNegatedExpression(N1, DAG, LegalOperations)); + GetNegatedExpression(N1, DAG, LegalOperations), + Flags); } } - // Combine multiple FDIVs with the same divisor into multiple FMULs by the - // reciprocal. - // E.g., (a / D; b / D;) -> (recip = 1.0 / D; a * recip; b * recip) - // Notice that this is not always beneficial. One reason is different target - // may have different costs for FDIV and FMUL, so sometimes the cost of two - // FDIVs may be lower than the cost of one FDIV and two FMULs. Another reason - // is the critical path is increased from "one FDIV" to "one FDIV + one FMUL". - if (Options.UnsafeFPMath) { - // Skip if current node is a reciprocal. - if (N0CFP && N0CFP->isExactlyValue(1.0)) - return SDValue(); - - SmallVector Users; - // Find all FDIV users of the same divisor. 
- for (auto *U : N1->uses()) { - if (U->getOpcode() == ISD::FDIV && U->getOperand(1) == N1) - Users.push_back(U); - } - - if (TLI.combineRepeatedFPDivisors(Users.size())) { - SDValue FPOne = DAG.getConstantFP(1.0, DL, VT); - SDValue Reciprocal = DAG.getNode(ISD::FDIV, DL, VT, FPOne, N1); - - // Dividend / Divisor -> Dividend * Reciprocal - for (auto *U : Users) { - SDValue Dividend = U->getOperand(0); - if (Dividend != FPOne) { - SDValue NewNode = DAG.getNode(ISD::FMUL, SDLoc(U), VT, Dividend, - Reciprocal); - DAG.ReplaceAllUsesWith(U, NewNode.getNode()); - } - } - return SDValue(); - } - } + if (SDValue CombineRepeatedDivisors = combineRepeatedFPDivisors(N)) + return CombineRepeatedDivisors; return SDValue(); } @@ -8400,36 +8830,58 @@ SDValue DAGCombiner::visitFREM(SDNode *N) { // fold (frem c1, c2) -> fmod(c1,c2) if (N0CFP && N1CFP) - return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1); + return DAG.getNode(ISD::FREM, SDLoc(N), VT, N0, N1, + &cast(N)->Flags); return SDValue(); } SDValue DAGCombiner::visitFSQRT(SDNode *N) { - if (DAG.getTarget().Options.UnsafeFPMath && - !TLI.isFsqrtCheap()) { - // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5) - if (SDValue RV = BuildRsqrtEstimate(N->getOperand(0))) { - EVT VT = RV.getValueType(); - SDLoc DL(N); - RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV); - AddToWorklist(RV.getNode()); + if (!DAG.getTarget().Options.UnsafeFPMath || TLI.isFsqrtCheap()) + return SDValue(); - // Unfortunately, RV is now NaN if the input was exactly 0. - // Select out this case and force the answer to 0. - SDValue Zero = DAG.getConstantFP(0.0, DL, VT); - SDValue ZeroCmp = - DAG.getSetCC(DL, TLI.getSetCCResultType(*DAG.getContext(), VT), - N->getOperand(0), Zero, ISD::SETEQ); - AddToWorklist(ZeroCmp.getNode()); - AddToWorklist(RV.getNode()); + // TODO: FSQRT nodes should have flags that propagate to the created nodes. + // For now, create a Flags object for use with all unsafe math transforms. + SDNodeFlags Flags; + Flags.setUnsafeAlgebra(true); - RV = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, - DL, VT, ZeroCmp, Zero, RV); - return RV; - } + // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5) + SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags); + if (!RV) + return SDValue(); + + EVT VT = RV.getValueType(); + SDLoc DL(N); + RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags); + AddToWorklist(RV.getNode()); + + // Unfortunately, RV is now NaN if the input was exactly 0. + // Select out this case and force the answer to 0. + SDValue Zero = DAG.getConstantFP(0.0, DL, VT); + EVT CCVT = getSetCCResultType(VT); + SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ); + AddToWorklist(ZeroCmp.getNode()); + AddToWorklist(RV.getNode()); + + return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, + ZeroCmp, Zero, RV); +} + +/// copysign(x, fp_extend(y)) -> copysign(x, y) +/// copysign(x, fp_round(y)) -> copysign(x, y) +static inline bool CanCombineFCOPYSIGN_EXTEND_ROUND(SDNode *N) { + SDValue N1 = N->getOperand(1); + if ((N1.getOpcode() == ISD::FP_EXTEND || + N1.getOpcode() == ISD::FP_ROUND)) { + // Do not optimize out type conversion of f128 type yet. + // For some targets like x86_64, configuration is changed to keep one f128 + // value in one SSE register, but instruction selection cannot handle + // FCOPYSIGN on SSE registers yet. 
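// [Illustrative note, not from the original patch] Concretely: with
// __float128 q,
//   copysign(d, (double)q)
// keeps its fp_round, since folding would hand FCOPYSIGN an f128 sign
// operand that instruction selection cannot yet match on SSE registers; a
// no-op (same-type) conversion or a non-f128 source still folds.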
+ EVT N1VT = N1->getValueType(0); + EVT N1Op0VT = N1->getOperand(0)->getValueType(0); + return (N1VT == N1Op0VT || N1Op0VT != MVT::f128); } - return SDValue(); + return false; } SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { @@ -8475,7 +8927,7 @@ SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) { // copysign(x, fp_extend(y)) -> copysign(x, y) // copysign(x, fp_round(y)) -> copysign(x, y) - if (N1.getOpcode() == ISD::FP_EXTEND || N1.getOpcode() == ISD::FP_ROUND) + if (CanCombineFCOPYSIGN_EXTEND_ROUND(N)) return DAG.getNode(ISD::FCOPYSIGN, SDLoc(N), VT, N0, N1.getOperand(0)); @@ -8830,11 +9282,12 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { APFloat CVal = CFP1->getValueAPF(); CVal.changeSign(); if (Level >= AfterLegalizeDAG && - (TLI.isFPImmLegal(CVal, N->getValueType(0)) || - TLI.isOperationLegal(ISD::ConstantFP, N->getValueType(0)))) - return DAG.getNode( - ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0.getOperand(1))); + (TLI.isFPImmLegal(CVal, VT) || + TLI.isOperationLegal(ISD::ConstantFP, VT))) + return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0), + DAG.getNode(ISD::FNEG, SDLoc(N), VT, + N0.getOperand(1)), + &cast(N0)->Flags); } } @@ -8844,20 +9297,20 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { SDValue DAGCombiner::visitFMINNUM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - const ConstantFPSDNode *N0CFP = dyn_cast(N0); - const ConstantFPSDNode *N1CFP = dyn_cast(N1); + EVT VT = N->getValueType(0); + const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), N->getValueType(0)); + return DAG.getConstantFP(minnum(C0, C1), SDLoc(N), VT); } - if (N0CFP) { - EVT VT = N->getValueType(0); - // Canonicalize to constant on RHS. + // Canonicalize to constant on RHS. + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMINNUM, SDLoc(N), VT, N1, N0); - } return SDValue(); } @@ -8865,20 +9318,20 @@ SDValue DAGCombiner::visitFMINNUM(SDNode *N) { SDValue DAGCombiner::visitFMAXNUM(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); - const ConstantFPSDNode *N0CFP = dyn_cast(N0); - const ConstantFPSDNode *N1CFP = dyn_cast(N1); + EVT VT = N->getValueType(0); + const ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0); + const ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1); if (N0CFP && N1CFP) { const APFloat &C0 = N0CFP->getValueAPF(); const APFloat &C1 = N1CFP->getValueAPF(); - return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), N->getValueType(0)); + return DAG.getConstantFP(maxnum(C0, C1), SDLoc(N), VT); } - if (N0CFP) { - EVT VT = N->getValueType(0); - // Canonicalize to constant on RHS. + // Canonicalize to constant on RHS. + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMAXNUM, SDLoc(N), VT, N1, N0); - } return SDValue(); } @@ -9027,8 +9480,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) { SDValue Op1 = TheXor->getOperand(1); if (Op0.getOpcode() == Op1.getOpcode()) { // Avoid missing important xor optimizations. 
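// [Illustrative note, not from the original patch] Re-running visitXOR
// here can, for instance, rewrite brcond (xor (setcc a, b, eq), 1) as
// brcond (setcc a, b, ne), inverting the compare instead of branching on a
// negated boolean.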
- SDValue Tmp = visitXOR(TheXor); - if (Tmp.getNode()) { + if (SDValue Tmp = visitXOR(TheXor)) { if (Tmp.getNode() != TheXor) { DEBUG(dbgs() << "\nReplacing.8 "; TheXor->dump(&DAG); @@ -9144,7 +9596,8 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, } else return false; - return TLI.isLegalAddressingMode(AM, VT.getTypeForEVT(*DAG.getContext()), AS); + return TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, + VT.getTypeForEVT(*DAG.getContext()), AS); } /// Try turning a load/store into a pre-indexed load/store when the base @@ -9714,8 +10167,8 @@ struct LoadedSlice { void addSliceGain(const LoadedSlice &LS) { // Each slice saves a truncate. const TargetLowering &TLI = LS.DAG->getTargetLoweringInfo(); - if (!TLI.isTruncateFree(LS.Inst->getValueType(0), - LS.Inst->getOperand(0).getValueType())) + if (!TLI.isTruncateFree(LS.Inst->getOperand(0).getValueType(), + LS.Inst->getValueType(0))) ++Truncates; // If there is a shift amount, this slice gets rid of it. if (LS.Shift) @@ -10617,30 +11070,109 @@ struct BaseIndexOffset { }; } // namespace +// This is a helper function for visitMUL to check the profitability +// of folding (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2). +// MulNode is the original multiply, AddNode is (add x, c1), +// and ConstNode is c2. +// +// If the (add x, c1) has multiple uses, we could increase +// the number of adds if we make this transformation. +// It would only be worth doing this if we can remove a +// multiply in the process. Check for that here. +// To illustrate: +// (A + c1) * c3 +// (A + c2) * c3 +// We're checking for cases where we have common "c3 * A" expressions. +bool DAGCombiner::isMulAddWithConstProfitable(SDNode *MulNode, + SDValue &AddNode, + SDValue &ConstNode) { + APInt Val; + + // If the add only has one use, this would be OK to do. + if (AddNode.getNode()->hasOneUse()) + return true; + + // Walk all the users of the constant with which we're multiplying. + for (SDNode *Use : ConstNode->uses()) { + + if (Use == MulNode) // This use is the one we're on right now. Skip it. + continue; + + if (Use->getOpcode() == ISD::MUL) { // We have another multiply use. + SDNode *OtherOp; + SDNode *MulVar = AddNode.getOperand(0).getNode(); + + // OtherOp is what we're multiplying against the constant. + if (Use->getOperand(0) == ConstNode) + OtherOp = Use->getOperand(1).getNode(); + else + OtherOp = Use->getOperand(0).getNode(); + + // Check to see if multiply is with the same operand of our "add". + // + // ConstNode = CONST + // Use = ConstNode * A <-- visiting Use. OtherOp is A. + // ... + // AddNode = (A + c1) <-- MulVar is A. + // = AddNode * ConstNode <-- current visiting instruction. + // + // If we make this transformation, we will have a common + // multiply (ConstNode * A) that we can save. + if (OtherOp == MulVar) + return true; + + // Now check to see if a future expansion will give us a common + // multiply. + // + // ConstNode = CONST + // AddNode = (A + c1) + // ... = AddNode * ConstNode <-- current visiting instruction. + // ... + // OtherOp = (A + c2) + // Use = OtherOp * ConstNode <-- visiting Use. + // + // If we make this transformation, we will have a common + // multiply (CONST * A) after we also do the same transformation + // to the "t2" instruction. + if (OtherOp->getOpcode() == ISD::ADD && + isConstantIntBuildVectorOrConstantInt(OtherOp->getOperand(1)) && + OtherOp->getOperand(0).getNode() == MulVar) + return true; + } + } + + // Didn't find a case where this would be profitable. 
+ return false; +} + SDValue DAGCombiner::getMergedConstantVectorStore(SelectionDAG &DAG, SDLoc SL, ArrayRef Stores, + SmallVectorImpl &Chains, EVT Ty) const { SmallVector BuildVector; - for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) - BuildVector.push_back(cast(Stores[I].MemNode)->getValue()); + for (unsigned I = 0, E = Ty.getVectorNumElements(); I != E; ++I) { + StoreSDNode *St = cast(Stores[I].MemNode); + Chains.push_back(St->getChain()); + BuildVector.push_back(St->getValue()); + } return DAG.getNode(ISD::BUILD_VECTOR, SL, Ty, BuildVector); } bool DAGCombiner::MergeStoresOfConstantsOrVecElts( SmallVectorImpl &StoreNodes, EVT MemVT, - unsigned NumElem, bool IsConstantSrc, bool UseVector) { + unsigned NumStores, bool IsConstantSrc, bool UseVector) { // Make sure we have something to merge. - if (NumElem < 2) + if (NumStores < 2) return false; int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8; LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; unsigned LatestNodeUsed = 0; - for (unsigned i=0; i < NumElem; ++i) { + for (unsigned i=0; i < NumStores; ++i) { // Find a chain for the new wide-store operand. Notice that some // of the store nodes that we found may not be selected for inclusion // in the wide store. The chain we use needs to be the chain of the @@ -10649,45 +11181,57 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( LatestNodeUsed = i; } + SmallVector Chains; + // The latest Node in the DAG. LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; SDLoc DL(StoreNodes[0].MemNode); SDValue StoredVal; if (UseVector) { - // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT, NumElem); + bool IsVec = MemVT.isVector(); + unsigned Elts = NumStores; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); + } + // Get the type for the merged vector store. + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); assert(TLI.isTypeLegal(Ty) && "Illegal vector store"); + if (IsConstantSrc) { - StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Ty); + StoredVal = getMergedConstantVectorStore(DAG, DL, StoreNodes, Chains, Ty); } else { SmallVector Ops; - for (unsigned i = 0; i < NumElem ; ++i) { + for (unsigned i = 0; i < NumStores; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); SDValue Val = St->getValue(); - // All of the operands of a BUILD_VECTOR must have the same type. + // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type. if (Val.getValueType() != MemVT) return false; Ops.push_back(Val); + Chains.push_back(St->getChain()); } // Build the extracted vector elements back into a vector. - StoredVal = DAG.getNode(ISD::BUILD_VECTOR, DL, Ty, Ops); - } + StoredVal = DAG.getNode(IsVec ? ISD::CONCAT_VECTORS : ISD::BUILD_VECTOR, + DL, Ty, Ops); } } else { // We should always use a vector store when merging extracted vector // elements, so this path implies a store of constants. assert(IsConstantSrc && "Merged vector elements should use vector store"); - unsigned SizeInBits = NumElem * ElementSizeBytes * 8; + unsigned SizeInBits = NumStores * ElementSizeBytes * 8; APInt StoreInt(SizeInBits, 0); // Construct a single integer constant which is made of the smaller // constant inputs. bool IsLE = DAG.getDataLayout().isLittleEndian(); - for (unsigned i = 0; i < NumElem ; ++i) { - unsigned Idx = IsLE ? (NumElem - 1 - i) : i; + for (unsigned i = 0; i < NumStores; ++i) { + unsigned Idx = IsLE ? 
(NumStores - 1 - i) : i; StoreSDNode *St = cast(StoreNodes[Idx].MemNode); + Chains.push_back(St->getChain()); + SDValue Val = St->getValue(); StoreInt <<= ElementSizeBytes * 8; if (ConstantSDNode *C = dyn_cast(Val)) { @@ -10704,7 +11248,10 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( StoredVal = DAG.getConstant(StoreInt, DL, StoreTy); } - SDValue NewStore = DAG.getStore(LatestOp->getChain(), DL, StoredVal, + assert(!Chains.empty()); + + SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Chains); + SDValue NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, @@ -10713,7 +11260,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( // Replace the last store with the new store CombineTo(LatestOp, NewStore); // Erase all other stores. - for (unsigned i = 0; i < NumElem ; ++i) { + for (unsigned i = 0; i < NumStores; ++i) { if (StoreNodes[i].MemNode == LatestOp) continue; StoreSDNode *St = cast(StoreNodes[i].MemNode); @@ -10735,17 +11282,6 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( return true; } -static bool allowableAlignment(const SelectionDAG &DAG, - const TargetLowering &TLI, EVT EVTTy, - unsigned AS, unsigned Align) { - if (TLI.allowsMisalignedMemoryAccesses(EVTTy, AS, Align)) - return true; - - Type *Ty = EVTTy.getTypeForEVT(*DAG.getContext()); - unsigned ABIAlignment = DAG.getDataLayout().getPrefTypeAlignment(Ty); - return (Align >= ABIAlignment); -} - void DAGCombiner::getStoreMergeAndAliasCandidates( StoreSDNode* St, SmallVectorImpl &StoreNodes, SmallVectorImpl &AliasLoadNodes) { @@ -10767,12 +11303,44 @@ void DAGCombiner::getStoreMergeAndAliasCandidates( EVT MemVT = St->getMemoryVT(); unsigned Seq = 0; StoreSDNode *Index = St; - while (Index) { - // If the chain has more than one use, then we can't reorder the mem ops. - if (Index != St && !SDValue(Index, 0)->hasOneUse()) - break; - // Find the base pointer and offset for this memory node. + + bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA + : DAG.getSubtarget().useAA(); + + if (UseAA) { + // Look at other users of the same chain. Stores on the same chain do not + // alias. If combiner-aa is enabled, non-aliasing stores are canonicalized + // to be on the same chain, so don't bother looking at adjacent chains. + + SDValue Chain = St->getChain(); + for (auto I = Chain->use_begin(), E = Chain->use_end(); I != E; ++I) { + if (StoreSDNode *OtherST = dyn_cast(*I)) { + if (I.getOperandNo() != 0) + continue; + + if (OtherST->isVolatile() || OtherST->isIndexed()) + continue; + + if (OtherST->getMemoryVT() != MemVT) + continue; + + BaseIndexOffset Ptr = BaseIndexOffset::match(OtherST->getBasePtr()); + + if (Ptr.equalBaseIndex(BasePtr)) + StoreNodes.push_back(MemOpLink(OtherST, Ptr.Offset, Seq++)); + } + } + + return; + } + + while (Index) { + // If the chain has more than one use, then we can't reorder the mem ops. + if (Index != St && !SDValue(Index, 0)->hasOneUse()) + break; + + // Find the base pointer and offset for this memory node. BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr()); // Check that the base pointer is the same as the original one. @@ -10792,6 +11360,13 @@ void DAGCombiner::getStoreMergeAndAliasCandidates( if (Index->getMemoryVT() != MemVT) break; + // We do not allow under-aligned stores in order to prevent + // overriding stores. NOTE: this is a bad hack. 
Alignment SHOULD
+    // be irrelevant here; what MATTERS is that we not move memory
+    // operations that potentially overlap past each other.
+    if (Index->getAlignment() < MemVT.getStoreSize())
+      break;
+
     // We found a potential memory operand to merge.
     StoreNodes.push_back(MemOpLink(Index, Ptr.Offset, Seq++));
 
@@ -10836,8 +11411,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
   if (ElementSizeBytes * 8 != MemVT.getSizeInBits())
     return false;
 
-  // Don't merge vectors into wider inputs.
-  if (MemVT.isVector() || !MemVT.isSimple())
+  if (!MemVT.isSimple())
     return false;
 
   // Perform an early exit check. Do not bother looking at stored values that
   // are not constants, loads, or extracted vector elements.
   SDValue StoredVal = St->getValue();
   bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
   bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
                        isa<ConstantFPSDNode>(StoredVal);
-  bool IsExtractVecEltSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT);
+  bool IsExtractVecSrc = (StoredVal.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+                          StoredVal.getOpcode() == ISD::EXTRACT_SUBVECTOR);
 
-  if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecEltSrc)
+  if (!IsConstantSrc && !IsLoadSrc && !IsExtractVecSrc)
+    return false;
+
+  // Don't merge vectors into wider vectors if the source data comes from loads.
+  // TODO: This restriction can be lifted by using logic similar to the
+  // ExtractVecSrc case.
+  if (MemVT.isVector() && IsLoadSrc)
     return false;
 
   // Only look at ends of store sequences.
@@ -10860,22 +11441,28 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
   // We need to make sure that these nodes do not interfere with
   // any of the store nodes.
   SmallVector<LSBaseSDNode*, 8> AliasLoadNodes;
-  
+
   // Save the StoreSDNodes that we find in the chain.
   SmallVector<MemOpLink, 8> StoreNodes;
 
   getStoreMergeAndAliasCandidates(St, StoreNodes, AliasLoadNodes);
-  
+
   // Check if there is anything to merge.
   if (StoreNodes.size() < 2)
     return false;
 
-  // Sort the memory operands according to their distance from the base pointer.
+  // Sort the memory operands according to their distance from the
+  // base pointer. As a secondary criterion: make sure stores coming
+  // later in the code come first in the list. This is important for
+  // the non-UseAA case, because we're merging stores into the FINAL
+  // store along a chain which potentially contains aliasing stores.
+  // Thus, if there are multiple stores to the same address, the last
+  // one can be considered for merging but not the others.
   std::sort(StoreNodes.begin(), StoreNodes.end(),
             [](MemOpLink LHS, MemOpLink RHS) {
     return LHS.OffsetFromBase < RHS.OffsetFromBase ||
            (LHS.OffsetFromBase == RHS.OffsetFromBase &&
-            LHS.SequenceNum > RHS.SequenceNum);
+            LHS.SequenceNum < RHS.SequenceNum);
   });
 
   // Scan the memory operations on the chain and find the first non-consecutive
@@ -10892,15 +11479,12 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) {
       break;
     }
 
-    bool Alias = false;
     // Check if this store interferes with any of the loads that we found.
-    for (unsigned ld = 0, lde = AliasLoadNodes.size(); ld < lde; ++ld)
-      if (isAlias(AliasLoadNodes[ld], StoreNodes[i].MemNode)) {
-        Alias = true;
-        break;
-      }
-    // We found a load that alias with this store. Stop the sequence.
-    if (Alias)
+    // If we find a load that aliases with this store, stop the sequence.
+    if (std::any_of(AliasLoadNodes.begin(), AliasLoadNodes.end(),
+                    [&](LSBaseSDNode* Ldn) {
+                      return isAlias(Ldn, StoreNodes[i].MemNode);
+                    }))
       break;
 
     // Mark this node as useful.
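// Illustrative sketch (hypothetical values, not from this patch): four
// consecutive i8 constant stores of 1, 2, 3, 4 to base+0..base+3 can be
// merged into one i32 store. On a little-endian target the stores are
// visited in reverse offset order, so the combined constant becomes
//   StoreInt = 0x04030201
// and the single 'store i32 0x04030201, base' writes the bytes
// 01 02 03 04, exactly as the four narrow stores did.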
@@ -10911,6 +11495,8 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; unsigned FirstStoreAS = FirstInChain->getAddressSpace(); unsigned FirstStoreAlign = FirstInChain->getAlignment(); + LLVMContext &Context = *DAG.getContext(); + const DataLayout &DL = DAG.getDataLayout(); // Store the constants into memory as one consecutive store. if (IsConstantSrc) { @@ -10932,43 +11518,40 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { // Find a legal type for the constant store. unsigned SizeInBits = (i+1) * ElementSizeBytes * 8; - EVT StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits); + EVT StoreTy = EVT::getIntegerVT(Context, SizeInBits); + bool IsFast; if (TLI.isTypeLegal(StoreTy) && - allowableAlignment(DAG, TLI, StoreTy, FirstStoreAS, - FirstStoreAlign)) { + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) { LastLegalType = i+1; // Or check whether a truncstore is legal. - } else if (TLI.getTypeAction(*DAG.getContext(), StoreTy) == + } else if (TLI.getTypeAction(Context, StoreTy) == TargetLowering::TypePromoteInteger) { EVT LegalizedStoredValueTy = - TLI.getTypeToTransformTo(*DAG.getContext(), StoredVal.getValueType()); + TLI.getTypeToTransformTo(Context, StoredVal.getValueType()); if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && - allowableAlignment(DAG, TLI, LegalizedStoredValueTy, FirstStoreAS, - FirstStoreAlign)) { + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstStoreAS, FirstStoreAlign, &IsFast) && + IsFast) { LastLegalType = i + 1; } } - // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT, i+1); - if (TLI.isTypeLegal(Ty) && - allowableAlignment(DAG, TLI, Ty, FirstStoreAS, FirstStoreAlign)) { - LastLegalVectorType = i + 1; + // We only use vectors if the constant is known to be zero or the target + // allows it and the function is not marked with the noimplicitfloat + // attribute. + if ((!NonZero || TLI.storeOfVectorConstantIsCheap(MemVT, i+1, + FirstStoreAS)) && + !NoVectors) { + // Find a legal type for the vector store. + EVT Ty = EVT::getVectorVT(Context, MemVT, i+1); + if (TLI.isTypeLegal(Ty) && + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) + LastLegalVectorType = i + 1; } } - - // We only use vectors if the constant is known to be zero or the target - // allows it and the function is not marked with the noimplicitfloat - // attribute. - if (NoVectors) { - LastLegalVectorType = 0; - } else if (NonZero && !TLI.storeOfVectorConstantIsCheap(MemVT, - LastLegalVectorType, - FirstStoreAS)) { - LastLegalVectorType = 0; - } - // Check if we found a legal integer type to store. if (LastLegalType == 0 && LastLegalVectorType == 0) return false; @@ -10982,27 +11565,36 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { // When extracting multiple vector elements, try to store them // in one vector store rather than a sequence of scalar stores. - if (IsExtractVecEltSrc) { - unsigned NumElem = 0; + if (IsExtractVecSrc) { + unsigned NumStoresToMerge = 0; + bool IsVec = MemVT.isVector(); for (unsigned i = 0; i < LastConsecutiveStore + 1; ++i) { StoreSDNode *St = cast(StoreNodes[i].MemNode); - SDValue StoredVal = St->getValue(); + unsigned StoreValOpcode = St->getValue().getOpcode(); // This restriction could be loosened. // Bail out if any stored values are not elements extracted from a vector. 
// It should be possible to handle mixed sources, but load sources need // more careful handling (see the block of code below that handles // consecutive loads). - if (StoredVal.getOpcode() != ISD::EXTRACT_VECTOR_ELT) + if (StoreValOpcode != ISD::EXTRACT_VECTOR_ELT && + StoreValOpcode != ISD::EXTRACT_SUBVECTOR) return false; // Find a legal type for the vector store. - EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT, i+1); + unsigned Elts = i + 1; + if (IsVec) { + // When merging vector stores, get the total number of elements. + Elts *= MemVT.getVectorNumElements(); + } + EVT Ty = EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(), Elts); + bool IsFast; if (TLI.isTypeLegal(Ty) && - allowableAlignment(DAG, TLI, Ty, FirstStoreAS, FirstStoreAlign)) - NumElem = i + 1; + TLI.allowsMemoryAccess(Context, DL, Ty, FirstStoreAS, + FirstStoreAlign, &IsFast) && IsFast) + NumStoresToMerge = i + 1; } - return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumElem, + return MergeStoresOfConstantsOrVecElts(StoreNodes, MemVT, NumStoresToMerge, false, true); } @@ -11076,7 +11668,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { StartAddress = LoadNodes[0].OffsetFromBase; SDValue FirstChain = FirstLoad->getChain(); for (unsigned i = 1; i < LoadNodes.size(); ++i) { - // All loads much share the same chain. + // All loads must share the same chain. if (LoadNodes[i].MemNode->getChain() != FirstChain) break; @@ -11084,35 +11676,41 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (CurrAddress - StartAddress != (ElementSizeBytes * i)) break; LastConsecutiveLoad = i; - // Find a legal type for the vector store. - EVT StoreTy = EVT::getVectorVT(*DAG.getContext(), MemVT, i+1); + EVT StoreTy = EVT::getVectorVT(Context, MemVT, i+1); + bool IsFastSt, IsFastLd; if (TLI.isTypeLegal(StoreTy) && - allowableAlignment(DAG, TLI, StoreTy, FirstStoreAS, FirstStoreAlign) && - allowableAlignment(DAG, TLI, StoreTy, FirstLoadAS, FirstLoadAlign)) { + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && IsFastLd) { LastLegalVectorType = i + 1; } // Find a legal type for the integer store. unsigned SizeInBits = (i+1) * ElementSizeBytes * 8; - StoreTy = EVT::getIntegerVT(*DAG.getContext(), SizeInBits); + StoreTy = EVT::getIntegerVT(Context, SizeInBits); if (TLI.isTypeLegal(StoreTy) && - allowableAlignment(DAG, TLI, StoreTy, FirstStoreAS, FirstStoreAlign) && - allowableAlignment(DAG, TLI, StoreTy, FirstLoadAS, FirstLoadAlign)) + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstStoreAS, + FirstStoreAlign, &IsFastSt) && IsFastSt && + TLI.allowsMemoryAccess(Context, DL, StoreTy, FirstLoadAS, + FirstLoadAlign, &IsFastLd) && IsFastLd) LastLegalIntegerType = i + 1; // Or check whether a truncstore and extload is legal. 
- else if (TLI.getTypeAction(*DAG.getContext(), StoreTy) == + else if (TLI.getTypeAction(Context, StoreTy) == TargetLowering::TypePromoteInteger) { EVT LegalizedStoredValueTy = - TLI.getTypeToTransformTo(*DAG.getContext(), StoreTy); + TLI.getTypeToTransformTo(Context, StoreTy); if (TLI.isTruncStoreLegal(LegalizedStoredValueTy, StoreTy) && TLI.isLoadExtLegal(ISD::ZEXTLOAD, LegalizedStoredValueTy, StoreTy) && TLI.isLoadExtLegal(ISD::SEXTLOAD, LegalizedStoredValueTy, StoreTy) && TLI.isLoadExtLegal(ISD::EXTLOAD, LegalizedStoredValueTy, StoreTy) && - allowableAlignment(DAG, TLI, LegalizedStoredValueTy, FirstStoreAS, - FirstStoreAlign) && - allowableAlignment(DAG, TLI, LegalizedStoredValueTy, FirstLoadAS, - FirstLoadAlign)) + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstStoreAS, FirstStoreAlign, &IsFastSt) && + IsFastSt && + TLI.allowsMemoryAccess(Context, DL, LegalizedStoredValueTy, + FirstLoadAS, FirstLoadAlign, &IsFastLd) && + IsFastLd) LastLegalIntegerType = i+1; } } @@ -11130,6 +11728,10 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (NumElem < 2) return false; + // Collect the chains from all merged stores. + SmallVector MergeStoreChains; + MergeStoreChains.push_back(StoreNodes[0].MemNode->getChain()); + // The latest Node in the DAG. unsigned LatestNodeUsed = 0; for (unsigned i=1; igetChain()); } LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; @@ -11147,34 +11751,33 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { // to memory. EVT JointMemOpVT; if (UseVectorTy) { - JointMemOpVT = EVT::getVectorVT(*DAG.getContext(), MemVT, NumElem); + JointMemOpVT = EVT::getVectorVT(Context, MemVT, NumElem); } else { unsigned SizeInBits = NumElem * ElementSizeBytes * 8; - JointMemOpVT = EVT::getIntegerVT(*DAG.getContext(), SizeInBits); + JointMemOpVT = EVT::getIntegerVT(Context, SizeInBits); } SDLoc LoadDL(LoadNodes[0].MemNode); SDLoc StoreDL(StoreNodes[0].MemNode); + // The merged loads are required to have the same incoming chain, so + // using the first's chain is acceptable. SDValue NewLoad = DAG.getLoad( JointMemOpVT, LoadDL, FirstLoad->getChain(), FirstLoad->getBasePtr(), FirstLoad->getPointerInfo(), false, false, false, FirstLoadAlign); + SDValue NewStoreChain = + DAG.getNode(ISD::TokenFactor, StoreDL, MVT::Other, MergeStoreChains); + SDValue NewStore = DAG.getStore( - LatestOp->getChain(), StoreDL, NewLoad, FirstInChain->getBasePtr(), + NewStoreChain, StoreDL, NewLoad, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, FirstStoreAlign); - // Replace one of the loads with the new load. - LoadSDNode *Ld = cast(LoadNodes[0].MemNode); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), - SDValue(NewLoad.getNode(), 1)); - - // Remove the rest of the load chains. - for (unsigned i = 1; i < NumElem ; ++i) { - // Replace all chain users of the old load nodes with the chain of the new - // load node. + // Transfer chain users from old loads to the new load. + for (unsigned i = 0; i < NumElem; ++i) { LoadSDNode *Ld = cast(LoadNodes[i].MemNode); - DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Ld->getChain()); + DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), + SDValue(NewLoad.getNode(), 1)); } // Replace the last store with the new store. @@ -11192,6 +11795,114 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { return true; } +SDValue DAGCombiner::replaceStoreChain(StoreSDNode *ST, SDValue BetterChain) { + SDLoc SL(ST); + SDValue ReplStore; + + // Replace the chain to avoid dependency. 
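// Illustrative sketch (hypothetical chain, not from this patch): if ST
// was chained as ST -> TokenFactor(Ld, OtherSt) and the combiner proves
// OtherSt cannot alias ST, the store is re-created below on the shorter
// chain. The TokenFactor built afterwards joins the old chain with the
// new store, so every user of ST's chain result still sees both orderings.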
+ if (ST->isTruncatingStore()) { + ReplStore = DAG.getTruncStore(BetterChain, SL, ST->getValue(), + ST->getBasePtr(), ST->getMemoryVT(), + ST->getMemOperand()); + } else { + ReplStore = DAG.getStore(BetterChain, SL, ST->getValue(), ST->getBasePtr(), + ST->getMemOperand()); + } + + // Create token to keep both nodes around. + SDValue Token = DAG.getNode(ISD::TokenFactor, SL, + MVT::Other, ST->getChain(), ReplStore); + + // Make sure the new and old chains are cleaned up. + AddToWorklist(Token.getNode()); + + // Don't add users to work list. + return CombineTo(ST, Token, false); +} + +SDValue DAGCombiner::replaceStoreOfFPConstant(StoreSDNode *ST) { + SDValue Value = ST->getValue(); + if (Value.getOpcode() == ISD::TargetConstantFP) + return SDValue(); + + SDLoc DL(ST); + + SDValue Chain = ST->getChain(); + SDValue Ptr = ST->getBasePtr(); + + const ConstantFPSDNode *CFP = cast(Value); + + // NOTE: If the original store is volatile, this transform must not increase + // the number of stores. For example, on x86-32 an f64 can be stored in one + // processor operation but an i64 (which is not legal) requires two. So the + // transform should not be done in this case. + + SDValue Tmp; + switch (CFP->getSimpleValueType(0).SimpleTy) { + default: + llvm_unreachable("Unknown FP type"); + case MVT::f16: // We don't do this for these yet. + case MVT::f80: + case MVT::f128: + case MVT::ppcf128: + return SDValue(); + case MVT::f32: + if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) || + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { + ; + Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF(). + bitcastToAPInt().getZExtValue(), SDLoc(CFP), + MVT::i32); + return DAG.getStore(Chain, DL, Tmp, Ptr, ST->getMemOperand()); + } + + return SDValue(); + case MVT::f64: + if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations && + !ST->isVolatile()) || + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) { + ; + Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt(). + getZExtValue(), SDLoc(CFP), MVT::i64); + return DAG.getStore(Chain, DL, Tmp, + Ptr, ST->getMemOperand()); + } + + if (!ST->isVolatile() && + TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { + // Many FP stores are not made apparent until after legalize, e.g. for + // argument passing. Since this is so common, custom legalize the + // 64-bit integer store into two 32-bit stores. 
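// Illustrative sketch (hypothetical constant, not from this patch): for
// 'store double 1.0, Ptr' the bit pattern is 0x3FF0000000000000, so the
// code below emits Lo = 0x00000000 at Ptr and Hi = 0x3FF00000 at Ptr+4
// on a little-endian target; big-endian targets swap Lo and Hi first.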
+ uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue(); + SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32); + SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32); + if (DAG.getDataLayout().isBigEndian()) + std::swap(Lo, Hi); + + unsigned Alignment = ST->getAlignment(); + bool isVolatile = ST->isVolatile(); + bool isNonTemporal = ST->isNonTemporal(); + AAMDNodes AAInfo = ST->getAAInfo(); + + SDValue St0 = DAG.getStore(Chain, DL, Lo, + Ptr, ST->getPointerInfo(), + isVolatile, isNonTemporal, + ST->getAlignment(), AAInfo); + Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, + DAG.getConstant(4, DL, Ptr.getValueType())); + Alignment = MinAlign(Alignment, 4U); + SDValue St1 = DAG.getStore(Chain, DL, Hi, + Ptr, ST->getPointerInfo().getWithOffset(4), + isVolatile, isNonTemporal, + Alignment, AAInfo); + return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, + St0, St1); + } + + return SDValue(); + } +} + SDValue DAGCombiner::visitSTORE(SDNode *N) { StoreSDNode *ST = cast(N); SDValue Chain = ST->getChain(); @@ -11219,81 +11930,6 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { if (Value.getOpcode() == ISD::UNDEF && ST->isUnindexed()) return Chain; - // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr' - if (ConstantFPSDNode *CFP = dyn_cast(Value)) { - // NOTE: If the original store is volatile, this transform must not increase - // the number of stores. For example, on x86-32 an f64 can be stored in one - // processor operation but an i64 (which is not legal) requires two. So the - // transform should not be done in this case. - if (Value.getOpcode() != ISD::TargetConstantFP) { - SDValue Tmp; - switch (CFP->getSimpleValueType(0).SimpleTy) { - default: llvm_unreachable("Unknown FP type"); - case MVT::f16: // We don't do this for these yet. - case MVT::f80: - case MVT::f128: - case MVT::ppcf128: - break; - case MVT::f32: - if ((isTypeLegal(MVT::i32) && !LegalOperations && !ST->isVolatile()) || - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { - ; - Tmp = DAG.getConstant((uint32_t)CFP->getValueAPF(). - bitcastToAPInt().getZExtValue(), SDLoc(CFP), - MVT::i32); - return DAG.getStore(Chain, SDLoc(N), Tmp, - Ptr, ST->getMemOperand()); - } - break; - case MVT::f64: - if ((TLI.isTypeLegal(MVT::i64) && !LegalOperations && - !ST->isVolatile()) || - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i64)) { - ; - Tmp = DAG.getConstant(CFP->getValueAPF().bitcastToAPInt(). - getZExtValue(), SDLoc(CFP), MVT::i64); - return DAG.getStore(Chain, SDLoc(N), Tmp, - Ptr, ST->getMemOperand()); - } - - if (!ST->isVolatile() && - TLI.isOperationLegalOrCustom(ISD::STORE, MVT::i32)) { - // Many FP stores are not made apparent until after legalize, e.g. for - // argument passing. Since this is so common, custom legalize the - // 64-bit integer store into two 32-bit stores. 
- uint64_t Val = CFP->getValueAPF().bitcastToAPInt().getZExtValue(); - SDValue Lo = DAG.getConstant(Val & 0xFFFFFFFF, SDLoc(CFP), MVT::i32); - SDValue Hi = DAG.getConstant(Val >> 32, SDLoc(CFP), MVT::i32); - if (DAG.getDataLayout().isBigEndian()) - std::swap(Lo, Hi); - - unsigned Alignment = ST->getAlignment(); - bool isVolatile = ST->isVolatile(); - bool isNonTemporal = ST->isNonTemporal(); - AAMDNodes AAInfo = ST->getAAInfo(); - - SDLoc DL(N); - - SDValue St0 = DAG.getStore(Chain, SDLoc(ST), Lo, - Ptr, ST->getPointerInfo(), - isVolatile, isNonTemporal, - ST->getAlignment(), AAInfo); - Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr, - DAG.getConstant(4, DL, Ptr.getValueType())); - Alignment = MinAlign(Alignment, 4U); - SDValue St1 = DAG.getStore(Chain, SDLoc(ST), Hi, - Ptr, ST->getPointerInfo().getWithOffset(4), - isVolatile, isNonTemporal, - Alignment, AAInfo); - return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, - St0, St1); - } - - break; - } - } - } - // Try to infer better alignment information than the store already has. if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { @@ -11311,8 +11947,7 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // Try transforming a pair floating point load / store ops to integer // load / store ops. - SDValue NewST = TransformFPLoadStorePair(N); - if (NewST.getNode()) + if (SDValue NewST = TransformFPLoadStorePair(N)) return NewST; bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA @@ -11323,31 +11958,17 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { UseAA = false; #endif if (UseAA && ST->isUnindexed()) { - // Walk up chain skipping non-aliasing memory nodes. - SDValue BetterChain = FindBetterChain(N, Chain); - - // If there is a better chain. - if (Chain != BetterChain) { - SDValue ReplStore; - - // Replace the chain to avoid dependency. - if (ST->isTruncatingStore()) { - ReplStore = DAG.getTruncStore(BetterChain, SDLoc(N), Value, Ptr, - ST->getMemoryVT(), ST->getMemOperand()); - } else { - ReplStore = DAG.getStore(BetterChain, SDLoc(N), Value, Ptr, - ST->getMemOperand()); - } - - // Create token to keep both nodes around. - SDValue Token = DAG.getNode(ISD::TokenFactor, SDLoc(N), - MVT::Other, Chain, ReplStore); - - // Make sure the new and old chains are cleaned up. - AddToWorklist(Token.getNode()); + // FIXME: We should do this even without AA enabled. AA will just allow + // FindBetterChain to work in more situations. The problem with this is that + // any combine that expects memory operations to be on consecutive chains + // first needs to be updated to look for users of the same chain. - // Don't add users to work list. - return CombineTo(N, Token, false); + // Walk up chain skipping non-aliasing memory nodes, on this store and any + // adjacent stores. + if (findBetterNeighborChains(ST)) { + // replaceStoreChain uses CombineTo, which handled all of the worklist + // manipulation. Return the original node to not do anything else. + return SDValue(ST, 0); } } @@ -11432,6 +12053,16 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { return SDValue(N, 0); } + // Turn 'store float 1.0, Ptr' -> 'store int 0x12345678, Ptr' + // + // Make sure to do this only after attempting to merge stores in order to + // avoid changing the types of some subset of stores due to visit order, + // preventing their merging. 
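// Illustrative sketch (hypothetical constant, not from this patch):
// 'store float 1.0, Ptr' becomes 'store i32 0x3F800000, Ptr', 0x3F800000
// being the bit pattern of 1.0f. Doing this only after merging lets a
// run of equal f32 constant stores first combine into one wide store.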
+  if (isa<ConstantFPSDNode>(Value)) {
+    if (SDValue NewSt = replaceStoreOfFPConstant(ST))
+      return NewSt;
+  }
+
   return ReduceLoadOpStoreWidth(N);
 }
 
@@ -11605,7 +12236,24 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
   }
 
   SDValue EltNo = N->getOperand(1);
-  bool ConstEltNo = isa<ConstantSDNode>(EltNo);
+  ConstantSDNode *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo);
+
+  // extract_vector_elt (build_vector x, y), 1 -> y
+  if (ConstEltNo &&
+      InVec.getOpcode() == ISD::BUILD_VECTOR &&
+      TLI.isTypeLegal(VT) &&
+      (InVec.hasOneUse() ||
+       TLI.aggressivelyPreferBuildVectorSources(VT))) {
+    SDValue Elt = InVec.getOperand(ConstEltNo->getZExtValue());
+    EVT InEltVT = Elt.getValueType();
+
+    // Sometimes build_vector's scalar input types do not match result type.
+    if (NVT == InEltVT)
+      return Elt;
+
+    // TODO: It may be useful to truncate here (if the truncation is free and
+    // the build_vector implicitly converts).
+  }
 
   // Transform: (EXTRACT_VECTOR_ELT( VECTOR_SHUFFLE )) -> EXTRACT_VECTOR_ELT.
   // We only perform this optimization before the op legalization phase because
@@ -11613,13 +12261,11 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
   // patterns. For example on AVX, extracting elements from a wide vector
   // without using extract_subvector. However, if we can find an underlying
   // scalar value, then we can always use that.
-  if (InVec.getOpcode() == ISD::VECTOR_SHUFFLE
-      && ConstEltNo) {
-    int Elt = cast<ConstantSDNode>(EltNo)->getZExtValue();
+  if (ConstEltNo && InVec.getOpcode() == ISD::VECTOR_SHUFFLE) {
     int NumElem = VT.getVectorNumElements();
     ShuffleVectorSDNode *SVOp = cast<ShuffleVectorSDNode>(InVec);
     // Find the new index to extract from.
-    int OrigElt = SVOp->getMaskElt(Elt);
+    int OrigElt = SVOp->getMaskElt(ConstEltNo->getZExtValue());
 
     // Extracting an undef index is undef.
     if (OrigElt == -1)
@@ -11648,7 +12294,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     // scalar_to_vector here as well.
 
     if (!LegalOperations) {
-      EVT IndexTy = TLI.getVectorIdxTy();
+      EVT IndexTy = TLI.getVectorIdxTy(DAG.getDataLayout());
       return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(N), NVT, SVInVec,
                          DAG.getConstant(OrigElt, SDLoc(SVOp), IndexTy));
     }
@@ -12079,10 +12725,13 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) {
 
         // Try to replace VecIn1 with two extract_subvectors
         // No need to update the masks, they should still be correct.
-        VecIn2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
-          DAG.getConstant(VT.getVectorNumElements(), dl, TLI.getVectorIdxTy()));
-        VecIn1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
-          DAG.getConstant(0, dl, TLI.getVectorIdxTy()));
+        VecIn2 = DAG.getNode(
+            ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
+            DAG.getConstant(VT.getVectorNumElements(), dl,
+                            TLI.getVectorIdxTy(DAG.getDataLayout())));
+        VecIn1 = DAG.getNode(
+            ISD::EXTRACT_SUBVECTOR, dl, VT, VecIn1,
+            DAG.getConstant(0, dl, TLI.getVectorIdxTy(DAG.getDataLayout())));
       } else
         return SDValue();
     }
@@ -12172,12 +12821,90 @@ static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) {
                      DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Ops));
 }
 
-SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) {
-  // TODO: Check to see if this is a CONCAT_VECTORS of a bunch of
-  // EXTRACT_SUBVECTOR operations. If so, and if the EXTRACT_SUBVECTOR vector
-  // inputs come from at most two distinct vectors, turn this into a shuffle
-  // node.
+// Check to see if this is a CONCAT_VECTORS of a bunch of EXTRACT_SUBVECTOR
+// operations.
If so, and if the EXTRACT_SUBVECTOR vector inputs come from at +// most two distinct vectors the same size as the result, attempt to turn this +// into a legal shuffle. +static SDValue combineConcatVectorOfExtracts(SDNode *N, SelectionDAG &DAG) { + EVT VT = N->getValueType(0); + EVT OpVT = N->getOperand(0).getValueType(); + int NumElts = VT.getVectorNumElements(); + int NumOpElts = OpVT.getVectorNumElements(); + + SDValue SV0 = DAG.getUNDEF(VT), SV1 = DAG.getUNDEF(VT); + SmallVector Mask; + + for (SDValue Op : N->ops()) { + // Peek through any bitcast. + while (Op.getOpcode() == ISD::BITCAST) + Op = Op.getOperand(0); + + // UNDEF nodes convert to UNDEF shuffle mask values. + if (Op.getOpcode() == ISD::UNDEF) { + Mask.append((unsigned)NumOpElts, -1); + continue; + } + + if (Op.getOpcode() != ISD::EXTRACT_SUBVECTOR) + return SDValue(); + + // What vector are we extracting the subvector from and at what index? + SDValue ExtVec = Op.getOperand(0); + + // We want the EVT of the original extraction to correctly scale the + // extraction index. + EVT ExtVT = ExtVec.getValueType(); + + // Peek through any bitcast. + while (ExtVec.getOpcode() == ISD::BITCAST) + ExtVec = ExtVec.getOperand(0); + + // UNDEF nodes convert to UNDEF shuffle mask values. + if (ExtVec.getOpcode() == ISD::UNDEF) { + Mask.append((unsigned)NumOpElts, -1); + continue; + } + + if (!isa(Op.getOperand(1))) + return SDValue(); + int ExtIdx = cast(Op.getOperand(1))->getZExtValue(); + + // Ensure that we are extracting a subvector from a vector the same + // size as the result. + if (ExtVT.getSizeInBits() != VT.getSizeInBits()) + return SDValue(); + // Scale the subvector index to account for any bitcast. + int NumExtElts = ExtVT.getVectorNumElements(); + if (0 == (NumExtElts % NumElts)) + ExtIdx /= (NumExtElts / NumElts); + else if (0 == (NumElts % NumExtElts)) + ExtIdx *= (NumElts / NumExtElts); + else + return SDValue(); + + // At most we can reference 2 inputs in the final shuffle. + if (SV0.getOpcode() == ISD::UNDEF || SV0 == ExtVec) { + SV0 = ExtVec; + for (int i = 0; i != NumOpElts; ++i) + Mask.push_back(i + ExtIdx); + } else if (SV1.getOpcode() == ISD::UNDEF || SV1 == ExtVec) { + SV1 = ExtVec; + for (int i = 0; i != NumOpElts; ++i) + Mask.push_back(i + ExtIdx + NumElts); + } else { + return SDValue(); + } + } + + if (!DAG.getTargetLoweringInfo().isShuffleMaskLegal(Mask, VT)) + return SDValue(); + + return DAG.getVectorShuffle(VT, SDLoc(N), DAG.getBitcast(VT, SV0), + DAG.getBitcast(VT, SV1), Mask); +} + +SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // If we only have one input vector, we don't need to do any concatenation. if (N->getNumOperands() == 1) return N->getOperand(0); @@ -12278,6 +13005,11 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (SDValue V = combineConcatVectorOfScalars(N, DAG)) return V; + // Fold CONCAT_VECTORS of EXTRACT_SUBVECTOR (or undef) to VECTOR_SHUFFLE. + if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) + if (SDValue V = combineConcatVectorOfExtracts(N, DAG)) + return V; + // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR // nodes often generate nop CONCAT_VECTOR nodes. 
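// Illustrative sketch (hypothetical types, not from this patch): the
// combineConcatVectorOfExtracts fold added above turns, for example,
//   concat_vectors (extract_subvector v4i32 X, 0),
//                  (extract_subvector v4i32 Y, 2)
// into vector_shuffle v4i32 X, Y, <0, 1, 6, 7>: the first operand adds
// mask entries ExtIdx + i = 0, 1 and the second adds
// ExtIdx + i + NumElts = 6, 7.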
// Scan the CONCAT_VECTOR operands and look for a CONCAT operations that @@ -12492,7 +13224,7 @@ static SDValue partitionShuffleOfConcats(SDNode *N, SelectionDAG &DAG) { std::all_of(SVN->getMask().begin() + NumElemsPerConcat, SVN->getMask().end(), [](int i) { return i == -1; })) { N0 = DAG.getVectorShuffle(ConcatVT, SDLoc(N), N0.getOperand(0), N0.getOperand(1), - ArrayRef(SVN->getMask().begin(), NumElemsPerConcat)); + makeArrayRef(SVN->getMask().begin(), NumElemsPerConcat)); N1 = DAG.getUNDEF(ConcatVT); return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), VT, N0, N1); } @@ -12970,6 +13702,21 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) + if (N0->getOpcode() == ISD::AND) { + ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1)); + if (AndConst && AndConst->getAPIntValue() == 0xffff) { + return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0), + N0.getOperand(0)); + } + } + + return SDValue(); +} + /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle /// with the destination vector and a zero vector. /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> @@ -12991,34 +13738,76 @@ SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { if (RHS.getOpcode() == ISD::BITCAST) RHS = RHS.getOperand(0); - if (RHS.getOpcode() == ISD::BUILD_VECTOR) { + if (RHS.getOpcode() != ISD::BUILD_VECTOR) + return SDValue(); + + EVT RVT = RHS.getValueType(); + unsigned NumElts = RHS.getNumOperands(); + + // Attempt to create a valid clear mask, splitting the mask into + // sub elements and checking to see if each is + // all zeros or all ones - suitable for shuffle masking. + auto BuildClearMask = [&](int Split) { + int NumSubElts = NumElts * Split; + int NumSubBits = RVT.getScalarSizeInBits() / Split; + SmallVector Indices; - unsigned NumElts = RHS.getNumOperands(); + for (int i = 0; i != NumSubElts; ++i) { + int EltIdx = i / Split; + int SubIdx = i % Split; + SDValue Elt = RHS.getOperand(EltIdx); + if (Elt.getOpcode() == ISD::UNDEF) { + Indices.push_back(-1); + continue; + } - for (unsigned i = 0; i != NumElts; ++i) { - SDValue Elt = RHS.getOperand(i); - if (isAllOnesConstant(Elt)) + APInt Bits; + if (isa(Elt)) + Bits = cast(Elt)->getAPIntValue(); + else if (isa(Elt)) + Bits = cast(Elt)->getValueAPF().bitcastToAPInt(); + else + return SDValue(); + + // Extract the sub element from the constant bit mask. + if (DAG.getDataLayout().isBigEndian()) { + Bits = Bits.lshr((Split - SubIdx - 1) * NumSubBits); + } else { + Bits = Bits.lshr(SubIdx * NumSubBits); + } + + if (Split > 1) + Bits = Bits.trunc(NumSubBits); + + if (Bits.isAllOnesValue()) Indices.push_back(i); - else if (isNullConstant(Elt)) - Indices.push_back(NumElts+i); + else if (Bits == 0) + Indices.push_back(i + NumSubElts); else return SDValue(); } // Let's see if the target supports this vector_shuffle. - EVT RVT = RHS.getValueType(); - if (!TLI.isVectorClearMaskLegal(Indices, RVT)) + EVT ClearSVT = EVT::getIntegerVT(*DAG.getContext(), NumSubBits); + EVT ClearVT = EVT::getVectorVT(*DAG.getContext(), ClearSVT, NumSubElts); + if (!TLI.isVectorClearMaskLegal(Indices, ClearVT)) return SDValue(); - // Return the new VECTOR_SHUFFLE node. 
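// Illustrative sketch (hypothetical values, not from this patch): with
// the BuildClearMask lambda above, (and v4i32 X, <-1, 0, -1, 0>) becomes
//   vector_shuffle v4i32 X, zeroinitializer, <0, 5, 2, 7>
// where all-ones lanes keep index i and all-zero lanes take lane
// i + NumSubElts of the zero vector. The Split loop also catches
// sub-element masks, e.g. a v2i64 AND with 0x00000000FFFFFFFF per lane
// is matched at Split = 2 as a v4i32-style clear mask.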
- EVT EltVT = RVT.getVectorElementType(); - SmallVector ZeroOps(RVT.getVectorNumElements(), - DAG.getConstant(0, dl, EltVT)); - SDValue Zero = DAG.getNode(ISD::BUILD_VECTOR, dl, RVT, ZeroOps); - LHS = DAG.getNode(ISD::BITCAST, dl, RVT, LHS); - SDValue Shuf = DAG.getVectorShuffle(RVT, dl, LHS, Zero, &Indices[0]); - return DAG.getNode(ISD::BITCAST, dl, VT, Shuf); - } + SDValue Zero = DAG.getConstant(0, dl, ClearVT); + return DAG.getBitcast(VT, DAG.getVectorShuffle(ClearVT, dl, + DAG.getBitcast(ClearVT, LHS), + Zero, &Indices[0])); + }; + + // Determine maximum split level (byte level masking). + int MaxSplit = 1; + if (RVT.getScalarSizeInBits() % 8 == 0) + MaxSplit = RVT.getScalarSizeInBits() / 8; + + for (int Split = 1; Split <= MaxSplit; ++Split) + if (RVT.getScalarSizeInBits() % Split == 0) + if (SDValue S = BuildClearMask(Split)) + return S; return SDValue(); } @@ -13030,60 +13819,17 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); + SDValue Ops[] = {LHS, RHS}; + // See if we can constant fold the vector operation. + if (SDValue Fold = DAG.FoldConstantVectorArithmetic( + N->getOpcode(), SDLoc(LHS), LHS.getValueType(), Ops, N->getFlags())) + return Fold; + + // Try to convert a constant mask AND into a shuffle clear mask. if (SDValue Shuffle = XformToShuffleWithZero(N)) return Shuffle; - // If the LHS and RHS are BUILD_VECTOR nodes, see if we can constant fold - // this operation. - if (LHS.getOpcode() == ISD::BUILD_VECTOR && - RHS.getOpcode() == ISD::BUILD_VECTOR) { - // Check if both vectors are constants. If not bail out. - if (!(cast(LHS)->isConstant() && - cast(RHS)->isConstant())) - return SDValue(); - - SmallVector Ops; - for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) { - SDValue LHSOp = LHS.getOperand(i); - SDValue RHSOp = RHS.getOperand(i); - - // Can't fold divide by zero. - if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV || - N->getOpcode() == ISD::FDIV) { - if (isNullConstant(RHSOp) || (RHSOp.getOpcode() == ISD::ConstantFP && - cast(RHSOp.getNode())->isZero())) - break; - } - - EVT VT = LHSOp.getValueType(); - EVT RVT = RHSOp.getValueType(); - if (RVT != VT) { - // Integer BUILD_VECTOR operands may have types larger than the element - // size (e.g., when the element type is not legal). Prior to type - // legalization, the types may not match between the two BUILD_VECTORS. - // Truncate one of the operands to make them match. - if (RVT.getSizeInBits() > VT.getSizeInBits()) { - RHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, RHSOp); - } else { - LHSOp = DAG.getNode(ISD::TRUNCATE, SDLoc(N), RVT, LHSOp); - VT = RVT; - } - } - SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(LHS), VT, - LHSOp, RHSOp); - if (FoldOp.getOpcode() != ISD::UNDEF && - FoldOp.getOpcode() != ISD::Constant && - FoldOp.getOpcode() != ISD::ConstantFP) - break; - Ops.push_back(FoldOp); - AddToWorklist(FoldOp.getNode()); - } - - if (Ops.size() == LHS.getNumOperands()) - return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), LHS.getValueType(), Ops); - } - // Type legalization might introduce new shuffles in the DAG. // Fold (VBinOp (shuffle (A, Undef, Mask)), (shuffle (B, Undef, Mask))) // -> (shuffle (VBinOp (A, B)), Undef, Mask). 
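// Illustrative sketch (hypothetical values, not from this patch): the
// FoldConstantVectorArithmetic call added earlier in SimplifyVBinOp
// subsumes the BUILD_VECTOR walk deleted above, e.g. it folds
//   (add (build_vector 1, 2), (build_vector 3, 4)) -> (build_vector 4, 6)
// while a division with a zero element simply fails to fold rather than
// needing the explicit divide-by-zero check kept here before.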
@@ -13098,7 +13844,8 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { EVT VT = N->getValueType(0); SDValue UndefVector = LHS.getOperand(1); SDValue NewBinOp = DAG.getNode(N->getOpcode(), SDLoc(N), VT, - LHS.getOperand(0), RHS.getOperand(0)); + LHS.getOperand(0), RHS.getOperand(0), + N->getFlags()); AddUsersToWorklist(N); return DAG.getVectorShuffle(VT, SDLoc(N), NewBinOp, UndefVector, &SVN0->getMask()[0]); @@ -13358,8 +14105,9 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, // Create a ConstantArray of the two constants. Constant *CA = ConstantArray::get(ArrayType::get(FPTy, 2), Elts); - SDValue CPIdx = DAG.getConstantPool(CA, TLI.getPointerTy(), - TD.getPrefTypeAlignment(FPTy)); + SDValue CPIdx = + DAG.getConstantPool(CA, TLI.getPointerTy(DAG.getDataLayout()), + TD.getPrefTypeAlignment(FPTy)); unsigned Alignment = cast(CPIdx)->getAlignment(); // Get the offsets to the 0 and 1 element of the array so that we can @@ -13378,9 +14126,10 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, CPIdx = DAG.getNode(ISD::ADD, DL, CPIdx.getValueType(), CPIdx, CstOffset); AddToWorklist(CPIdx.getNode()); - return DAG.getLoad(TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, - MachinePointerInfo::getConstantPool(), false, - false, false, Alignment); + return DAG.getLoad( + TV->getValueType(0), DL, DAG.getEntryNode(), CPIdx, + MachinePointerInfo::getConstantPool(DAG.getMachineFunction()), + false, false, false, Alignment); } } @@ -13469,8 +14218,7 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, // Get a SetCC of the condition // NOTE: Don't create a SETCC if it's not legal on this target. if (!LegalOperations || - TLI.isOperationLegal(ISD::SETCC, - LegalTypes ? getSetCCResultType(N0.getValueType()) : MVT::i1)) { + TLI.isOperationLegal(ISD::SETCC, N0.getValueType())) { SDValue Temp, SCC; // cast from setcc result type to select result type if (LegalTypes) { @@ -13502,51 +14250,6 @@ SDValue DAGCombiner::SimplifySelectCC(SDLoc DL, SDValue N0, SDValue N1, } } - // Check to see if this is the equivalent of setcc - // FIXME: Turn all of these into setcc if setcc if setcc is legal - // otherwise, go ahead with the folds. 
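// Illustrative note (not from this patch): the block deleted below was
// already dead, guarded by 'if (0 && ...)'. As a worked instance of what
// it described: for i32 X, (seteq X, 0) -> (srl (ctlz X), 5), because
// ctlz yields 32 (binary 100000) exactly when X == 0, and shifting right
// by log2(32) = 5 leaves the desired 0-or-1 result.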
- if (0 && isNullConstant(N3) && isOneConstant(N2)) { - EVT XType = N0.getValueType(); - if (!LegalOperations || - TLI.isOperationLegal(ISD::SETCC, getSetCCResultType(XType))) { - SDValue Res = DAG.getSetCC(DL, getSetCCResultType(XType), N0, N1, CC); - if (Res.getValueType() != VT) - Res = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Res); - return Res; - } - - // fold (seteq X, 0) -> (srl (ctlz X, log2(size(X)))) - if (isNullConstant(N1) && CC == ISD::SETEQ && - (!LegalOperations || - TLI.isOperationLegal(ISD::CTLZ, XType))) { - SDValue Ctlz = DAG.getNode(ISD::CTLZ, SDLoc(N0), XType, N0); - return DAG.getNode(ISD::SRL, DL, XType, Ctlz, - DAG.getConstant(Log2_32(XType.getSizeInBits()), - SDLoc(Ctlz), - getShiftAmountTy(Ctlz.getValueType()))); - } - // fold (setgt X, 0) -> (srl (and (-X, ~X), size(X)-1)) - if (isNullConstant(N1) && CC == ISD::SETGT) { - SDLoc DL(N0); - SDValue NegN0 = DAG.getNode(ISD::SUB, DL, - XType, DAG.getConstant(0, DL, XType), N0); - SDValue NotN0 = DAG.getNOT(DL, N0, XType); - return DAG.getNode(ISD::SRL, DL, XType, - DAG.getNode(ISD::AND, DL, XType, NegN0, NotN0), - DAG.getConstant(XType.getSizeInBits() - 1, DL, - getShiftAmountTy(XType))); - } - // fold (setgt X, -1) -> (xor (srl (X, size(X)-1), 1)) - if (isAllOnesConstant(N1) && CC == ISD::SETGT) { - SDLoc DL(N0); - SDValue Sign = DAG.getNode(ISD::SRL, DL, XType, N0, - DAG.getConstant(XType.getSizeInBits() - 1, DL, - getShiftAmountTy(N0.getValueType()))); - return DAG.getNode(ISD::XOR, DL, XType, Sign, DAG.getConstant(1, DL, - XType)); - } - } - // Check to see if this is an integer abs. // select_cc setg[te] X, 0, X, -X -> // select_cc setgt X, -1, X, -X -> @@ -13654,7 +14357,7 @@ SDValue DAGCombiner::BuildUDIV(SDNode *N) { return S; } -SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { +SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) { if (Level >= AfterLegalizeDAG) return SDValue(); @@ -13678,16 +14381,16 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { // Newton iterations: Est = Est + Est (1 - Arg * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est); + SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPOne, NewEst); + NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPOne, NewEst, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst); + NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags); AddToWorklist(NewEst.getNode()); - Est = DAG.getNode(ISD::FADD, DL, VT, Est, NewEst); + Est = DAG.getNode(ISD::FADD, DL, VT, Est, NewEst, Flags); AddToWorklist(Est.getNode()); } } @@ -13704,31 +14407,32 @@ SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) { /// X_{i+1} = X_i (1.5 - A X_i^2 / 2) /// As a result, we precompute A/2 prior to the iteration loop. SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est, - unsigned Iterations) { + unsigned Iterations, + SDNodeFlags *Flags) { EVT VT = Arg.getValueType(); SDLoc DL(Arg); SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT); // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that // this entire sequence requires only one FP constant. 
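// Illustrative sketch (hypothetical values, not from this patch): with
// only the 1.5 constant materialized, HalfArg = 1.5*A - A = 0.5*A, and
// each iteration computes X_{i+1} = X_i * (1.5 - 0.5*A*X_i*X_i). For
// A = 4 and X_0 = 0.5: X_1 = 0.5 * (1.5 - 0.5*4*0.25) = 0.5, already the
// exact 1/sqrt(4), so the iteration is a fixed point there.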
- SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg); + SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, ThreeHalves, Arg, Flags); AddToWorklist(HalfArg.getNode()); - HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg); + HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Arg, Flags); AddToWorklist(HalfArg.getNode()); // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est); + SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst); + NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst, Flags); AddToWorklist(NewEst.getNode()); - NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst); + NewEst = DAG.getNode(ISD::FSUB, DL, VT, ThreeHalves, NewEst, Flags); AddToWorklist(NewEst.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags); AddToWorklist(Est.getNode()); } return Est; @@ -13740,7 +14444,8 @@ SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est, /// => /// X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0)) SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est, - unsigned Iterations) { + unsigned Iterations, + SDNodeFlags *Flags) { EVT VT = Arg.getValueType(); SDLoc DL(Arg); SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT); @@ -13748,25 +14453,25 @@ SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est, // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est) for (unsigned i = 0; i < Iterations; ++i) { - SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf); + SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags); AddToWorklist(HalfEst.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree); + Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags); AddToWorklist(Est.getNode()); - Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst); + Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags); AddToWorklist(Est.getNode()); } return Est; } -SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op) { +SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) { if (Level >= AfterLegalizeDAG) return SDValue(); @@ -13778,8 +14483,8 @@ SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op) { AddToWorklist(Est.getNode()); if (Iterations) { Est = UseOneConstNR ? - BuildRsqrtNROneConst(Op, Est, Iterations) : - BuildRsqrtNRTwoConst(Op, Est, Iterations); + BuildRsqrtNROneConst(Op, Est, Iterations, Flags) : + BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags); } return Est; } @@ -13832,6 +14537,15 @@ bool DAGCombiner::isAlias(LSBaseSDNode *Op0, LSBaseSDNode *Op1) const { // If they are both volatile then they cannot be reordered. if (Op0->isVolatile() && Op1->isVolatile()) return true; + // If one operation reads from invariant memory, and the other may store, they + // cannot alias. These should really be checking the equivalent of mayWrite, + // but it only matters for memory nodes other than load /store. 
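// Illustrative note (assumed semantics, not from this patch): a load
// tagged invariant (e.g. one fed from a constant pool) reads memory that
// no store in the function may modify, so it can safely be reordered past
// any store; writeMem() is what separates stores from other memory nodes
// reaching this LSBaseSDNode-based check.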
+ if (Op0->isInvariant() && Op1->writeMem()) + return false; + + if (Op1->isInvariant() && Op0->writeMem()) + return false; + // Gather base node and offset information. SDValue Base1, Base2; int64_t Offset1, Offset2; @@ -13934,14 +14648,12 @@ void DAGCombiner::GatherAllAliases(SDNode *N, SDValue OriginalChain, SDValue Chain = Chains.pop_back_val(); // For TokenFactor nodes, look at each operand and only continue up the - // chain until we find two aliases. If we've seen two aliases, assume we'll - // find more and revert to original chain since the xform is unlikely to be - // profitable. + // chain until we reach the depth limit. // // FIXME: The depth check could be made to return the last non-aliasing // chain we found before we hit a tokenfactor rather than the original // chain. - if (Depth > 6 || Aliases.size() == 2) { + if (Depth > TLI.getGatherAllAliasesMaxDepth()) { Aliases.clear(); Aliases.push_back(OriginalChain); return; @@ -14073,6 +14785,83 @@ SDValue DAGCombiner::FindBetterChain(SDNode *N, SDValue OldChain) { return DAG.getNode(ISD::TokenFactor, SDLoc(N), MVT::Other, Aliases); } +bool DAGCombiner::findBetterNeighborChains(StoreSDNode* St) { + // This holds the base pointer, index, and the offset in bytes from the base + // pointer. + BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr()); + + // We must have a base and an offset. + if (!BasePtr.Base.getNode()) + return false; + + // Do not handle stores to undef base pointers. + if (BasePtr.Base.getOpcode() == ISD::UNDEF) + return false; + + SmallVector ChainedStores; + ChainedStores.push_back(St); + + // Walk up the chain and look for nodes with offsets from the same + // base pointer. Stop when reaching an instruction with a different kind + // or instruction which has a different base pointer. + StoreSDNode *Index = St; + while (Index) { + // If the chain has more than one use, then we can't reorder the mem ops. + if (Index != St && !SDValue(Index, 0)->hasOneUse()) + break; + + if (Index->isVolatile() || Index->isIndexed()) + break; + + // Find the base pointer and offset for this memory node. + BaseIndexOffset Ptr = BaseIndexOffset::match(Index->getBasePtr()); + + // Check that the base pointer is the same as the original one. + if (!Ptr.equalBaseIndex(BasePtr)) + break; + + // Find the next memory operand in the chain. If the next operand in the + // chain is a store then move up and continue the scan with the next + // memory operand. If the next operand is a load save it and use alias + // information to check if it interferes with anything. + SDNode *NextInChain = Index->getChain().getNode(); + while (true) { + if (StoreSDNode *STn = dyn_cast(NextInChain)) { + // We found a store node. Use it for the next iteration. + ChainedStores.push_back(STn); + Index = STn; + break; + } else if (LoadSDNode *Ldn = dyn_cast(NextInChain)) { + NextInChain = Ldn->getChain().getNode(); + continue; + } else { + Index = nullptr; + break; + } + } + } + + bool MadeChange = false; + SmallVector, 8> BetterChains; + + for (StoreSDNode *ChainedStore : ChainedStores) { + SDValue Chain = ChainedStore->getChain(); + SDValue BetterChain = FindBetterChain(ChainedStore, Chain); + + if (Chain != BetterChain) { + MadeChange = true; + BetterChains.push_back(std::make_pair(ChainedStore, BetterChain)); + } + } + + // Do all replacements after finding the replacements to make to avoid making + // the chains more complicated by introducing new TokenFactors. 
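// Illustrative sketch (hypothetical chain, not from this patch): for
// chained stores St1 <- St2 <- St3, rewriting St3's chain during the walk
// would splice a TokenFactor into the very chain being traversed;
// collecting (store, better chain) pairs first and rewriting afterwards
// keeps the scan over the original, simpler topology.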
+ for (auto Replacement : BetterChains) + replaceStoreChain(Replacement.first, Replacement.second); + + return MadeChange; +} + /// This is the entry point for the file. void SelectionDAG::Combine(CombineLevel Level, AliasAnalysis &AA, CodeGenOpt::Level OptLevel) {