Record whether the weights on out-edges from a MBB are normalized.
[oota-llvm.git] / lib / CodeGen / MachineBranchProbabilityInfo.cpp
index 6fbc2be70486ad3ff2059ebae14fa9aa1edd56f2..fe03d4d0b5fcd98cf43c289a98b64b4c635401e7 100644 (file)
@@ -28,36 +28,35 @@ char MachineBranchProbabilityInfo::ID = 0;
 
 void MachineBranchProbabilityInfo::anchor() { }
 
-uint32_t MachineBranchProbabilityInfo::
-getSumForBlock(const MachineBasicBlock *MBB, uint32_t &Scale) const {
-  // First we compute the sum with 64-bits of precision, ensuring that cannot
-  // overflow by bounding the number of weights considered. Hopefully no one
-  // actually needs 2^32 successors.
-  assert(MBB->succ_size() < UINT32_MAX);
-  uint64_t Sum = 0;
-  Scale = 1;
+uint32_t
+MachineBranchProbabilityInfo::getSumForBlock(MachineBasicBlock *MBB) const {
+  // Normalize the weights of MBB's all successors so that the sum is guaranteed
+  // to be no greater than UINT32_MAX.
+  MBB->normalizeSuccWeights();
+
+  SmallVector<uint32_t, 8> Weights;
   for (MachineBasicBlock::const_succ_iterator I = MBB->succ_begin(),
-       E = MBB->succ_end(); I != E; ++I) {
-    uint32_t Weight = getEdgeWeight(MBB, I);
-    Sum += Weight;
-  }
+                                              E = MBB->succ_end();
+       I != E; ++I)
+    Weights.push_back(getEdgeWeight(MBB, I));
 
-  // If the computed sum fits in 32-bits, we're done.
-  if (Sum <= UINT32_MAX)
-    return Sum;
+  return std::accumulate(Weights.begin(), Weights.end(), 0u);
+}
 
-  // Otherwise, compute the scale necessary to cause the weights to fit, and
-  // re-sum with that scale applied.
-  assert((Sum / UINT32_MAX) < UINT32_MAX);
-  Scale = (Sum / UINT32_MAX) + 1;
-  Sum = 0;
+uint32_t
+MachineBranchProbabilityInfo::getSumForBlock(const MachineBasicBlock *MBB,
+                                             uint32_t &Scale) const {
+  SmallVector<uint32_t, 8> Weights;
   for (MachineBasicBlock::const_succ_iterator I = MBB->succ_begin(),
-       E = MBB->succ_end(); I != E; ++I) {
-    uint32_t Weight = getEdgeWeight(MBB, I);
-    Sum += Weight / Scale;
-  }
-  assert(Sum <= UINT32_MAX);
-  return Sum;
+                                              E = MBB->succ_end();
+       I != E; ++I)
+    Weights.push_back(getEdgeWeight(MBB, I));
+
+  if (MBB->areSuccWeightsNormalized())
+    Scale = 1;
+  else
+    Scale = MachineBranchProbabilityInfo::normalizeEdgeWeights(Weights);
+  return std::accumulate(Weights.begin(), Weights.end(), 0u);
 }
 
 uint32_t MachineBranchProbabilityInfo::