[AArch64] Handle Cyclone-specific register in common way
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index ae77ca164029bf764e74902a7cca9853736c54ae..9a45ca711bf273810973fba168306a53318e8973 100644 (file)
@@ -281,14 +281,39 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
 
-  // f16 is storage-only, so we promote operations to f32 if we know this is
-  // valid, and ignore them otherwise. The operations not mentioned here will
-  // fail to select, but this is not a major problem as no source language
-  // should be emitting native f16 operations yet.
-  setOperationAction(ISD::FADD, MVT::f16, Promote);
-  setOperationAction(ISD::FDIV, MVT::f16, Promote);
-  setOperationAction(ISD::FMUL, MVT::f16, Promote);
-  setOperationAction(ISD::FSUB, MVT::f16, Promote);
+  // f16 is a storage-only type, always promote it to f32.
+  setOperationAction(ISD::SETCC,       MVT::f16,  Promote);
+  setOperationAction(ISD::BR_CC,       MVT::f16,  Promote);
+  setOperationAction(ISD::SELECT_CC,   MVT::f16,  Promote);
+  setOperationAction(ISD::SELECT,      MVT::f16,  Promote);
+  setOperationAction(ISD::FADD,        MVT::f16,  Promote);
+  setOperationAction(ISD::FSUB,        MVT::f16,  Promote);
+  setOperationAction(ISD::FMUL,        MVT::f16,  Promote);
+  setOperationAction(ISD::FDIV,        MVT::f16,  Promote);
+  setOperationAction(ISD::FREM,        MVT::f16,  Promote);
+  setOperationAction(ISD::FMA,         MVT::f16,  Promote);
+  setOperationAction(ISD::FNEG,        MVT::f16,  Promote);
+  setOperationAction(ISD::FABS,        MVT::f16,  Promote);
+  setOperationAction(ISD::FCEIL,       MVT::f16,  Promote);
+  setOperationAction(ISD::FCOPYSIGN,   MVT::f16,  Promote);
+  setOperationAction(ISD::FCOS,        MVT::f16,  Promote);
+  setOperationAction(ISD::FFLOOR,      MVT::f16,  Promote);
+  setOperationAction(ISD::FNEARBYINT,  MVT::f16,  Promote);
+  setOperationAction(ISD::FPOW,        MVT::f16,  Promote);
+  setOperationAction(ISD::FPOWI,       MVT::f16,  Promote);
+  setOperationAction(ISD::FRINT,       MVT::f16,  Promote);
+  setOperationAction(ISD::FSIN,        MVT::f16,  Promote);
+  setOperationAction(ISD::FSINCOS,     MVT::f16,  Promote);
+  setOperationAction(ISD::FSQRT,       MVT::f16,  Promote);
+  setOperationAction(ISD::FEXP,        MVT::f16,  Promote);
+  setOperationAction(ISD::FEXP2,       MVT::f16,  Promote);
+  setOperationAction(ISD::FLOG,        MVT::f16,  Promote);
+  setOperationAction(ISD::FLOG2,       MVT::f16,  Promote);
+  setOperationAction(ISD::FLOG10,      MVT::f16,  Promote);
+  setOperationAction(ISD::FROUND,      MVT::f16,  Promote);
+  setOperationAction(ISD::FTRUNC,      MVT::f16,  Promote);
+  setOperationAction(ISD::FMINNUM,     MVT::f16,  Promote);
+  setOperationAction(ISD::FMAXNUM,     MVT::f16,  Promote);
 
   // v4f16 is also a storage-only type, so promote it to v4f32 when that is
   // known to be safe.
@@ -481,6 +506,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   // Enable TBZ/TBNZ
   MaskAndBranchFoldingIsLegal = true;
+  EnableExtLdPromotion = true;
 
   setMinFunctionAlignment(2);
 
