Remove extra whitespace. NFC.
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index f0fb03451b2a6f72dbdd7ff7a65f42c41f02091c..4ecfbe9e228091335dc720281446e4668be1a672 100644 (file)
@@ -237,6 +237,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
   setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
+  for (MVT VT : MVT::vector_valuetypes()) {
+    setOperationAction(ISD::SDIVREM, VT, Expand);
+    setOperationAction(ISD::UDIVREM, VT, Expand);
+  }
   setOperationAction(ISD::SREM, MVT::i32, Expand);
   setOperationAction(ISD::SREM, MVT::i64, Expand);
   setOperationAction(ISD::UDIVREM, MVT::i32, Expand);
@@ -687,12 +691,10 @@ void AArch64TargetLowering::addTypeForNEON(EVT VT, EVT PromotedBitwiseVT) {
   setOperationAction(ISD::FP_TO_SINT, VT.getSimpleVT(), Custom);
   setOperationAction(ISD::FP_TO_UINT, VT.getSimpleVT(), Custom);
 
-  // [SU][MIN|MAX] and [SU]ABSDIFF are available for all NEON types apart from
-  // i64.
+  // [SU][MIN|MAX] are available for all NEON types apart from i64.
   if (!VT.isFloatingPoint() &&
       VT.getSimpleVT() != MVT::v2i64 && VT.getSimpleVT() != MVT::v1i64)
-    for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
-                            ISD::SABSDIFF, ISD::UABSDIFF})
+    for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
       setOperationAction(Opcode, VT.getSimpleVT(), Legal);
 
   // F[MIN|MAX][NUM|NAN] are available for all FP NEON types (not f16 though!).
@@ -1846,6 +1848,16 @@ static SDValue LowerVectorFP_TO_INT(SDValue Op, SelectionDAG &DAG) {
   // in the cost tables.
   EVT InVT = Op.getOperand(0).getValueType();
   EVT VT = Op.getValueType();
+  unsigned NumElts = InVT.getVectorNumElements();
+
+  // f16 vectors are promoted to f32 before a conversion.
+  if (InVT.getVectorElementType() == MVT::f16) {
+    MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
+    SDLoc dl(Op);
+    return DAG.getNode(
+        Op.getOpcode(), dl, Op.getValueType(),
+        DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
+  }
 
   if (VT.getSizeInBits() < InVT.getSizeInBits()) {
     SDLoc dl(Op);
@@ -2326,11 +2338,6 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   }
 }
 
-/// getFunctionAlignment - Return the Log2 alignment of this function.
-unsigned AArch64TargetLowering::getFunctionAlignment(const Function *F) const {
-  return 2;
-}
-
 //===----------------------------------------------------------------------===//
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
@@ -2419,7 +2426,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments(
 
       continue;
     }
-    
+
     if (VA.isRegLoc()) {
       // Arguments stored in registers.
       EVT RegVT = VA.getLocVT();
@@ -3264,6 +3271,19 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     Flag = Chain.getValue(1);
     RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT()));
   }
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  const MCPhysReg *I =
+      TRI->getCalleeSavedRegsViaCopy(&DAG.getMachineFunction());
+  if (I) {
+    for (; *I; ++I) {
+      if (AArch64::GPR64RegClass.contains(*I))
+        RetOps.push_back(DAG.getRegister(*I, MVT::i64));
+      else if (AArch64::FPR64RegClass.contains(*I))
+        RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64)));
+      else
+        llvm_unreachable("Unexpected register class in CSRsViaCopy!");
+    }
+  }
 
   RetOps[0] = Chain; // Update chain.
 
@@ -5054,7 +5074,7 @@ static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
 
   // The index of an EXT is the first element if it is not UNDEF.
   // Watch out for the beginning UNDEFs. The EXT index should be the expected
-  // value of the first element.  E.g. 
+  // value of the first element.  E.g.
   // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
   // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
   // ExpectedElt is the last mask index plus 1.
@@ -6723,7 +6743,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   case Intrinsic::aarch64_neon_ld4r: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     // Conservatively set memVT to the entire set of vectors loaded.
-    uint64_t NumElts = DL.getTypeAllocSize(I.getType()) / 8;
+    uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64;
     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
     Info.offset = 0;
@@ -6749,7 +6769,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
       Type *ArgTy = I.getArgOperand(ArgI)->getType();
       if (!ArgTy->isVectorTy())
         break;
-      NumElts += DL.getTypeAllocSize(ArgTy) / 8;
+      NumElts += DL.getTypeSizeInBits(ArgTy) / 64;
     }
     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
@@ -6992,7 +7012,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
   const DataLayout &DL = LI->getModule()->getDataLayout();
 
   VectorType *VecTy = Shuffles[0]->getType();
-  unsigned VecSize = DL.getTypeAllocSizeInBits(VecTy);
+  unsigned VecSize = DL.getTypeSizeInBits(VecTy);
 
   // Skip if we do not have NEON and skip illegal vector types.
   if (!Subtarget->hasNEON() || (VecSize != 64 && VecSize != 128))
