[FastISel][AArch64] Fold sign-/zero-extends into the load instruction.
[oota-llvm.git] / lib / Target / AArch64 / AArch64FastISel.cpp
index fb4f1cb7237a5b02810e7031f6ea8016b6505048..f20dcae71686b716686f9f3fb5548d2c199cfd6a 100644 (file)
@@ -177,7 +177,7 @@ private:
   bool emitICmp(MVT RetVT, const Value *LHS, const Value *RHS, bool IsZExt);
   bool emitICmp_ri(MVT RetVT, unsigned LHSReg, bool LHSIsKill, uint64_t Imm);
   bool emitFCmp(MVT RetVT, const Value *LHS, const Value *RHS);
-  bool emitLoad(MVT VT, unsigned &ResultReg, Address Addr,
+  bool emitLoad(MVT VT, unsigned &ResultReg, Address Addr, bool WantZExt = true,
                 MachineMemOperand *MMO = nullptr);
   bool emitStore(MVT VT, unsigned SrcReg, Address Addr,
                  MachineMemOperand *MMO = nullptr);
@@ -255,6 +255,23 @@ public:
 
 #include "AArch64GenCallingConv.inc"
 
+/// \brief Check if the sign-/zero-extend will be a noop.
+static bool isIntExtFree(const Instruction *I) {
+  assert((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
+         "Unexpected integer extend instruction.");
+  bool IsZExt = isa<ZExtInst>(I);
+
+  if (const auto *LI = dyn_cast<LoadInst>(I->getOperand(0)))
+    if (LI->hasOneUse())
+      return true;
+
+  if (const auto *Arg = dyn_cast<Argument>(I->getOperand(0)))
+    if ((IsZExt && Arg->hasZExtAttr()) || (!IsZExt && Arg->hasSExtAttr()))
+      return true;
+
+  return false;
+}
+
 /// \brief Determine the implicit scale factor that is applied by a memory
 /// operation for a given value type.
 static unsigned getImplicitScaleFactor(MVT VT) {
@@ -585,72 +602,74 @@ bool AArch64FastISel::computeAddress(const Value *Obj, Address &Addr, Type *Ty)
     if (Addr.getOffsetReg())
       break;
 
-    if (const auto *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
-      unsigned Val = CI->getZExtValue();
-      if (Val < 1 || Val > 3)
-        break;
+    const auto *CI = dyn_cast<ConstantInt>(U->getOperand(1));
+    if (!CI)
+      break;
 
-      uint64_t NumBytes = 0;
-      if (Ty && Ty->isSized()) {
-        uint64_t NumBits = DL.getTypeSizeInBits(Ty);
-        NumBytes = NumBits / 8;
-        if (!isPowerOf2_64(NumBits))
-          NumBytes = 0;
-      }
+    unsigned Val = CI->getZExtValue();
+    if (Val < 1 || Val > 3)
+      break;
 
-      if (NumBytes != (1ULL << Val))
-        break;
+    uint64_t NumBytes = 0;
+    if (Ty && Ty->isSized()) {
+      uint64_t NumBits = DL.getTypeSizeInBits(Ty);
+      NumBytes = NumBits / 8;
+      if (!isPowerOf2_64(NumBits))
+        NumBytes = 0;
+    }
+
+    if (NumBytes != (1ULL << Val))
+      break;
 
-      Addr.setShift(Val);
-      Addr.setExtendType(AArch64_AM::LSL);
+    Addr.setShift(Val);
+    Addr.setExtendType(AArch64_AM::LSL);
 
-      const Value *Src = U->getOperand(0);
-      if (const auto *I = dyn_cast<Instruction>(Src))
-        if (FuncInfo.MBBMap[I->getParent()] == FuncInfo.MBB)
-          Src = I;
+    const Value *Src = U->getOperand(0);
+    if (const auto *I = dyn_cast<Instruction>(Src))
+      if (FuncInfo.MBBMap[I->getParent()] == FuncInfo.MBB)
+        Src = I;
 
-      if (const auto *ZE = dyn_cast<ZExtInst>(Src)) {
-        if (ZE->getOperand(0)->getType()->isIntegerTy(32)) {
+    // Fold the zext or sext when it won't become a noop.
+    if (const auto *ZE = dyn_cast<ZExtInst>(Src)) {
+      if (!isIntExtFree(ZE) && ZE->getOperand(0)->getType()->isIntegerTy(32)) {
           Addr.setExtendType(AArch64_AM::UXTW);
           Src = ZE->getOperand(0);
-        }
-      } else if (const auto *SE = dyn_cast<SExtInst>(Src)) {
-        if (SE->getOperand(0)->getType()->isIntegerTy(32)) {
-          Addr.setExtendType(AArch64_AM::SXTW);
-          Src = SE->getOperand(0);
-        }
       }
+    } else if (const auto *SE = dyn_cast<SExtInst>(Src)) {
+      if (!isIntExtFree(SE) && SE->getOperand(0)->getType()->isIntegerTy(32)) {
+        Addr.setExtendType(AArch64_AM::SXTW);
+        Src = SE->getOperand(0);
+      }
+    }
 
-      if (const auto *AI = dyn_cast<BinaryOperator>(Src))
-        if (AI->getOpcode() == Instruction::And) {
-          const Value *LHS = AI->getOperand(0);
-          const Value *RHS = AI->getOperand(1);
-
-          if (const auto *C = dyn_cast<ConstantInt>(LHS))
-            if (C->getValue() == 0xffffffff)
-              std::swap(LHS, RHS);
-
-          if (const auto *C = dyn_cast<ConstantInt>(RHS))
-            if (C->getValue() == 0xffffffff) {
-              Addr.setExtendType(AArch64_AM::UXTW);
-              unsigned Reg = getRegForValue(LHS);
-              if (!Reg)
-                return false;
-              bool RegIsKill = hasTrivialKill(LHS);
-              Reg = fastEmitInst_extractsubreg(MVT::i32, Reg, RegIsKill,
-                                               AArch64::sub_32);
-              Addr.setOffsetReg(Reg);
-              return true;
-            }
-        }
+    if (const auto *AI = dyn_cast<BinaryOperator>(Src))
+      if (AI->getOpcode() == Instruction::And) {
+        const Value *LHS = AI->getOperand(0);
+        const Value *RHS = AI->getOperand(1);
 
-      unsigned Reg = getRegForValue(Src);
-      if (!Reg)
-        return false;
-      Addr.setOffsetReg(Reg);
-      return true;
-    }
-    break;
+        if (const auto *C = dyn_cast<ConstantInt>(LHS))
+          if (C->getValue() == 0xffffffff)
+            std::swap(LHS, RHS);
+
+        if (const auto *C = dyn_cast<ConstantInt>(RHS))
+          if (C->getValue() == 0xffffffff) {
+            Addr.setExtendType(AArch64_AM::UXTW);
+            unsigned Reg = getRegForValue(LHS);
+            if (!Reg)
+              return false;
+            bool RegIsKill = hasTrivialKill(LHS);
+            Reg = fastEmitInst_extractsubreg(MVT::i32, Reg, RegIsKill,
+                                             AArch64::sub_32);
+            Addr.setOffsetReg(Reg);
+            return true;
+          }
+      }
+
+    unsigned Reg = getRegForValue(Src);
+    if (!Reg)
+      return false;
+    Addr.setOffsetReg(Reg);
+    return true;
   }
   case Instruction::Mul: {
     if (Addr.getOffsetReg())
@@ -692,13 +711,15 @@ bool AArch64FastISel::computeAddress(const Value *Obj, Address &Addr, Type *Ty)
       if (FuncInfo.MBBMap[I->getParent()] == FuncInfo.MBB)
         Src = I;
 
+
+    // Fold the zext or sext when it won't become a noop.
     if (const auto *ZE = dyn_cast<ZExtInst>(Src)) {
-      if (ZE->getOperand(0)->getType()->isIntegerTy(32)) {
+      if (!isIntExtFree(ZE) && ZE->getOperand(0)->getType()->isIntegerTy(32)) {
         Addr.setExtendType(AArch64_AM::UXTW);
         Src = ZE->getOperand(0);
       }
     } else if (const auto *SE = dyn_cast<SExtInst>(Src)) {
-      if (SE->getOperand(0)->getType()->isIntegerTy(32)) {
+      if (!isIntExtFree(SE) && SE->getOperand(0)->getType()->isIntegerTy(32)) {
         Addr.setExtendType(AArch64_AM::SXTW);
         Src = SE->getOperand(0);
       }
@@ -1568,7 +1589,7 @@ unsigned AArch64FastISel::emitAnd_ri(MVT RetVT, unsigned LHSReg, bool LHSIsKill,
 }
 
 bool AArch64FastISel::emitLoad(MVT VT, unsigned &ResultReg, Address Addr,
-                               MachineMemOperand *MMO) {
+                               bool WantZExt, MachineMemOperand *MMO) {
   // Simplify this down to something we can handle.
   if (!simplifyAddress(Addr, VT))
     return false;
@@ -1585,20 +1606,38 @@ bool AArch64FastISel::emitLoad(MVT VT, unsigned &ResultReg, Address Addr,
     ScaleFactor = 1;
   }
 
-  static const unsigned OpcTable[4][6] = {
-    { AArch64::LDURBBi,  AArch64::LDURHHi,  AArch64::LDURWi,  AArch64::LDURXi,
-      AArch64::LDURSi,   AArch64::LDURDi },
-    { AArch64::LDRBBui,  AArch64::LDRHHui,  AArch64::LDRWui,  AArch64::LDRXui,
-      AArch64::LDRSui,   AArch64::LDRDui },
-    { AArch64::LDRBBroX, AArch64::LDRHHroX, AArch64::LDRWroX, AArch64::LDRXroX,
-      AArch64::LDRSroX,  AArch64::LDRDroX },
-    { AArch64::LDRBBroW, AArch64::LDRHHroW, AArch64::LDRWroW, AArch64::LDRXroW,
-      AArch64::LDRSroW,  AArch64::LDRDroW }
+  static const unsigned GPOpcTable[2][4][4] = {
+    // Sign-extend.
+    { { AArch64::LDURSBWi,  AArch64::LDURSHWi,  AArch64::LDURSWi,
+        AArch64::LDURXi  },
+      { AArch64::LDRSBWui,  AArch64::LDRSHWui,  AArch64::LDRSWui,
+        AArch64::LDRXui  },
+      { AArch64::LDRSBWroX, AArch64::LDRSHWroX, AArch64::LDRSWroX,
+        AArch64::LDRXroX },
+      { AArch64::LDRSBWroW, AArch64::LDRSHWroW, AArch64::LDRSWroW,
+        AArch64::LDRXroW },
+    },
+    // Zero-extend.
+    { { AArch64::LDURBBi,   AArch64::LDURHHi,   AArch64::LDURWi,
+        AArch64::LDURXi  },
+      { AArch64::LDRBBui,   AArch64::LDRHHui,   AArch64::LDRWui,
+        AArch64::LDRXui  },
+      { AArch64::LDRBBroX,  AArch64::LDRHHroX,  AArch64::LDRWroX,
+        AArch64::LDRXroX },
+      { AArch64::LDRBBroW,  AArch64::LDRHHroW,  AArch64::LDRWroW,
+        AArch64::LDRXroW }
+    }
+  };
+
+  static const unsigned FPOpcTable[4][2] = {
+    { AArch64::LDURSi,  AArch64::LDURDi  },
+    { AArch64::LDRSui,  AArch64::LDRDui  },
+    { AArch64::LDRSroX, AArch64::LDRDroX },
+    { AArch64::LDRSroW, AArch64::LDRDroW }
   };
 
   unsigned Opc;
   const TargetRegisterClass *RC;
-  bool VTIsi1 = false;
   bool UseRegOffset = Addr.isRegBase() && !Addr.getOffset() && Addr.getReg() &&
                       Addr.getOffsetReg();
   unsigned Idx = UseRegOffset ? 2 : UseScaled ? 1 : 0;
@@ -1607,14 +1646,33 @@ bool AArch64FastISel::emitLoad(MVT VT, unsigned &ResultReg, Address Addr,
     Idx++;
 
   switch (VT.SimpleTy) {
-  default: llvm_unreachable("Unexpected value type.");
-  case MVT::i1:  VTIsi1 = true; // Intentional fall-through.
-  case MVT::i8:  Opc = OpcTable[Idx][0]; RC = &AArch64::GPR32RegClass; break;
-  case MVT::i16: Opc = OpcTable[Idx][1]; RC = &AArch64::GPR32RegClass; break;
-  case MVT::i32: Opc = OpcTable[Idx][2]; RC = &AArch64::GPR32RegClass; break;
-  case MVT::i64: Opc = OpcTable[Idx][3]; RC = &AArch64::GPR64RegClass; break;
-  case MVT::f32: Opc = OpcTable[Idx][4]; RC = &AArch64::FPR32RegClass; break;
-  case MVT::f64: Opc = OpcTable[Idx][5]; RC = &AArch64::FPR64RegClass; break;
+  default:
+    llvm_unreachable("Unexpected value type.");
+  case MVT::i1: // Intentional fall-through.
+  case MVT::i8:
+    Opc = GPOpcTable[WantZExt][Idx][0];
+    RC = &AArch64::GPR32RegClass;
+    break;
+  case MVT::i16:
+    Opc = GPOpcTable[WantZExt][Idx][1];
+    RC = &AArch64::GPR32RegClass;
+    break;
+  case MVT::i32:
+    Opc = GPOpcTable[WantZExt][Idx][2];
+    RC = WantZExt ? &AArch64::GPR32RegClass : &AArch64::GPR64RegClass;
+    break;
+  case MVT::i64:
+    Opc = GPOpcTable[WantZExt][Idx][3];
+    RC = &AArch64::GPR64RegClass;
+    break;
+  case MVT::f32:
+    Opc = FPOpcTable[Idx][0];
+    RC = &AArch64::FPR32RegClass;
+    break;
+  case MVT::f64:
+    Opc = FPOpcTable[Idx][1];
+    RC = &AArch64::FPR64RegClass;
+    break;
   }
 
   // Create the base instruction, then add the operands.
@@ -1623,8 +1681,14 @@ bool AArch64FastISel::emitLoad(MVT VT, unsigned &ResultReg, Address Addr,
                                     TII.get(Opc), ResultReg);
   addLoadStoreOperands(Addr, MIB, MachineMemOperand::MOLoad, ScaleFactor, MMO);
 
+  // For 32bit loads we do sign-extending loads to 64bit and then extract the
+  // subreg. In the end this is just a NOOP.
+  if (VT == MVT::i32 && !WantZExt)
+    ResultReg = fastEmitInst_extractsubreg(MVT::i32, ResultReg, /*IsKill=*/true,
+                                           AArch64::sub_32);
+
   // Loading an i1 requires special handling.
-  if (VTIsi1) {
+  if (VT == MVT::i1) {
     unsigned ANDReg = emitAnd_ri(MVT::i32, ResultReg, /*IsKill=*/true, 1);
     assert(ANDReg && "Unexpected AND instruction emission failure.");
     ResultReg = ANDReg;
@@ -1701,8 +1765,12 @@ bool AArch64FastISel::selectLoad(const Instruction *I) {
   if (!computeAddress(I->getOperand(0), Addr, I->getType()))
     return false;
 
+  bool WantZExt = true;
+  if (I->hasOneUse() && isa<SExtInst>(I->use_begin()->getUser()))
+    WantZExt = false;
+
   unsigned ResultReg;
-  if (!emitLoad(VT, ResultReg, Addr, createMachineMemOperandFor(I)))
+  if (!emitLoad(VT, ResultReg, Addr, WantZExt, createMachineMemOperandFor(I)))
     return false;
 
   updateValueMap(I, ResultReg);
@@ -3776,46 +3844,60 @@ unsigned AArch64FastISel::emitIntExt(MVT SrcVT, unsigned SrcReg, MVT DestVT,
 }
 
 bool AArch64FastISel::selectIntExt(const Instruction *I) {
-  // On ARM, in general, integer casts don't involve legal types; this code
-  // handles promotable integers.  The high bits for a type smaller than
-  // the register size are assumed to be undefined.
-  Type *DestTy = I->getType();
-  Value *Src = I->getOperand(0);
-  Type *SrcTy = Src->getType();
-
-  unsigned SrcReg = getRegForValue(Src);
-  if (!SrcReg)
+  assert((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
+         "Unexpected integer extend instruction.");
+  MVT RetVT;
+  MVT SrcVT;
+  if (!isTypeSupported(I->getType(), RetVT))
     return false;
 
-  EVT SrcEVT = TLI.getValueType(SrcTy, true);
-  EVT DestEVT = TLI.getValueType(DestTy, true);
-  if (!SrcEVT.isSimple())
-    return false;
-  if (!DestEVT.isSimple())
+  if (!isTypeSupported(I->getOperand(0)->getType(), SrcVT))
     return false;
 
-  MVT SrcVT = SrcEVT.getSimpleVT();
-  MVT DestVT = DestEVT.getSimpleVT();
-  unsigned ResultReg = 0;
+  if (isIntExtFree(I)) {
+    unsigned SrcReg = getRegForValue(I->getOperand(0));
+    if (!SrcReg)
+      return false;
+    bool SrcIsKill = hasTrivialKill(I->getOperand(0));
 
-  bool IsZExt = isa<ZExtInst>(I);
-  // Check if it is an argument and if it is already zero/sign-extended.
-  if (const auto *Arg = dyn_cast<Argument>(Src)) {
-    if ((IsZExt && Arg->hasZExtAttr()) || (!IsZExt && Arg->hasSExtAttr())) {
-      if (DestVT == MVT::i64) {
-        ResultReg = createResultReg(TLI.getRegClassFor(DestVT));
-        BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
-                TII.get(AArch64::SUBREG_TO_REG), ResultReg)
+    const TargetRegisterClass *RC = (RetVT == MVT::i64) ?
+        &AArch64::GPR64RegClass : &AArch64::GPR32RegClass;
+    unsigned ResultReg = createResultReg(RC);
+    if (RetVT == MVT::i64 && SrcVT != MVT::i64) {
+      BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+              TII.get(AArch64::SUBREG_TO_REG), ResultReg)
           .addImm(0)
-          .addReg(SrcReg)
+          .addReg(SrcReg, getKillRegState(SrcIsKill))
           .addImm(AArch64::sub_32);
-      } else
-        ResultReg = SrcReg;
+    } else {
+      BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+              TII.get(TargetOpcode::COPY), ResultReg)
+          .addReg(SrcReg, getKillRegState(SrcIsKill));
     }
+    updateValueMap(I, ResultReg);
+    return true;
+  }
+
+  unsigned SrcReg = getRegForValue(I->getOperand(0));
+  if (!SrcReg)
+    return false;
+  bool SrcRegIsKill = hasTrivialKill(I->getOperand(0));
+
+  unsigned ResultReg = 0;
+  if (isIntExtFree(I)) {
+    if (RetVT == MVT::i64) {
+      ResultReg = createResultReg(&AArch64::GPR64RegClass);
+      BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc,
+              TII.get(AArch64::SUBREG_TO_REG), ResultReg)
+          .addImm(0)
+          .addReg(SrcReg, getKillRegState(SrcRegIsKill))
+          .addImm(AArch64::sub_32);
+    } else
+      ResultReg = SrcReg;
   }
 
   if (!ResultReg)
-    ResultReg = emitIntExt(SrcVT, SrcReg, DestVT, IsZExt);
+    ResultReg = emitIntExt(SrcVT, SrcReg, RetVT, isa<ZExtInst>(I));
 
   if (!ResultReg)
     return false;
@@ -3891,18 +3973,22 @@ bool AArch64FastISel::selectMul(const Instruction *I) {
       MVT SrcVT = VT;
       bool IsZExt = true;
       if (const auto *ZExt = dyn_cast<ZExtInst>(Src0)) {
-        MVT VT;
-        if (isValueAvailable(ZExt) && isTypeSupported(ZExt->getSrcTy(), VT)) {
-          SrcVT = VT;
-          IsZExt = true;
-          Src0 = ZExt->getOperand(0);
+        if (!isIntExtFree(ZExt)) {
+          MVT VT;
+          if (isValueAvailable(ZExt) && isTypeSupported(ZExt->getSrcTy(), VT)) {
+            SrcVT = VT;
+            IsZExt = true;
+            Src0 = ZExt->getOperand(0);
+          }
         }
       } else if (const auto *SExt = dyn_cast<SExtInst>(Src0)) {
-        MVT VT;
-        if (isValueAvailable(SExt) && isTypeSupported(SExt->getSrcTy(), VT)) {
-          SrcVT = VT;
-          IsZExt = false;
-          Src0 = SExt->getOperand(0);
+        if (!isIntExtFree(SExt)) {
+          MVT VT;
+          if (isValueAvailable(SExt) && isTypeSupported(SExt->getSrcTy(), VT)) {
+            SrcVT = VT;
+            IsZExt = false;
+            Src0 = SExt->getOperand(0);
+          }
         }
       }
 
@@ -3954,18 +4040,22 @@ bool AArch64FastISel::selectShift(const Instruction *I) {
     bool IsZExt = (I->getOpcode() == Instruction::AShr) ? false : true;
     const Value *Op0 = I->getOperand(0);
     if (const auto *ZExt = dyn_cast<ZExtInst>(Op0)) {
-      MVT TmpVT;
-      if (isValueAvailable(ZExt) && isTypeSupported(ZExt->getSrcTy(), TmpVT)) {
-        SrcVT = TmpVT;
-        IsZExt = true;
-        Op0 = ZExt->getOperand(0);
+      if (!isIntExtFree(ZExt)) {
+        MVT TmpVT;
+        if (isValueAvailable(ZExt) && isTypeSupported(ZExt->getSrcTy(), TmpVT)) {
+          SrcVT = TmpVT;
+          IsZExt = true;
+          Op0 = ZExt->getOperand(0);
+        }
       }
     } else if (const auto *SExt = dyn_cast<SExtInst>(Op0)) {
-      MVT TmpVT;
-      if (isValueAvailable(SExt) && isTypeSupported(SExt->getSrcTy(), TmpVT)) {
-        SrcVT = TmpVT;
-        IsZExt = false;
-        Op0 = SExt->getOperand(0);
+      if (!isIntExtFree(SExt)) {
+        MVT TmpVT;
+        if (isValueAvailable(SExt) && isTypeSupported(SExt->getSrcTy(), TmpVT)) {
+          SrcVT = TmpVT;
+          IsZExt = false;
+          Op0 = SExt->getOperand(0);
+        }
       }
     }
 
@@ -4213,13 +4303,8 @@ bool AArch64FastISel::fastSelectInstruction(const Instruction *I) {
   case Instruction::FPToUI:
     return selectFPToInt(I, /*Signed=*/false);
   case Instruction::ZExt:
-    if (!selectCast(I, ISD::ZERO_EXTEND))
-      return selectIntExt(I);
-    return true;
   case Instruction::SExt:
-    if (!selectCast(I, ISD::SIGN_EXTEND))
-      return selectIntExt(I);
-    return true;
+    return selectIntExt(I);
   case Instruction::Trunc:
     if (!selectCast(I, ISD::TRUNCATE))
       return selectTrunc(I);