@@ -1257,7 +1283,7 @@ getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
   case ISD::SMULO:
   case ISD::UMULO: {
     CC = AArch64CC::NE;
-    bool IsSigned = (Op.getOpcode() == ISD::SMULO) ? true : false;
+    bool IsSigned = Op.getOpcode() == ISD::SMULO;
     if (Op.getValueType() == MVT::i32) {
       unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
       // For a 32 bit multiply with overflow check we want the instruction
@@ -1557,6 +1583,14 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
   if (Op.getOperand(0).getValueType().isVector())
     return LowerVectorFP_TO_INT(Op, DAG);
 
+  // f16 conversions are promoted to f32.
+  if (Op.getOperand(0).getValueType() == MVT::f16) {
+    SDLoc dl(Op);
+    return DAG.getNode(
+        Op.getOpcode(), dl, Op.getValueType(),
+        DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, Op.getOperand(0)));
+  }
+
   if (Op.getOperand(0).getValueType() != MVT::f128) {
     // It's legal except when f128 is involved
     return Op;
@@ -1606,6 +1640,15 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
   if (Op.getValueType().isVector())
     return LowerVectorINT_TO_FP(Op, DAG);
 
+  // f16 conversions are promoted to f32.
+  if (Op.getValueType() == MVT::f16) {
+    SDLoc dl(Op);
+    return DAG.getNode(
+        ISD::FP_ROUND, dl, MVT::f16,
+        DAG.getNode(Op.getOpcode(), dl, MVT::f32, Op.getOperand(0)),
+        DAG.getIntPtrConstant(0));
+  }
+
   // i128 conversions are libcalls.
   if (Op.getOperand(0).getValueType() == MVT::i128)
     return SDValue();
@@ -2701,8 +2744,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
             DAG.getConstant(Outs[i].Flags.getByValSize(), MVT::i64);
         SDValue Cpy = DAG.getMemcpy(
             Chain, DL, DstAddr, Arg, SizeNode, Outs[i].Flags.getByValAlign(),
-            /*isVol = */ false,
-            /*AlwaysInline = */ false, DstInfo, MachinePointerInfo());
+            /*isVol = */ false, /*AlwaysInline = */ false,
+            /*isTailCall = */ false,
+            DstInfo, MachinePointerInfo());
 
         MemOpChains.push_back(Cpy);
       } else {
@@ -2794,13 +2838,13 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
   const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
   if (IsThisReturn) {
     // For 'this' returns, use the X0-preserving mask if applicable
-    Mask = TRI->getThisReturnPreservedMask(CallConv);
+    Mask = TRI->getThisReturnPreservedMask(MF, CallConv);
     if (!Mask) {
       IsThisReturn = false;
-      Mask = TRI->getCallPreservedMask(CallConv);
+      Mask = TRI->getCallPreservedMask(MF, CallConv);
     }
   } else
-    Mask = TRI->getCallPreservedMask(CallConv);
+    Mask = TRI->getCallPreservedMask(MF, CallConv);
 
   assert(Mask && "Missing call preserved mask for calling convention");
   Ops.push_back(DAG.getRegisterMask(Mask));
@@ -3514,49 +3558,10 @@ static bool selectCCOpsAreFMaxCompatible(SDValue Cmp, SDValue Result) {
   return Result->getOpcode() == ISD::FP_EXTEND && Result->getOperand(0) == Cmp;
 }
 
-SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
-                                           SelectionDAG &DAG) const {
-  SDValue CC = Op->getOperand(0);
-  SDValue TVal = Op->getOperand(1);
-  SDValue FVal = Op->getOperand(2);
-  SDLoc DL(Op);
-
-  unsigned Opc = CC.getOpcode();
-  // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
-  // instruction.
-  if (CC.getResNo() == 1 &&
-      (Opc == ISD::SADDO || Opc == ISD::UADDO || Opc == ISD::SSUBO ||
-       Opc == ISD::USUBO || Opc == ISD::SMULO || Opc == ISD::UMULO)) {
-    // Only lower legal XALUO ops.
-    if (!DAG.getTargetLoweringInfo().isTypeLegal(CC->getValueType(0)))
-      return SDValue();
-
-    AArch64CC::CondCode OFCC;
-    SDValue Value, Overflow;
-    std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CC.getValue(0), DAG);
-    SDValue CCVal = DAG.getConstant(OFCC, MVT::i32);
-
-    return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal,
-                       CCVal, Overflow);
-  }
-
-  if (CC.getOpcode() == ISD::SETCC)
-    return DAG.getSelectCC(DL, CC.getOperand(0), CC.getOperand(1), TVal, FVal,
-                           cast<CondCodeSDNode>(CC.getOperand(2))->get());
-  else
-    return DAG.getSelectCC(DL, CC, DAG.getConstant(0, CC.getValueType()), TVal,
-                           FVal, ISD::SETNE);
-}
-
-SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
+SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
+                                              SDValue RHS, SDValue TVal,
+                                              SDValue FVal, SDLoc dl,
                                               SelectionDAG &DAG) const {
-  ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
-  SDValue LHS = Op.getOperand(0);
-  SDValue RHS = Op.getOperand(1);
-  SDValue TVal = Op.getOperand(2);
-  SDValue FVal = Op.getOperand(3);
-  SDLoc dl(Op);
-
   // Handle f128 first, because it will result in a comparison of some RTLIB
   // call result against zero.
   if (LHS.getValueType() == MVT::f128) {
@@ -3664,14 +3669,14 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
     SDValue CCVal;
     SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
 
-    EVT VT = Op.getValueType();
+    EVT VT = TVal.getValueType();
     return DAG.getNode(Opcode, dl, VT, TVal, FVal, CCVal, Cmp);
   }
 
   // Now we know we're dealing with FP values.
   assert(LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
   assert(LHS.getValueType() == RHS.getValueType());
-  EVT VT = Op.getValueType();
+  EVT VT = TVal.getValueType();
 
   // Try to match this select into a max/min operation, which have dedicated
   // opcode in the instruction set.
@@ -3732,6 +3737,58 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
   return CS1;
 }
 
+SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
+  SDValue LHS = Op.getOperand(0);
+  SDValue RHS = Op.getOperand(1);
+  SDValue TVal = Op.getOperand(2);
+  SDValue FVal = Op.getOperand(3);
+  SDLoc DL(Op);
+  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+}
+
+SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
+                                           SelectionDAG &DAG) const {
+  SDValue CCVal = Op->getOperand(0);
+  SDValue TVal = Op->getOperand(1);
+  SDValue FVal = Op->getOperand(2);
+  SDLoc DL(Op);
+
+  unsigned Opc = CCVal.getOpcode();
+  // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
+  // instruction.
+  if (CCVal.getResNo() == 1 &&
+      (Opc == ISD::SADDO || Opc == ISD::UADDO || Opc == ISD::SSUBO ||
+       Opc == ISD::USUBO || Opc == ISD::SMULO || Opc == ISD::UMULO)) {
+    // Only lower legal XALUO ops.
+    if (!DAG.getTargetLoweringInfo().isTypeLegal(CCVal->getValueType(0)))
+      return SDValue();
+
+    AArch64CC::CondCode OFCC;
+    SDValue Value, Overflow;
+    std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG);
+    SDValue CCVal = DAG.getConstant(OFCC, MVT::i32);
+
+    return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal,
+                       CCVal, Overflow);
+  }
+
+  // Lower it the same way as we would lower a SELECT_CC node.
+  ISD::CondCode CC;
+  SDValue LHS, RHS;
+  if (CCVal.getOpcode() == ISD::SETCC) {
+    LHS = CCVal.getOperand(0);
+    RHS = CCVal.getOperand(1);
+    CC = cast<CondCodeSDNode>(CCVal->getOperand(2))->get();
+  } else {
+    LHS = CCVal;
+    RHS = DAG.getConstant(0, CCVal.getValueType());
+    CC = ISD::SETNE;
+  }
+  return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
+}
+
 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
                                               SelectionDAG &DAG) const {
   // Jump table entries as PC relative offsets. No additional tweaking
@@ -3920,7 +3977,7 @@ SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op,
 
   return DAG.getMemcpy(Op.getOperand(0), SDLoc(Op), Op.getOperand(1),
                        Op.getOperand(2), DAG.getConstant(VaListSize, MVT::i32),
-                       8, false, false, MachinePointerInfo(DestSV),
+                       8, false, false, false, MachinePointerInfo(DestSV),
                        MachinePointerInfo(SrcSV));
 }
 