@@ -7078,7 +7098,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
   VectorType *SubVecTy = VectorType::get(EltTy, NumSubElts);
 
   const DataLayout &DL = SI->getModule()->getDataLayout();
-  unsigned SubVecSize = DL.getTypeAllocSizeInBits(SubVecTy);
+  unsigned SubVecSize = DL.getTypeSizeInBits(SubVecTy);
 
   // Skip if we do not have NEON and skip illegal vector types.
   if (!Subtarget->hasNEON() || (SubVecSize != 64 && SubVecSize != 128))
@@ -8237,15 +8257,14 @@ static SDValue performAddSubLongCombine(SDNode *N,
 //   (aarch64_neon_umull (extract_high (v2i64 vec)))
 //                     (extract_high (v2i64 (dup128 scalar)))))
 //
-static SDValue tryCombineLongOpWithDup(SDNode *N,
+static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        SelectionDAG &DAG) {
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
-  bool IsIntrinsic = N->getOpcode() == ISD::INTRINSIC_WO_CHAIN;
-  SDValue LHS = N->getOperand(IsIntrinsic ? 1 : 0);
-  SDValue RHS = N->getOperand(IsIntrinsic ? 2 : 1);
+  SDValue LHS = N->getOperand(1);
+  SDValue RHS = N->getOperand(2);
   assert(LHS.getValueType().is64BitVector() &&
          RHS.getValueType().is64BitVector() &&
          "unexpected shape for long operation");
@@ -8263,13 +8282,8 @@ static SDValue tryCombineLongOpWithDup(SDNode *N,
       return SDValue();
   }
 
-  // N could either be an intrinsic or a sabsdiff/uabsdiff node.
-  if (IsIntrinsic)
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
-                       N->getOperand(0), LHS, RHS);
-  else
-    return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
-                       LHS, RHS);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
+                     N->getOperand(0), LHS, RHS);
 }
 
 static SDValue tryCombineShiftImm(unsigned IID, SDNode *N, SelectionDAG &DAG) {
@@ -8387,12 +8401,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_neon_fmin:
     return DAG.getNode(ISD::FMINNAN, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
-  case Intrinsic::aarch64_neon_sabd:
-    return DAG.getNode(ISD::SABSDIFF, SDLoc(N), N->getValueType(0),
-                       N->getOperand(1), N->getOperand(2));
-  case Intrinsic::aarch64_neon_uabd:
-    return DAG.getNode(ISD::UABSDIFF, SDLoc(N), N->getValueType(0),
-                       N->getOperand(1), N->getOperand(2));
   case Intrinsic::aarch64_neon_fmaxnm:
     return DAG.getNode(ISD::FMAXNUM, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
@@ -8403,7 +8411,7 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_neon_umull:
   case Intrinsic::aarch64_neon_pmull:
   case Intrinsic::aarch64_neon_sqdmull:
-    return tryCombineLongOpWithDup(N, DCI, DAG);
+    return tryCombineLongOpWithDup(IID, N, DCI, DAG);
   case Intrinsic::aarch64_neon_sqshl:
   case Intrinsic::aarch64_neon_uqshl:
   case Intrinsic::aarch64_neon_sqshlu:
@@ -8428,15 +8436,18 @@ static SDValue performExtendCombine(SDNode *N,
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == ISD::SABSDIFF ||
-       N->getOperand(0).getOpcode() == ISD::UABSDIFF)) {
+      N->getOperand(0).getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
     SDNode *ABDNode = N->getOperand(0).getNode();
-    SDValue NewABD = tryCombineLongOpWithDup(ABDNode, DCI, DAG);
-    if (!NewABD.getNode())
-      return SDValue();
+    unsigned IID = getIntrinsicID(ABDNode);
+    if (IID == Intrinsic::aarch64_neon_sabd ||
+        IID == Intrinsic::aarch64_neon_uabd) {
+      SDValue NewABD = tryCombineLongOpWithDup(IID, ABDNode, DCI, DAG);
+      if (!NewABD.getNode())
+        return SDValue();
 
-    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
-                       NewABD);
+      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
+                         NewABD);
+    }
   }
 
   // This is effectively a custom type legalization for AArch64.
@@ -9480,6 +9491,103 @@ static SDValue performBRCONDCombine(SDNode *N,
   return SDValue();
 }
 
