Revert "[AArch64] Add DAG combine for extract extend pattern"
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index 77dfb8afe3fcaebd2f8617c0eb28274ee2c61d12..9f5beff121002a5e9103f31b57af3f3762c0406f 100644 (file)
@@ -237,6 +237,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
   setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
+  for (MVT VT : MVT::vector_valuetypes()) {
+    setOperationAction(ISD::SDIVREM, VT, Expand);
+    setOperationAction(ISD::UDIVREM, VT, Expand);
+  }
   setOperationAction(ISD::SREM, MVT::i32, Expand);
   setOperationAction(ISD::SREM, MVT::i64, Expand);
   setOperationAction(ISD::UDIVREM, MVT::i32, Expand);
@@ -489,6 +493,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(ISD::BITCAST);
   setTargetDAGCombine(ISD::CONCAT_VECTORS);
   setTargetDAGCombine(ISD::STORE);
+  if (Subtarget->supportsAddressTopByteIgnored())
+    setTargetDAGCombine(ISD::LOAD);
 
   setTargetDAGCombine(ISD::MUL);
 
@@ -685,12 +691,10 @@ void AArch64TargetLowering::addTypeForNEON(EVT VT, EVT PromotedBitwiseVT) {
   setOperationAction(ISD::FP_TO_SINT, VT.getSimpleVT(), Custom);
   setOperationAction(ISD::FP_TO_UINT, VT.getSimpleVT(), Custom);
 
-  // [SU][MIN|MAX] and [SU]ABSDIFF are available for all NEON types apart from
-  // i64.
+  // [SU][MIN|MAX] are available for all NEON types apart from i64.
   if (!VT.isFloatingPoint() &&
       VT.getSimpleVT() != MVT::v2i64 && VT.getSimpleVT() != MVT::v1i64)
-    for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX,
-                            ISD::SABSDIFF, ISD::UABSDIFF})
+    for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
       setOperationAction(Opcode, VT.getSimpleVT(), Legal);
 
   // F[MIN|MAX][NUM|NAN] are available for all FP NEON types (not f16 though!).
@@ -1182,8 +1186,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
   // register to WZR/XZR if it ends up being unused.
   unsigned Opcode = AArch64ISD::SUBS;
 
