X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=lib%2FTarget%2FAArch64%2FAArch64ISelLowering.cpp;h=6da468ed6b14d650bb6e3d88eafeaa889dfbed2a;hb=d7a4f74f15f8cff2b2b38a33e6027209ea3796f3;hp=88ec06f778ef16295bb863561506a2cc3ea15f0a;hpb=2daff76c0545a32c083577b7ef766145b3c42085;p=oota-llvm.git diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp index 88ec06f778e..6da468ed6b1 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -12,9 +12,10 @@ //===----------------------------------------------------------------------===// #include "AArch64ISelLowering.h" +#include "AArch64CallingConvention.h" +#include "AArch64MachineFunctionInfo.h" #include "AArch64PerfectShuffle.h" #include "AArch64Subtarget.h" -#include "AArch64MachineFunctionInfo.h" #include "AArch64TargetMachine.h" #include "AArch64TargetObjectFile.h" #include "MCTargetDesc/AArch64AddressingModes.h" @@ -38,10 +39,12 @@ using namespace llvm; STATISTIC(NumTailCalls, "Number of tail calls"); STATISTIC(NumShiftInserts, "Number of vector shift inserts"); +namespace { enum AlignMode { StrictAlign, NoStrictAlign }; +} static cl::opt Align(cl::desc("Load/store alignment support"), @@ -64,18 +67,9 @@ EnableAArch64SlrGeneration("aarch64-shift-insert-generation", cl::Hidden, cl::desc("Allow AArch64 SLI/SRI formation"), cl::init(false)); -//===----------------------------------------------------------------------===// -// AArch64 Lowering public interface. -//===----------------------------------------------------------------------===// -static TargetLoweringObjectFile *createTLOF(const Triple &TT) { - if (TT.isOSBinFormatMachO()) - return new AArch64_MachoTargetObjectFile(); - - return new AArch64_ELFTargetObjectFile(); -} -AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) - : TargetLowering(TM, createTLOF(Triple(TM.getTargetTriple()))) { +AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM) + : TargetLowering(TM) { Subtarget = &TM.getSubtarget(); // AArch64 doesn't have comparisons which set GPRs or setcc instructions, so @@ -106,6 +100,7 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) addDRTypeForNEON(MVT::v2i32); addDRTypeForNEON(MVT::v1i64); addDRTypeForNEON(MVT::v1f64); + addDRTypeForNEON(MVT::v4f16); addQRTypeForNEON(MVT::v4f32); addQRTypeForNEON(MVT::v2f64); @@ -113,6 +108,7 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) addQRTypeForNEON(MVT::v8i16); addQRTypeForNEON(MVT::v4i32); addQRTypeForNEON(MVT::v2i64); + addQRTypeForNEON(MVT::v8f16); } // Compute derived properties from the register classes @@ -278,6 +274,94 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom); setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom); + // f16 is storage-only, so we promote operations to f32 if we know this is + // valid, and ignore them otherwise. The operations not mentioned here will + // fail to select, but this is not a major problem as no source language + // should be emitting native f16 operations yet. + setOperationAction(ISD::FADD, MVT::f16, Promote); + setOperationAction(ISD::FDIV, MVT::f16, Promote); + setOperationAction(ISD::FMUL, MVT::f16, Promote); + setOperationAction(ISD::FSUB, MVT::f16, Promote); + + // v4f16 is also a storage-only type, so promote it to v4f32 when that is + // known to be safe. 
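A minimal sketch of what the f16 "Promote" legalization above amounts to numerically (illustration only, not part of the patch; it assumes a host compiler that provides _Float16): the half-precision operands are widened to f32, the arithmetic is done in f32, and the result is rounded back to f16 for storage.

    // Hypothetical illustration of the legalized form of an f16 FADD.
    static _Float16 promoted_fadd_f16(_Float16 A, _Float16 B) {
      float AF = static_cast<float>(A);        // FP_EXTEND to f32
      float BF = static_cast<float>(B);        // FP_EXTEND to f32
      return static_cast<_Float16>(AF + BF);   // f32 FADD, then FP_ROUND back to f16
    }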
+ setOperationAction(ISD::FADD, MVT::v4f16, Promote); + setOperationAction(ISD::FSUB, MVT::v4f16, Promote); + setOperationAction(ISD::FMUL, MVT::v4f16, Promote); + setOperationAction(ISD::FDIV, MVT::v4f16, Promote); + setOperationAction(ISD::FP_EXTEND, MVT::v4f16, Promote); + setOperationAction(ISD::FP_ROUND, MVT::v4f16, Promote); + AddPromotedToType(ISD::FADD, MVT::v4f16, MVT::v4f32); + AddPromotedToType(ISD::FSUB, MVT::v4f16, MVT::v4f32); + AddPromotedToType(ISD::FMUL, MVT::v4f16, MVT::v4f32); + AddPromotedToType(ISD::FDIV, MVT::v4f16, MVT::v4f32); + AddPromotedToType(ISD::FP_EXTEND, MVT::v4f16, MVT::v4f32); + AddPromotedToType(ISD::FP_ROUND, MVT::v4f16, MVT::v4f32); + + // Expand all other v4f16 operations. + // FIXME: We could generate better code by promoting some operations to + // a pair of v4f32s + setOperationAction(ISD::FABS, MVT::v4f16, Expand); + setOperationAction(ISD::FCEIL, MVT::v4f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::v4f16, Expand); + setOperationAction(ISD::FCOS, MVT::v4f16, Expand); + setOperationAction(ISD::FFLOOR, MVT::v4f16, Expand); + setOperationAction(ISD::FMA, MVT::v4f16, Expand); + setOperationAction(ISD::FNEARBYINT, MVT::v4f16, Expand); + setOperationAction(ISD::FNEG, MVT::v4f16, Expand); + setOperationAction(ISD::FPOW, MVT::v4f16, Expand); + setOperationAction(ISD::FPOWI, MVT::v4f16, Expand); + setOperationAction(ISD::FREM, MVT::v4f16, Expand); + setOperationAction(ISD::FROUND, MVT::v4f16, Expand); + setOperationAction(ISD::FRINT, MVT::v4f16, Expand); + setOperationAction(ISD::FSIN, MVT::v4f16, Expand); + setOperationAction(ISD::FSINCOS, MVT::v4f16, Expand); + setOperationAction(ISD::FSQRT, MVT::v4f16, Expand); + setOperationAction(ISD::FTRUNC, MVT::v4f16, Expand); + setOperationAction(ISD::SETCC, MVT::v4f16, Expand); + setOperationAction(ISD::BR_CC, MVT::v4f16, Expand); + setOperationAction(ISD::SELECT, MVT::v4f16, Expand); + setOperationAction(ISD::SELECT_CC, MVT::v4f16, Expand); + setOperationAction(ISD::FEXP, MVT::v4f16, Expand); + setOperationAction(ISD::FEXP2, MVT::v4f16, Expand); + setOperationAction(ISD::FLOG, MVT::v4f16, Expand); + setOperationAction(ISD::FLOG2, MVT::v4f16, Expand); + setOperationAction(ISD::FLOG10, MVT::v4f16, Expand); + + + // v8f16 is also a storage-only type, so expand it. 
+ setOperationAction(ISD::FABS, MVT::v8f16, Expand); + setOperationAction(ISD::FADD, MVT::v8f16, Expand); + setOperationAction(ISD::FCEIL, MVT::v8f16, Expand); + setOperationAction(ISD::FCOPYSIGN, MVT::v8f16, Expand); + setOperationAction(ISD::FCOS, MVT::v8f16, Expand); + setOperationAction(ISD::FDIV, MVT::v8f16, Expand); + setOperationAction(ISD::FFLOOR, MVT::v8f16, Expand); + setOperationAction(ISD::FMA, MVT::v8f16, Expand); + setOperationAction(ISD::FMUL, MVT::v8f16, Expand); + setOperationAction(ISD::FNEARBYINT, MVT::v8f16, Expand); + setOperationAction(ISD::FNEG, MVT::v8f16, Expand); + setOperationAction(ISD::FPOW, MVT::v8f16, Expand); + setOperationAction(ISD::FPOWI, MVT::v8f16, Expand); + setOperationAction(ISD::FREM, MVT::v8f16, Expand); + setOperationAction(ISD::FROUND, MVT::v8f16, Expand); + setOperationAction(ISD::FRINT, MVT::v8f16, Expand); + setOperationAction(ISD::FSIN, MVT::v8f16, Expand); + setOperationAction(ISD::FSINCOS, MVT::v8f16, Expand); + setOperationAction(ISD::FSQRT, MVT::v8f16, Expand); + setOperationAction(ISD::FSUB, MVT::v8f16, Expand); + setOperationAction(ISD::FTRUNC, MVT::v8f16, Expand); + setOperationAction(ISD::SETCC, MVT::v8f16, Expand); + setOperationAction(ISD::BR_CC, MVT::v8f16, Expand); + setOperationAction(ISD::SELECT, MVT::v8f16, Expand); + setOperationAction(ISD::SELECT_CC, MVT::v8f16, Expand); + setOperationAction(ISD::FP_EXTEND, MVT::v8f16, Expand); + setOperationAction(ISD::FEXP, MVT::v8f16, Expand); + setOperationAction(ISD::FEXP2, MVT::v8f16, Expand); + setOperationAction(ISD::FLOG, MVT::v8f16, Expand); + setOperationAction(ISD::FLOG2, MVT::v8f16, Expand); + setOperationAction(ISD::FLOG10, MVT::v8f16, Expand); + // AArch64 has implementations of a lot of rounding-like FP operations. static MVT RoundingTypes[] = { MVT::f32, MVT::f64}; for (unsigned I = 0; I < array_lengthof(RoundingTypes); ++I) { @@ -439,6 +523,11 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) // AArch64 doesn't have MUL.2d: setOperationAction(ISD::MUL, MVT::v2i64, Expand); + // Custom handling for some quad-vector types to detect MULL. + setOperationAction(ISD::MUL, MVT::v8i16, Custom); + setOperationAction(ISD::MUL, MVT::v4i32, Custom); + setOperationAction(ISD::MUL, MVT::v2i64, Custom); + setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal); setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand); // Likewise, narrowing and extending vector loads/stores aren't handled @@ -477,16 +566,20 @@ AArch64TargetLowering::AArch64TargetLowering(TargetMachine &TM) setOperationAction(ISD::FROUND, Ty, Legal); } } + + // Prefer likely predicted branches to selects on out-of-order cores. 
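A small source-level illustration of the trade-off PredictableSelectIsExpensive controls (my example, not from the patch): a CSEL keeps the result data-dependent on the compare, while a correctly predicted branch lets the core run ahead of it, so on these cores a predictable condition is often better left as a branch.

    // Hypothetical example: a highly predictable condition that if-conversion
    // would otherwise be tempted to turn into cmp + csel.
    int clampToZero(int X) {
      if (X < 0)   // assume this branch is almost never taken in practice
        return 0;
      return X;
    }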
+ if (Subtarget->isCortexA57()) + PredictableSelectIsExpensive = true; } void AArch64TargetLowering::addTypeForNEON(EVT VT, EVT PromotedBitwiseVT) { - if (VT == MVT::v2f32) { + if (VT == MVT::v2f32 || VT == MVT::v4f16) { setOperationAction(ISD::LOAD, VT.getSimpleVT(), Promote); AddPromotedToType(ISD::LOAD, VT.getSimpleVT(), MVT::v2i32); setOperationAction(ISD::STORE, VT.getSimpleVT(), Promote); AddPromotedToType(ISD::STORE, VT.getSimpleVT(), MVT::v2i32); - } else if (VT == MVT::v2f64 || VT == MVT::v4f32) { + } else if (VT == MVT::v2f64 || VT == MVT::v4f32 || VT == MVT::v8f16) { setOperationAction(ISD::LOAD, VT.getSimpleVT(), Promote); AddPromotedToType(ISD::LOAD, VT.getSimpleVT(), MVT::v2i64); @@ -727,6 +820,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::TC_RETURN: return "AArch64ISD::TC_RETURN"; case AArch64ISD::SITOF: return "AArch64ISD::SITOF"; case AArch64ISD::UITOF: return "AArch64ISD::UITOF"; + case AArch64ISD::NVCAST: return "AArch64ISD::NVCAST"; case AArch64ISD::SQSHL_I: return "AArch64ISD::SQSHL_I"; case AArch64ISD::UQSHL_I: return "AArch64ISD::UQSHL_I"; case AArch64ISD::SRSHR_I: return "AArch64ISD::SRSHR_I"; @@ -756,6 +850,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const { case AArch64ISD::ST2LANEpost: return "AArch64ISD::ST2LANEpost"; case AArch64ISD::ST3LANEpost: return "AArch64ISD::ST3LANEpost"; case AArch64ISD::ST4LANEpost: return "AArch64ISD::ST4LANEpost"; + case AArch64ISD::SMULL: return "AArch64ISD::SMULL"; + case AArch64ISD::UMULL: return "AArch64ISD::UMULL"; } } @@ -774,7 +870,8 @@ AArch64TargetLowering::EmitF128CSEL(MachineInstr *MI, // EndBB: // Dest = PHI [IfTrue, TrueBB], [IfFalse, OrigBB] - const TargetInstrInfo *TII = getTargetMachine().getInstrInfo(); + const TargetInstrInfo *TII = + getTargetMachine().getSubtargetImpl()->getInstrInfo(); MachineFunction *MF = MBB->getParent(); const BasicBlock *LLVM_BB = MBB->getBasicBlock(); DebugLoc DL = MI->getDebugLoc(); @@ -1020,6 +1117,8 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC, static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, SDValue &AArch64cc, SelectionDAG &DAG, SDLoc dl) { + SDValue Cmp; + AArch64CC::CondCode AArch64CC; if (ConstantSDNode *RHSC = dyn_cast(RHS.getNode())) { EVT VT = RHS.getValueType(); uint64_t C = RHSC->getZExtValue(); @@ -1051,9 +1150,9 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, break; case ISD::SETLE: case ISD::SETGT: - if ((VT == MVT::i32 && C != 0x7fffffff && + if ((VT == MVT::i32 && C != INT32_MAX && isLegalArithImmed((uint32_t)(C + 1))) || - (VT == MVT::i64 && C != 0x7ffffffffffffffULL && + (VT == MVT::i64 && C != INT64_MAX && isLegalArithImmed(C + 1ULL))) { CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE; C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1; @@ -1062,9 +1161,9 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, break; case ISD::SETULE: case ISD::SETUGT: - if ((VT == MVT::i32 && C != 0xffffffff && + if ((VT == MVT::i32 && C != UINT32_MAX && isLegalArithImmed((uint32_t)(C + 1))) || - (VT == MVT::i64 && C != 0xfffffffffffffffULL && + (VT == MVT::i64 && C != UINT64_MAX && isLegalArithImmed(C + 1ULL))) { CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE; C = (VT == MVT::i32) ? 
(uint32_t)(C + 1) : C + 1; @@ -1074,9 +1173,45 @@ static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC, } } } - - SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG); - AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC); + // The imm operand of ADDS is an unsigned immediate, in the range 0 to 4095. + // For the i8 operand, the largest immediate is 255, so this can be easily + // encoded in the compare instruction. For the i16 operand, however, the + // largest immediate cannot be encoded in the compare. + // Therefore, use a sign extending load and cmn to avoid materializing the -1 + // constant. For example, + // movz w1, #65535 + // ldrh w0, [x0, #0] + // cmp w0, w1 + // ==> + // ldrsh w0, [x0, #0] + // cmn w0, #1 + // Fundamentally, we're relying on the property that (zext LHS) == (zext RHS) + // if and only if (sext LHS) == (sext RHS). The checks are in place to ensure + // both the LHS and RHS are truly zero extended and to make sure the + // transformation is profitable. + if ((CC == ISD::SETEQ || CC == ISD::SETNE) && isa(RHS)) { + if ((cast(RHS)->getZExtValue() >> 16 == 0) && + isa(LHS)) { + if (cast(LHS)->getExtensionType() == ISD::ZEXTLOAD && + cast(LHS)->getMemoryVT() == MVT::i16 && + LHS.getNode()->hasNUsesOfValue(1, 0)) { + int16_t ValueofRHS = cast(RHS)->getZExtValue(); + if (ValueofRHS < 0 && isLegalArithImmed(-ValueofRHS)) { + SDValue SExt = + DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, LHS.getValueType(), LHS, + DAG.getValueType(MVT::i16)); + Cmp = emitComparison(SExt, + DAG.getConstant(ValueofRHS, RHS.getValueType()), + CC, dl, DAG); + AArch64CC = changeIntCCToAArch64CC(CC); + AArch64cc = DAG.getConstant(AArch64CC, MVT::i32); + return Cmp; + } + } + } + } + Cmp = emitComparison(LHS, RHS, CC, dl, DAG); + AArch64CC = changeIntCCToAArch64CC(CC); AArch64cc = DAG.getConstant(AArch64CC, MVT::i32); return Cmp; } @@ -1333,8 +1468,7 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) { SDLoc DL(Op); unsigned IsWrite = cast(Op.getOperand(2))->getZExtValue(); unsigned Locality = cast(Op.getOperand(3))->getZExtValue(); - // The data thing is not used. - // unsigned isData = cast(Op.getOperand(4))->getZExtValue(); + unsigned IsData = cast(Op.getOperand(4))->getZExtValue(); bool IsStream = !Locality; // When the locality number is set @@ -1349,6 +1483,7 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) { // built the mask value encoding the expected behavior. unsigned PrfOp = (IsWrite << 4) | // Load/Store bit + (!IsData << 3) | // IsDataCache bit (Locality << 1) | // Cache level bits (unsigned)IsStream; // Stream bit return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Op.getOperand(0), @@ -1400,7 +1535,10 @@ static SDValue LowerVectorFP_TO_INT(SDValue Op, SelectionDAG &DAG) { if (VT.getSizeInBits() > InVT.getSizeInBits()) { SDLoc dl(Op); - SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v2f64, Op.getOperand(0)); + MVT ExtVT = + MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()), + VT.getVectorNumElements()); + SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0)); return DAG.getNode(Op.getOpcode(), dl, VT, Ext); } @@ -1505,7 +1643,7 @@ SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op, (ArgVT == MVT::f64) ?
"__sincos_stret" : "__sincosf_stret"; SDValue Callee = DAG.getExternalSymbol(LibcallName, getPointerTy()); - StructType *RetTy = StructType::get(ArgTy, ArgTy, NULL); + StructType *RetTy = StructType::get(ArgTy, ArgTy, nullptr); TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(dl).setChain(DAG.getEntryNode()) .setCallee(CallingConv::Fast, RetTy, Callee, std::move(Args), 0); @@ -1529,6 +1667,197 @@ static SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) { 0); } +static EVT getExtensionTo64Bits(const EVT &OrigVT) { + if (OrigVT.getSizeInBits() >= 64) + return OrigVT; + + assert(OrigVT.isSimple() && "Expecting a simple value type"); + + MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy; + switch (OrigSimpleTy) { + default: llvm_unreachable("Unexpected Vector Type"); + case MVT::v2i8: + case MVT::v2i16: + return MVT::v2i32; + case MVT::v4i8: + return MVT::v4i16; + } +} + +static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG, + const EVT &OrigTy, + const EVT &ExtTy, + unsigned ExtOpcode) { + // The vector originally had a size of OrigTy. It was then extended to ExtTy. + // We expect the ExtTy to be 128-bits total. If the OrigTy is less than + // 64-bits we need to insert a new extension so that it will be 64-bits. + assert(ExtTy.is128BitVector() && "Unexpected extension size"); + if (OrigTy.getSizeInBits() >= 64) + return N; + + // Must extend size to at least 64 bits to be used as an operand for VMULL. + EVT NewVT = getExtensionTo64Bits(OrigTy); + + return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N); +} + +static bool isExtendedBUILD_VECTOR(SDNode *N, SelectionDAG &DAG, + bool isSigned) { + EVT VT = N->getValueType(0); + + if (N->getOpcode() != ISD::BUILD_VECTOR) + return false; + + for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) { + SDNode *Elt = N->getOperand(i).getNode(); + if (ConstantSDNode *C = dyn_cast(Elt)) { + unsigned EltSize = VT.getVectorElementType().getSizeInBits(); + unsigned HalfSize = EltSize / 2; + if (isSigned) { + if (!isIntN(HalfSize, C->getSExtValue())) + return false; + } else { + if (!isUIntN(HalfSize, C->getZExtValue())) + return false; + } + continue; + } + return false; + } + + return true; +} + +static SDValue skipExtensionForVectorMULL(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() == ISD::SIGN_EXTEND || N->getOpcode() == ISD::ZERO_EXTEND) + return addRequiredExtensionForVectorMULL(N->getOperand(0), DAG, + N->getOperand(0)->getValueType(0), + N->getValueType(0), + N->getOpcode()); + + assert(N->getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR"); + EVT VT = N->getValueType(0); + unsigned EltSize = VT.getVectorElementType().getSizeInBits() / 2; + unsigned NumElts = VT.getVectorNumElements(); + MVT TruncVT = MVT::getIntegerVT(EltSize); + SmallVector Ops; + for (unsigned i = 0; i != NumElts; ++i) { + ConstantSDNode *C = cast(N->getOperand(i)); + const APInt &CInt = C->getAPIntValue(); + // Element types smaller than 32 bits are not legal, so use i32 elements. + // The values are implicitly truncated so sext vs. zext doesn't matter. 
+ Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), MVT::i32)); + } + return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), + MVT::getVectorVT(TruncVT, NumElts), Ops); +} + +static bool isSignExtended(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() == ISD::SIGN_EXTEND) + return true; + if (isExtendedBUILD_VECTOR(N, DAG, true)) + return true; + return false; +} + +static bool isZeroExtended(SDNode *N, SelectionDAG &DAG) { + if (N->getOpcode() == ISD::ZERO_EXTEND) + return true; + if (isExtendedBUILD_VECTOR(N, DAG, false)) + return true; + return false; +} + +static bool isAddSubSExt(SDNode *N, SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + if (Opcode == ISD::ADD || Opcode == ISD::SUB) { + SDNode *N0 = N->getOperand(0).getNode(); + SDNode *N1 = N->getOperand(1).getNode(); + return N0->hasOneUse() && N1->hasOneUse() && + isSignExtended(N0, DAG) && isSignExtended(N1, DAG); + } + return false; +} + +static bool isAddSubZExt(SDNode *N, SelectionDAG &DAG) { + unsigned Opcode = N->getOpcode(); + if (Opcode == ISD::ADD || Opcode == ISD::SUB) { + SDNode *N0 = N->getOperand(0).getNode(); + SDNode *N1 = N->getOperand(1).getNode(); + return N0->hasOneUse() && N1->hasOneUse() && + isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG); + } + return false; +} + +static SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) { + // Multiplications are only custom-lowered for 128-bit vectors so that + // VMULL can be detected. Otherwise v2i64 multiplications are not legal. + EVT VT = Op.getValueType(); + assert(VT.is128BitVector() && VT.isInteger() && + "unexpected type for custom-lowering ISD::MUL"); + SDNode *N0 = Op.getOperand(0).getNode(); + SDNode *N1 = Op.getOperand(1).getNode(); + unsigned NewOpc = 0; + bool isMLA = false; + bool isN0SExt = isSignExtended(N0, DAG); + bool isN1SExt = isSignExtended(N1, DAG); + if (isN0SExt && isN1SExt) + NewOpc = AArch64ISD::SMULL; + else { + bool isN0ZExt = isZeroExtended(N0, DAG); + bool isN1ZExt = isZeroExtended(N1, DAG); + if (isN0ZExt && isN1ZExt) + NewOpc = AArch64ISD::UMULL; + else if (isN1SExt || isN1ZExt) { + // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these + // into (s/zext A * s/zext C) + (s/zext B * s/zext C) + if (isN1SExt && isAddSubSExt(N0, DAG)) { + NewOpc = AArch64ISD::SMULL; + isMLA = true; + } else if (isN1ZExt && isAddSubZExt(N0, DAG)) { + NewOpc = AArch64ISD::UMULL; + isMLA = true; + } else if (isN0ZExt && isAddSubZExt(N1, DAG)) { + std::swap(N0, N1); + NewOpc = AArch64ISD::UMULL; + isMLA = true; + } + } + + if (!NewOpc) { + if (VT == MVT::v2i64) + // Fall through to expand this. It is not legal. + return SDValue(); + else + // Other vector multiplications are legal. + return Op; + } + } + + // Legalize to a S/UMULL instruction + SDLoc DL(Op); + SDValue Op0; + SDValue Op1 = skipExtensionForVectorMULL(N1, DAG); + if (!isMLA) { + Op0 = skipExtensionForVectorMULL(N0, DAG); + assert(Op0.getValueType().is64BitVector() && + Op1.getValueType().is64BitVector() && + "unexpected types for extended operands to VMULL"); + return DAG.getNode(NewOpc, DL, VT, Op0, Op1); + } + // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during + // isel lowering to take advantage of no-stall back to back s/umul + s/umla. 
+ // This is true for CPUs with accumulate forwarding such as Cortex-A53/A57 + SDValue N00 = skipExtensionForVectorMULL(N0->getOperand(0).getNode(), DAG); + SDValue N01 = skipExtensionForVectorMULL(N0->getOperand(1).getNode(), DAG); + EVT Op1VT = Op1.getValueType(); + return DAG.getNode(N0->getOpcode(), DL, VT, + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1), + DAG.getNode(NewOpc, DL, VT, + DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)); +} SDValue AArch64TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { @@ -1629,6 +1958,8 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, return LowerFP_TO_INT(Op, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, DAG); + case ISD::MUL: + return LowerMUL(Op, DAG); } } @@ -1643,8 +1974,7 @@ unsigned AArch64TargetLowering::getFunctionAlignment(const Function *F) const { #include "AArch64GenCallingConv.inc" -/// Selects the correct CCAssignFn for a the given CallingConvention -/// value. +/// Selects the correct CCAssignFn for a given CallingConvention value. CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC, bool IsVarArg) const { switch (CC) { @@ -1669,8 +1999,8 @@ SDValue AArch64TargetLowering::LowerFormalArguments( // Assign locations to all of the incoming arguments. SmallVector ArgLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), ArgLocs, *DAG.getContext()); + CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), ArgLocs, + *DAG.getContext()); // At this point, Ins[].VT may already be promoted to i32. To correctly // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and @@ -1774,7 +2104,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments( } else { // VA.isRegLoc() assert(VA.isMemLoc() && "CCValAssign is neither reg nor mem"); unsigned ArgOffset = VA.getLocMemOffset(); - unsigned ArgSize = VA.getLocVT().getSizeInBits() / 8; + unsigned ArgSize = VA.getValVT().getSizeInBits() / 8; uint32_t BEAlign = 0; if (ArgSize < 8 && !Subtarget->isLittleEndian()) @@ -1809,7 +2139,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments( ArgValue = DAG.getExtLoad(ExtType, DL, VA.getLocVT(), Chain, FIN, MachinePointerInfo::getFixedStack(FI), - MemVT, false, false, false, nullptr); + MemVT, false, false, false, 0); InVals.push_back(ArgValue); } @@ -1941,8 +2271,8 @@ SDValue AArch64TargetLowering::LowerCallResult( : RetCC_AArch64_AAPCS; // Assign locations to each value returned by this call. SmallVector RVLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), RVLocs, *DAG.getContext()); + CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs, + *DAG.getContext()); CCInfo.AnalyzeCallResult(Ins, RetCC); // Copy all of the result registers out of their specified physreg. @@ -2011,6 +2341,19 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( return false; } + // Externally-defined functions with weak linkage should not be + // tail-called on AArch64 when the OS does not support dynamic + // pre-emption of symbols, as the AAELF spec requires normal calls + // to undefined weak functions to be replaced with a NOP or jump to the + // next instruction. The behaviour of branch instructions in this + // situation (as used for tail calls) is implementation-defined, so we + // cannot rely on the linker replacing the tail call with a return. 
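A concrete case of the extern-weak hazard described above (my example; maybe_absent is hypothetical): if the weak symbol is undefined at link time, the AAELF-mandated rewrite of a normal BL to a NOP keeps the caller correct, but there is no equivalent guaranteed rewrite for a tail-call branch.

    // Hypothetical extern-weak callee.
    extern "C" __attribute__((weak)) void maybe_absent();

    extern "C" void caller() {
      // Looks like a perfect tail call, but must stay a normal BL so the linker
      // can turn it into a NOP if maybe_absent ends up undefined.
      maybe_absent();
    }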
+ if (GlobalAddressSDNode *G = dyn_cast(Callee)) { + const GlobalValue *GV = G->getGlobal(); + if (GV->hasExternalWeakLinkage()) + return false; + } + // Now we search for cases where we can use a tail call without changing the // ABI. Sibcall is used in some places (particularly gcc) to refer to this // concept. @@ -2028,8 +2371,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // FIXME: for now we take the most conservative of these in both cases: // disallow all variadic memory operands. SmallVector ArgLocs; - CCState CCInfo(CalleeCC, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), ArgLocs, *DAG.getContext()); + CCState CCInfo(CalleeCC, isVarArg, DAG.getMachineFunction(), ArgLocs, + *DAG.getContext()); CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, true)); for (unsigned i = 0, e = ArgLocs.size(); i != e; ++i) @@ -2041,13 +2384,13 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( // results are returned in the same way as what the caller expects. if (!CCMatch) { SmallVector RVLocs1; - CCState CCInfo1(CalleeCC, false, DAG.getMachineFunction(), - getTargetMachine(), RVLocs1, *DAG.getContext()); + CCState CCInfo1(CalleeCC, false, DAG.getMachineFunction(), RVLocs1, + *DAG.getContext()); CCInfo1.AnalyzeCallResult(Ins, CCAssignFnForCall(CalleeCC, isVarArg)); SmallVector RVLocs2; - CCState CCInfo2(CallerCC, false, DAG.getMachineFunction(), - getTargetMachine(), RVLocs2, *DAG.getContext()); + CCState CCInfo2(CallerCC, false, DAG.getMachineFunction(), RVLocs2, + *DAG.getContext()); CCInfo2.AnalyzeCallResult(Ins, CCAssignFnForCall(CallerCC, isVarArg)); if (RVLocs1.size() != RVLocs2.size()) @@ -2072,8 +2415,8 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization( return true; SmallVector ArgLocs; - CCState CCInfo(CalleeCC, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), ArgLocs, *DAG.getContext()); + CCState CCInfo(CalleeCC, isVarArg, DAG.getMachineFunction(), ArgLocs, + *DAG.getContext()); CCInfo.AnalyzeCallOperands(Outs, CCAssignFnForCall(CalleeCC, isVarArg)); @@ -2170,8 +2513,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // Analyze operands of the call, assigning locations to each operand. SmallVector ArgLocs; - CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), - getTargetMachine(), ArgLocs, *DAG.getContext()); + CCState CCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), ArgLocs, + *DAG.getContext()); if (IsVarArg) { // Handle fixed and variable vector arguments differently. @@ -2316,7 +2659,7 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // common case. It should also work for fundamental types too. uint32_t BEAlign = 0; unsigned OpSize = Flags.isByVal() ? Flags.getByValSize() * 8 - : VA.getLocVT().getSizeInBits(); + : VA.getValVT().getSizeInBits(); OpSize = (OpSize + 7) / 8; if (!Subtarget->isLittleEndian() && !Flags.isByVal()) { if (OpSize < 8) @@ -2350,8 +2693,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, DAG.getConstant(Outs[i].Flags.getByValSize(), MVT::i64); SDValue Cpy = DAG.getMemcpy( Chain, DL, DstAddr, Arg, SizeNode, Outs[i].Flags.getByValAlign(), - /*isVolatile = */ false, - /*alwaysInline = */ false, DstInfo, MachinePointerInfo()); + /*isVol = */ false, + /*AlwaysInline = */ false, DstInfo, MachinePointerInfo()); MemOpChains.push_back(Cpy); } else { @@ -2440,7 +2783,8 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, // Add a register mask operand representing the call-preserved registers. 
const uint32_t *Mask; - const TargetRegisterInfo *TRI = getTargetMachine().getRegisterInfo(); + const TargetRegisterInfo *TRI = + getTargetMachine().getSubtargetImpl()->getRegisterInfo(); const AArch64RegisterInfo *ARI = static_cast(TRI); if (IsThisReturn) { @@ -2494,7 +2838,7 @@ bool AArch64TargetLowering::CanLowerReturn( ? RetCC_AArch64_WebKit_JS : RetCC_AArch64_AAPCS; SmallVector RVLocs; - CCState CCInfo(CallConv, isVarArg, MF, getTargetMachine(), RVLocs, Context); + CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context); return CCInfo.CheckReturn(Outs, RetCC); } @@ -2508,8 +2852,8 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, ? RetCC_AArch64_WebKit_JS : RetCC_AArch64_AAPCS; SmallVector RVLocs; - CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), - getTargetMachine(), RVLocs, *DAG.getContext()); + CCState CCInfo(CallConv, isVarArg, DAG.getMachineFunction(), RVLocs, + *DAG.getContext()); CCInfo.AnalyzeReturn(Outs, RetCC); // Copy the result values into the output registers. @@ -2560,7 +2904,8 @@ SDValue AArch64TargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { EVT PtrVT = getPointerTy(); SDLoc DL(Op); - const GlobalValue *GV = cast(Op)->getGlobal(); + const GlobalAddressSDNode *GN = cast(Op); + const GlobalValue *GV = GN->getGlobal(); unsigned char OpFlags = Subtarget->ClassifyGlobalReference(GV, getTargetMachine()); @@ -2575,6 +2920,25 @@ SDValue AArch64TargetLowering::LowerGlobalAddress(SDValue Op, return DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, GotAddr); } + if ((OpFlags & AArch64II::MO_CONSTPOOL) != 0) { + assert(getTargetMachine().getCodeModel() == CodeModel::Small && + "use of MO_CONSTPOOL only supported on small model"); + SDValue Hi = DAG.getTargetConstantPool(GV, PtrVT, 0, 0, AArch64II::MO_PAGE); + SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, PtrVT, Hi); + unsigned char LoFlags = AArch64II::MO_PAGEOFF | AArch64II::MO_NC; + SDValue Lo = DAG.getTargetConstantPool(GV, PtrVT, 0, 0, LoFlags); + SDValue PoolAddr = DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, ADRP, Lo); + SDValue GlobalAddr = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), PoolAddr, + MachinePointerInfo::getConstantPool(), + /*isVolatile=*/ false, + /*isNonTemporal=*/ true, + /*isInvariant=*/ true, 8); + if (GN->getOffset() != 0) + return DAG.getNode(ISD::ADD, DL, PtrVT, GlobalAddr, + DAG.getConstant(GN->getOffset(), PtrVT)); + return GlobalAddr; + } + if (getTargetMachine().getCodeModel() == CodeModel::Large) { const unsigned char MO_NC = AArch64II::MO_NC; return DAG.getNode( @@ -2651,7 +3015,8 @@ AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op, // TLS calls preserve all registers except those that absolutely must be // trashed: X0 (it takes an argument), LR (it's a call) and NZCV (let's not be // silly). - const TargetRegisterInfo *TRI = getTargetMachine().getRegisterInfo(); + const TargetRegisterInfo *TRI = + getTargetMachine().getSubtargetImpl()->getRegisterInfo(); const AArch64RegisterInfo *ARI = static_cast(TRI); const uint32_t *Mask = ARI->getTLSCallPreservedMask(); @@ -2701,7 +3066,8 @@ SDValue AArch64TargetLowering::LowerELFTLSDescCall(SDValue SymAddr, // TLS calls preserve all registers except those that absolutely must be // trashed: X0 (it takes an argument), LR (it's a call) and NZCV (let's not be // silly). 
- const TargetRegisterInfo *TRI = getTargetMachine().getRegisterInfo(); + const TargetRegisterInfo *TRI = + getTargetMachine().getSubtargetImpl()->getRegisterInfo(); const AArch64RegisterInfo *ARI = static_cast(TRI); const uint32_t *Mask = ARI->getTLSCallPreservedMask(); @@ -2916,11 +3282,6 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { isPowerOf2_64(LHS.getConstantOperandVal(1))) { SDValue Test = LHS.getOperand(0); uint64_t Mask = LHS.getConstantOperandVal(1); - - // TBZ only operates on i64's, but the ext should be free. - if (Test.getValueType() == MVT::i32) - Test = DAG.getAnyExtOrTrunc(Test, dl, MVT::i64); - return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, Test, DAG.getConstant(Log2_64(Mask), MVT::i64), Dest); } @@ -2936,18 +3297,29 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const { isPowerOf2_64(LHS.getConstantOperandVal(1))) { SDValue Test = LHS.getOperand(0); uint64_t Mask = LHS.getConstantOperandVal(1); - - // TBNZ only operates on i64's, but the ext should be free. - if (Test.getValueType() == MVT::i32) - Test = DAG.getAnyExtOrTrunc(Test, dl, MVT::i64); - return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, Test, DAG.getConstant(Log2_64(Mask), MVT::i64), Dest); } return DAG.getNode(AArch64ISD::CBNZ, dl, MVT::Other, Chain, LHS, Dest); + } else if (CC == ISD::SETLT && LHS.getOpcode() != ISD::AND) { + // Don't combine AND since emitComparison converts the AND to an ANDS + // (a.k.a. TST) and the test in the test bit and branch instruction + // becomes redundant. This would also increase register pressure. + uint64_t Mask = LHS.getValueType().getSizeInBits() - 1; + return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS, + DAG.getConstant(Mask, MVT::i64), Dest); } } + if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT && + LHS.getOpcode() != ISD::AND) { + // Don't combine AND since emitComparison converts the AND to an ANDS + // (a.k.a. TST) and the test in the test bit and branch instruction + // becomes redundant. This would also increase register pressure. 
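Why a single bit is enough in these two branch lowerings (my note, not the patch's): a signed "x < 0" test depends only on the sign bit, so it becomes TBNZ on bit width-1 (the case just above), and "x > -1", i.e. "x >= 0", becomes TBZ on the same bit (the case here).

    #include <cstdint>

    // Illustration only: the bit TBNZ #63 / TBZ #63 would test on an i64 value.
    static bool isStrictlyNegative(int64_t X) {
      return (static_cast<uint64_t>(X) >> 63) & 1;   // sign bit set <=> X < 0
    }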
+ uint64_t Mask = LHS.getValueType().getSizeInBits() - 1; + return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS, + DAG.getConstant(Mask, MVT::i64), Dest); + } SDValue CCVal; SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl); @@ -3062,6 +3434,9 @@ SDValue AArch64TargetLowering::LowerCTPOP(SDValue Op, SelectionDAG &DAG) const { AttributeSet::FunctionIndex, Attribute::NoImplicitFloat)) return SDValue(); + if (!Subtarget->hasNEON()) + return SDValue(); + // While there is no integer popcount instruction, it can // be more efficiently lowered to the following sequence that uses // AdvSIMD registers/instructions as long as the copies to/from @@ -4013,8 +4388,10 @@ void AArch64TargetLowering::LowerAsmOperandForConstraint( return; case 'J': { uint64_t NVal = -C->getSExtValue(); - if (isUInt<12>(NVal) || isShiftedUInt<12, 12>(NVal)) + if (isUInt<12>(NVal) || isShiftedUInt<12, 12>(NVal)) { + CVal = C->getSExtValue(); break; + } return; } // The K and L constraints apply *only* to logical immediates, including @@ -4138,10 +4515,30 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, EVT VT = Op.getValueType(); unsigned NumElts = VT.getVectorNumElements(); - SmallVector SourceVecs; - SmallVector MinElts; - SmallVector MaxElts; + struct ShuffleSourceInfo { + SDValue Vec; + unsigned MinElt; + unsigned MaxElt; + + // We may insert some combination of BITCASTs and VEXT nodes to force Vec to + // be compatible with the shuffle we intend to construct. As a result + // ShuffleVec will be some sliding window into the original Vec. + SDValue ShuffleVec; + + // Code should guarantee that element i in Vec starts at element "WindowBase + // + i * WindowScale in ShuffleVec". + int WindowBase; + int WindowScale; + + bool operator ==(SDValue OtherVec) { return Vec == OtherVec; } + ShuffleSourceInfo(SDValue Vec) + : Vec(Vec), MinElt(UINT_MAX), MaxElt(0), ShuffleVec(Vec), WindowBase(0), + WindowScale(1) {} + }; + // First gather all vectors used as an immediate source for this BUILD_VECTOR + // node. + SmallVector Sources; for (unsigned i = 0; i < NumElts; ++i) { SDValue V = Op.getOperand(i); if (V.getOpcode() == ISD::UNDEF) @@ -4152,158 +4549,153 @@ SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op, return SDValue(); } - // Record this extraction against the appropriate vector if possible... + // Add this element source to the list if it's not already there. SDValue SourceVec = V.getOperand(0); - unsigned EltNo = cast(V.getOperand(1))->getZExtValue(); - bool FoundSource = false; - for (unsigned j = 0; j < SourceVecs.size(); ++j) { - if (SourceVecs[j] == SourceVec) { - if (MinElts[j] > EltNo) - MinElts[j] = EltNo; - if (MaxElts[j] < EltNo) - MaxElts[j] = EltNo; - FoundSource = true; - break; - } - } + auto Source = std::find(Sources.begin(), Sources.end(), SourceVec); + if (Source == Sources.end()) + Source = Sources.insert(Sources.end(), ShuffleSourceInfo(SourceVec)); - // Or record a new source if not... - if (!FoundSource) { - SourceVecs.push_back(SourceVec); - MinElts.push_back(EltNo); - MaxElts.push_back(EltNo); - } + // Update the minimum and maximum lane number seen. + unsigned EltNo = cast(V.getOperand(1))->getZExtValue(); + Source->MinElt = std::min(Source->MinElt, EltNo); + Source->MaxElt = std::max(Source->MaxElt, EltNo); } // Currently only do something sane when at most two source vectors - // involved. - if (SourceVecs.size() > 2) + // are involved. 
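A worked instance of the WindowBase/WindowScale invariant documented in ShuffleSourceInfo above (the numbers are mine, purely illustrative): if a v4i32 source is later bitcast to the v8i16 shuffle type, WindowScale is 2, so with WindowBase 0 original lane i occupies shuffle lanes 2*i and 2*i+1, and lane 3 of the source maps to shuffle lanes 6 and 7.

    // Per the invariant: element i of Vec starts at this lane of ShuffleVec.
    static int firstShuffleLane(int WindowBase, int WindowScale, int SrcLane) {
      return WindowBase + SrcLane * WindowScale;   // e.g. (0, 2, 3) -> 6
    }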
+ if (Sources.size() > 2) return SDValue(); // Find out the smallest element size among result and two sources, and use // it as element size to build the shuffle_vector. EVT SmallestEltTy = VT.getVectorElementType(); - for (unsigned i = 0; i < SourceVecs.size(); ++i) { - EVT SrcEltTy = SourceVecs[i].getValueType().getVectorElementType(); + for (auto &Source : Sources) { + EVT SrcEltTy = Source.Vec.getValueType().getVectorElementType(); if (SrcEltTy.bitsLT(SmallestEltTy)) { SmallestEltTy = SrcEltTy; } } unsigned ResMultiplier = VT.getVectorElementType().getSizeInBits() / SmallestEltTy.getSizeInBits(); - int VEXTOffsets[2] = { 0, 0 }; - int OffsetMultipliers[2] = { 1, 1 }; NumElts = VT.getSizeInBits() / SmallestEltTy.getSizeInBits(); EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts); - SDValue ShuffleSrcs[2] = {DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT)}; - - // This loop extracts the usage patterns of the source vectors - // and prepares appropriate SDValues for a shuffle if possible. - for (unsigned i = 0; i < SourceVecs.size(); ++i) { - unsigned NumSrcElts = SourceVecs[i].getValueType().getVectorNumElements(); - SDValue CurSource = SourceVecs[i]; - if (SourceVecs[i].getValueType().getVectorElementType() != - ShuffleVT.getVectorElementType()) { - // As ShuffleVT holds smallest element size, it may hit here only if - // the element type of SourceVecs is bigger than that of ShuffleVT. - // Adjust the element size of SourceVecs to match ShuffleVT, and record - // the multipliers. - EVT CastVT = EVT::getVectorVT( - *DAG.getContext(), ShuffleVT.getVectorElementType(), - SourceVecs[i].getValueSizeInBits() / - ShuffleVT.getVectorElementType().getSizeInBits()); - - CurSource = DAG.getNode(ISD::BITCAST, dl, CastVT, SourceVecs[i]); - OffsetMultipliers[i] = CastVT.getVectorNumElements() / NumSrcElts; - NumSrcElts *= OffsetMultipliers[i]; - MaxElts[i] *= OffsetMultipliers[i]; - MinElts[i] *= OffsetMultipliers[i]; - } - if (CurSource.getValueType() == ShuffleVT) { - // No VEXT necessary - ShuffleSrcs[i] = CurSource; - VEXTOffsets[i] = 0; + // If the source vector is too wide or too narrow, we may nevertheless be able + // to construct a compatible shuffle either by concatenating it with UNDEF or + // extracting a suitable range of elements. + for (auto &Src : Sources) { + EVT SrcVT = Src.ShuffleVec.getValueType(); + + if (SrcVT.getSizeInBits() == VT.getSizeInBits()) continue; - } else if (NumSrcElts < NumElts) { + + // This stage of the search produces a source with the same element type as + // the original, but with a total width matching the BUILD_VECTOR output. + EVT EltVT = SrcVT.getVectorElementType(); + unsigned NumSrcElts = VT.getSizeInBits() / EltVT.getSizeInBits(); + EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts); + + if (SrcVT.getSizeInBits() < VT.getSizeInBits()) { + assert(2 * SrcVT.getSizeInBits() == VT.getSizeInBits()); // We can pad out the smaller vector for free, so if it's part of a // shuffle... - ShuffleSrcs[i] = - DAG.getNode(ISD::CONCAT_VECTORS, dl, ShuffleVT, CurSource, - DAG.getUNDEF(CurSource.getValueType())); + Src.ShuffleVec = + DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, Src.ShuffleVec, + DAG.getUNDEF(Src.ShuffleVec.getValueType())); continue; } - // Since only 64-bit and 128-bit vectors are legal on ARM and - // we've eliminated the other cases... 
- assert(NumSrcElts == 2 * NumElts && - "unexpected vector sizes in ReconstructShuffle"); + assert(SrcVT.getSizeInBits() == 2 * VT.getSizeInBits()); - if (MaxElts[i] - MinElts[i] >= NumElts) { + if (Src.MaxElt - Src.MinElt >= NumSrcElts) { // Span too large for a VEXT to cope return SDValue(); } - if (MinElts[i] >= NumElts) { + if (Src.MinElt >= NumSrcElts) { // The extraction can just take the second half - VEXTOffsets[i] = NumElts; - ShuffleSrcs[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ShuffleVT, - CurSource, DAG.getIntPtrConstant(NumElts)); - } else if (MaxElts[i] < NumElts) { + Src.ShuffleVec = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec, + DAG.getIntPtrConstant(NumSrcElts)); + Src.WindowBase = -NumSrcElts; + } else if (Src.MaxElt < NumSrcElts) { // The extraction can just take the first half - VEXTOffsets[i] = 0; - ShuffleSrcs[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ShuffleVT, - CurSource, DAG.getIntPtrConstant(0)); + Src.ShuffleVec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, + Src.ShuffleVec, DAG.getIntPtrConstant(0)); } else { // An actual VEXT is needed - VEXTOffsets[i] = MinElts[i]; - SDValue VEXTSrc1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ShuffleVT, - CurSource, DAG.getIntPtrConstant(0)); - SDValue VEXTSrc2 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ShuffleVT, - CurSource, DAG.getIntPtrConstant(NumElts)); - unsigned Imm = VEXTOffsets[i] * getExtFactor(VEXTSrc1); - ShuffleSrcs[i] = DAG.getNode(AArch64ISD::EXT, dl, ShuffleVT, VEXTSrc1, + SDValue VEXTSrc1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, + Src.ShuffleVec, DAG.getIntPtrConstant(0)); + SDValue VEXTSrc2 = + DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec, + DAG.getIntPtrConstant(NumSrcElts)); + unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1); + + Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1, VEXTSrc2, DAG.getConstant(Imm, MVT::i32)); + Src.WindowBase = -Src.MinElt; } } - SmallVector Mask; - unsigned VTEltSize = VT.getVectorElementType().getSizeInBits(); + // Another possible incompatibility occurs from the vector element types. We + // can fix this by bitcasting the source vectors to the same type we intend + // for the shuffle. + for (auto &Src : Sources) { + EVT SrcEltTy = Src.ShuffleVec.getValueType().getVectorElementType(); + if (SrcEltTy == SmallestEltTy) + continue; + assert(ShuffleVT.getVectorElementType() == SmallestEltTy); + Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec); + Src.WindowScale = SrcEltTy.getSizeInBits() / SmallestEltTy.getSizeInBits(); + Src.WindowBase *= Src.WindowScale; + } + + // Final sanity check before we try to actually produce a shuffle. + DEBUG( + for (auto Src : Sources) + assert(Src.ShuffleVec.getValueType() == ShuffleVT); + ); + // The stars all align, our next step is to produce the mask for the shuffle. + SmallVector Mask(ShuffleVT.getVectorNumElements(), -1); + int BitsPerShuffleLane = ShuffleVT.getVectorElementType().getSizeInBits(); for (unsigned i = 0; i < VT.getVectorNumElements(); ++i) { SDValue Entry = Op.getOperand(i); - int SourceNum = 1; - unsigned LanePartNum = 0; - int ExtractElt; - if (Entry.getOpcode() != ISD::UNDEF) { - // Check how many parts of source lane should be inserted. 
- SDValue ExtractVec = Entry.getOperand(0); - if (ExtractVec == SourceVecs[0]) - SourceNum = 0; - ExtractElt = cast(Entry.getOperand(1))->getSExtValue(); - unsigned ExtEltSize = - ExtractVec.getValueType().getVectorElementType().getSizeInBits(); - unsigned SmallerSize = ExtEltSize < VTEltSize ? ExtEltSize : VTEltSize; - LanePartNum = SmallerSize / SmallestEltTy.getSizeInBits(); - } + if (Entry.getOpcode() == ISD::UNDEF) + continue; - for (unsigned j = 0; j != ResMultiplier; ++j) { - if (j < LanePartNum) - Mask.push_back(ExtractElt * OffsetMultipliers[SourceNum] + - NumElts * SourceNum - VEXTOffsets[SourceNum] + j); - else - Mask.push_back(-1); - } + auto Src = std::find(Sources.begin(), Sources.end(), Entry.getOperand(0)); + int EltNo = cast(Entry.getOperand(1))->getSExtValue(); + + // EXTRACT_VECTOR_ELT performs an implicit any_ext; BUILD_VECTOR an implicit + // trunc. So only std::min(SrcBits, DestBits) actually get defined in this + // segment. + EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType(); + int BitsDefined = std::min(OrigEltTy.getSizeInBits(), + VT.getVectorElementType().getSizeInBits()); + int LanesDefined = BitsDefined / BitsPerShuffleLane; + + // This source is expected to fill ResMultiplier lanes of the final shuffle, + // starting at the appropriate offset. + int *LaneMask = &Mask[i * ResMultiplier]; + + int ExtractBase = EltNo * Src->WindowScale + Src->WindowBase; + ExtractBase += NumElts * (Src - Sources.begin()); + for (int j = 0; j < LanesDefined; ++j) + LaneMask[j] = ExtractBase + j; } // Final check before we try to produce nonsense... - if (isShuffleMaskLegal(Mask, ShuffleVT)) { - SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleSrcs[0], - ShuffleSrcs[1], &Mask[0]); - return DAG.getNode(ISD::BITCAST, dl, VT, Shuffle); - } + if (!isShuffleMaskLegal(Mask, ShuffleVT)) + return SDValue(); - return SDValue(); + SDValue ShuffleOps[] = { DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT) }; + for (unsigned i = 0; i < Sources.size(); ++i) + ShuffleOps[i] = Sources[i].ShuffleVec; + + SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleOps[0], + ShuffleOps[1], &Mask[0]); + return DAG.getNode(ISD::BITCAST, dl, VT, Shuffle); } // check if an EXT instruction can handle the shuffle mask when the @@ -4632,7 +5024,8 @@ static SDValue GeneratePerfectShuffle(unsigned PFEntry, SDValue LHS, VT.getVectorElementType() == MVT::f32) return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS); // vrev <4 x i16> -> REV32 - if (VT.getVectorElementType() == MVT::i16) + if (VT.getVectorElementType() == MVT::i16 || + VT.getVectorElementType() == MVT::f16) return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS); // vrev <4 x i8> -> REV16 assert(VT.getVectorElementType() == MVT::i8); @@ -4752,7 +5145,7 @@ static SDValue GenerateTBL(SDValue Op, ArrayRef ShuffleMask, static unsigned getDUPLANEOp(EVT EltType) { if (EltType == MVT::i8) return AArch64ISD::DUPLANE8; - if (EltType == MVT::i16) + if (EltType == MVT::i16 || EltType == MVT::f16) return AArch64ISD::DUPLANE16; if (EltType == MVT::i32 || EltType == MVT::f32) return AArch64ISD::DUPLANE32; @@ -4882,7 +5275,8 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op, SDValue SrcLaneV = DAG.getConstant(SrcLane, MVT::i64); EVT ScalarVT = VT.getVectorElementType(); - if (ScalarVT.getSizeInBits() < 32) + + if (ScalarVT.getSizeInBits() < 32 && ScalarVT.isInteger()) ScalarVT = MVT::i32; return DAG.getNode( @@ -4970,7 +5364,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = 
DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType2(CnstVal)) { @@ -4979,7 +5373,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType3(CnstVal)) { @@ -4988,7 +5382,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(16, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType4(CnstVal)) { @@ -4997,7 +5391,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(24, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType5(CnstVal)) { @@ -5006,7 +5400,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType6(CnstVal)) { @@ -5015,7 +5409,7 @@ SDValue AArch64TargetLowering::LowerVectorAND(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::BICi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } } @@ -5170,7 +5564,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType2(CnstVal)) { @@ -5179,7 +5573,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType3(CnstVal)) { @@ -5188,7 +5582,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(16, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType4(CnstVal)) { @@ -5197,7 +5591,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(24, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType5(CnstVal)) { @@ -5206,7 +5600,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue 
Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType6(CnstVal)) { @@ -5215,7 +5609,7 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::ORRi, dl, MovTy, LHS, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } } @@ -5288,13 +5682,13 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, if (VT.getSizeInBits() == 128) { SDValue Mov = DAG.getNode(AArch64ISD::MOVIedit, dl, MVT::v2i64, DAG.getConstant(CnstVal, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } // Support the V64 version via subregister insertion. SDValue Mov = DAG.getNode(AArch64ISD::MOVIedit, dl, MVT::f64, DAG.getConstant(CnstVal, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType1(CnstVal)) { @@ -5303,7 +5697,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType2(CnstVal)) { @@ -5312,7 +5706,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType3(CnstVal)) { @@ -5321,7 +5715,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(16, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType4(CnstVal)) { @@ -5330,7 +5724,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(24, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType5(CnstVal)) { @@ -5339,7 +5733,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType6(CnstVal)) { @@ -5348,7 +5742,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType7(CnstVal)) { @@ -5357,7 +5751,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVImsl, 
dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(264, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType8(CnstVal)) { @@ -5366,7 +5760,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MOVImsl, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(272, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType9(CnstVal)) { @@ -5374,7 +5768,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v16i8 : MVT::v8i8; SDValue Mov = DAG.getNode(AArch64ISD::MOVI, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } // The few faces of FMOV... @@ -5383,7 +5777,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4f32 : MVT::v2f32; SDValue Mov = DAG.getNode(AArch64ISD::FMOV, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType12(CnstVal) && @@ -5391,7 +5785,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, CnstVal = AArch64_AM::encodeAdvSIMDModImmType12(CnstVal); SDValue Mov = DAG.getNode(AArch64ISD::FMOV, dl, MVT::v2f64, DAG.getConstant(CnstVal, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } // The many faces of MVNI... @@ -5402,7 +5796,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType2(CnstVal)) { @@ -5411,7 +5805,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType3(CnstVal)) { @@ -5420,7 +5814,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(16, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType4(CnstVal)) { @@ -5429,7 +5823,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(24, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType5(CnstVal)) { @@ -5438,7 +5832,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(0, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if 
(AArch64_AM::isAdvSIMDModImmType6(CnstVal)) { @@ -5447,7 +5841,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNIshift, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(8, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType7(CnstVal)) { @@ -5456,7 +5850,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNImsl, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(264, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } if (AArch64_AM::isAdvSIMDModImmType8(CnstVal)) { @@ -5465,7 +5859,7 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op, SDValue Mov = DAG.getNode(AArch64ISD::MVNImsl, dl, MovTy, DAG.getConstant(CnstVal, MVT::i32), DAG.getConstant(272, MVT::i32)); - return DAG.getNode(ISD::BITCAST, dl, VT, Mov); + return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov); } } @@ -5641,11 +6035,12 @@ SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op, // Insertion/extraction are legal for V128 types. if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 || - VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64) + VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 || + VT == MVT::v8f16) return Op; if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 && - VT != MVT::v1i64 && VT != MVT::v2f32) + VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16) return SDValue(); // For V64 types, we perform insertion by expanding the value @@ -5674,11 +6069,12 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op, // Insertion/extraction are legal for V128 types. if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 || - VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64) + VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 || + VT == MVT::v8f16) return Op; if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 && - VT != MVT::v1i64 && VT != MVT::v2f32) + VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16) return SDValue(); // For V64 types, we perform extraction by expanding the value @@ -6212,7 +6608,7 @@ EVT AArch64TargetLowering::getOptimalMemOpType(uint64_t Size, unsigned DstAlign, !F->getAttributes().hasAttribute(AttributeSet::FunctionIndex, Attribute::NoImplicitFloat) && (memOpAlign(SrcAlign, DstAlign, 16) || - (allowsUnalignedMemoryAccesses(MVT::f128, 0, &Fast) && Fast))) + (allowsMisalignedMemoryAccesses(MVT::f128, 0, 1, &Fast) && Fast))) return MVT::f128; return Size >= 8 ? MVT::i64 : MVT::i32; @@ -6421,7 +6817,7 @@ AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor, SDValue N0 = N->getOperand(0); unsigned Lg2 = Divisor.countTrailingZeros(); SDValue Zero = DAG.getConstant(0, VT); - SDValue Pow2MinusOne = DAG.getConstant((1 << Lg2) - 1, VT); + SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, VT); // Add (N0 < 0) ? Pow2 - 1 : 0; SDValue CCVal; @@ -7333,11 +7729,11 @@ static SDValue performExtendCombine(SDNode *N, // If the vector type isn't a simple VT, it's beyond the scope of what // we're worried about here. Let legalization do its thing and hope for // the best. 
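Editor's aside on the BuildSDIVPow2 change above: widening the shift to 1ULL << Lg2 matters because Lg2 can exceed 31 for i64 divisions, and (1 << Lg2) - 1 would then be computed in 32-bit int. A minimal sketch of the select-and-shift identity that lowering relies on, with a hypothetical helper name sdiv_pow2 (an illustration, not the patch's code):

    // Signed division by 1 << Lg2, rounding toward zero:
    //   add (n < 0) ? Pow2 - 1 : 0, then arithmetic shift right by Lg2.
    long long sdiv_pow2(long long n, unsigned Lg2) {
      long long adj = (n < 0) ? (long long)((1ULL << Lg2) - 1) : 0;
      return (n + adj) >> Lg2;  // assumes arithmetic right shift of negatives
    }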
-  if (!ResVT.isSimple())
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src->getValueType(0);
+  if (!ResVT.isSimple() || !SrcVT.isSimple())
     return SDValue();
 
-  SDValue Src = N->getOperand(0);
-  MVT SrcVT = Src->getValueType(0).getSimpleVT();
   // If the source VT is a 64-bit vector, we can play games and get the
   // better results we want.
   if (SrcVT.getSizeInBits() != 64)
@@ -7571,7 +7967,7 @@ static SDValue performPostLD1Combine(SDNode *N,
     Ops.push_back(Inc);
 
     EVT Tys[3] = { VT, MVT::i64, MVT::Other };
-    SDVTList SDTys = DAG.getVTList(ArrayRef<EVT>(Tys, 3));
+    SDVTList SDTys = DAG.getVTList(Tys);
     unsigned NewOp = IsLaneOp ? AArch64ISD::LD1LANEpost : AArch64ISD::LD1DUPpost;
     SDValue UpdN = DAG.getMemIntrinsicNode(NewOp, SDLoc(N), SDTys, Ops, MemVT,
@@ -7701,7 +8097,7 @@ static SDValue performNEONPostLDSTCombine(SDNode *N,
     Tys[n] = VecTy;
     Tys[n++] = MVT::i64;  // Type of write back register
     Tys[n] = MVT::Other;  // Type of the chain
-    SDVTList SDTys = DAG.getVTList(ArrayRef<EVT>(Tys, NumResultVecs + 2));
+    SDVTList SDTys = DAG.getVTList(makeArrayRef(Tys, NumResultVecs + 2));
 
     MemIntrinsicSDNode *MemInt = cast<MemIntrinsicSDNode>(N);
     SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, SDLoc(N), SDTys, Ops,
@@ -7722,10 +8118,272 @@ static SDValue performNEONPostLDSTCombine(SDNode *N,
   return SDValue();
 }
 
+// Checks to see if the value is the prescribed width and returns information
+// about its extension mode.
+static
+bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
+  ExtType = ISD::NON_EXTLOAD;
+  switch(V.getNode()->getOpcode()) {
+  default:
+    return false;
+  case ISD::LOAD: {
+    LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
+    if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
+        || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
+      ExtType = LoadNode->getExtensionType();
+      return true;
+    }
+    return false;
+  }
+  case ISD::AssertSext: {
+    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
+    if ((TypeNode->getVT() == MVT::i8 && width == 8)
+        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
+      ExtType = ISD::SEXTLOAD;
+      return true;
+    }
+    return false;
+  }
+  case ISD::AssertZext: {
+    VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
+    if ((TypeNode->getVT() == MVT::i8 && width == 8)
+        || (TypeNode->getVT() == MVT::i16 && width == 16)) {
+      ExtType = ISD::ZEXTLOAD;
+      return true;
+    }
+    return false;
+  }
+  case ISD::Constant:
+  case ISD::TargetConstant: {
+    if (std::abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
+        1LL << (width - 1))
+      return true;
+    return false;
+  }
+  }
+
+  return true;
+}
+
+// This function does a whole lot of voodoo to determine if the tests are
+// equivalent without and with a mask. Essentially what happens is that given a
+// DAG resembling:
+//
+//  +-------------+ +-------------+ +-------------+ +-------------+
+//  |    Input    | | AddConstant | | CompConstant| |     CC      |
+//  +-------------+ +-------------+ +-------------+ +-------------+
+//         |               |               |              |
+//         V               V               |   +----------+
+//       +-------------+ +----+            |   |
+//       |     ADD     | |0xff|            |   |
+//       +-------------+ +----+            |   |
+//              |           |              |   |
+//              V           V              |   |
+//            +-------------+              |   |
+//            |     AND     |              |   |
+//            +-------------+              |   |
+//                   |                     |   |
+//                   +-----+               |   |
+//                         |               |   |
+//                         V               V   V
+//                        +-------------+
+//                        |     CMP     |
+//                        +-------------+
+//
+// The AND node may be safely removed for some combinations of inputs. In
+// particular we need to take into account the extension type of the Input,
+// the exact values of AddConstant, CompConstant, and CC, along with the
+// nominal width of the input (this can work for any width of input; the
+// above graph is specific to 8 bits).
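The paragraph above describes when the AND can be dropped. As a hedged, source-level illustration (editor's sketch, not part of the patch; the function name narrow_cmp is made up), this is the kind of input that produces the DAG shape in the diagram: an 8-bit value is widened, a constant is added, the result is truncated back to 8 bits (the AND with 0xff), and then compared.

    int narrow_cmp(const unsigned char *p) {
      unsigned char v = *p + 13;  // ADD, then the implicit AND with #0xff
      return v < 200;             // CMP against CompConstant, tested via CC
    }

When the conditions below prove the masked and unmasked comparisons agree for every possible *p, the truncation (the AND) can be removed and the CMP can use the ADD result directly. The derivation of the exact conditions continues in the comment below.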
+//
+// The specific equations were worked out by generating output tables for each
+// AArch64CC value in terms of the Input (w0), AddConstant (w1), and
+// CompConstant (w2). The problem was simplified by working with 4 bit inputs,
+// which means we only needed to reason about 24 distinct bit patterns: 8
+// patterns unique to zero extension (8,15), 8 patterns unique to sign
+// extension (-8,-1), and 8 patterns present in both extensions (0,7). For
+// every distinct set of AddConstant and CompConstant bit patterns we can
+// consider the masked and unmasked versions to be equivalent if the result of
+// this function is true for all 16 distinct bit patterns of the Input (w0)
+// for the current extension type.
+//
+//   sub     w8, w0, w1
+//   and     w10, w8, #0x0f
+//   cmp     w8, w2
+//   cset    w9, AArch64CC
+//   cmp     w10, w2
+//   cset    w11, AArch64CC
+//   cmp     w9, w11
+//   cset    w0, eq
+//   ret
+//
+// Since the above sequence shows when the outputs are equivalent, it defines
+// when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
+// would be expensive to run during compilation. The equations below were
+// written in a test harness that confirmed they give outputs equivalent to the
+// sequence above for all inputs, so they can be used instead to determine
+// whether the removal is legal.
+//
+// isEquivalentMaskless() is the test for whether the AND can be removed. It is
+// factored out of the DAG recognition because the DAG can take several forms.
+
+static
+bool isEquivalentMaskless(unsigned CC, unsigned width,
+                          ISD::LoadExtType ExtType, signed AddConstant,
+                          signed CompConstant) {
+  // By being careful about our equations and only writing them in terms of
+  // symbolic values and well-known constants (0, 1, -1, MaxUInt) we can make
+  // them generally applicable to all bit widths.
+  signed MaxUInt = (1 << width);
+
+  // For the purposes of these comparisons, sign extending the type is
+  // equivalent to zero extending the add and displacing it by half the integer
+  // width. Provided we are careful and make sure our equations are valid over
+  // the whole range, we can just adjust the input and avoid writing equations
+  // for sign extended inputs.
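// Editor's aside (not part of the patch): a worked instance of the
// displacement described above, assuming width == 8 so half the integer width
// is 1 << 7 == 128. A sign-extended Input x in [-128, 127] produces the same
// sum as the shifted, zero-extension-style input (x + 128) in [0, 255] once
// AddConstant is reduced by 128:
//   x = -100, AddConstant = 30:   -100 + 30                 == -70
//   shifted view:                 (-100 + 128) + (30 - 128) == 28 - 98 == -70
// So the equations derived for zero-extended inputs apply unchanged after the
// adjustment performed immediately below.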
+  if (ExtType == ISD::SEXTLOAD)
+    AddConstant -= (1 << (width-1));
+
+  switch(CC) {
+  case AArch64CC::LE:
+  case AArch64CC::GT: {
+    if ((AddConstant == 0) ||
+        (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
+        (AddConstant >= 0 && CompConstant < 0) ||
+        (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
+      return true;
+  } break;
+  case AArch64CC::LT:
+  case AArch64CC::GE: {
+    if ((AddConstant == 0) ||
+        (AddConstant >= 0 && CompConstant <= 0) ||
+        (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
+      return true;
+  } break;
+  case AArch64CC::HI:
+  case AArch64CC::LS: {
+    if ((AddConstant >= 0 && CompConstant < 0) ||
+        (AddConstant <= 0 && CompConstant >= -1 &&
+         CompConstant < AddConstant + MaxUInt))
+      return true;
+  } break;
+  case AArch64CC::PL:
+  case AArch64CC::MI: {
+    if ((AddConstant == 0) ||
+        (AddConstant > 0 && CompConstant <= 0) ||
+        (AddConstant < 0 && CompConstant <= AddConstant))
+      return true;
+  } break;
+  case AArch64CC::LO:
+  case AArch64CC::HS: {
+    if ((AddConstant >= 0 && CompConstant <= 0) ||
+        (AddConstant <= 0 && CompConstant >= 0 &&
+         CompConstant <= AddConstant + MaxUInt))
+      return true;
+  } break;
+  case AArch64CC::EQ:
+  case AArch64CC::NE: {
+    if ((AddConstant > 0 && CompConstant < 0) ||
+        (AddConstant < 0 && CompConstant >= 0 &&
+         CompConstant < AddConstant + MaxUInt) ||
+        (AddConstant >= 0 && CompConstant >= 0 &&
+         CompConstant >= AddConstant) ||
+        (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
+      return true;
+  } break;
+  case AArch64CC::VS:
+  case AArch64CC::VC:
+  case AArch64CC::AL:
+  case AArch64CC::NV:
+    return true;
+  case AArch64CC::Invalid:
+    break;
+  }
+
+  return false;
+}
+
+static
+SDValue performCONDCombine(SDNode *N,
+                           TargetLowering::DAGCombinerInfo &DCI,
+                           SelectionDAG &DAG, unsigned CCIndex,
+                           unsigned CmpIndex) {
+  unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
+  SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
+  unsigned CondOpcode = SubsNode->getOpcode();
+
+  if (CondOpcode != AArch64ISD::SUBS)
+    return SDValue();
+
+  // There is a SUBS feeding this condition. Is it fed by a mask we can
+  // use?
+
+  SDNode *AndNode = SubsNode->getOperand(0).getNode();
+  unsigned MaskBits = 0;
+
+  if (AndNode->getOpcode() != ISD::AND)
+    return SDValue();
+
+  if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
+    uint32_t CNV = CN->getZExtValue();
+    if (CNV == 255)
+      MaskBits = 8;
+    else if (CNV == 65535)
+      MaskBits = 16;
+  }
+
+  if (!MaskBits)
+    return SDValue();
+
+  SDValue AddValue = AndNode->getOperand(0);
+
+  if (AddValue.getOpcode() != ISD::ADD)
+    return SDValue();
+
+  // The basic DAG structure is correct; grab the inputs and validate them.
+
+  SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
+  SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
+  SDValue SubsInputValue = SubsNode->getOperand(1);
+
+  // The mask is present and all of the values come from a narrower type;
+  // let's see if the mask is superfluous.
+ + if (!isa(AddInputValue2.getNode()) || + !isa(SubsInputValue.getNode())) + return SDValue(); + + ISD::LoadExtType ExtType; + + if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) || + !checkValueWidth(AddInputValue2, MaskBits, ExtType) || + !checkValueWidth(AddInputValue1, MaskBits, ExtType) ) + return SDValue(); + + if(!isEquivalentMaskless(CC, MaskBits, ExtType, + cast(AddInputValue2.getNode())->getSExtValue(), + cast(SubsInputValue.getNode())->getSExtValue())) + return SDValue(); + + // The AND is not necessary, remove it. + + SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0), + SubsNode->getValueType(1)); + SDValue Ops[] = { AddValue, SubsNode->getOperand(1) }; + + SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops); + DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode()); + + return SDValue(N, 0); +} + // Optimize compare with zero and branch. static SDValue performBRCONDCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) { + SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3); + if (NV.getNode()) + N = NV.getNode(); SDValue Chain = N->getOperand(0); SDValue Dest = N->getOperand(1); SDValue CCVal = N->getOperand(2); @@ -7814,21 +8472,23 @@ static SDValue performSelectCombine(SDNode *N, SelectionDAG &DAG) { SDValue N0 = N->getOperand(0); EVT ResVT = N->getValueType(0); - if (!N->getOperand(1).getValueType().isVector()) - return SDValue(); - if (N0.getOpcode() != ISD::SETCC || N0.getValueType() != MVT::i1) return SDValue(); - SDLoc DL(N0); - + // If NumMaskElts == 0, the comparison is larger than select result. The + // largest real NEON comparison is 64-bits per lane, which means the result is + // at most 32-bits and an illegal vector. Just bail out for now. EVT SrcVT = N0.getOperand(0).getValueType(); - SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, - ResVT.getSizeInBits() / SrcVT.getSizeInBits()); + int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits(); + if (!ResVT.isVector() || NumMaskElts == 0) + return SDValue(); + + SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts); EVT CCVT = SrcVT.changeVectorElementTypeToInteger(); // First perform a vector comparison, where lane 0 is the one we're interested // in. + SDLoc DL(N0); SDValue LHS = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(0)); SDValue RHS = @@ -7838,8 +8498,8 @@ static SDValue performSelectCombine(SDNode *N, SelectionDAG &DAG) { // Now duplicate the comparison mask we want across all other lanes. 
SmallVector DUPMask(CCVT.getVectorNumElements(), 0); SDValue Mask = DAG.getVectorShuffle(CCVT, DL, SetCC, SetCC, DUPMask.data()); - Mask = DAG.getNode(ISD::BITCAST, DL, ResVT.changeVectorElementTypeToInteger(), - Mask); + Mask = DAG.getNode(ISD::BITCAST, DL, + ResVT.changeVectorElementTypeToInteger(), Mask); return DAG.getSelect(DL, ResVT, Mask, N->getOperand(1), N->getOperand(2)); } @@ -7880,6 +8540,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performSTORECombine(N, DCI, DAG, Subtarget); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); + case AArch64ISD::CSEL: + return performCONDCombine(N, DCI, DAG, 2, 3); case AArch64ISD::DUP: return performPostLD1Combine(N, DCI, false); case ISD::INSERT_VECTOR_ELT: @@ -8067,17 +8729,14 @@ void AArch64TargetLowering::ReplaceNodeResults( } } -bool AArch64TargetLowering::shouldExpandAtomicInIR(Instruction *Inst) const { - // Loads and stores less than 128-bits are already atomic; ones above that - // are doomed anyway, so defer to the default libcall and blame the OS when - // things go wrong: - if (StoreInst *SI = dyn_cast(Inst)) - return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128; - else if (LoadInst *LI = dyn_cast(Inst)) - return LI->getType()->getPrimitiveSizeInBits() == 128; +bool AArch64TargetLowering::useLoadStackGuardNode() const { + return true; +} - // For the real atomic operations, we have ldxr/stxr up to 128 bits. - return Inst->getType()->getPrimitiveSizeInBits() <= 128; +bool AArch64TargetLowering::combineRepeatedFPDivisors(unsigned NumUsers) const { + // Combine multiple FDIVs with the same divisor into multiple FMULs by the + // reciprocal if there are three or more FDIVs. + return NumUsers > 2; } TargetLoweringBase::LegalizeTypeAction @@ -8092,12 +8751,37 @@ AArch64TargetLowering::getPreferredVectorAction(EVT VT) const { return TargetLoweringBase::getPreferredVectorAction(VT); } +// Loads and stores less than 128-bits are already atomic; ones above that +// are doomed anyway, so defer to the default libcall and blame the OS when +// things go wrong. +bool AArch64TargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const { + unsigned Size = SI->getValueOperand()->getType()->getPrimitiveSizeInBits(); + return Size == 128; +} + +// Loads and stores less than 128-bits are already atomic; ones above that +// are doomed anyway, so defer to the default libcall and blame the OS when +// things go wrong. 
+bool AArch64TargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const { + unsigned Size = LI->getType()->getPrimitiveSizeInBits(); + return Size == 128; +} + +// For the real atomic operations, we have ldxr/stxr up to 128 bits, +bool AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const { + unsigned Size = AI->getType()->getPrimitiveSizeInBits(); + return Size <= 128; +} + +bool AArch64TargetLowering::hasLoadLinkedStoreConditional() const { + return true; +} + Value *AArch64TargetLowering::emitLoadLinked(IRBuilder<> &Builder, Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); Type *ValTy = cast(Addr->getType())->getElementType(); - bool IsAcquire = - Ord == Acquire || Ord == AcquireRelease || Ord == SequentiallyConsistent; + bool IsAcquire = isAtLeastAcquire(Ord); // Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd // intrinsic must return {i64, i64} and we have to recombine them into a @@ -8132,8 +8816,7 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder, Value *Val, Value *Addr, AtomicOrdering Ord) const { Module *M = Builder.GetInsertBlock()->getParent()->getParent(); - bool IsRelease = - Ord == Release || Ord == AcquireRelease || Ord == SequentiallyConsistent; + bool IsRelease = isAtLeastRelease(Ord); // Since the intrinsics must have legal type, the i128 intrinsics take two // parameters: "i64, i64". We must marshal Val into the appropriate form @@ -8160,3 +8843,8 @@ Value *AArch64TargetLowering::emitStoreConditional(IRBuilder<> &Builder, Val, Stxr->getFunctionType()->getParamType(0)), Addr); } + +bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters( + Type *Ty, CallingConv::ID CallConv, bool isVarArg) const { + return Ty->isArrayTy(); +}
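A hedged sketch of the pattern combineRepeatedFPDivisors is aimed at (the function scale3 is illustrative only, and the rewrite only happens when the optimizer is allowed to reassociate, e.g. under fast-math-style options):

    void scale3(float *a, float *b, float *c, float d) {
      *a /= d;
      *b /= d;
      *c /= d;  // three FDIVs sharing 'd'; with more than two users the
                // combiner may compute r = 1.0f / d once and multiply instead
    }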
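The final hunk ties array-typed arguments to the consecutive-register requirement. A hedged C-level example of why (the struct and function names are made up): AAPCS64 passes a homogeneous floating-point aggregate either entirely in consecutive FP registers or entirely on the stack, and the ABI lowering commonly presents such an argument to the backend as an array type such as [4 x float], which this hook flags so the calling-convention code keeps the pieces together:

    struct Quad { float x, y, z, w; };          /* HFA with four float members */
    float first(struct Quad q) { return q.x; }  /* q expected in s0-s3, not split
                                                   across registers and stack */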