Switch lowering: add heuristic for filling leaf nodes in the weight-balanced binary...
authorHans Wennborg <hans@hanshq.net>
Sat, 20 Jun 2015 17:14:07 +0000 (17:14 +0000)
committerHans Wennborg <hans@hanshq.net>
Sat, 20 Jun 2015 17:14:07 +0000 (17:14 +0000)
Sparse switches with profile info are lowered as weight-balanced BSTs. For
example, if the node weights are {1,1,1,1,1,1000}, the right-most node would
end up in a tree by itself, bringing it closer to the top.

However, a leaf in this BST can contain up to 3 cases, and having a single
case in a leaf node as in the example means the tree might become
unnecessarily high.

This patch adds a heauristic to the pivot selection algorithm that moves more
cases into leaf nodes unless that would lower their rank. It still doesn't
yield the optimal tree in every case, but I believe it's conservatibely correct.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@240224 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
test/CodeGen/X86/switch.ll

index ab988f6..8313a48 100644 (file)
@@ -7996,6 +7996,18 @@ void SelectionDAGBuilder::lowerWorkItem(SwitchWorkListItem W, Value *Cond,
   }
 }
 
+unsigned SelectionDAGBuilder::caseClusterRank(const CaseCluster &CC,
+                                              CaseClusterIt First,
+                                              CaseClusterIt Last) {
+  return std::count_if(First, Last + 1, [&](const CaseCluster &X) {
+    if (X.Weight != CC.Weight)
+      return X.Weight > CC.Weight;
+
+    // Ties are broken by comparing the case value.
+    return X.Low->getValue().slt(CC.Low->getValue());
+  });
+}
+
 void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
                                         const SwitchWorkListItem &W,
                                         Value *Cond,
@@ -8025,6 +8037,48 @@ void SelectionDAGBuilder::splitWorkItem(SwitchWorkList &WorkList,
       RightWeight += (--FirstRight)->Weight;
     I++;
   }
+
+  for (;;) {
+    // Our binary search tree differs from a typical BST in that ours can have up
+    // to three values in each leaf. The pivot selection above doesn't take that
+    // into account, which means the tree might require more nodes and be less
+    // efficient. We compensate for this here.
+
+    unsigned NumLeft = LastLeft - W.FirstCluster + 1;
+    unsigned NumRight = W.LastCluster - FirstRight + 1;
+
+    if (std::min(NumLeft, NumRight) < 3 && std::max(NumLeft, NumRight) > 3) {
+      // If one side has less than 3 clusters, and the other has more than 3,
+      // consider taking a cluster from the other side.
+
+      if (NumLeft < NumRight) {
+        // Consider moving the first cluster on the right to the left side.
+        CaseCluster &CC = *FirstRight;
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        if (LeftSideRank <= RightSideRank) {
+          // Moving the cluster to the left does not demote it.
+          ++LastLeft;
+          ++FirstRight;
+          continue;
+        }
+      } else {
+        assert(NumRight < NumLeft);
+        // Consider moving the last element on the left to the right side.
+        CaseCluster &CC = *LastLeft;
+        unsigned LeftSideRank = caseClusterRank(CC, W.FirstCluster, LastLeft);
+        unsigned RightSideRank = caseClusterRank(CC, FirstRight, W.LastCluster);
+        if (RightSideRank <= LeftSideRank) {
+          // Moving the cluster to the right does not demot it.
+          --LastLeft;
+          --FirstRight;
+          continue;
+        }
+      }
+    }
+    break;
+  }
+
   assert(LastLeft + 1 == FirstRight);
   assert(LastLeft >= W.FirstCluster);
   assert(FirstRight <= W.LastCluster);
index f0c03af..f225d54 100644 (file)
@@ -342,6 +342,11 @@ private:
   };
   typedef SmallVector<SwitchWorkListItem, 4> SwitchWorkList;
 