+// Optimize some simple tbz/tbnz cases.  Returns the new operand and bit to test
+// as well as whether the test should be inverted.  This code is required to
+// catch these cases (as opposed to standard dag combines) because
+// AArch64ISD::TBZ is matched during legalization.
+static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
+                                 SelectionDAG &DAG) {
+
+  if (!Op->hasOneUse())
+    return Op;
+
+  // We don't handle undef/constant-fold cases below, as they should have
+  // already been taken care of (e.g. and of 0, test of undefined shifted bits,
+  // etc.)
+
+  // (tbz (trunc x), b) -> (tbz x, b)
+  // This case is just here to enable more of the below cases to be caught.
+  if (Op->getOpcode() == ISD::TRUNCATE &&
+      Bit < Op->getValueType(0).getSizeInBits()) {
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+  }
+
+  if (Op->getNumOperands() != 2)
+    return Op;
+
+  auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+  if (!C)
+    return Op;
+
+  switch (Op->getOpcode()) {
+  default:
+    return Op;
+
+  // (tbz (and x, m), b) -> (tbz x, b)
+  case ISD::AND:
+    if ((C->getZExtValue() >> Bit) & 1)
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    return Op;
+
+  // (tbz (shl x, c), b) -> (tbz x, b-c)
+  case ISD::SHL:
+    if (C->getZExtValue() <= Bit &&
+        (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
+      Bit = Bit - C->getZExtValue();
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x
+  case ISD::SRA:
+    Bit = Bit + C->getZExtValue();
+    if (Bit >= Op->getValueType(0).getSizeInBits())
+      Bit = Op->getValueType(0).getSizeInBits() - 1;
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+
+  // (tbz (srl x, c), b) -> (tbz x, b+c)
+  case ISD::SRL:
+    if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
+      Bit = Bit + C->getZExtValue();
+      return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+    }
+    return Op;
+
+  // (tbz (xor x, -1), b) -> (tbnz x, b)
+  case ISD::XOR:
+    if ((C->getZExtValue() >> Bit) & 1)
+      Invert = !Invert;
+    return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
+  }
+}
+
+// Optimize test single bit zero/non-zero and branch.
+static SDValue performTBZCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI,
+                                 SelectionDAG &DAG) {
+  unsigned Bit = cast<ConstantSDNode>(N->getOperand(2))->getZExtValue();
+  bool Invert = false;
+  SDValue TestSrc = N->getOperand(1);
+  SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
+
+  if (TestSrc == NewTestSrc)
+    return SDValue();
+
+  unsigned NewOpc = N->getOpcode();
+  if (Invert) {
+    if (NewOpc == AArch64ISD::TBZ)
+      NewOpc = AArch64ISD::TBNZ;
+    else {
+      assert(NewOpc == AArch64ISD::TBNZ);
+      NewOpc = AArch64ISD::TBZ;
+    }
+  }
+
+  SDLoc DL(N);
+  return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
+                     DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
+}
+
 // vselect (v1i1 setcc) ->
 //     vselect (v1iXX setcc)  (XX is the size of the compared operand type)
 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
@@ -9631,6 +9739,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
+  case AArch64ISD::TBNZ:
+  case AArch64ISD::TBZ:
+    return performTBZCombine(N, DCI, DAG);
   case AArch64ISD::CSEL:
     return performCONDCombine(N, DCI, DAG, 2, 3);
   case AArch64ISD::DUP:
@@ -10005,3 +10116,49 @@ Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) cons
       IRB.CreateConstGEP1_32(IRB.CreateCall(ThreadPointerFunc), TlsOffset),
       Type::getInt8PtrTy(IRB.getContext())->getPointerTo(0));
 }
+
+void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
+  // Update IsSplitCSR in AArch64unctionInfo.
+  AArch64FunctionInfo *AFI = Entry->getParent()->getInfo<AArch64FunctionInfo>();
+  AFI->setIsSplitCSR(true);
+}
+
+void AArch64TargetLowering::insertCopiesSplitCSR(
+    MachineBasicBlock *Entry,
+    const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
+  if (!IStart)
+    return;
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
+  for (const MCPhysReg *I = IStart; *I; ++I) {
+    const TargetRegisterClass *RC = nullptr;
+    if (AArch64::GPR64RegClass.contains(*I))
+      RC = &AArch64::GPR64RegClass;
+    else if (AArch64::FPR64RegClass.contains(*I))
+      RC = &AArch64::FPR64RegClass;
+    else
+      llvm_unreachable("Unexpected register class in CSRsViaCopy!");
+
+    unsigned NewVR = MRI->createVirtualRegister(RC);
+    // Create copy from CSR to a virtual register.
+    // FIXME: this currently does not emit CFI pseudo-instructions, it works
+    // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
+    // nounwind. If we want to generalize this later, we may need to emit
+    // CFI pseudo-instructions.
+    assert(Entry->getParent()->getFunction()->hasFnAttribute(
+               Attribute::NoUnwind) &&
+           "Function should be nounwind in insertCopiesSplitCSR!");
+    Entry->addLiveIn(*I);
+    BuildMI(*Entry, Entry->begin(), DebugLoc(), TII->get(TargetOpcode::COPY),
+            NewVR)
+        .addReg(*I);
+
+    for (auto *Exit : Exits)
+      BuildMI(*Exit, Exit->begin(), DebugLoc(), TII->get(TargetOpcode::COPY),
+              *I)
+          .addReg(NewVR);
+  }
+}