-  if (RHS.getOpcode() == ISD::SUB && isa<ConstantSDNode>(RHS.getOperand(0)) &&
-      cast<ConstantSDNode>(RHS.getOperand(0))->getZExtValue() == 0 &&
+  if (RHS.getOpcode() == ISD::SUB && isNullConstant(RHS.getOperand(0)) &&
       (CC == ISD::SETEQ || CC == ISD::SETNE)) {
     // We'd like to combine a (CMP op1, (sub 0, op2) into a CMN instruction on
     // the grounds that "op1 - (-op2) == op1 + op2". However, the C and V flags
@@ -1197,8 +1200,7 @@ static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
     // the absence of information about op2.
     Opcode = AArch64ISD::ADDS;
     RHS = RHS.getOperand(1);
-  } else if (LHS.getOpcode() == ISD::AND && isa<ConstantSDNode>(RHS) &&
-             cast<ConstantSDNode>(RHS)->getZExtValue() == 0 &&
+  } else if (LHS.getOpcode() == ISD::AND && isNullConstant(RHS) &&
              !isUnsignedIntSetCC(CC)) {
     // Similarly, (CMP (and X, Y), 0) can be implemented with a TST
     // (a.k.a. ANDS) except that the flags are only guaranteed to work for one
@@ -1263,8 +1265,7 @@ static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
     Opcode = AArch64ISD::FCCMP;
   else if (RHS.getOpcode() == ISD::SUB) {
     SDValue SubOp0 = RHS.getOperand(0);
-    if (const ConstantSDNode *SubOp0C = dyn_cast<ConstantSDNode>(SubOp0))
-      if (SubOp0C->isNullValue() && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
+    if (isNullConstant(SubOp0) && (CC == ISD::SETEQ || CC == ISD::SETNE)) {
         // See emitComparison() on why we can only do this for SETEQ and SETNE.
         Opcode = AArch64ISD::CCMN;
         RHS = RHS.getOperand(1);
@@ -1847,6 +1848,16 @@ static SDValue LowerVectorFP_TO_INT(SDValue Op, SelectionDAG &DAG) {
   // in the cost tables.
   EVT InVT = Op.getOperand(0).getValueType();
   EVT VT = Op.getValueType();
+  unsigned NumElts = InVT.getVectorNumElements();
+
+  // f16 vectors are promoted to f32 before a conversion.
+  if (InVT.getVectorElementType() == MVT::f16) {
+    MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
+    SDLoc dl(Op);
+    return DAG.getNode(
+        Op.getOpcode(), dl, Op.getValueType(),
+        DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
+  }
 
   if (VT.getSizeInBits() < InVT.getSizeInBits()) {
     SDLoc dl(Op);
@@ -2327,11 +2338,6 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
   }
 }
 
-/// getFunctionAlignment - Return the Log2 alignment of this function.
-unsigned AArch64TargetLowering::getFunctionAlignment(const Function *F) const {
-  return 2;
-}
-
 //===----------------------------------------------------------------------===//
 //                      Calling Convention Implementation
 //===----------------------------------------------------------------------===//
@@ -3265,6 +3271,19 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
     Flag = Chain.getValue(1);
     RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT()));
   }
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  const MCPhysReg *I =
+      TRI->getCalleeSavedRegsViaCopy(&DAG.getMachineFunction());
+  if (I) {
+    for (; *I; ++I) {
+      if (AArch64::GPR64RegClass.contains(*I))
+        RetOps.push_back(DAG.getRegister(*I, MVT::i64));
+      else if (AArch64::FPR64RegClass.contains(*I))
+        RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64)));
+      else
+        llvm_unreachable("Unexpected register class in CSRsViaCopy!");
+    }
+  }
 
   RetOps[0] = Chain; // Update chain.
 
@@ -3580,8 +3599,7 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
   // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a branch
   // instruction.
   unsigned Opc = LHS.getOpcode();