+  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
+  /// than each cluster in the range, its rank is 0.
+  static unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
+                                  CaseClusterIt Last);
+
   /// Emit comparison and split W into two subtrees.
   void splitWorkItem(SwitchWorkList &WorkList, const SwitchWorkListItem &W,
                      Value *Cond, MachineBasicBlock *SwitchMBB);
index fc217d5..748fd6f 100644 (file)
@@ -499,6 +499,8 @@ entry:
     i32 30, label %bb3
     i32 40, label %bb4
     i32 50, label %bb5
+    i32 60, label %bb6
+    i32 70, label %bb6
   ], !prof !4
 bb0: tail call void @g(i32 0) br label %return
 bb1: tail call void @g(i32 1) br label %return
@@ -506,16 +508,87 @@ bb2: tail call void @g(i32 2) br label %return
 bb3: tail call void @g(i32 3) br label %return
 bb4: tail call void @g(i32 4) br label %return
 bb5: tail call void @g(i32 5) br label %return
+bb6: tail call void @g(i32 6) br label %return
+bb7: tail call void @g(i32 7) br label %return
 return: ret void
 
-; To balance the tree by weight, the pivot is shifted to the right, moving hot
-; cases closer to the root.
+; Without branch probabilities, the pivot would be 40, since that would yield
+; equal-sized sub-trees. When taking weights into account, case 70 becomes the
+; pivot. Since there is room for 3 cases in a leaf, cases 50 and 60 are also
+; included in the right-hand side because that doesn't reduce their rank.
+
 ; CHECK-LABEL: left_leaning_weight_balanced_tree
 ; CHECK-NOT: cmpl
-; CHECK: cmpl $39
+; CHECK: cmpl $49
+}
+
+!4 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1000}
+
+
+define void @left_leaning_weight_balanced_tree2(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    i32 0,  label %bb0
+    i32 10, label %bb1
+    i32 20, label %bb2
+    i32 30, label %bb3
+    i32 40, label %bb4
+    i32 50, label %bb5
+    i32 60, label %bb6
+    i32 70, label %bb6
+  ], !prof !5
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+bb4: tail call void @g(i32 4) br label %return
+bb5: tail call void @g(i32 5) br label %return
+bb6: tail call void @g(i32 6) br label %return
+bb7: tail call void @g(i32 7) br label %return
+return: ret void
+
+; Same as the previous test, except case 50 has higher rank to the left than it
+; would have on the right. Case 60 would have the same rank on both sides, so is
+; moved into the leaf.
+
+; CHECK-LABEL: left_leaning_weight_balanced_tree2
+; CHECK-NOT: cmpl
+; CHECK: cmpl $59
+}
+
+!5 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 1, i32 90, i32 70, i32 1000}
+
+
+define void @right_leaning_weight_balanced_tree(i32 %x) {
+entry:
+  switch i32 %x, label %return [
+    i32 0,  label %bb0
+    i32 10, label %bb1
+    i32 20, label %bb2
+    i32 30, label %bb3
+    i32 40, label %bb4
+    i32 50, label %bb5
+    i32 60, label %bb6
+    i32 70, label %bb6
+  ], !prof !6
+bb0: tail call void @g(i32 0) br label %return
+bb1: tail call void @g(i32 1) br label %return
+bb2: tail call void @g(i32 2) br label %return
+bb3: tail call void @g(i32 3) br label %return
+bb4: tail call void @g(i32 4) br label %return
+bb5: tail call void @g(i32 5) br label %return
+bb6: tail call void @g(i32 6) br label %return
+bb7: tail call void @g(i32 7) br label %return
+return: ret void
+
+; Analogous to left_leaning_weight_balanced_tree.
+
+; CHECK-LABEL: right_leaning_weight_balanced_tree
+; CHECK-NOT: cmpl
+; CHECK: cmpl $19
 }
 
-!4 = !{!"branch_weights", i32 1, i32 10, i32 1, i32 1, i32 1, i32 10, i32 10}
+!6 = !{!"branch_weights", i32 1, i32 1000, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 10}
 
 
 define void @jump_table_affects_balance(i32 %x) {