PGO: preserve branch-weight metadata when removing a case which jumps
[oota-llvm.git] / lib / Transforms / Utils / Local.cpp
index 060143356534aedb3341956d4cf51d600d29ba7f..dc6b506a3bac75fcdeeb95480c4e90f0ebf56ccf 100644 (file)
@@ -23,6 +23,7 @@
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
 #include "llvm/Intrinsics.h"
+#include "llvm/MDBuilder.h"
 #include "llvm/Metadata.h"
 #include "llvm/Operator.h"
 #include "llvm/ADT/DenseMap.h"
@@ -122,6 +123,27 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
       // Check to see if this branch is going to the same place as the default
       // dest.  If so, eliminate it as an explicit compare.
       if (i.getCaseSuccessor() == DefaultDest) {
+        MDNode* MD = SI->getMetadata(LLVMContext::MD_prof);
+        // MD should have 2 + NumCases operands.
+        if (MD && MD->getNumOperands() == 2 + SI->getNumCases()) {
+          // Collect branch weights into a vector.
+          SmallVector<uint32_t, 8> Weights;
+          for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e;
+               ++MD_i) {
+            ConstantInt* CI = dyn_cast<ConstantInt>(MD->getOperand(MD_i));
+            assert(CI);
+            Weights.push_back(CI->getValue().getZExtValue());
+          }
+          // Merge weight of this case to the default weight.
+          unsigned idx = i.getCaseIndex();
+          Weights[0] += Weights[idx+1];
+          // Remove weight for this case.
+          std::swap(Weights[idx+1], Weights.back());
+          Weights.pop_back();
+          SI->setMetadata(LLVMContext::MD_prof,
+                          MDBuilder(BB->getContext()).
+                          createBranchWeights(Weights));
+        }
         // Remove this entry.
         DefaultDest->removePredecessor(SI->getParent());
         SI->removeCase(i);