-  if (LHS.getResNo() == 1 && isa<ConstantSDNode>(RHS) &&
-      cast<ConstantSDNode>(RHS)->isOne() &&
+  if (LHS.getResNo() == 1 && isOneConstant(RHS) &&
       (Opc == ISD::SADDO || Opc == ISD::UADDO || Opc == ISD::SSUBO ||
        Opc == ISD::USUBO || Opc == ISD::SMULO || Opc == ISD::UMULO)) {
     assert((CC == ISD::SETEQ || CC == ISD::SETNE) &&
@@ -3885,7 +3903,13 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
     }
   }
 
-  // Handle integers first.
+  // Also handle f16, for which we need to do a f32 comparison.
+  if (LHS.getValueType() == MVT::f16) {
+    LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
+    RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
+  }
+
+  // Next, handle integers.
   if (LHS.getValueType().isInteger()) {
     assert((LHS.getValueType() == RHS.getValueType()) &&
            (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
@@ -3908,9 +3932,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
     } else if (TVal.getOpcode() == ISD::XOR) {
       // If TVal is a NOT we want to swap TVal and FVal so that we can match
       // with a CSINV rather than a CSEL.
-      ConstantSDNode *CVal = dyn_cast<ConstantSDNode>(TVal.getOperand(1));
-
-      if (CVal && CVal->isAllOnesValue()) {
+      if (isAllOnesConstant(TVal.getOperand(1))) {
         std::swap(TVal, FVal);
         std::swap(CTVal, CFVal);
         CC = ISD::getSetCCInverse(CC, true);
@@ -3918,9 +3940,7 @@ SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
     } else if (TVal.getOpcode() == ISD::SUB) {
       // If TVal is a negation (SUB from 0) we want to swap TVal and FVal so
       // that we can match with a CSNEG rather than a CSEL.
-      ConstantSDNode *CVal = dyn_cast<ConstantSDNode>(TVal.getOperand(0));
-
-      if (CVal && CVal->isNullValue()) {
+      if (isNullConstant(TVal.getOperand(0))) {
         std::swap(TVal, FVal);
         std::swap(CTVal, CFVal);
         CC = ISD::getSetCCInverse(CC, true);
@@ -4380,46 +4400,57 @@ SDValue AArch64TargetLowering::LowerShiftRightParts(SDValue Op,
   SDValue ShOpLo = Op.getOperand(0);
   SDValue ShOpHi = Op.getOperand(1);
   SDValue ShAmt = Op.getOperand(2);
-  SDValue ARMcc;
   unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
 
   assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
 
   SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
                                  DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
-  SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
+  SDValue HiBitsForLo = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
+
+  // Unfortunately, if ShAmt == 0, we just calculated "(SHL ShOpHi, 64)" which
+  // is "undef". We wanted 0, so CSEL it directly.
+  SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
+                               ISD::SETEQ, dl, DAG);
+  SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
+  HiBitsForLo =
+      DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
+                  HiBitsForLo, CCVal, Cmp);
+
   SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
                                    DAG.getConstant(VTBits, dl, MVT::i64));
-  SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
 
-  SDValue Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64),
-                               ISD::SETGE, dl, DAG);
-  SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue LoBitsForLo = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
+  SDValue LoForNormalShift =
+      DAG.getNode(ISD::OR, dl, VT, LoBitsForLo, HiBitsForLo);
 
-  SDValue FalseValLo = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
-  SDValue TrueValLo = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
-  SDValue Lo =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValLo, FalseValLo, CCVal, Cmp);
+  Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
+                       dl, DAG);
+  CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue LoForBigShift = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
+  SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
+                           LoForNormalShift, CCVal, Cmp);
 
   // AArch64 shifts larger than the register width are wrapped rather than
   // clamped, so we can't just emit "hi >> x".
-  SDValue FalseValHi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
-  SDValue TrueValHi = Opc == ISD::SRA
-                          ? DAG.getNode(Opc, dl, VT, ShOpHi,
-                                        DAG.getConstant(VTBits - 1, dl,
-                                                        MVT::i64))
-                          : DAG.getConstant(0, dl, VT);
-  SDValue Hi =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValHi, FalseValHi, CCVal, Cmp);
+  SDValue HiForNormalShift = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
+  SDValue HiForBigShift =
+      Opc == ISD::SRA
+          ? DAG.getNode(Opc, dl, VT, ShOpHi,
+                        DAG.getConstant(VTBits - 1, dl, MVT::i64))
+          : DAG.getConstant(0, dl, VT);
+  SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
+                           HiForNormalShift, CCVal, Cmp);
 
   SDValue Ops[2] = { Lo, Hi };
   return DAG.getMergeValues(Ops, dl);
 }
 
+
 /// LowerShiftLeftParts - Lower SHL_PARTS, which returns two
 /// i64 values and take a 2 x i64 value to shift plus a shift amount.
 SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op,
-                                                 SelectionDAG &DAG) const {
+                                                   SelectionDAG &DAG) const {
   assert(Op.getNumOperands() == 3 && "Not a double-shift!");
   EVT VT = Op.getValueType();
   unsigned VTBits = VT.getSizeInBits();
@@ -4427,31 +4458,41 @@ SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op,
   SDValue ShOpLo = Op.getOperand(0);
   SDValue ShOpHi = Op.getOperand(1);
   SDValue ShAmt = Op.getOperand(2);
-  SDValue ARMcc;
 
   assert(Op.getOpcode() == ISD::SHL_PARTS);
   SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
                                  DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
-  SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
+  SDValue LoBitsForHi = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
+
+  // Unfortunately, if ShAmt == 0, we just calculated "(SRL ShOpLo, 64)" which
+  // is "undef". We wanted 0, so CSEL it directly.
+  SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
+                               ISD::SETEQ, dl, DAG);
+  SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
+  LoBitsForHi =
+      DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
+                  LoBitsForHi, CCVal, Cmp);
+
   SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
                                    DAG.getConstant(VTBits, dl, MVT::i64));
