PGO: preserve branch-weight metadata when simplifying Switch to a sub, an icmp
authorManman Ren <mren@apple.com>
Tue, 18 Sep 2012 00:47:33 +0000 (00:47 +0000)
committerManman Ren <mren@apple.com>
Tue, 18 Sep 2012 00:47:33 +0000 (00:47 +0000)
and a conditional branch; also when removing dead cases from a switch.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@164084 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Utils/SimplifyCFG.cpp
test/Transforms/SimplifyCFG/preserve-branchweights.ll

index 1316b859e8c2457f2d9a620fffc8a2244cd29084..3365c2feed66e4ddc24ae74b0dd14cdb5658d83d 100644 (file)
@@ -2853,9 +2853,28 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) {
   if (!Offset->isNullValue())
     Sub = Builder.CreateAdd(Sub, Offset, Sub->getName()+".off");
   Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
-  Builder.CreateCondBr(
+  BranchInst *NewBI = Builder.CreateCondBr(
       Cmp, SI->case_begin().getCaseSuccessor(), SI->getDefaultDest());
 
+  // Update weight for the newly-created conditional branch.
+  SmallVector<uint64_t, 8> Weights;
+  bool HasWeights = HasBranchWeights(SI);
+  if (HasWeights) {
+    GetBranchWeights(SI, Weights);
+    if (Weights.size() == 1 + SI->getNumCases()) {
+      // Combine all weights for the cases to be the true weight of NewBI.
+      // We assume that the sum of all weights for a Terminator can fit into 32
+      // bits.
+      uint32_t NewTrueWeight = 0;
+      for (unsigned I = 1, E = Weights.size(); I != E; ++I)
+        NewTrueWeight += (uint32_t)Weights[I];
+      NewBI->setMetadata(LLVMContext::MD_prof,
+                         MDBuilder(SI->getContext()).
+                         createBranchWeights(NewTrueWeight,
+                                             (uint32_t)Weights[0]));
+    }
+  }
+
   // Prune obsolete incoming values off the successor's PHI nodes.
   for (BasicBlock::iterator BBI = SI->case_begin().getCaseSuccessor()->begin();
        isa<PHINode>(BBI); ++BBI) {
@@ -2886,15 +2905,33 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI) {
     }
   }
 
+  SmallVector<uint64_t, 8> Weights;
+  bool HasWeight = HasBranchWeights(SI);
+  if (HasWeight) {
+    GetBranchWeights(SI, Weights);
+    HasWeight = (Weights.size() == 1 + SI->getNumCases());
+  }
+
   // Remove dead cases from the switch.
   for (unsigned I = 0, E = DeadCases.size(); I != E; ++I) {
     SwitchInst::CaseIt Case = SI->findCaseValue(DeadCases[I]);
     assert(Case != SI->case_default() &&
            "Case was not found. Probably mistake in DeadCases forming.");
+    if (HasWeight) {
+      std::swap(Weights[Case.getCaseIndex()+1], Weights.back());
+      Weights.pop_back();
+    }
+
     // Prune unused values from PHI nodes.
     Case.getCaseSuccessor()->removePredecessor(SI->getParent());
     SI->removeCase(Case);
   }
+  if (HasWeight) {
+    SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
+    SI->setMetadata(LLVMContext::MD_prof,
+                    MDBuilder(SI->getParent()->getContext()).
+                    createBranchWeights(MDWeights));
+  }
 
   return !DeadCases.empty();
 }
index 4b78747cb7a6b8ff49bd728a7ed5c83801e5dc08..beef5270082042cc61d754deeff57feb4f2be579 100644 (file)
@@ -250,6 +250,49 @@ end:
     ret i1 %ret
 }
 
+define void @test10(i32 %x) nounwind readnone ssp noredzone {
+entry:
+ switch i32 %x, label %lor.rhs [
+   i32 2, label %lor.end
+   i32 1, label %lor.end
+   i32 3, label %lor.end
+ ], !prof !7
+
+lor.rhs:
+ call void @helper(i32 1) nounwind
+ ret void
+
+lor.end:
+ call void @helper(i32 0) nounwind
+ ret void
+
+; CHECK: test10
+; CHECK: %x.off = add i32 %x, -1
+; CHECK: %switch = icmp ult i32 %x.off, 3
+; CHECK: br i1 %switch, label %lor.end, label %lor.rhs, !prof !8
+}
+
+; Remove dead cases from the switch.
+define void @test11(i32 %x) nounwind {
+  %i = shl i32 %x, 1
+  switch i32 %i, label %a [
+    i32 21, label %b
+    i32 24, label %c
+  ], !prof !8
+; CHECK: %cond = icmp eq i32 %i, 24
+; CHECK: br i1 %cond, label %c, label %a, !prof !9
+
+a:
+ call void @helper(i32 0) nounwind
+ ret void
+b:
+ call void @helper(i32 1) nounwind
+ ret void
+c:
+ call void @helper(i32 2) nounwind
+ ret void
+}
+
 !0 = metadata !{metadata !"branch_weights", i32 3, i32 5}
 !1 = metadata !{metadata !"branch_weights", i32 1, i32 1}
 !2 = metadata !{metadata !"branch_weights", i32 1, i32 2}
@@ -258,6 +301,7 @@ end:
 !5 = metadata !{metadata !"branch_weights", i32 7, i32 6, i32 5}
 !6 = metadata !{metadata !"branch_weights", i32 1, i32 3}
 !7 = metadata !{metadata !"branch_weights", i32 33, i32 9, i32 8, i32 7}
+!8 = metadata !{metadata !"branch_weights", i32 33, i32 9, i32 8}
 
 ; CHECK: !0 = metadata !{metadata !"branch_weights", i32 5, i32 11}
 ; CHECK: !1 = metadata !{metadata !"branch_weights", i32 1, i32 5}
@@ -267,4 +311,6 @@ end:
 ; CHECK: !5 = metadata !{metadata !"branch_weights", i32 17, i32 15} 
 ; CHECK: !6 = metadata !{metadata !"branch_weights", i32 9, i32 7}
 ; CHECK: !7 = metadata !{metadata !"branch_weights", i32 17, i32 9, i32 8, i32 7, i32 17}
-; CHECK-NOT: !8
+; CHECK: !8 = metadata !{metadata !"branch_weights", i32 24, i32 33}
+; CHECK: !9 = metadata !{metadata !"branch_weights", i32 8, i32 33}
+; CHECK-NOT: !9