X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FCodeGen%2FSelectionDAG%2FDAGCombiner.cpp;h=df721e2d3b5e3fa27bac77e682faae04c7b6f306;hb=b176a4f2e41751bb52bd791979749c16d22bef18;hp=7dc79a4bfeec8fa46e03f7170026629b15dd2b22;hpb=62ffaaac7c03a1b50b98d93263a02514e39c634f;p=oota-llvm.git diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index 7dc79a4bfee..df721e2d3b5 100644 --- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -251,7 +251,6 @@ namespace { SDValue visitORLike(SDValue N0, SDValue N1, SDNode *LocReference); SDValue visitXOR(SDNode *N); SDValue SimplifyVBinOp(SDNode *N); - SDValue SimplifyVUnaryOp(SDNode *N); SDValue visitSHL(SDNode *N); SDValue visitSRA(SDNode *N); SDValue visitSRL(SDNode *N); @@ -308,6 +307,10 @@ namespace { SDValue visitINSERT_SUBVECTOR(SDNode *N); SDValue visitMLOAD(SDNode *N); SDValue visitMSTORE(SDNode *N); + SDValue visitFP_TO_FP16(SDNode *N); + + SDValue visitFADDForFMACombine(SDNode *N); + SDValue visitFSUBForFMACombine(SDNode *N); SDValue XformToShuffleWithZero(SDNode *N); SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS); @@ -706,13 +709,23 @@ static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) { EltVT.getSizeInBits() >= SplatBitSize); } -// \brief Returns the SDNode if it is a constant BuildVector or constant. -static SDNode *isConstantBuildVectorOrConstantInt(SDValue N) { +// \brief Returns the SDNode if it is a constant integer BuildVector +// or constant integer. +static SDNode *isConstantIntBuildVectorOrConstantInt(SDValue N) { if (isa(N)) return N.getNode(); - BuildVectorSDNode *BV = dyn_cast(N); - if (BV && BV->isConstant()) - return BV; + if (ISD::isBuildVectorOfConstantSDNodes(N.getNode())) + return N.getNode(); + return nullptr; +} + +// \brief Returns the SDNode if it is a constant float BuildVector +// or constant float. +static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) { + if (isa(N)) + return N.getNode(); + if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode())) + return N.getNode(); return nullptr; } @@ -758,8 +771,8 @@ SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL, SDValue N0, SDValue N1) { EVT VT = N0.getValueType(); if (N0.getOpcode() == Opc) { - if (SDNode *L = isConstantBuildVectorOrConstantInt(N0.getOperand(1))) { - if (SDNode *R = isConstantBuildVectorOrConstantInt(N1)) { + if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0.getOperand(1))) { + if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1)) { // reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2)) if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, L, R)) return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode); @@ -778,8 +791,8 @@ SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL, } if (N1.getOpcode() == Opc) { - if (SDNode *R = isConstantBuildVectorOrConstantInt(N1.getOperand(1))) { - if (SDNode *L = isConstantBuildVectorOrConstantInt(N0)) { + if (SDNode *R = isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) { + if (SDNode *L = isConstantIntBuildVectorOrConstantInt(N0)) { // reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2)) if (SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, R, L)) return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode); @@ -1183,11 +1196,6 @@ void DAGCombiner::Run(CombineLevel AtLevel) { LegalOperations = Level >= AfterLegalizeVectorOps; LegalTypes = Level >= AfterLegalizeTypes; - // Early exit if this basic block is in an optnone function. - if (DAG.getMachineFunction().getFunction()->hasFnAttribute( - Attribute::OptimizeNone)) - return; - // Add all the dag nodes to the worklist. for (SelectionDAG::allnodes_iterator I = DAG.allnodes_begin(), E = DAG.allnodes_end(); I != E; ++I) @@ -1376,6 +1384,7 @@ SDValue DAGCombiner::visit(SDNode *N) { case ISD::INSERT_SUBVECTOR: return visitINSERT_SUBVECTOR(N); case ISD::MLOAD: return visitMLOAD(N); case ISD::MSTORE: return visitMSTORE(N); + case ISD::FP_TO_FP16: return visitFP_TO_FP16(N); } return SDValue(); } @@ -1573,8 +1582,8 @@ SDValue DAGCombiner::visitADD(SDNode *N) { // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (add x, 0) -> x, vector edition if (ISD::isBuildVectorAllZeros(N1.getNode())) @@ -1594,7 +1603,8 @@ SDValue DAGCombiner::visitADD(SDNode *N) { if (N0C && N1C) return DAG.FoldConstantArithmetic(ISD::ADD, VT, N0C, N1C); // canonicalize constant to RHS - if (N0C && !N1C) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::ADD, SDLoc(N), VT, N1, N0); // fold (add x, 0) -> x if (N1C && N1C->isNullValue()) @@ -1614,8 +1624,7 @@ SDValue DAGCombiner::visitADD(SDNode *N) { N0C->getAPIntValue(), VT), N0.getOperand(1)); // reassociate add - SDValue RADD = ReassociateOps(ISD::ADD, SDLoc(N), N0, N1); - if (RADD.getNode()) + if (SDValue RADD = ReassociateOps(ISD::ADD, SDLoc(N), N0, N1)) return RADD; // fold ((0-A) + B) -> B-A if (N0.getOpcode() == ISD::SUB && isa(N0.getOperand(0)) && @@ -1818,8 +1827,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) { // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (sub x, 0) -> x, vector edition if (ISD::isBuildVectorAllZeros(N1.getNode())) @@ -1974,26 +1983,27 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { APInt ConstValue0, ConstValue1; // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; N0IsConst = isConstantSplatVector(N0.getNode(), ConstValue0); N1IsConst = isConstantSplatVector(N1.getNode(), ConstValue1); } else { - N0IsConst = dyn_cast(N0) != nullptr; - ConstValue0 = N0IsConst ? (dyn_cast(N0))->getAPIntValue() - : APInt(); - N1IsConst = dyn_cast(N1) != nullptr; - ConstValue1 = N1IsConst ? (dyn_cast(N1))->getAPIntValue() - : APInt(); + N0IsConst = isa(N0); + if (N0IsConst) + ConstValue0 = cast(N0)->getAPIntValue(); + N1IsConst = isa(N1); + if (N1IsConst) + ConstValue1 = cast(N1)->getAPIntValue(); } // fold (mul c1, c2) -> c1*c2 if (N0IsConst && N1IsConst) return DAG.FoldConstantArithmetic(ISD::MUL, VT, N0.getNode(), N1.getNode()); - // canonicalize constant to RHS - if (N0IsConst && !N1IsConst) + // canonicalize constant to RHS (vector doesn't have to splat) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0); // fold (mul x, 0) -> 0 if (N1IsConst && ConstValue1 == 0) @@ -2073,8 +2083,7 @@ SDValue DAGCombiner::visitMUL(SDNode *N) { N0.getOperand(1), N1)); // reassociate mul - SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1); - if (RMUL.getNode()) + if (SDValue RMUL = ReassociateOps(ISD::MUL, SDLoc(N), N0, N1)) return RMUL; return SDValue(); @@ -2086,10 +2095,9 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { EVT VT = N->getValueType(0); // fold vector ops - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } + if (VT.isVector()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (sdiv c1, c2) -> c1/c2 ConstantSDNode *N0C = isConstOrConstSplat(N0); @@ -2153,7 +2161,7 @@ SDValue DAGCombiner::visitSDIV(SDNode *N) { return DAG.getNode(ISD::SUB, SDLoc(N), VT, DAG.getConstant(0, VT), SRA); } - // if integer divide is expensive and we satisfy the requirements, emit an + // If integer divide is expensive and we satisfy the requirements, emit an // alternate sequence. if (N1C && !TLI.isIntDivCheap()) { SDValue Op = BuildSDIV(N); @@ -2176,10 +2184,9 @@ SDValue DAGCombiner::visitUDIV(SDNode *N) { EVT VT = N->getValueType(0); // fold vector ops - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } + if (VT.isVector()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (udiv c1, c2) -> c1/c2 ConstantSDNode *N0C = isConstOrConstSplat(N0); @@ -2449,8 +2456,8 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); - // If the type twice as wide is legal, transform the mulhu to a wider multiply - // plus a shift. + // If the type is twice as wide is legal, transform the mulhu to a wider + // multiply plus a shift. if (VT.isSimple() && !VT.isVector()) { MVT Simple = VT.getSimpleVT(); unsigned SimpleSize = Simple.getSizeInBits(); @@ -2479,8 +2486,8 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) { EVT VT = N->getValueType(0); SDLoc DL(N); - // If the type twice as wide is legal, transform the mulhu to a wider multiply - // plus a shift. + // If the type is twice as wide is legal, transform the mulhu to a wider + // multiply plus a shift. if (VT.isSimple() && !VT.isVector()) { MVT Simple = VT.getSimpleVT(); unsigned SimpleSize = Simple.getSizeInBits(); @@ -2799,8 +2806,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (and x, 0) -> 0, vector edition if (ISD::isBuildVectorAllZeros(N0.getNode())) @@ -2829,7 +2836,8 @@ SDValue DAGCombiner::visitAND(SDNode *N) { if (N0C && N1C) return DAG.FoldConstantArithmetic(ISD::AND, VT, N0C, N1C); // canonicalize constant to RHS - if (N0C && !N1C) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::AND, SDLoc(N), VT, N1, N0); // fold (and x, -1) -> x if (N1C && N1C->isAllOnesValue()) @@ -2840,8 +2848,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) { APInt::getAllOnesValue(BitWidth))) return DAG.getConstant(0, VT); // reassociate and - SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1); - if (RAND.getNode()) + if (SDValue RAND = ReassociateOps(ISD::AND, SDLoc(N), N0, N1)) return RAND; // fold (and (or x, C), D) -> D if (C & D) == D if (N1C && N0.getOpcode() == ISD::OR) @@ -3460,8 +3467,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (or x, 0) -> x, vector edition if (ISD::isBuildVectorAllZeros(N0.getNode())) @@ -3546,7 +3553,8 @@ SDValue DAGCombiner::visitOR(SDNode *N) { if (N0C && N1C) return DAG.FoldConstantArithmetic(ISD::OR, VT, N0C, N1C); // canonicalize constant to RHS - if (N0C && !N1C) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::OR, SDLoc(N), VT, N1, N0); // fold (or x, 0) -> x if (N1C && N1C->isNullValue()) @@ -3570,8 +3578,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) { return BSwap; // reassociate or - SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1); - if (ROR.getNode()) + if (SDValue ROR = ReassociateOps(ISD::OR, SDLoc(N), N0, N1)) return ROR; // Canonicalize (or (and X, c1), c2) -> (and (or X, c2), c1|c2) // iff (c1 & c2) == 0. @@ -3865,8 +3872,8 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { // fold vector ops if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (xor x, 0) -> x, vector edition if (ISD::isBuildVectorAllZeros(N0.getNode())) @@ -3889,14 +3896,14 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (N0C && N1C) return DAG.FoldConstantArithmetic(ISD::XOR, VT, N0C, N1C); // canonicalize constant to RHS - if (N0C && !N1C) + if (isConstantIntBuildVectorOrConstantInt(N0) && + !isConstantIntBuildVectorOrConstantInt(N1)) return DAG.getNode(ISD::XOR, SDLoc(N), VT, N1, N0); // fold (xor x, 0) -> x if (N1C && N1C->isNullValue()) return N0; // reassociate xor - SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1); - if (RXOR.getNode()) + if (SDValue RXOR = ReassociateOps(ISD::XOR, SDLoc(N), N0, N1)) return RXOR; // fold !(x cc y) -> (x !cc y) @@ -3980,6 +3987,32 @@ SDValue DAGCombiner::visitXOR(SDNode *N) { if (N0 == N1) return tryFoldToZero(SDLoc(N), TLI, VT, DAG, LegalOperations, LegalTypes); + // fold (xor (shl 1, x), -1) -> (rotl ~1, x) + // Here is a concrete example of this equivalence: + // i16 x == 14 + // i16 shl == 1 << 14 == 16384 == 0b0100000000000000 + // i16 xor == ~(1 << 14) == 49151 == 0b1011111111111111 + // + // => + // + // i16 ~1 == 0b1111111111111110 + // i16 rol(~1, 14) == 0b1011111111111111 + // + // Some additional tips to help conceptualize this transform: + // - Try to see the operation as placing a single zero in a value of all ones. + // - There exists no value for x which would allow the result to contain zero. + // - Values of x larger than the bitwidth are undefined and do not require a + // consistent result. + // - Pushing the zero left requires shifting one bits in from the right. + // A rotate left of ~1 is a nice way of achieving the desired result. + if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT)) + if (auto *N1C = dyn_cast(N1.getNode())) + if (N0.getOpcode() == ISD::SHL) + if (auto *ShlLHS = dyn_cast(N0.getOperand(0))) + if (N1C->isAllOnesValue() && ShlLHS->isOne()) + return DAG.getNode(ISD::ROTL, SDLoc(N), VT, DAG.getConstant(~1, VT), + N0.getOperand(1)); + // Simplify: xor (op x...), (op y...) -> (op (xor x, y)) if (N0.getOpcode() == N1.getOpcode()) { SDValue Tmp = SimplifyBinOpWithSameOpcodeHands(N); @@ -4116,8 +4149,8 @@ SDValue DAGCombiner::visitSHL(SDNode *N) { // fold vector ops ConstantSDNode *N1C = dyn_cast(N1); if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; BuildVectorSDNode *N1CV = dyn_cast(N1); // If setcc produces all-one true value then: @@ -4296,8 +4329,8 @@ SDValue DAGCombiner::visitSRA(SDNode *N) { // fold vector ops ConstantSDNode *N1C = dyn_cast(N1); if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; N1C = isConstOrConstSplat(N1); } @@ -4442,8 +4475,8 @@ SDValue DAGCombiner::visitSRL(SDNode *N) { // fold vector ops ConstantSDNode *N1C = dyn_cast(N1); if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; N1C = isConstOrConstSplat(N1); } @@ -4853,7 +4886,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { SDValue N1_0 = N1->getOperand(0); SDValue N1_1 = N1->getOperand(1); SDValue N1_2 = N1->getOperand(2); - if (N1_2 == N2) { + if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) { // Create the actual and node if we can generate good code for it. if (!TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT)) { SDValue And = DAG.getNode(ISD::AND, SDLoc(N), N0.getValueType(), @@ -4872,7 +4905,7 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { SDValue N2_0 = N2->getOperand(0); SDValue N2_1 = N2->getOperand(1); SDValue N2_2 = N2->getOperand(2); - if (N2_1 == N1) { + if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) { // Create the actual or node if we can generate good code for it. if (!TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT)) { SDValue Or = DAG.getNode(ISD::OR, SDLoc(N), N0.getValueType(), @@ -5160,6 +5193,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) { } } + if (SimplifySelectOps(N, N1, N2)) + return SDValue(N, 0); // Don't revisit N. + // If the VSELECT result requires splitting and the mask is provided by a // SETCC, then split both nodes and its operands before legalization. This // prevents the type legalizer from unrolling SETCC into scalar comparisons @@ -6536,7 +6572,7 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) { if (N0.getValueType() == N->getValueType(0)) return N0; // fold (truncate c1) -> c1 - if (isa(N0)) + if (isConstantIntBuildVectorOrConstantInt(N0)) return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0); // fold (truncate (truncate x)) -> (truncate x) if (N0.getOpcode() == ISD::TRUNCATE) @@ -6898,6 +6934,51 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) { return CombineLD; } + // Remove double bitcasts from shuffles - this is often a legacy of + // XformToShuffleWithZero being used to combine bitmaskings (of + // float vectors bitcast to integer vectors) into shuffles. + // bitcast(shuffle(bitcast(s0),bitcast(s1))) -> shuffle(s0,s1) + if (Level < AfterLegalizeDAG && TLI.isTypeLegal(VT) && VT.isVector() && + N0->getOpcode() == ISD::VECTOR_SHUFFLE && + VT.getVectorNumElements() >= N0.getValueType().getVectorNumElements() && + !(VT.getVectorNumElements() % N0.getValueType().getVectorNumElements())) { + ShuffleVectorSDNode *SVN = cast(N0); + + // If operands are a bitcast, peek through if it casts the original VT. + // If operands are a UNDEF or constant, just bitcast back to original VT. + auto PeekThroughBitcast = [&](SDValue Op) { + if (Op.getOpcode() == ISD::BITCAST && + Op.getOperand(0)->getValueType(0) == VT) + return SDValue(Op.getOperand(0)); + if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || + ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode())) + return DAG.getNode(ISD::BITCAST, SDLoc(N), VT, Op); + return SDValue(); + }; + + SDValue SV0 = PeekThroughBitcast(N0->getOperand(0)); + SDValue SV1 = PeekThroughBitcast(N0->getOperand(1)); + if (!(SV0 && SV1)) + return SDValue(); + + int MaskScale = + VT.getVectorNumElements() / N0.getValueType().getVectorNumElements(); + SmallVector NewMask; + for (int M : SVN->getMask()) + for (int i = 0; i != MaskScale; ++i) + NewMask.push_back(M < 0 ? -1 : M * MaskScale + i); + + bool LegalMask = TLI.isShuffleMaskLegal(NewMask, VT); + if (!LegalMask) { + std::swap(SV0, SV1); + ShuffleVectorSDNode::commuteMask(NewMask); + LegalMask = TLI.isShuffleMaskLegal(NewMask, VT); + } + + if (LegalMask) + return DAG.getVectorShuffle(VT, SDLoc(N), SV0, SV1, NewMask); + } + return SDValue(); } @@ -7001,7 +7082,6 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { // Finally, this must be the case where we are shrinking elements: each input // turns into multiple outputs. - bool isS2V = ISD::isScalarToVector(BV); unsigned NumOutputsPerInput = SrcBitSize/DstBitSize; EVT VT = EVT::getVectorVT(*DAG.getContext(), DstEltVT, NumOutputsPerInput*BV->getNumOperands()); @@ -7019,10 +7099,6 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { for (unsigned j = 0; j != NumOutputsPerInput; ++j) { APInt ThisVal = OpVal.trunc(DstBitSize); Ops.push_back(DAG.getConstant(ThisVal, DstEltVT)); - if (isS2V && i == 0 && j == 0 && ThisVal.zext(SrcBitSize) == OpVal) - // Simply turn this into a SCALAR_TO_VECTOR of the new type. - return DAG.getNode(ISD::SCALAR_TO_VECTOR, SDLoc(BV), VT, - Ops[0]); OpVal = OpVal.lshr(DstBitSize); } @@ -7034,20 +7110,44 @@ ConstantFoldBITCASTofBUILD_VECTOR(SDNode *BV, EVT DstEltVT) { return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(BV), VT, Ops); } -// Attempt different variants of (fadd (fmul a, b), c) -> fma or fmad -static SDValue performFaddFmulCombines(unsigned FusedOpcode, - bool Aggressive, - SDNode *N, - const TargetLowering &TLI, - SelectionDAG &DAG) { +/// Try to perform FMA combining on a given FADD node. +SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { + + + + SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); + SDLoc SL(N); + + const TargetOptions &Options = DAG.getTarget().Options; + bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || + Options.UnsafeFPMath); + + // Floating-point multiply-add with intermediate rounding. + bool HasFMAD = (LegalOperations && + TLI.isOperationLegal(ISD::FMAD, VT)); + + // Floating-point multiply-add without intermediate rounding. + bool HasFMA = ((!LegalOperations || + TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && + TLI.isFMAFasterThanFMulAndFAdd(VT) && + UnsafeFPMath); + + // No valid opcode, do not combine. + if (!HasFMAD && !HasFMA) + return SDValue(); + + // Always prefer FMAD to FMA for precision. + unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + bool LookThroughFPExt = TLI.isFPExtFree(VT); // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (N0.getOpcode() == ISD::FMUL && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(FusedOpcode, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1); } @@ -7055,53 +7155,180 @@ static SDValue performFaddFmulCombines(unsigned FusedOpcode, // Note: Commutes FADD operands. if (N1.getOpcode() == ISD::FMUL && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(FusedOpcode, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0); } + // Look through FP_EXTEND nodes to do more combining. + if (UnsafeFPMath && LookThroughFPExt) { + // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), N1); + } + + // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) + // Note: Commutes FADD operands. + if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(1)), N0); + } + } + // More folding opportunities when target permits. - if (Aggressive) { + if ((UnsafeFPMath || HasFMAD) && Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) - if (N0.getOpcode() == ISD::FMA && + if (N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL) { - return DAG.getNode(FusedOpcode, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(FusedOpcode, SDLoc(N), VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1)); } // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) - if (N1->getOpcode() == ISD::FMA && + if (N1->getOpcode() == PreferredFusedOpcode && N1.getOperand(2).getOpcode() == ISD::FMUL) { - return DAG.getNode(FusedOpcode, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(FusedOpcode, SDLoc(N), VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0)); } + + if (UnsafeFPMath && LookThroughFPExt) { + // fold (fadd (fma x, y, (fpext (fmul u, v))), z) + // -> (fma x, y, (fma (fpext u), (fpext v), z)) + auto FoldFAddFMAFPExtFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, U), + DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + Z)); + }; + if (N0.getOpcode() == PreferredFusedOpcode) { + SDValue N02 = N0.getOperand(2); + if (N02.getOpcode() == ISD::FP_EXTEND) { + SDValue N020 = N02.getOperand(0); + if (N020.getOpcode() == ISD::FMUL) + return FoldFAddFMAFPExtFMul(N0.getOperand(0), N0.getOperand(1), + N020.getOperand(0), N020.getOperand(1), + N1); + } + } + + // fold (fadd (fpext (fma x, y, (fmul u, v))), z) + // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + auto FoldFAddFPExtFMAFMul = [&] ( + SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z) { + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, X), + DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, U), + DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + Z)); + }; + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == PreferredFusedOpcode) { + SDValue N002 = N00.getOperand(2); + if (N002.getOpcode() == ISD::FMUL) + return FoldFAddFPExtFMAFMul(N00.getOperand(0), N00.getOperand(1), + N002.getOperand(0), N002.getOperand(1), + N1); + } + } + + // fold (fadd x, (fma y, z, (fpext (fmul u, v))) + // -> (fma y, z, (fma (fpext u), (fpext v), x)) + if (N1.getOpcode() == PreferredFusedOpcode) { + SDValue N12 = N1.getOperand(2); + if (N12.getOpcode() == ISD::FP_EXTEND) { + SDValue N120 = N12.getOperand(0); + if (N120.getOpcode() == ISD::FMUL) + return FoldFAddFMAFPExtFMul(N1.getOperand(0), N1.getOperand(1), + N120.getOperand(0), N120.getOperand(1), + N0); + } + } + + // fold (fadd x, (fpext (fma y, z, (fmul u, v))) + // -> (fma (fpext y), (fpext z), (fma (fpext u), (fpext v), x)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == PreferredFusedOpcode) { + SDValue N102 = N10.getOperand(2); + if (N102.getOpcode() == ISD::FMUL) + return FoldFAddFPExtFMAFMul(N10.getOperand(0), N10.getOperand(1), + N102.getOperand(0), N102.getOperand(1), + N0); + } + } + } } return SDValue(); } -static SDValue performFsubFmulCombines(unsigned FusedOpcode, - bool Aggressive, - SDNode *N, - const TargetLowering &TLI, - SelectionDAG &DAG) { +/// Try to perform FMA combining on a given FSUB node. +SDValue DAGCombiner::visitFSUBForFMACombine(SDNode *N) { + + + SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + const TargetOptions &Options = DAG.getTarget().Options; + bool UnsafeFPMath = (Options.AllowFPOpFusion == FPOpFusion::Fast || + Options.UnsafeFPMath); + + // Floating-point multiply-add with intermediate rounding. + bool HasFMAD = (LegalOperations && + TLI.isOperationLegal(ISD::FMAD, VT)); + + // Floating-point multiply-add without intermediate rounding. + bool HasFMA = ((!LegalOperations || + TLI.isOperationLegalOrCustom(ISD::FMA, VT)) && + TLI.isFMAFasterThanFMulAndFAdd(VT) && + UnsafeFPMath); + + // No valid opcode, do not combine. + if (!HasFMAD && !HasFMA) + return SDValue(); + + // Always prefer FMAD to FMA for precision. + unsigned int PreferredFusedOpcode = HasFMAD ? ISD::FMAD : ISD::FMA; + bool Aggressive = TLI.enableAggressiveFMAFusion(VT); + bool LookThroughFPExt = TLI.isFPExtFree(VT); + // fold (fsub (fmul x, y), z) -> (fma x, y, (fneg z)) if (N0.getOpcode() == ISD::FMUL && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(FusedOpcode, SL, VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), DAG.getNode(ISD::FNEG, SL, VT, N1)); } @@ -7110,7 +7337,7 @@ static SDValue performFsubFmulCombines(unsigned FusedOpcode, // Note: Commutes FSUB operands. if (N1.getOpcode() == ISD::FMUL && (Aggressive || N1->hasOneUse())) - return DAG.getNode(FusedOpcode, SL, VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1), N0); @@ -7121,41 +7348,214 @@ static SDValue performFsubFmulCombines(unsigned FusedOpcode, (Aggressive || (N0->hasOneUse() && N0.getOperand(0).hasOneUse()))) { SDValue N00 = N0.getOperand(0).getOperand(0); SDValue N01 = N0.getOperand(0).getOperand(1); - return DAG.getNode(FusedOpcode, SL, VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, DAG.getNode(ISD::FNEG, SL, VT, N00), N01, DAG.getNode(ISD::FNEG, SL, VT, N1)); } + // Look through FP_EXTEND nodes to do more combining. + if (UnsafeFPMath && LookThroughFPExt) { + // fold (fsub (fpext (fmul x, y)), z) + // -> (fma (fpext x), (fpext y), (fneg z)) + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, N1)); + } + + // fold (fsub x, (fpext (fmul y, z))) + // -> (fma (fneg (fpext y)), (fpext z), x) + // Note: Commutes FSUB operands. + if (N1.getOpcode() == ISD::FP_EXTEND) { + SDValue N10 = N1.getOperand(0); + if (N10.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(0))), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N10.getOperand(1)), + N0); + } + + // fold (fsub (fpext (fneg (fmul, x, y))), z) + // -> (fneg (fma (fpext x), (fpext y), z)) + // Note: This could be removed with appropriate canonicalization of the + // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the + // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent + // from implementing the canonicalization in visitFSUB. + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FNEG) { + SDValue N000 = N00.getOperand(0); + if (N000.getOpcode() == ISD::FMUL) { + return DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(1)), + N1)); + } + } + } + + // fold (fsub (fneg (fpext (fmul, x, y))), z) + // -> (fneg (fma (fpext x)), (fpext y), z) + // Note: This could be removed with appropriate canonicalization of the + // input expression into (fneg (fadd (fpext (fmul, x, y)), z). However, the + // orthogonal flags -fp-contract=fast and -enable-unsafe-fp-math prevent + // from implementing the canonicalization in visitFSUB. + if (N0.getOpcode() == ISD::FNEG) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == ISD::FP_EXTEND) { + SDValue N000 = N00.getOperand(0); + if (N000.getOpcode() == ISD::FMUL) { + return DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N000.getOperand(1)), + N1)); + } + } + } + + } + // More folding opportunities when target permits. - if (Aggressive) { + if ((UnsafeFPMath || HasFMAD) && Aggressive) { // fold (fsub (fma x, y, (fmul u, v)), z) // -> (fma x, y (fma u, v, (fneg z))) - if (N0.getOpcode() == FusedOpcode && + if (N0.getOpcode() == PreferredFusedOpcode && N0.getOperand(2).getOpcode() == ISD::FMUL) { - return DAG.getNode(FusedOpcode, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(FusedOpcode, SDLoc(N), VT, + DAG.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, + DAG.getNode(ISD::FNEG, SL, VT, N1))); } // fold (fsub x, (fma y, z, (fmul u, v))) // -> (fma (fneg y), z, (fma (fneg u), v, x)) - if (N1.getOpcode() == FusedOpcode && + if (N1.getOpcode() == PreferredFusedOpcode && N1.getOperand(2).getOpcode() == ISD::FMUL) { SDValue N20 = N1.getOperand(2).getOperand(0); SDValue N21 = N1.getOperand(2).getOperand(1); - return DAG.getNode(FusedOpcode, SDLoc(N), VT, - DAG.getNode(ISD::FNEG, SDLoc(N), VT, + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), N1.getOperand(1), - DAG.getNode(FusedOpcode, SDLoc(N), VT, - DAG.getNode(ISD::FNEG, SDLoc(N), VT, - N20), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, N20), + N21, N0)); } + + if (UnsafeFPMath && LookThroughFPExt) { + // fold (fsub (fma x, y, (fpext (fmul u, v))), z) + // -> (fma x, y (fma (fpext u), (fpext v), (fneg z))) + if (N0.getOpcode() == PreferredFusedOpcode) { + SDValue N02 = N0.getOperand(2); + if (N02.getOpcode() == ISD::FP_EXTEND) { + SDValue N020 = N02.getOperand(0); + if (N020.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + N0.getOperand(0), N0.getOperand(1), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N020.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N020.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, + N1))); + } + } + + // fold (fsub (fpext (fma x, y, (fmul u, v))), z) + // -> (fma (fpext x), (fpext y), + // (fma (fpext u), (fpext v), (fneg z))) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + if (N0.getOpcode() == ISD::FP_EXTEND) { + SDValue N00 = N0.getOperand(0); + if (N00.getOpcode() == PreferredFusedOpcode) { + SDValue N002 = N00.getOperand(2); + if (N002.getOpcode() == ISD::FMUL) + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N00.getOperand(1)), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N002.getOperand(0)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N002.getOperand(1)), + DAG.getNode(ISD::FNEG, SL, VT, + N1))); + } + } + + // fold (fsub x, (fma y, z, (fpext (fmul u, v)))) + // -> (fma (fneg y), z, (fma (fneg (fpext u)), (fpext v), x)) + if (N1.getOpcode() == PreferredFusedOpcode && + N1.getOperand(2).getOpcode() == ISD::FP_EXTEND) { + SDValue N120 = N1.getOperand(2).getOperand(0); + if (N120.getOpcode() == ISD::FMUL) { + SDValue N1200 = N120.getOperand(0); + SDValue N1201 = N120.getOperand(1); + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, N1.getOperand(0)), + N1.getOperand(1), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, + VT, N1200)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N1201), + N0)); + } + } + + // fold (fsub x, (fpext (fma y, z, (fmul u, v)))) + // -> (fma (fneg (fpext y)), (fpext z), + // (fma (fneg (fpext u)), (fpext v), x)) + // FIXME: This turns two single-precision and one double-precision + // operation into two double-precision operations, which might not be + // interesting for all targets, especially GPUs. + if (N1.getOpcode() == ISD::FP_EXTEND && + N1.getOperand(0).getOpcode() == PreferredFusedOpcode) { + SDValue N100 = N1.getOperand(0).getOperand(0); + SDValue N101 = N1.getOperand(0).getOperand(1); + SDValue N102 = N1.getOperand(0).getOperand(2); + if (N102.getOpcode() == ISD::FMUL) { + SDValue N1020 = N102.getOperand(0); + SDValue N1021 = N102.getOperand(1); + return DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N100)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, N101), + DAG.getNode(PreferredFusedOpcode, SL, VT, + DAG.getNode(ISD::FNEG, SL, VT, + DAG.getNode(ISD::FP_EXTEND, SL, + VT, N1020)), + DAG.getNode(ISD::FP_EXTEND, SL, VT, + N1021), + N0)); + } + } + } } return SDValue(); @@ -7170,10 +7570,9 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // fold vector ops - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } + if (VT.isVector()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (fadd c1, c2) -> c1 + c2 if (N0CFP && N1CFP) @@ -7300,55 +7699,11 @@ SDValue DAGCombiner::visitFADD(SDNode *N) { } } // enable-unsafe-fp-math - if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) { - // Assume if there is an fmad instruction that it should be aggressively - // used. - if (SDValue Fused = performFaddFmulCombines(ISD::FMAD, true, N, TLI, DAG)) - return Fused; - } - // FADD -> FMA combines: - if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) { - - if (!TLI.isOperationLegal(ISD::FMAD, VT)) { - // Don't form FMA if we are preferring FMAD. - if (SDValue Fused - = performFaddFmulCombines(ISD::FMA, - TLI.enableAggressiveFMAFusion(VT), - N, TLI, DAG)) { - return Fused; - } - } - - // When FP_EXTEND nodes are free on the target, and there is an opportunity - // to combine into FMA, arrange such nodes accordingly. - if (TLI.isFPExtFree(VT)) { - - // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N00.getOperand(1)), N1); - } - - // fold (fadd x, (fpext (fmul y, z)), z) -> (fma (fpext y), (fpext z), x) - // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { - SDValue N10 = N1.getOperand(0); - if (N10.getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N10.getOperand(1)), N0); - } - } + SDValue Fused = visitFADDForFMACombine(N); + if (Fused) { + AddToWorklist(Fused.getNode()); + return Fused; } return SDValue(); @@ -7364,10 +7719,9 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // fold vector ops - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } + if (VT.isVector()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (fsub c1, c2) -> c1-c2 if (N0CFP && N1CFP) @@ -7410,96 +7764,11 @@ SDValue DAGCombiner::visitFSUB(SDNode *N) { } } - if (LegalOperations && TLI.isOperationLegal(ISD::FMAD, VT)) { - // Assume if there is an fmad instruction that it should be aggressively - // used. - if (SDValue Fused = performFsubFmulCombines(ISD::FMAD, true, N, TLI, DAG)) - return Fused; - } - // FSUB -> FMA combines: - if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) && - TLI.isFMAFasterThanFMulAndFAdd(VT) && - (!LegalOperations || TLI.isOperationLegalOrCustom(ISD::FMA, VT))) { - - if (!TLI.isOperationLegal(ISD::FMAD, VT)) { - // Don't form FMA if we are preferring FMAD. - - if (SDValue Fused - = performFsubFmulCombines(ISD::FMA, - TLI.enableAggressiveFMAFusion(VT), - N, TLI, DAG)) { - return Fused; - } - } - - // When FP_EXTEND nodes are free on the target, and there is an opportunity - // to combine into FMA, arrange such nodes accordingly. - if (TLI.isFPExtFree(VT)) { - // fold (fsub (fpext (fmul x, y)), z) - // -> (fma (fpext x), (fpext y), (fneg z)) - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N00.getOperand(1)), - DAG.getNode(ISD::FNEG, SDLoc(N), VT, N1)); - } - - // fold (fsub x, (fpext (fmul y, z))) - // -> (fma (fneg (fpext y)), (fpext z), x) - // Note: Commutes FSUB operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { - SDValue N10 = N1.getOperand(0); - if (N10.getOpcode() == ISD::FMUL) - return DAG.getNode(ISD::FMA, SDLoc(N), VT, - DAG.getNode(ISD::FNEG, SDLoc(N), VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), - VT, N10.getOperand(0))), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N10.getOperand(1)), - N0); - } - - // fold (fsub (fpext (fneg (fmul, x, y))), z) - // -> (fma (fneg (fpext x)), (fpext y), (fneg z)) - if (N0.getOpcode() == ISD::FP_EXTEND) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FNEG) { - SDValue N000 = N00.getOperand(0); - if (N000.getOpcode() == ISD::FMUL) { - return DAG.getNode(ISD::FMA, dl, VT, - DAG.getNode(ISD::FNEG, dl, VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), - VT, N000.getOperand(0))), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N000.getOperand(1)), - DAG.getNode(ISD::FNEG, dl, VT, N1)); - } - } - } - - // fold (fsub (fneg (fpext (fmul, x, y))), z) - // -> (fma (fneg (fpext x)), (fpext y), (fneg z)) - if (N0.getOpcode() == ISD::FNEG) { - SDValue N00 = N0.getOperand(0); - if (N00.getOpcode() == ISD::FP_EXTEND) { - SDValue N000 = N00.getOperand(0); - if (N000.getOpcode() == ISD::FMUL) { - return DAG.getNode(ISD::FMA, dl, VT, - DAG.getNode(ISD::FNEG, dl, VT, - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), - VT, N000.getOperand(0))), - DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, - N000.getOperand(1)), - DAG.getNode(ISD::FNEG, dl, VT, N1)); - } - } - } - } + SDValue Fused = visitFSUBForFMACombine(N); + if (Fused) { + AddToWorklist(Fused.getNode()); + return Fused; } return SDValue(); @@ -7516,15 +7785,8 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { // fold vector ops if (VT.isVector()) { // This just handles C1 * C2 for vectors. Other vector folds are below. - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) return FoldedVOp; - // Canonicalize vector constant to RHS. - if (N0.getOpcode() == ISD::BUILD_VECTOR && - N1.getOpcode() != ISD::BUILD_VECTOR) - if (auto *BV0 = dyn_cast(N0)) - if (BV0->isConstant()) - return DAG.getNode(N->getOpcode(), SDLoc(N), VT, N1, N0); } // fold (fmul c1, c2) -> c1*c2 @@ -7532,7 +7794,8 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) { return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0, N1); // canonicalize constant to RHS - if (N0CFP && !N1CFP) + if (isConstantFPBuildVectorOrConstantFP(N0) && + !isConstantFPBuildVectorOrConstantFP(N1)) return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N1, N0); // fold (fmul A, 1.0) -> A @@ -7698,10 +7961,9 @@ SDValue DAGCombiner::visitFDIV(SDNode *N) { const TargetOptions &Options = DAG.getTarget().Options; // fold vector ops - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVBinOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } + if (VT.isVector()) + if (SDValue FoldedVOp = SimplifyVBinOp(N)) + return FoldedVOp; // fold (fdiv c1, c2) -> c1/c2 if (N0CFP && N1CFP) @@ -7926,8 +8188,7 @@ SDValue DAGCombiner::visitSINT_TO_FP(SDNode *N) { EVT OpVT = N0.getValueType(); // fold (sint_to_fp c1) -> c1fp - ConstantSDNode *N0C = dyn_cast(N0); - if (N0C && + if (isConstantIntBuildVectorOrConstantInt(N0) && // ...but only if the target supports immediate floating-point values (!LegalOperations || TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) @@ -7979,8 +8240,7 @@ SDValue DAGCombiner::visitUINT_TO_FP(SDNode *N) { EVT OpVT = N0.getValueType(); // fold (uint_to_fp c1) -> c1fp - ConstantSDNode *N0C = dyn_cast(N0); - if (N0C && + if (isConstantIntBuildVectorOrConstantInt(N0) && // ...but only if the target supports immediate floating-point values (!LegalOperations || TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT))) @@ -8138,7 +8398,6 @@ SDValue DAGCombiner::visitFP_ROUND_INREG(SDNode *N) { SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { SDValue N0 = N->getOperand(0); - ConstantFPSDNode *N0CFP = dyn_cast(N0); EVT VT = N->getValueType(0); // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded. @@ -8147,9 +8406,14 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { return SDValue(); // fold (fp_extend c1fp) -> c1fp - if (N0CFP) + if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0); + // fold (fp_extend (fp16_to_fp op)) -> (fp16_to_fp op) + if (N0.getOpcode() == ISD::FP16_TO_FP && + TLI.getOperationAction(ISD::FP16_TO_FP, VT) == TargetLowering::Legal) + return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), VT, N0.getOperand(0)); + // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the // value of X. if (N0.getOpcode() == ISD::FP_ROUND @@ -8183,11 +8447,10 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) { SDValue DAGCombiner::visitFCEIL(SDNode *N) { SDValue N0 = N->getOperand(0); - ConstantFPSDNode *N0CFP = dyn_cast(N0); EVT VT = N->getValueType(0); // fold (fceil c1) -> fceil(c1) - if (N0CFP) + if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FCEIL, SDLoc(N), VT, N0); return SDValue(); @@ -8195,11 +8458,10 @@ SDValue DAGCombiner::visitFCEIL(SDNode *N) { SDValue DAGCombiner::visitFTRUNC(SDNode *N) { SDValue N0 = N->getOperand(0); - ConstantFPSDNode *N0CFP = dyn_cast(N0); EVT VT = N->getValueType(0); // fold (ftrunc c1) -> ftrunc(c1) - if (N0CFP) + if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FTRUNC, SDLoc(N), VT, N0); return SDValue(); @@ -8207,11 +8469,10 @@ SDValue DAGCombiner::visitFTRUNC(SDNode *N) { SDValue DAGCombiner::visitFFLOOR(SDNode *N) { SDValue N0 = N->getOperand(0); - ConstantFPSDNode *N0CFP = dyn_cast(N0); EVT VT = N->getValueType(0); // fold (ffloor c1) -> ffloor(c1) - if (N0CFP) + if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FFLOOR, SDLoc(N), VT, N0); return SDValue(); @@ -8222,14 +8483,9 @@ SDValue DAGCombiner::visitFNEG(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVUnaryOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } - // Constant fold FNEG. - if (isa(N0)) - return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N->getOperand(0)); + if (isConstantFPBuildVectorOrConstantFP(N0)) + return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0); if (isNegatibleForFree(N0, LegalOperations, DAG.getTargetLoweringInfo(), &DAG.getTarget().Options)) @@ -8324,13 +8580,8 @@ SDValue DAGCombiner::visitFABS(SDNode *N) { SDValue N0 = N->getOperand(0); EVT VT = N->getValueType(0); - if (VT.isVector()) { - SDValue FoldedVOp = SimplifyVUnaryOp(N); - if (FoldedVOp.getNode()) return FoldedVOp; - } - // fold (fabs c1) -> fabs(c1) - if (isa(N0)) + if (isConstantFPBuildVectorOrConstantFP(N0)) return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0); // fold (fabs (fabs x)) -> (fabs x) @@ -8553,11 +8804,11 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, if (LoadSDNode *LD = dyn_cast(Use)) { if (LD->isIndexed() || LD->getBasePtr().getNode() != N) return false; - VT = Use->getValueType(0); + VT = LD->getMemoryVT(); } else if (StoreSDNode *ST = dyn_cast(Use)) { if (ST->isIndexed() || ST->getBasePtr().getNode() != N) return false; - VT = ST->getValue().getValueType(); + VT = ST->getMemoryVT(); } else return false; @@ -9046,7 +9297,8 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) { LD->getMemoryVT(), LD->isVolatile(), LD->isNonTemporal(), LD->isInvariant(), Align, LD->getAAInfo()); - return CombineTo(N, NewLoad, SDValue(NewLoad.getNode(), 1), true); + if (NewLoad.getNode() != N) + return CombineTo(N, NewLoad, SDValue(NewLoad.getNode(), 1), true); } } } @@ -9957,6 +10209,7 @@ SDValue DAGCombiner::TransformFPLoadStorePair(SDNode *N) { return SDValue(); } +namespace { /// Helper struct to parse and store a memory address as base + index + offset. /// We ignore sign extensions when it is safe to do so. /// The following two expressions are not equivalent. To differentiate we need @@ -10044,6 +10297,7 @@ struct BaseIndexOffset { return BaseIndexOffset(Base, Index, Off, IsIndexSignExt); } }; +} // namespace bool DAGCombiner::MergeStoresOfConstantsOrVecElts( SmallVectorImpl &StoreNodes, EVT MemVT, @@ -10054,19 +10308,19 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( int64_t ElementSizeBytes = MemVT.getSizeInBits() / 8; LSBaseSDNode *FirstInChain = StoreNodes[0].MemNode; - unsigned EarliestNodeUsed = 0; + unsigned LatestNodeUsed = 0; for (unsigned i=0; i < NumElem; ++i) { // Find a chain for the new wide-store operand. Notice that some // of the store nodes that we found may not be selected for inclusion // in the wide store. The chain we use needs to be the chain of the - // earliest store node which is *used* and replaced by the wide store. - if (StoreNodes[i].SequenceNum > StoreNodes[EarliestNodeUsed].SequenceNum) - EarliestNodeUsed = i; + // latest store node which is *used* and replaced by the wide store. + if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum) + LatestNodeUsed = i; } - // The earliest Node in the DAG. - LSBaseSDNode *EarliestOp = StoreNodes[EarliestNodeUsed].MemNode; + // The latest Node in the DAG. + LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; SDLoc DL(StoreNodes[0].MemNode); SDValue StoredVal; @@ -10125,17 +10379,17 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts( StoredVal = DAG.getConstant(StoreInt, StoreTy); } - SDValue NewStore = DAG.getStore(EarliestOp->getChain(), DL, StoredVal, + SDValue NewStore = DAG.getStore(LatestOp->getChain(), DL, StoredVal, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, FirstInChain->getAlignment()); - // Replace the first store with the new store - CombineTo(EarliestOp, NewStore); + // Replace the last store with the new store + CombineTo(LatestOp, NewStore); // Erase all other stores. for (unsigned i = 0; i < NumElem ; ++i) { - if (StoreNodes[i].MemNode == EarliestOp) + if (StoreNodes[i].MemNode == LatestOp) continue; StoreSDNode *St = cast(StoreNodes[i].MemNode); // ReplaceAllUsesWith will replace all uses that existed when it was @@ -10512,18 +10766,19 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { if (NumElem < 2) return false; - // The earliest Node in the DAG. - unsigned EarliestNodeUsed = 0; - LSBaseSDNode *EarliestOp = StoreNodes[EarliestNodeUsed].MemNode; + // The latest Node in the DAG. + unsigned LatestNodeUsed = 0; for (unsigned i=1; i StoreNodes[EarliestNodeUsed].SequenceNum) - EarliestNodeUsed = i; + // latest store node which is *used* and replaced by the wide store. + if (StoreNodes[i].SequenceNum < StoreNodes[LatestNodeUsed].SequenceNum) + LatestNodeUsed = i; } + LSBaseSDNode *LatestOp = StoreNodes[LatestNodeUsed].MemNode; + // Find if it is better to use vectors or integers to load and store // to memory. EVT JointMemOpVT; @@ -10545,7 +10800,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { false, false, false, FirstLoad->getAlignment()); - SDValue NewStore = DAG.getStore(EarliestOp->getChain(), StoreDL, NewLoad, + SDValue NewStore = DAG.getStore(LatestOp->getChain(), StoreDL, NewLoad, FirstInChain->getBasePtr(), FirstInChain->getPointerInfo(), false, false, FirstInChain->getAlignment()); @@ -10563,12 +10818,12 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode* St) { DAG.ReplaceAllUsesOfValueWith(SDValue(Ld, 1), Ld->getChain()); } - // Replace the first store with the new store. - CombineTo(EarliestOp, NewStore); + // Replace the last store with the new store. + CombineTo(LatestOp, NewStore); // Erase all other stores. for (unsigned i = 0; i < NumElem ; ++i) { // Remove all Store nodes. - if (StoreNodes[i].MemNode == EarliestOp) + if (StoreNodes[i].MemNode == LatestOp) continue; StoreSDNode *St = cast(StoreNodes[i].MemNode); DAG.ReplaceAllUsesOfValueWith(SDValue(St, 0), St->getChain()); @@ -10677,11 +10932,15 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) { // Try to infer better alignment information than the store already has. if (OptLevel != CodeGenOpt::None && ST->isUnindexed()) { if (unsigned Align = DAG.InferPtrAlignment(Ptr)) { - if (Align > ST->getAlignment()) - return DAG.getTruncStore(Chain, SDLoc(N), Value, + if (Align > ST->getAlignment()) { + SDValue NewStore = + DAG.getTruncStore(Chain, SDLoc(N), Value, Ptr, ST->getPointerInfo(), ST->getMemoryVT(), ST->isVolatile(), ST->isNonTemporal(), Align, ST->getAAInfo()); + if (NewStore.getNode() != N) + return CombineTo(ST, NewStore, true); + } } } @@ -11493,6 +11752,68 @@ SDValue DAGCombiner::visitBUILD_VECTOR(SDNode *N) { return SDValue(); } +static SDValue combineConcatVectorOfScalars(SDNode *N, SelectionDAG &DAG) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT OpVT = N->getOperand(0).getValueType(); + + // If the operands are legal vectors, leave them alone. + if (TLI.isTypeLegal(OpVT)) + return SDValue(); + + SDLoc DL(N); + EVT VT = N->getValueType(0); + SmallVector Ops; + + EVT SVT = EVT::getIntegerVT(*DAG.getContext(), OpVT.getSizeInBits()); + SDValue ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT); + + // Keep track of what we encounter. + bool AnyInteger = false; + bool AnyFP = false; + for (const SDValue &Op : N->ops()) { + if (ISD::BITCAST == Op.getOpcode() && + !Op.getOperand(0).getValueType().isVector()) + Ops.push_back(Op.getOperand(0)); + else if (ISD::UNDEF == Op.getOpcode()) + Ops.push_back(ScalarUndef); + else + return SDValue(); + + // Note whether we encounter an integer or floating point scalar. + // If it's neither, bail out, it could be something weird like x86mmx. + EVT LastOpVT = Ops.back().getValueType(); + if (LastOpVT.isFloatingPoint()) + AnyFP = true; + else if (LastOpVT.isInteger()) + AnyInteger = true; + else + return SDValue(); + } + + // If any of the operands is a floating point scalar bitcast to a vector, + // use floating point types throughout, and bitcast everything. + // Replace UNDEFs by another scalar UNDEF node, of the final desired type. + if (AnyFP) { + SVT = EVT::getFloatingPointVT(OpVT.getSizeInBits()); + ScalarUndef = DAG.getNode(ISD::UNDEF, DL, SVT); + if (AnyInteger) { + for (SDValue &Op : Ops) { + if (Op.getValueType() == SVT) + continue; + if (Op.getOpcode() == ISD::UNDEF) + Op = ScalarUndef; + else + Op = DAG.getNode(ISD::BITCAST, DL, SVT, Op); + } + } + } + + EVT VecVT = EVT::getVectorVT(*DAG.getContext(), SVT, + VT.getSizeInBits() / SVT.getSizeInBits()); + return DAG.getNode(ISD::BITCAST, DL, VT, + DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, Ops)); +} + SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { // TODO: Check to see if this is a CONCAT_VECTORS of a bunch of // EXTRACT_SUBVECTOR operations. If so, and if the EXTRACT_SUBVECTOR vector @@ -11508,9 +11829,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (ISD::allOperandsUndef(N)) return DAG.getUNDEF(VT); - // Optimize concat_vectors where one of the vectors is undef. - if (N->getNumOperands() == 2 && - N->getOperand(1)->getOpcode() == ISD::UNDEF) { + // Optimize concat_vectors where all but the first of the vectors are undef. + if (std::all_of(std::next(N->op_begin()), N->op_end(), [](const SDValue &Op) { + return Op.getOpcode() == ISD::UNDEF; + })) { SDValue In = N->getOperand(0); assert(In.getValueType().isVector() && "Must concat vectors"); @@ -11518,6 +11840,15 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { if (In->getOpcode() == ISD::BITCAST && !In->getOperand(0)->getValueType(0).isVector()) { SDValue Scalar = In->getOperand(0); + + // If the bitcast type isn't legal, it might be a trunc of a legal type; + // look through the trunc so we can still do the transform: + // concat_vectors(trunc(scalar), undef) -> scalar_to_vector(scalar) + if (Scalar->getOpcode() == ISD::TRUNCATE && + !TLI.isTypeLegal(Scalar.getValueType()) && + TLI.isTypeLegal(Scalar->getOperand(0).getValueType())) + Scalar = Scalar->getOperand(0); + EVT SclTy = Scalar->getValueType(0); if (!SclTy.isFloatingPoint() && !SclTy.isInteger()) @@ -11585,6 +11916,10 @@ SDValue DAGCombiner::visitCONCAT_VECTORS(SDNode *N) { return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Opnds); } + // Fold CONCAT_VECTORS of only bitcast scalars (or undef) to BUILD_VECTOR. + if (SDValue V = combineConcatVectorOfScalars(N, DAG)) + return V; + // Type legalization of vectors and DAG canonicalization of SHUFFLE_VECTOR // nodes often generate nop CONCAT_VECTOR nodes. // Scan the CONCAT_VECTOR operands and look for a CONCAT operations that @@ -11646,7 +11981,7 @@ SDValue DAGCombiner::visitEXTRACT_SUBVECTOR(SDNode* N) { // type. if (V->getOperand(0).getValueType() != NVT) return SDValue(); - unsigned Idx = dyn_cast(N->getOperand(1))->getZExtValue(); + unsigned Idx = N->getConstantOperandVal(1); unsigned NumElems = NVT.getVectorNumElements(); assert((Idx % NumElems) == 0 && "IDX in concat is not a multiple of the result vector length."); @@ -11949,7 +12284,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // We may have jumped through bitcasts, so the type of the // BUILD_VECTOR may not match the type of the shuffle. if (V->getValueType(0) != VT) - NewBV = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, NewBV); + NewBV = DAG.getNode(ISD::BITCAST, SDLoc(N), VT, NewBV); return NewBV; } } @@ -11971,6 +12306,43 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { return V; } + // Attempt to combine a shuffle of 2 inputs of 'scalar sources' - + // BUILD_VECTOR or SCALAR_TO_VECTOR into a single BUILD_VECTOR. + if (Level < AfterLegalizeVectorOps && TLI.isTypeLegal(VT)) { + SmallVector Ops; + for (int M : SVN->getMask()) { + SDValue Op = DAG.getUNDEF(VT.getScalarType()); + if (M >= 0) { + int Idx = M % NumElts; + SDValue &S = (M < (int)NumElts ? N0 : N1); + if (S.getOpcode() == ISD::BUILD_VECTOR && S.hasOneUse()) { + Op = S.getOperand(Idx); + } else if (S.getOpcode() == ISD::SCALAR_TO_VECTOR && S.hasOneUse()) { + if (Idx == 0) + Op = S.getOperand(0); + } else { + // Operand can't be combined - bail out. + break; + } + } + Ops.push_back(Op); + } + if (Ops.size() == VT.getVectorNumElements()) { + // BUILD_VECTOR requires all inputs to be of the same type, find the + // maximum type and extend them all. + EVT SVT = VT.getScalarType(); + if (SVT.isInteger()) + for (SDValue &Op : Ops) + SVT = (SVT.bitsLT(Op.getValueType()) ? Op.getValueType() : SVT); + if (SVT != VT.getScalarType()) + for (SDValue &Op : Ops) + Op = TLI.isZExtFree(Op.getValueType(), SVT) + ? DAG.getZExtOrTrunc(Op, SDLoc(N), SVT) + : DAG.getSExtOrTrunc(Op, SDLoc(N), SVT); + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), VT, Ops); + } + } + // If this shuffle only has a single input that is a bitcasted shuffle, // attempt to merge the 2 shuffles and suitably bitcast the inputs/output // back to their original types. @@ -12030,16 +12402,8 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { SDValue SV1 = BC0->getOperand(1); bool LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT); if (!LegalMask) { - for (int i = 0, e = (int)NewMask.size(); i != e; ++i) { - int idx = NewMask[i]; - if (idx < 0) - continue; - else if (idx < e) - NewMask[i] = idx + e; - else - NewMask[i] = idx - e; - } std::swap(SV0, SV1); + ShuffleVectorSDNode::commuteMask(NewMask); LegalMask = TLI.isShuffleMaskLegal(NewMask, ScaleVT); } @@ -12163,16 +12527,7 @@ SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) { // Avoid introducing shuffles with illegal mask. if (!TLI.isShuffleMaskLegal(Mask, VT)) { - // Compute the commuted shuffle mask and test again. - for (unsigned i = 0; i != NumElts; ++i) { - int idx = Mask[i]; - if (idx < 0) - continue; - else if (idx < (int)NumElts) - Mask[i] = idx + NumElts; - else - Mask[i] = idx - NumElts; - } + ShuffleVectorSDNode::commuteMask(Mask); if (!TLI.isShuffleMaskLegal(Mask, VT)) return SDValue(); @@ -12247,50 +12602,67 @@ SDValue DAGCombiner::visitINSERT_SUBVECTOR(SDNode *N) { return SDValue(); } +SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) { + SDValue N0 = N->getOperand(0); + + // fold (fp_to_fp16 (fp16_to_fp op)) -> op + if (N0->getOpcode() == ISD::FP16_TO_FP) + return N0->getOperand(0); + + return SDValue(); +} + /// Returns a vector_shuffle if it able to transform an AND to a vector_shuffle /// with the destination vector and a zero vector. /// e.g. AND V, <0xffffffff, 0, 0xffffffff, 0>. ==> /// vector_shuffle V, Zero, <0, 4, 2, 4> SDValue DAGCombiner::XformToShuffleWithZero(SDNode *N) { EVT VT = N->getValueType(0); - SDLoc dl(N); SDValue LHS = N->getOperand(0); SDValue RHS = N->getOperand(1); - if (N->getOpcode() == ISD::AND) { - if (RHS.getOpcode() == ISD::BITCAST) - RHS = RHS.getOperand(0); - if (RHS.getOpcode() == ISD::BUILD_VECTOR) { - SmallVector Indices; - unsigned NumElts = RHS.getNumOperands(); - for (unsigned i = 0; i != NumElts; ++i) { - SDValue Elt = RHS.getOperand(i); - if (!isa(Elt)) - return SDValue(); + SDLoc dl(N); - if (cast(Elt)->isAllOnesValue()) - Indices.push_back(i); - else if (cast(Elt)->isNullValue()) - Indices.push_back(NumElts+i); - else - return SDValue(); - } + // Make sure we're not running after operation legalization where it + // may have custom lowered the vector shuffles. + if (LegalOperations) + return SDValue(); + + if (N->getOpcode() != ISD::AND) + return SDValue(); - // Let's see if the target supports this vector_shuffle and make sure - // we're not running after operation legalization where it may have - // custom lowered the vector shuffles. - EVT RVT = RHS.getValueType(); - if (LegalOperations || !TLI.isVectorClearMaskLegal(Indices, RVT)) + if (RHS.getOpcode() == ISD::BITCAST) + RHS = RHS.getOperand(0); + + if (RHS.getOpcode() == ISD::BUILD_VECTOR) { + SmallVector Indices; + unsigned NumElts = RHS.getNumOperands(); + + for (unsigned i = 0; i != NumElts; ++i) { + SDValue Elt = RHS.getOperand(i); + if (!isa(Elt)) return SDValue(); - // Return the new VECTOR_SHUFFLE node. - EVT EltVT = RVT.getVectorElementType(); - SmallVector ZeroOps(RVT.getVectorNumElements(), - DAG.getConstant(0, EltVT)); - SDValue Zero = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), RVT, ZeroOps); - LHS = DAG.getNode(ISD::BITCAST, dl, RVT, LHS); - SDValue Shuf = DAG.getVectorShuffle(RVT, dl, LHS, Zero, &Indices[0]); - return DAG.getNode(ISD::BITCAST, dl, VT, Shuf); + if (cast(Elt)->isAllOnesValue()) + Indices.push_back(i); + else if (cast(Elt)->isNullValue()) + Indices.push_back(NumElts+i); + else + return SDValue(); } + + // Let's see if the target supports this vector_shuffle. + EVT RVT = RHS.getValueType(); + if (!TLI.isVectorClearMaskLegal(Indices, RVT)) + return SDValue(); + + // Return the new VECTOR_SHUFFLE node. + EVT EltVT = RVT.getVectorElementType(); + SmallVector ZeroOps(RVT.getVectorNumElements(), + DAG.getConstant(0, EltVT)); + SDValue Zero = DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), RVT, ZeroOps); + LHS = DAG.getNode(ISD::BITCAST, dl, RVT, LHS); + SDValue Shuf = DAG.getVectorShuffle(RVT, dl, LHS, Zero, &Indices[0]); + return DAG.getNode(ISD::BITCAST, dl, VT, Shuf); } return SDValue(); @@ -12383,38 +12755,6 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) { return SDValue(); } -/// Visit a binary vector operation, like FABS/FNEG. -SDValue DAGCombiner::SimplifyVUnaryOp(SDNode *N) { - assert(N->getValueType(0).isVector() && - "SimplifyVUnaryOp only works on vectors!"); - - SDValue N0 = N->getOperand(0); - - if (N0.getOpcode() != ISD::BUILD_VECTOR) - return SDValue(); - - // Operand is a BUILD_VECTOR node, see if we can constant fold it. - SmallVector Ops; - for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) { - SDValue Op = N0.getOperand(i); - if (Op.getOpcode() != ISD::UNDEF && - Op.getOpcode() != ISD::ConstantFP) - break; - EVT EltVT = Op.getValueType(); - SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(N0), EltVT, Op); - if (FoldOp.getOpcode() != ISD::UNDEF && - FoldOp.getOpcode() != ISD::ConstantFP) - break; - Ops.push_back(FoldOp); - AddToWorklist(FoldOp.getNode()); - } - - if (Ops.size() != N0.getNumOperands()) - return SDValue(); - - return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), N0.getValueType(), Ops); -} - SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0, SDValue N1, SDValue N2){ assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!"); @@ -12451,6 +12791,38 @@ SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0, bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, SDValue RHS) { + // fold (select (setcc x, -0.0, *lt), NaN, (fsqrt x)) + // The select + setcc is redundant, because fsqrt returns NaN for X < -0. + if (const ConstantFPSDNode *NaN = isConstOrConstSplatFP(LHS)) { + if (NaN->isNaN() && RHS.getOpcode() == ISD::FSQRT) { + // We have: (select (setcc ?, ?, ?), NaN, (fsqrt ?)) + SDValue Sqrt = RHS; + ISD::CondCode CC; + SDValue CmpLHS; + const ConstantFPSDNode *NegZero = nullptr; + + if (TheSelect->getOpcode() == ISD::SELECT_CC) { + CC = dyn_cast(TheSelect->getOperand(4))->get(); + CmpLHS = TheSelect->getOperand(0); + NegZero = isConstOrConstSplatFP(TheSelect->getOperand(1)); + } else { + // SELECT or VSELECT + SDValue Cmp = TheSelect->getOperand(0); + if (Cmp.getOpcode() == ISD::SETCC) { + CC = dyn_cast(Cmp.getOperand(2))->get(); + CmpLHS = Cmp.getOperand(0); + NegZero = isConstOrConstSplatFP(Cmp.getOperand(1)); + } + } + if (NegZero && NegZero->isNegative() && NegZero->isZero() && + Sqrt.getOperand(0) == CmpLHS && (CC == ISD::SETOLT || + CC == ISD::SETULT || CC == ISD::SETLT)) { + // We have: (select (setcc x, -0.0, *lt), NaN, (fsqrt x)) + CombineTo(TheSelect, Sqrt); + return true; + } + } + } // Cannot simplify select with vector condition if (TheSelect->getOperand(0).getValueType().isVector()) return false; @@ -12472,6 +12844,9 @@ bool DAGCombiner::SimplifySelectOps(SDNode *TheSelect, SDValue LHS, if (LHS.getOperand(0) != RHS.getOperand(0) || // Do not let this transformation reduce the number of volatile loads. LLD->isVolatile() || RLD->isVolatile() || + // FIXME: If either is a pre/post inc/dec load, + // we'd need to split out the address adjustment. + LLD->isIndexed() || RLD->isIndexed() || // If this is an EXTLOAD, the VT's must match. LLD->getMemoryVT() != RLD->getMemoryVT() || // If this is an EXTLOAD, the kind of extension must match.