-  SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
-  SDValue Tmp3 = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
+  SDValue HiBitsForHi = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
+  SDValue HiForNormalShift =
+      DAG.getNode(ISD::OR, dl, VT, LoBitsForHi, HiBitsForHi);
 
-  SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
+  SDValue HiForBigShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
 
-  SDValue Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64),
-                               ISD::SETGE, dl, DAG);
-  SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
-  SDValue Hi =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, Tmp3, FalseVal, CCVal, Cmp);
+  Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
+                       dl, DAG);
+  CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
+                           HiForNormalShift, CCVal, Cmp);
 
   // AArch64 shifts of larger than register sizes are wrapped rather than
   // clamped, so we can't just emit "lo << a" if a is too big.
-  SDValue TrueValLo = DAG.getConstant(0, dl, VT);
-  SDValue FalseValLo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
-  SDValue Lo =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValLo, FalseValLo, CCVal, Cmp);
+  SDValue LoForBigShift = DAG.getConstant(0, dl, VT);
+  SDValue LoForNormalShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
+  SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
+                           LoForNormalShift, CCVal, Cmp);
 
   SDValue Ops[2] = { Lo, Hi };
   return DAG.getMergeValues(Ops, dl);
@@ -4633,8 +4674,7 @@ void AArch64TargetLowering::LowerAsmOperandForConstraint(
   // Validate and return a target constant for them if we can.
   case 'z': {
     // 'z' maps to xzr or wzr so it needs an input of 0.
-    ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op);
-    if (!C || C->getZExtValue() != 0)
+    if (!isNullConstant(Op))
       return;
 
     if (Op.getValueType() == MVT::i64)
@@ -6395,24 +6435,11 @@ SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
   unsigned Val = Cst->getZExtValue();
 
   unsigned Size = Op.getValueType().getSizeInBits();
-  if (Val == 0) {
-    switch (Size) {
-    case 8:
-      return DAG.getTargetExtractSubreg(AArch64::bsub, dl, Op.getValueType(),
-                                        Op.getOperand(0));
-    case 16:
-      return DAG.getTargetExtractSubreg(AArch64::hsub, dl, Op.getValueType(),
-                                        Op.getOperand(0));
-    case 32:
-      return DAG.getTargetExtractSubreg(AArch64::ssub, dl, Op.getValueType(),
-                                        Op.getOperand(0));
-    case 64:
-      return DAG.getTargetExtractSubreg(AArch64::dsub, dl, Op.getValueType(),
-                                        Op.getOperand(0));
-    default:
-      llvm_unreachable("Unexpected vector type in extract_subvector!");
-    }
-  }
+
+  // This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
+  if (Val == 0)
+    return Op;
+
   // If this is extracting the upper 64-bits of a 128-bit vector, we match
   // that directly.
   if (Size == 64 && Val * VT.getVectorElementType().getSizeInBits() == 64)
@@ -6716,7 +6743,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
   case Intrinsic::aarch64_neon_ld4r: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     // Conservatively set memVT to the entire set of vectors loaded.
-    uint64_t NumElts = DL.getTypeAllocSize(I.getType()) / 8;
+    uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64;
     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
     Info.offset = 0;
@@ -6742,7 +6769,7 @@ bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
       Type *ArgTy = I.getArgOperand(ArgI)->getType();
       if (!ArgTy->isVectorTy())
         break;
-      NumElts += DL.getTypeAllocSize(ArgTy) / 8;
+      NumElts += DL.getTypeSizeInBits(ArgTy) / 64;
     }
     Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
     Info.ptrVal = I.getArgOperand(I.getNumArgOperands() - 1);
