Switch lowering: use profile info to build weight-balanced binary search trees
authorHans Wennborg <hans@hanshq.net>
Thu, 30 Apr 2015 00:57:37 +0000 (00:57 +0000)
committerHans Wennborg <hans@hanshq.net>
Thu, 30 Apr 2015 00:57:37 +0000 (00:57 +0000)
This will cause hot nodes to appear closer to the root.

The literature says building the tree like this makes it a near-optimal (in
terms of search time given key frequencies) binary search tree. In LLVM's case,
we can do up to 3 comparisons in each leaf node, so it might be better to opt
for lower tree height in some cases; that's something to look into in the
future.

Differential Revision: http://reviews.llvm.org/D9318

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@236192 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
test/CodeGen/Generic/MachineBranchProb.ll
test/CodeGen/X86/switch.ll

index db41c40..1c14d4d 100644 (file)
@@ -7986,14 +7986,41 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
   unsigned NumClusters = W.LastCluster - W.FirstCluster + 1;
   assert(NumClusters >= 2 && "Too small to split!");
 
-  // FIXME: When we have profile info, we might want to balance the tree based
-  // on weights instead of node count.
+  // Balance the tree based on branch weights to create a near-optimal (in terms
+  // of search time given key frequency) binary search tree. See e.g. Kurt
+  // Mehlhorn "Nearly Optimal Binary Search Trees" (1975).
+  CaseClusterIt LastLeft = W.FirstCluster;
+  CaseClusterIt FirstRight = W.LastCluster;
+  uint32_t LeftWeight = LastLeft->Weight;
+  uint32_t RightWeight = FirstRight->Weight;
+
+  // Move LastLeft and FirstRight towards each other from opposite directions to
+  // find a partitioning of the clusters which balances the weight on both
+  // sides.
+  while (LastLeft + 1 < FirstRight) {
+    // Zero-weight nodes would cause skewed trees since they don't affect
+    // LeftWeight or RightWeight.
+    assert(LastLeft->Weight != 0);
+    assert(FirstRight->Weight != 0);
+
+    if (LeftWeight < RightWeight)
+      LeftWeight += (++LastLeft)->Weight;
+    else
+      RightWeight += (--FirstRight)->Weight;
+  }
+  assert(LastLeft + 1 == FirstRight);
+  assert(LastLeft >= W.FirstCluster);
+  assert(FirstRight <= W.LastCluster);
+
+  // Use the first element on the right as pivot since we will make less-than
+  // comparisons against it.
+  CaseClusterIt PivotCluster = FirstRight;
+  assert(PivotCluster > W.FirstCluster);
+  assert(PivotCluster <= W.LastCluster);
 
-  CaseClusterIt PivotCluster = W.FirstCluster + NumClusters / 2;
   CaseClusterIt FirstLeft = W.FirstCluster;
-  CaseClusterIt LastLeft = PivotCluster - 1;
-  CaseClusterIt FirstRight = PivotCluster;
   CaseClusterIt LastRight = W.LastCluster;
+
   const ConstantInt *Pivot = PivotCluster->Low;
 
   // New blocks will be inserted immediately after the current one.
@@ -8032,7 +8059,8 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
   }
 
   // Create the CaseBlock record that will be used to lower the branch.
-  CaseBlock CB(ISD::SETLT, Cond, Pivot, nullptr, LeftMBB, RightMBB, W.MBB);
+  CaseBlock CB(ISD::SETLT, Cond, Pivot, nullptr, LeftMBB, RightMBB, W.MBB,
+               LeftWeight, RightWeight);
 
   if (W.MBB == SwitchMBB)
     visitSwitchCase(CB, SwitchMBB);
