Handle PHI nodes prefacing EH pads too
[oota-llvm.git] / lib / Transforms / Utils / LowerSwitch.cpp
index a057f5d0c0fadebbb915e1c1540b910f9103f872..4acd988691d22f99caaba4c65e98fef1565e026f 100644 (file)
@@ -101,7 +101,7 @@ namespace {
       return CI1->getValue().slt(CI2->getValue());
     }
   };
-} // namespace
+}
 
 char LowerSwitch::ID = 0;
 INITIALIZE_PASS(LowerSwitch, "lowerswitch",
@@ -364,9 +364,9 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
   std::sort(Cases.begin(), Cases.end(), CaseCmp());
 
   // Merge case into clusters
-  if (Cases.size()>=2)
-    for (CaseItr I = Cases.begin(), J = std::next(Cases.begin());
-         J != Cases.end();) {
+  if (Cases.size() >= 2) {
+    CaseItr I = Cases.begin();
+    for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) {
       int64_t nextValue = J->Low->getSExtValue();
       int64_t currentValue = I->High->getSExtValue();
       BasicBlock* nextBB = J->BB;
@@ -374,13 +374,16 @@ unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) {
 
       // If the two neighboring cases go to the same destination, merge them
       // into a single case.
-      if ((nextValue-currentValue==1) && (currentBB == nextBB)) {
+      assert(nextValue > currentValue && "Cases should be strictly ascending");
+      if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
         I->High = J->High;
-        J = Cases.erase(J);
-      } else {
-        I = J++;
+        // FIXME: Combine branch weights.
+      } else if (++I != J) {
+        *I = *J;
       }
     }
+    Cases.erase(std::next(I), Cases.end());
+  }
 
   for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) {
     if (I->Low != I->High)
@@ -476,12 +479,10 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) {
     // cases.
     assert(MaxPop > 0 && PopSucc);
     Default = PopSucc;
-    for (CaseItr I = Cases.begin(); I != Cases.end();) {
-      if (I->BB == PopSucc)
-        I = Cases.erase(I);
-      else
-        ++I;
-    }
+    Cases.erase(std::remove_if(
+                    Cases.begin(), Cases.end(),
+                    [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }),
+                Cases.end());
 
     // If there are no cases left, just branch.
     if (Cases.empty()) {