@@ -6985,7 +7012,7 @@ bool AArch64TargetLowering::lowerInterleavedLoad(
   const DataLayout &DL = LI->getModule()->getDataLayout();
 
   VectorType *VecTy = Shuffles[0]->getType();
-  unsigned VecSize = DL.getTypeAllocSizeInBits(VecTy);
+  unsigned VecSize = DL.getTypeSizeInBits(VecTy);
 
   // Skip if we do not have NEON and skip illegal vector types.
   if (!Subtarget->hasNEON() || (VecSize != 64 && VecSize != 128))
@@ -7071,7 +7098,7 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
   VectorType *SubVecTy = VectorType::get(EltTy, NumSubElts);
 
   const DataLayout &DL = SI->getModule()->getDataLayout();
-  unsigned SubVecSize = DL.getTypeAllocSizeInBits(SubVecTy);
+  unsigned SubVecSize = DL.getTypeSizeInBits(SubVecTy);
 
   // Skip if we do not have NEON and skip illegal vector types.
   if (!Subtarget->hasNEON() || (SubVecSize != 64 && SubVecSize != 128))
@@ -8230,15 +8257,14 @@ static SDValue performAddSubLongCombine(SDNode *N,
 //   (aarch64_neon_umull (extract_high (v2i64 vec)))
 //                     (extract_high (v2i64 (dup128 scalar)))))
 //
-static SDValue tryCombineLongOpWithDup(SDNode *N,
+static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        SelectionDAG &DAG) {
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
-  bool IsIntrinsic = N->getOpcode() == ISD::INTRINSIC_WO_CHAIN;
-  SDValue LHS = N->getOperand(IsIntrinsic ? 1 : 0);
-  SDValue RHS = N->getOperand(IsIntrinsic ? 2 : 1);
+  SDValue LHS = N->getOperand(1);
+  SDValue RHS = N->getOperand(2);
   assert(LHS.getValueType().is64BitVector() &&
          RHS.getValueType().is64BitVector() &&
          "unexpected shape for long operation");
@@ -8256,13 +8282,8 @@ static SDValue tryCombineLongOpWithDup(SDNode *N,
       return SDValue();
   }
 
-  // N could either be an intrinsic or a sabsdiff/uabsdiff node.
-  if (IsIntrinsic)
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
-                       N->getOperand(0), LHS, RHS);
-  else
-    return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0),
-                       LHS, RHS);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
+                     N->getOperand(0), LHS, RHS);
 }
 
 static SDValue tryCombineShiftImm(unsigned IID, SDNode *N, SelectionDAG &DAG) {
@@ -8380,12 +8401,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_neon_fmin:
     return DAG.getNode(ISD::FMINNAN, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
-  case Intrinsic::aarch64_neon_sabd:
-    return DAG.getNode(ISD::SABSDIFF, SDLoc(N), N->getValueType(0),
-                       N->getOperand(1), N->getOperand(2));
-  case Intrinsic::aarch64_neon_uabd:
-    return DAG.getNode(ISD::UABSDIFF, SDLoc(N), N->getValueType(0),
-                       N->getOperand(1), N->getOperand(2));
   case Intrinsic::aarch64_neon_fmaxnm:
     return DAG.getNode(ISD::FMAXNUM, SDLoc(N), N->getValueType(0),
                        N->getOperand(1), N->getOperand(2));
@@ -8396,7 +8411,7 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_neon_umull:
   case Intrinsic::aarch64_neon_pmull:
   case Intrinsic::aarch64_neon_sqdmull:
-    return tryCombineLongOpWithDup(N, DCI, DAG);
+    return tryCombineLongOpWithDup(IID, N, DCI, DAG);
   case Intrinsic::aarch64_neon_sqshl:
   case Intrinsic::aarch64_neon_uqshl:
   case Intrinsic::aarch64_neon_sqshlu:
@@ -8421,15 +8436,18 @@ static SDValue performExtendCombine(SDNode *N,
   // helps the backend to decide that an sabdl2 would be useful, saving a real
   // extract_high operation.
   if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
-      (N->getOperand(0).getOpcode() == ISD::SABSDIFF ||
-       N->getOperand(0).getOpcode() == ISD::UABSDIFF)) {
+      N->getOperand(0).getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
     SDNode *ABDNode = N->getOperand(0).getNode();
-    SDValue NewABD = tryCombineLongOpWithDup(ABDNode, DCI, DAG);
-    if (!NewABD.getNode())
-      return SDValue();
+    unsigned IID = getIntrinsicID(ABDNode);
+    if (IID == Intrinsic::aarch64_neon_sabd ||
+        IID == Intrinsic::aarch64_neon_uabd) {
+      SDValue NewABD = tryCombineLongOpWithDup(IID, ABDNode, DCI, DAG);
+      if (!NewABD.getNode())
+        return SDValue();
 
-    return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
-                       NewABD);
+      return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0),
+                         NewABD);
+    }
   }
 
   // This is effectively a custom type legalization for AArch64.