@@ -8048,7 +8076,7 @@ void SelectionDAGBuilder::visitSwitch(const SwitchInst &SI) {
   for (auto I : SI.cases()) {
     MachineBasicBlock *Succ = FuncInfo.MBBMap[I.getCaseSuccessor()];
     const ConstantInt *CaseVal = I.getCaseValue();
-    uint32_t Weight = 0; // FIXME: Use 1 instead?
+    uint32_t Weight = 1;
     if (BPI) {
       Weight = BPI->getEdgeWeight(SI.getParent(), I.getSuccessorIndex());
       assert(Weight <= UINT32_MAX / SI.getNumSuccessors());
index 83277c9..f030775 100644 (file)
@@ -5,7 +5,7 @@
 
 ; Make sure we have the correct weight attached to each successor.
 define i32 @test2(i32 %x) nounwind uwtable readnone ssp {
-; CHECK: Machine code for function test2:
+; CHECK-LABEL: Machine code for function test2:
 entry:
   %conv = sext i32 %x to i64
   switch i64 %conv, label %return [
@@ -33,3 +33,41 @@ return:
 }
 
 !0 = !{!"branch_weights", i32 7, i32 6, i32 4, i32 4, i32 64}
+
+
+declare void @g(i32)
+define void @left_leaning_weight_balanced_tree(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    i32 0,  label %bb0
+    i32 10, label %bb1
+    i32 20, label %bb2
+    i32 30, label %bb3
+    i32 40, label %bb4
+    i32 50, label %bb5
+  ], !prof !1
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+bb4: tail call void @g(i32 4) br label %return
+bb5: tail call void @g(i32 5) br label %return
+return: ret void
+
+; Check that we set branch weights on the pivot cmp instruction correctly.
+; Cases {0,10,20,30} go on the left with weight 13; cases {40,50} go on the
+; right with weight 20.
+;
+; CHECK-LABEL: Machine code for function left_leaning_weight_balanced_tree:
+; CHECK: BB#0: derived from LLVM BB %entry
+; CHECK-NOT: Successors
+; CHECK: Successors according to CFG: BB#8(13) BB#9(20)
+}
+
+!1 = !{!"branch_weights",
+  ; Default:
+  i32 1,
+  ; Case 0, 10, 20:
+  i32 10, i32 1, i32 1,
+  ; Case 30, 40, 50:
+  i32 1, i32 10, i32 10}
index 2e5c0a6..d50eaba 100644 (file)
@@ -442,3 +442,86 @@ return: ret void
        i32 1000,
        ; Case 300:
        i32 10}
+
+
+define void @zero_weight_tree(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    i32 0,  label %bb0
+    i32 10, label %bb1
+    i32 20, label %bb2
+    i32 30, label %bb3
+    i32 40, label %bb4
+    i32 50, label %bb5
+  ], !prof !3
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+bb4: tail call void @g(i32 4) br label %return
+bb5: tail call void @g(i32 5) br label %return
+return: ret void
+
+; Make sure to pick a pivot in the middle also with zero-weight cases.
+; CHECK-LABEL: zero_weight_tree
+; CHECK-NOT: cmpl
+; CHECK: cmpl $29
+}
+
+!3 = !{!"branch_weights", i32 1, i32 10, i32 0, i32 0, i32 0, i32 0, i32 10}
+
+
+define void @left_leaning_weight_balanced_tree(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    i32 0,  label %bb0
+    i32 10, label %bb1
+    i32 20, label %bb2
+    i32 30, label %bb3
+    i32 40, label %bb4
+    i32 50, label %bb5
+  ], !prof !4
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+bb4: tail call void @g(i32 4) br label %return
+bb5: tail call void @g(i32 5) br label %return
+return: ret void
+
+; To balance the tree by weight, the pivot is shifted to the right, moving hot
+; cases closer to the root.
+; CHECK-LABEL: left_leaning_weight_balanced_tree
+; CHECK-NOT: cmpl
+; CHECK: cmpl $39
+}
+
+!4 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 10, i32 10}
+
+
+define void @jump_table_affects_balance(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    ; Jump table:
+    i32 0,  label %bb0
+    i32 1,  label %bb1
+    i32 2,  label %bb2
+    i32 3,  label %bb3
+
+    i32 100, label %bb0
+    i32 200, label %bb1
+    i32 300, label %bb2
+  ]
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+return: ret void
+
+; CHECK-LABEL: jump_table_affects_balance
+; If the tree were balanced based on number of clusters, {0-3,100} would go on
+; the left and {200,300} on the right. However, the jump table weights as much
+; as its components, so 100 is selected as the pivot.
+; CHECK-NOT: cmpl
+; CHECK: cmpl $99
+}