[FastISel][AArch64] Fold compare with zero and branch into CBZ and CBNZ.
[oota-llvm.git] / lib / Target / AArch64 / AArch64FastISel.cpp
index 826c4c089a8fafd9c1fa21c8968c421ba7ffc9e8..da69735c8f1028a05f68a1f2c92ad04608151ed5 100644 (file)
@@ -1673,6 +1673,32 @@ static AArch64CC::CondCode getCompareCC(CmpInst::Predicate Pred) {
   }
 }
 
+/// \brief Check if the comparison against zero and the following branch can be
+/// folded into a single instruction (CBZ or CBNZ).
+static bool canFoldZeroIntoBranch(const CmpInst *CI) {
+  CmpInst::Predicate Predicate = CI->getPredicate();
+  if ((Predicate != CmpInst::ICMP_EQ) && (Predicate != CmpInst::ICMP_NE))
+    return false;
+
+  Type *Ty = CI->getOperand(0)->getType();
+  if (!Ty->isIntegerTy())
+    return false;
+
+  unsigned BW = cast<IntegerType>(Ty)->getBitWidth();
+  if (BW != 1 && BW != 8 && BW != 16 && BW != 32 && BW != 64)
+    return false;
+
+  if (const auto *C = dyn_cast<ConstantInt>(CI->getOperand(0)))
+    if (C->isNullValue())
+      return true;
+
+  if (const auto *C = dyn_cast<ConstantInt>(CI->getOperand(1)))
+    if (C->isNullValue())
+      return true;
+
+  return false;
+}
+
 bool AArch64FastISel::selectBranch(const Instruction *I) {
   const BranchInst *BI = cast<BranchInst>(I);
   if (BI->isUnconditional()) {
@@ -1706,6 +1732,44 @@ bool AArch64FastISel::selectBranch(const Instruction *I) {
         Predicate = CmpInst::getInversePredicate(Predicate);
       }
 
+      // Try to optimize comparisons against zero.
+      if (canFoldZeroIntoBranch(CI)) {
+        const Value *LHS = CI->getOperand(0);
+        const Value *RHS = CI->getOperand(1);
+
+        // Canonicalize zero values to the RHS.
+        if (const auto *C = dyn_cast<ConstantInt>(LHS))
+          if (C->isNullValue())
+            std::swap(LHS, RHS);
+
+        static const unsigned OpcTable[2][2] = {
+          {AArch64::CBZW,  AArch64::CBZX }, {AArch64::CBNZW, AArch64::CBNZX}
+        };
+        bool IsCmpNE = Predicate == CmpInst::ICMP_NE;
+        bool Is64Bit = LHS->getType()->isIntegerTy(64);
+        unsigned Opc = OpcTable[IsCmpNE][Is64Bit];
+
+        unsigned SrcReg = getRegForValue(LHS);
+        if (!SrcReg)
+          return false;
+        bool SrcIsKill = hasTrivialKill(LHS);
+
+        // Emit the combined compare and branch instruction.
+        BuildMI(*FuncInfo.MBB, FuncInfo.InsertPt, DbgLoc, TII.get(Opc))
+            .addReg(SrcReg, getKillRegState(SrcIsKill))
+            .addMBB(TBB);
+
+        // Obtain the branch weight and add the TrueBB to the successor list.
+        uint32_t BranchWeight = 0;
+        if (FuncInfo.BPI)
+          BranchWeight = FuncInfo.BPI->getEdgeWeight(BI->getParent(),
+                                                     TBB->getBasicBlock());
+        FuncInfo.MBB->addSuccessor(TBB, BranchWeight);
+
+        fastEmitBranch(FBB, DbgLoc);
+        return true;
+      }
+
       // Emit the cmp.
       if (!emitCmp(CI->getOperand(0), CI->getOperand(1), CI->isUnsigned()))
         return false;