@@ -8568,10 +8586,9 @@ static SDValue replaceSplatVectorStore(SelectionDAG &DAG, StoreSDNode *St) {
   return NewST1;
 }
 
-static SDValue performSTORECombine(SDNode *N,
-                                   TargetLowering::DAGCombinerInfo &DCI,
-                                   SelectionDAG &DAG,
-                                   const AArch64Subtarget *Subtarget) {
+static SDValue split16BStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                              SelectionDAG &DAG,
+                              const AArch64Subtarget *Subtarget) {
   if (!DCI.isBeforeLegalize())
     return SDValue();
 
@@ -8733,7 +8750,39 @@ static SDValue performPostLD1Combine(SDNode *N,
   return SDValue();
 }
 
-/// This function handles the log2-shuffle pattern produced by the
+/// Simplify \Addr given that the top byte of it is ignored by HW during
+/// address translation.
+static bool performTBISimplification(SDValue Addr,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     SelectionDAG &DAG) {
+  APInt DemandedMask = APInt::getLowBitsSet(64, 56);
+  APInt KnownZero, KnownOne;
+  TargetLowering::TargetLoweringOpt TLO(DAG, DCI.isBeforeLegalize(),
+                                        DCI.isBeforeLegalizeOps());
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.SimplifyDemandedBits(Addr, DemandedMask, KnownZero, KnownOne, TLO)) {
+    DCI.CommitTargetLoweringOpt(TLO);
+    return true;
+  }
+  return false;
+}
+
+static SDValue performSTORECombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   SelectionDAG &DAG,
+                                   const AArch64Subtarget *Subtarget) {
+  SDValue Split = split16BStores(N, DCI, DAG, Subtarget);
+  if (Split.getNode())
+    return Split;
+
+  if (Subtarget->supportsAddressTopByteIgnored() &&
+      performTBISimplification(N->getOperand(2), DCI, DAG))
+    return SDValue(N, 0);
+
+  return SDValue();
+}
+
+  /// This function handles the log2-shuffle pattern produced by the
 /// LoopVectorizer for the across vector reduction. It consists of
 /// log2(NumVectorElements) steps and, in each step, 2^(s) elements
 /// are reduced, where s is an induction variable from 0 to
