Gracefully degrade precision in branch probability numbers.
authorNick Lewycky <nicholas@mxc.ca>
Wed, 25 Jan 2012 09:43:14 +0000 (09:43 +0000)
committerNick Lewycky <nicholas@mxc.ca>
Wed, 25 Jan 2012 09:43:14 +0000 (09:43 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@148946 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Utils/SimplifyCFG.cpp

index 326bd7a2bd08fcc25262d5073eb3f05a1aaba779..d5ae609679f26a49ce83891059cbb09114a6a3e2 100644 (file)
@@ -1466,6 +1466,29 @@ static bool ExtractBranchMetadata(BranchInst *BI,
   return true;
 }
 
+/// MultiplyAndLosePrecision - Multiplies A and B, then returns the result. In
+/// the event of overflow, logically-shifts all four inputs right until the
+/// multiply fits.
+static APInt MultiplyAndLosePrecision(APInt &A, APInt &B, APInt &C, APInt &D,
+                                      unsigned &BitsLost) {
+  BitsLost = 0;
+  bool Overflow = false;
+  APInt Result = A.umul_ov(B, Overflow);
+  if (Overflow) {
+    APInt MaxB = APInt::getMaxValue(A.getBitWidth()).udiv(A);
+    do {
+      B = B.lshr(1);
+      ++BitsLost;
+    } while (B.ugt(MaxB));
+    A = A.lshr(BitsLost);
+    C = C.lshr(BitsLost);
+    D = D.lshr(BitsLost);
+    Result = A * B;
+  }
+  return Result;
+}
+
+
 /// FoldBranchToCommonDest - If this basic block is simple enough, and if a
 /// predecessor branches to us and one of our successors, fold the block into
 /// the predecessor and use logical operations to pick the right destination.
@@ -1665,32 +1688,64 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI) {
       // we get:
       // (a*c)% = A*C, (b+(a*d))% = A*D+B*C+B*D.
 
-      bool Overflow1 = false, Overflow2 = false, Overflow3 = false;
-      bool Overflow4 = false, Overflow5 = false, Overflow6 = false;
-      APInt ProbTrue = A.umul_ov(C, Overflow1);
+      // In the event of overflow, we want to drop the LSB of the input
+      // probabilities.
+      unsigned BitsLost;
 
-      APInt Tmp1 = A.umul_ov(D, Overflow2);
-      APInt Tmp2 = B.umul_ov(C, Overflow3);
-      APInt Tmp3 = B.umul_ov(D, Overflow4);
-      APInt Tmp4 = Tmp1.uadd_ov(Tmp2, Overflow5);
-      APInt ProbFalse = Tmp4.uadd_ov(Tmp3, Overflow6);
+      // Ignore overflow result on ProbTrue.
+      APInt ProbTrue = MultiplyAndLosePrecision(A, C, B, D, BitsLost);
 
-      APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
-      ProbTrue = ProbTrue.udiv(GCD);
-      ProbFalse = ProbFalse.udiv(GCD);
+      APInt Tmp1 = MultiplyAndLosePrecision(B, D, A, C, BitsLost);
+      if (BitsLost) {
+        ProbTrue = ProbTrue.lshr(BitsLost*2);
+      }
+
+      APInt Tmp2 = MultiplyAndLosePrecision(A, D, C, B, BitsLost);
+      if (BitsLost) {
+        ProbTrue = ProbTrue.lshr(BitsLost*2);
+        Tmp1 = Tmp1.lshr(BitsLost*2);
+      }
+
+      APInt Tmp3 = MultiplyAndLosePrecision(B, C, A, D, BitsLost);
+      if (BitsLost) {
+        ProbTrue = ProbTrue.lshr(BitsLost*2);
+        Tmp1 = Tmp1.lshr(BitsLost*2);
+        Tmp2 = Tmp2.lshr(BitsLost*2);
+      }
+
+      bool Overflow1 = false, Overflow2 = false;
+      APInt Tmp4 = Tmp2.uadd_ov(Tmp3, Overflow1);
+      APInt ProbFalse = Tmp4.uadd_ov(Tmp1, Overflow2);
+
+      if (Overflow1 || Overflow2) {
+        ProbTrue = ProbTrue.lshr(1);
+        Tmp1 = Tmp1.lshr(1);
+        Tmp2 = Tmp2.lshr(1);
+        Tmp3 = Tmp3.lshr(1);
+        Tmp4 = Tmp2 + Tmp3;
+        ProbFalse = Tmp4 + Tmp1;
+      }
+
+      // The sum of branch weights must fit in 32-bits.
+      if (ProbTrue.isNegative() && ProbFalse.isNegative()) {
+        ProbTrue = ProbTrue.lshr(1);
+        ProbFalse = ProbFalse.lshr(1);
+      }
+
+      if (ProbTrue != ProbFalse) {
+        // Normalize the result.
+        APInt GCD = APIntOps::GreatestCommonDivisor(ProbTrue, ProbFalse);
+        ProbTrue = ProbTrue.udiv(GCD);
+        ProbFalse = ProbFalse.udiv(GCD);
 
-      if (Overflow1 || Overflow2 || Overflow3 || Overflow4 || Overflow5 ||
-          Overflow6) {
-        DEBUG(dbgs() << "Overflow recomputing branch weight on: " << *PBI
-                     << "when merging with: " << *BI);
-        PBI->setMetadata(LLVMContext::MD_prof, NULL);
-      } else {
         LLVMContext &Context = BI->getContext();
         Value *Ops[3];
         Ops[0] = BI->getMetadata(LLVMContext::MD_prof)->getOperand(0);
         Ops[1] = ConstantInt::get(Context, ProbTrue);
         Ops[2] = ConstantInt::get(Context, ProbFalse);
         PBI->setMetadata(LLVMContext::MD_prof, MDNode::get(Context, Ops));
+      } else {
+        PBI->setMetadata(LLVMContext::MD_prof, NULL);
       }
     } else {
       PBI->setMetadata(LLVMContext::MD_prof, NULL);