[AAarch64] Optimize CSINC-branch sequence
[oota-llvm.git] / lib / Target / AArch64 / AArch64InstrInfo.cpp
index f984eb1ad15e719de9a7d6d877822f52bf414788..c32a1e95faaf5d1e17a683c156b6af4585ab33ef 100644 (file)
@@ -768,6 +768,39 @@ static unsigned convertFlagSettingOpcode(MachineInstr *MI) {
     return NewOpc;
 }
 
+/// True when condition code could be modified on the instruction
+/// trace starting at from and ending at to.
+static bool modifiesConditionCode(MachineInstr *From, MachineInstr *To,
+                                  const bool CheckOnlyCCWrites,
+                                  const TargetRegisterInfo *TRI) {
+  // We iterate backward starting \p To until we hit \p From
+  MachineBasicBlock::iterator I = To, E = From, B = To->getParent()->begin();
+
+  // Early exit if To is at the beginning of the BB.
+  if (I == B)
+    return true;
+
+  // Check whether the definition of SrcReg is in the same basic block as
+  // Compare. If not, assume the condition code gets modified on some path.
+  if (To->getParent() != From->getParent())
+    return true;
+
+  // Check that NZCV isn't set on the trace.
+  for (--I; I != E; --I) {
+    const MachineInstr &Instr = *I;
+
+    if (Instr.modifiesRegister(AArch64::NZCV, TRI) ||
+        (!CheckOnlyCCWrites && Instr.readsRegister(AArch64::NZCV, TRI)))
+      // This instruction modifies or uses NZCV after the one we want to
+      // change.
+      return true;
+    if (I == B)
+      // We currently don't allow the instruction trace to cross basic
+      // block boundaries
+      return true;
+  }
+  return false;
+}
 /// optimizeCompareInstr - Convert the instruction supplying the argument to the
 /// comparison into one that sets the zero bit in the flags register.
 bool AArch64InstrInfo::optimizeCompareInstr(
@@ -806,36 +839,10 @@ bool AArch64InstrInfo::optimizeCompareInstr(
   if (!MI)
     return false;
 
-  // We iterate backward, starting from the instruction before CmpInstr and
-  // stop when reaching the definition of the source register or done with the
-  // basic block, to check whether NZCV is used or modified in between.
-  MachineBasicBlock::iterator I = CmpInstr, E = MI,
-                              B = CmpInstr->getParent()->begin();
-
-  // Early exit if CmpInstr is at the beginning of the BB.
-  if (I == B)
-    return false;
-
-  // Check whether the definition of SrcReg is in the same basic block as
-  // Compare. If not, we can't optimize away the Compare.
-  if (MI->getParent() != CmpInstr->getParent())
-    return false;
-
-  // Check that NZCV isn't set between the comparison instruction and the one we
-  // want to change.
+  bool CheckOnlyCCWrites = false;
   const TargetRegisterInfo *TRI = &getRegisterInfo();
-  for (--I; I != E; --I) {
-    const MachineInstr &Instr = *I;
-
-    if (Instr.modifiesRegister(AArch64::NZCV, TRI) ||
-        Instr.readsRegister(AArch64::NZCV, TRI))
-      // This instruction modifies or uses NZCV after the one we want to
-      // change. We can't do this transformation.
-      return false;
-    if (I == B)
-      // The 'and' is below the comparison instruction.
-      return false;
-  }
+  if (modifiesConditionCode(MI, CmpInstr, CheckOnlyCCWrites, TRI))
+    return false;
 
   unsigned NewOpc = MI->getOpcode();
   switch (MI->getOpcode()) {
@@ -2830,3 +2837,103 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
 
   return;
 }
+
+/// \brief Replace csincr-branch sequence by simple conditional branch
+///
+/// Examples:
+/// 1.
+///   csinc  w9, wzr, wzr, <condition code>
+///   tbnz   w9, #0, 0x44
+/// to
+///   b.<inverted condition code>
+///
+/// 2.
+///   csinc w9, wzr, wzr, <condition code>
+///   tbz   w9, #0, 0x44
+/// to
+///   b.<condition code>
+///
+/// \param  MI Conditional Branch
+/// \return True when the simple conditional branch is generated
+///
+bool AArch64InstrInfo::optimizeCondBranch(MachineInstr *MI) const {
+  bool IsNegativeBranch = false;
+  bool IsTestAndBranch = false;
+  unsigned TargetBBInMI = 0;
+  unsigned CCInMI = 0;
+  switch (MI->getOpcode()) {
+  default:
+    llvm_unreachable("Unknown branch instruction?");
+  case AArch64::Bcc:
+    return false;
+  case AArch64::CBZW:
+  case AArch64::CBZX:
+    TargetBBInMI = 1;
+    CCInMI = 2;
+    break;
+  case AArch64::CBNZW:
+  case AArch64::CBNZX:
+    TargetBBInMI = 1;
+    CCInMI = 2;
+    IsNegativeBranch = true;
+    break;
+  case AArch64::TBZW:
+  case AArch64::TBZX:
+    TargetBBInMI = 2;
+    CCInMI = 3;
+    IsTestAndBranch = true;
+    break;
+  case AArch64::TBNZW:
+  case AArch64::TBNZX:
+    TargetBBInMI = 2;
+    CCInMI = 3;
+    IsNegativeBranch = true;
+    IsTestAndBranch = true;
+    break;
+  }
+  // So we increment a zero register and test for bits other
+  // than bit 0? Conservatively bail out in case the verifier
+  // missed this case.
+  if (IsTestAndBranch && MI->getOperand(1).getImm())
+    return false;
+
+  // Find Definition.
+  assert(MI->getParent() && "Incomplete machine instruciton\n");
+  MachineBasicBlock *MBB = MI->getParent();
+  MachineFunction *MF = MBB->getParent();
+  MachineRegisterInfo *MRI = &MF->getRegInfo();
+  unsigned VReg = MI->getOperand(0).getReg();
+  if (!TargetRegisterInfo::isVirtualRegister(VReg))
+    return false;
+
+  MachineInstr *DefMI = MRI->getVRegDef(VReg);
+
+  // Look for CSINC
+  if (!(DefMI->getOpcode() == AArch64::CSINCWr &&
+        DefMI->getOperand(1).getReg() == AArch64::WZR &&
+        DefMI->getOperand(2).getReg() == AArch64::WZR) &&
+      !(DefMI->getOpcode() == AArch64::CSINCXr &&
+        DefMI->getOperand(1).getReg() == AArch64::XZR &&
+        DefMI->getOperand(2).getReg() == AArch64::XZR))
+    return false;
+
+  if (DefMI->findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
+    return false;
+
+  AArch64CC::CondCode CC =
+      (AArch64CC::CondCode)DefMI->getOperand(CCInMI).getImm();
+  bool CheckOnlyCCWrites = true;
+  // Convert only when the condition code is not modified between
+  // the CSINC and the branch. The CC may be used by other
+  // instructions in between.
+  if (modifiesConditionCode(DefMI, MI, CheckOnlyCCWrites, &getRegisterInfo()))
+    return false;
+  MachineBasicBlock &RefToMBB = *MBB;
+  MachineBasicBlock *TBB = MI->getOperand(TargetBBInMI).getMBB();
+  DebugLoc DL = MI->getDebugLoc();
+  if (IsNegativeBranch)
+    CC = AArch64CC::getInvertedCondCode(CC);
+  BuildMI(RefToMBB, MI, DL, get(AArch64::Bcc)).addImm(CC).addMBB(TBB);
+  MI->eraseFromParent();
+  return true;
+}