@@ -5892,8 +5949,10 @@ FailedModImm:
 
     if (VT.getVectorElementType().isFloatingPoint()) {
       SmallVector<SDValue, 8> Ops;
-      MVT NewType =
-          (VT.getVectorElementType() == MVT::f32) ? MVT::i32 : MVT::i64;
+      EVT EltTy = VT.getVectorElementType();
+      assert ((EltTy == MVT::f16 || EltTy == MVT::f32 || EltTy == MVT::f64) &&
+              "Unsupported floating-point vector type");
+      MVT NewType = MVT::getIntegerVT(EltTy.getSizeInBits());
       for (unsigned i = 0; i < NumElts; ++i)
         Ops.push_back(DAG.getNode(ISD::BITCAST, dl, NewType, Op.getOperand(i)));
       EVT VecVT = EVT::getVectorVT(*DAG.getContext(), NewType, NumElts);
@@ -6552,6 +6611,59 @@ bool AArch64TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
           VT1.getSizeInBits() <= 32);
 }
 
+bool AArch64TargetLowering::isExtFreeImpl(const Instruction *Ext) const {
+  if (isa<FPExtInst>(Ext))
+    return false;
+
+  // Vector types are next free.
+  if (Ext->getType()->isVectorTy())
+    return false;
+
+  for (const Use &U : Ext->uses()) {
+    // The extension is free if we can fold it with a left shift in an
+    // addressing mode or an arithmetic operation: add, sub, and cmp.
+
+    // Is there a shift?
+    const Instruction *Instr = cast<Instruction>(U.getUser());
+
+    // Is this a constant shift?
+    switch (Instr->getOpcode()) {
+    case Instruction::Shl:
+      if (!isa<ConstantInt>(Instr->getOperand(1)))
+        return false;
+      break;
+    case Instruction::GetElementPtr: {
+      gep_type_iterator GTI = gep_type_begin(Instr);
+      std::advance(GTI, U.getOperandNo());
+      Type *IdxTy = *GTI;
+      // This extension will end up with a shift because of the scaling factor.
+      // 8-bit sized types have a scaling factor of 1, thus a shift amount of 0.
+      // Get the shift amount based on the scaling factor:
+      // log2(sizeof(IdxTy)) - log2(8).
+      uint64_t ShiftAmt =
+        countTrailingZeros(getDataLayout()->getTypeStoreSizeInBits(IdxTy)) - 3;
+      // Is the constant foldable in the shift of the addressing mode?
+      // I.e., shift amount is between 1 and 4 inclusive.
+      if (ShiftAmt == 0 || ShiftAmt > 4)
+        return false;
+      break;
+    }
+    case Instruction::Trunc:
+      // Check if this is a noop.
+      // trunc(sext ty1 to ty2) to ty1.
+      if (Instr->getType() == Ext->getOperand(0)->getType())
+        continue;
+    // FALL THROUGH.
+    default:
+      return false;
+    }
+
+    // At this point we can use the bfm family, so this extension is free
+    // for that use.
+  }
+  return true;
+}
+
 bool AArch64TargetLowering::hasPairedLoad(Type *LoadedType,
                                           unsigned &RequiredAligment) const {
   if (!LoadedType->isIntegerTy() && !LoadedType->isFloatTy())
@@ -6595,7 +6707,17 @@ EVT AArch64TargetLowering::getOptimalMemOpType(uint64_t Size, unsigned DstAlign,
        (allowsMisalignedMemoryAccesses(MVT::f128, 0, 1, &Fast) && Fast)))
     return MVT::f128;
 