@@ -8936,18 +8985,15 @@ performAcrossLaneMinMaxReductionCombine(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   // Expect to check only lane 0 from the vector SETCC.
-  if (!isa<ConstantSDNode>(N0.getOperand(1)) ||
-      cast<ConstantSDNode>(N0.getOperand(1))->getZExtValue() != 0)
+  if (!isNullConstant(N0.getOperand(1)))
     return SDValue();
 
   // Expect to extract the true value from lane 0.
-  if (!isa<ConstantSDNode>(IfTrue.getOperand(1)) ||
-      cast<ConstantSDNode>(IfTrue.getOperand(1))->getZExtValue() != 0)
+  if (!isNullConstant(IfTrue.getOperand(1)))
     return SDValue();
 
   // Expect to extract the false value from lane 1.
-  if (!isa<ConstantSDNode>(IfFalse.getOperand(1)) ||
-      cast<ConstantSDNode>(IfFalse.getOperand(1))->getZExtValue() != 1)
+  if (!isOneConstant(IfFalse.getOperand(1)))
     return SDValue();
 
   return tryMatchAcrossLaneShuffleForReduction(N, SetCC, Op, DAG);
@@ -8980,7 +9026,7 @@ performAcrossLaneAddReductionCombine(SDNode *N, SelectionDAG &DAG,
 
   // The vector extract idx must constant zero because we only expect the final
   // result of the reduction is placed in lane 0.
-  if (!isa<ConstantSDNode>(N1) || cast<ConstantSDNode>(N1)->getZExtValue() != 0)
+  if (!isNullConstant(N1))
     return SDValue();
 
   EVT VTy = N0.getValueType();
@@ -9422,10 +9468,10 @@ static SDValue performBRCONDCombine(SDNode *N,
   if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
     return SDValue();
 
-  if (isa<ConstantSDNode>(LHS) && cast<ConstantSDNode>(LHS)->isNullValue())
+  if (isNullConstant(LHS))
     std::swap(LHS, RHS);
 
-  if (!isa<ConstantSDNode>(RHS) || !cast<ConstantSDNode>(RHS)->isNullValue())
+  if (!isNullConstant(RHS))
     return SDValue();
 
   if (LHS.getOpcode() == ISD::SHL || LHS.getOpcode() == ISD::SRA ||
@@ -9588,6 +9634,10 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::VSELECT:
     return performVSelectCombine(N, DCI.DAG);
+  case ISD::LOAD:
+    if (performTBISimplification(N->getOperand(1), DCI, DAG))
+      return SDValue(N, 0);
+    break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case AArch64ISD::BRCOND:
@@ -9966,3 +10016,49 @@ Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) cons
       IRB.CreateConstGEP1_32(IRB.CreateCall(ThreadPointerFunc), TlsOffset),
       Type::getInt8PtrTy(IRB.getContext())->getPointerTo(0));
 }
+
+void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
+  // Update IsSplitCSR in AArch64unctionInfo.
+  AArch64FunctionInfo *AFI = Entry->getParent()->getInfo<AArch64FunctionInfo>();
+  AFI->setIsSplitCSR(true);
+}
+
+void AArch64TargetLowering::insertCopiesSplitCSR(
+    MachineBasicBlock *Entry,
+    const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
+  const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
+  const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
+  if (!IStart)
+    return;
+
+  const TargetInstrInfo *TII = Subtarget->getInstrInfo();
+  MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
+  for (const MCPhysReg *I = IStart; *I; ++I) {
+    const TargetRegisterClass *RC = nullptr;
+    if (AArch64::GPR64RegClass.contains(*I))
+      RC = &AArch64::GPR64RegClass;
+    else if (AArch64::FPR64RegClass.contains(*I))
+      RC = &AArch64::FPR64RegClass;
+    else
+      llvm_unreachable("Unexpected register class in CSRsViaCopy!");
+
+    unsigned NewVR = MRI->createVirtualRegister(RC);
+    // Create copy from CSR to a virtual register.
+    // FIXME: this currently does not emit CFI pseudo-instructions, it works
+    // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
+    // nounwind. If we want to generalize this later, we may need to emit
+    // CFI pseudo-instructions.
+    assert(Entry->getParent()->getFunction()->hasFnAttribute(
+               Attribute::NoUnwind) &&
+           "Function should be nounwind in insertCopiesSplitCSR!");
+    Entry->addLiveIn(*I);
+    BuildMI(*Entry, Entry->begin(), DebugLoc(), TII->get(TargetOpcode::COPY),
+            NewVR)
+        .addReg(*I);
+
+    for (auto *Exit : Exits)
+      BuildMI(*Exit, Exit->begin(), DebugLoc(), TII->get(TargetOpcode::COPY),
+              *I)
+          .addReg(NewVR);
+  }
+}