-  return Size >= 8 ? MVT::i64 : MVT::i32;
+  if (Size >= 8 &&
+      (memOpAlign(SrcAlign, DstAlign, 8) ||
+       (allowsMisalignedMemoryAccesses(MVT::i64, 0, 1, &Fast) && Fast)))
+    return MVT::i64;
+
+  if (Size >= 4 &&
+      (memOpAlign(SrcAlign, DstAlign, 4) ||
+       (allowsMisalignedMemoryAccesses(MVT::i32, 0, 1, &Fast) && Fast)))
+    return MVT::i32;
+
+  return MVT::Other;
 }
 
 // 12-bit optionally shifted immediates are legal for adds.
@@ -6746,7 +6868,7 @@ bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
   unsigned LZ = countLeadingZeros((uint64_t)Val);
   unsigned Shift = (63 - LZ) / 16;
   // MOVZ is free so return true for one or fewer MOVK.
-  return (Shift < 3) ? true : false;
+  return Shift < 3;
 }
 
 // Generate SUBS and CSEL for integer abs.
@@ -7176,21 +7298,54 @@ static SDValue performBitcastCombine(SDNode *N,
 static SDValue performConcatVectorsCombine(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            SelectionDAG &DAG) {
+  SDLoc dl(N);
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+  // Optimize concat_vectors of truncated vectors, where the intermediate
+  // type is illegal, to avoid said illegality,  e.g.,
+  //   (v4i16 (concat_vectors (v2i16 (truncate (v2i64))),
+  //                          (v2i16 (truncate (v2i64)))))
+  // ->
+  //   (v4i16 (truncate (vector_shuffle (v4i32 (bitcast (v2i64))),
+  //                                    (v4i32 (bitcast (v2i64))),
+  //                                    <0, 2, 4, 6>)))
+  // This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
+  // on both input and result type, so we might generate worse code.
+  // On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
+  if (N->getNumOperands() == 2 &&
+      N0->getOpcode() == ISD::TRUNCATE &&
+      N1->getOpcode() == ISD::TRUNCATE) {
+    SDValue N00 = N0->getOperand(0);
+    SDValue N10 = N1->getOperand(0);
+    EVT N00VT = N00.getValueType();
+
+    if (N00VT == N10.getValueType() &&
+        (N00VT == MVT::v2i64 || N00VT == MVT::v4i32) &&
+        N00VT.getScalarSizeInBits() == 4 * VT.getScalarSizeInBits()) {
+      MVT MidVT = (N00VT == MVT::v2i64 ? MVT::v4i32 : MVT::v8i16);
+      SmallVector<int, 8> Mask(MidVT.getVectorNumElements());
+      for (size_t i = 0; i < Mask.size(); ++i)
+        Mask[i] = i * 2;
+      return DAG.getNode(ISD::TRUNCATE, dl, VT,
+                         DAG.getVectorShuffle(
+                             MidVT, dl,
+                             DAG.getNode(ISD::BITCAST, dl, MidVT, N00),
+                             DAG.getNode(ISD::BITCAST, dl, MidVT, N10), Mask));
+    }
+  }
+
   // Wait 'til after everything is legalized to try this. That way we have
   // legal vector types and such.
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
-  SDLoc dl(N);
-  EVT VT = N->getValueType(0);
-
   // If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
   // splat. The indexed instructions are going to be expecting a DUPLANE64, so
   // canonicalise to that.
-  if (N->getOperand(0) == N->getOperand(1) && VT.getVectorNumElements() == 2) {
+  if (N0 == N1 && VT.getVectorNumElements() == 2) {
     assert(VT.getVectorElementType().getSizeInBits() == 64);
-    return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT,
-                       WidenVector(N->getOperand(0), DAG),
+    return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
                        DAG.getConstant(0, MVT::i64));
   }
 
@@ -7203,10 +7358,9 @@ static SDValue performConcatVectorsCombine(SDNode *N,
   // becomes
   //    (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
 
-  SDValue Op1 = N->getOperand(1);
-  if (Op1->getOpcode() != ISD::BITCAST)
+  if (N1->getOpcode() != ISD::BITCAST)
     return SDValue();
-  SDValue RHS = Op1->getOperand(0);
+  SDValue RHS = N1->getOperand(0);
   MVT RHSTy = RHS.getValueType().getSimpleVT();
   // If the RHS is not a vector, this is not the pattern we're looking for.
   if (!RHSTy.isVector())
@@ -7216,10 +7370,10 @@ static SDValue performConcatVectorsCombine(SDNode *N,
 
   MVT ConcatTy = MVT::getVectorVT(RHSTy.getVectorElementType(),
                                   RHSTy.getVectorNumElements() * 2);
-  return DAG.getNode(
-      ISD::BITCAST, dl, VT,
-      DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatTy,
-                  DAG.getNode(ISD::BITCAST, dl, RHSTy, N->getOperand(0)), RHS));
+  return DAG.getNode(ISD::BITCAST, dl, VT,
+                     DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatTy,
+                                 DAG.getNode(ISD::BITCAST, dl, RHSTy, N0),
+                                 RHS));
 }
 
 static SDValue tryCombineFixedPointConvert(SDNode *N,