For PR950:
[oota-llvm.git] / lib / Transforms / Utils / LowerSwitch.cpp
index ebad6c36dbc7bf0558fbe5b15c5f1ced1c18f1fd..b2974a98c8028ea40dc2601bdd3294f79ad8dbac 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
 #include "llvm/Constants.h"
 #include "llvm/Function.h"
 #include "llvm/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Debug.h"
-#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Compiler.h"
 #include <algorithm>
-#include <iostream>
 using namespace llvm;
 
 namespace {
-  Statistic<> NumLowered("lowerswitch", "Number of SwitchInst's replaced");
-
   /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
   /// instructions.  Note that this cannot be a BasicBlock pass because it
   /// modifies the CFG!
-  class LowerSwitch : public FunctionPass {
+  class VISIBILITY_HIDDEN LowerSwitch : public FunctionPass {
   public:
-    bool runOnFunction(Function &F);
+    virtual bool runOnFunction(Function &F);
+    
+    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+      // This is a cluster of orthogonal Transforms    
+      AU.addPreserved<UnifyFunctionExitNodes>();
+      AU.addPreservedID(PromoteMemoryToRegisterID);
+      AU.addPreservedID(LowerSelectID);
+      AU.addPreservedID(LowerInvokePassID);
+      AU.addPreservedID(LowerAllocationsID);
+    }
+        
     typedef std::pair<Constant*, BasicBlock*> Case;
     typedef std::vector<Case>::iterator       CaseItr;
   private:
@@ -48,15 +56,16 @@ namespace {
   struct CaseCmp {
     bool operator () (const LowerSwitch::Case& C1,
                       const LowerSwitch::Case& C2) {
-      if (const ConstantUInt* U1 = dyn_cast<const ConstantUInt>(C1.first))
-        return U1->getValue() < cast<const ConstantUInt>(C2.first)->getValue();
 
-      const ConstantSInt* S1 = dyn_cast<const ConstantSInt>(C1.first);
-      return S1->getValue() < cast<const ConstantSInt>(C2.first)->getValue();
+      const ConstantInt* CI1 = cast<const ConstantInt>(C1.first);
+      const ConstantInt* CI2 = cast<const ConstantInt>(C2.first);
+      if (CI1->getType()->isUnsigned()) 
+        return CI1->getZExtValue() < CI2->getZExtValue();
+      return CI1->getSExtValue() < CI2->getSExtValue();
     }
   };
 
-  RegisterOpt<LowerSwitch>
+  RegisterPass<LowerSwitch>
   X("lowerswitch", "Lower SwitchInst's to branches");
 }
 
@@ -96,6 +105,10 @@ std::ostream& operator<<(std::ostream &O,
 
   return O << "]";
 }
+OStream& operator<<(OStream &O, const std::vector<LowerSwitch::Case> &C) {
+  if (O.stream()) *O.stream() << C;
+  return O;
+}
 
 // switchConvert - Convert the switch statement into a binary lookup of
 // the case values. The function recursively builds this tree.
@@ -111,14 +124,13 @@ BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
 
   unsigned Mid = Size / 2;
   std::vector<Case> LHS(Begin, Begin + Mid);
-  DEBUG(std::cerr << "LHS: " << LHS << "\n");
+  DOUT << "LHS: " << LHS << "\n";
   std::vector<Case> RHS(Begin + Mid, End);
-  DEBUG(std::cerr << "RHS: " << RHS << "\n");
+  DOUT << "RHS: " << RHS << "\n";
 
   Case& Pivot = *(Begin + Mid);
-  DEBUG(std::cerr << "Pivot ==> "
-                  << (int64_t)cast<ConstantInt>(Pivot.first)->getRawValue()
-                  << "\n");
+  DOUT << "Pivot ==> "
+       << cast<ConstantInt>(Pivot.first)->getSExtValue() << "\n";
 
   BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val,
                                       OrigBlock, Default);
@@ -131,8 +143,7 @@ BasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End,
   BasicBlock* NewNode = new BasicBlock("NodeBlock");
   F->getBasicBlockList().insert(OrigBlock->getNext(), NewNode);
 
-  SetCondInst* Comp = new SetCondInst(Instruction::SetLT, Val, Pivot.first,
-                                      "Pivot");
+  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_ULT, Val, Pivot.first, "Pivot");
   NewNode->getInstList().push_back(Comp);
   new BranchInst(LBranch, RBranch, Comp, NewNode);
   return NewNode;
@@ -153,8 +164,8 @@ BasicBlock* LowerSwitch::newLeafBlock(Case& Leaf, Value* Val,
   F->getBasicBlockList().insert(OrigBlock->getNext(), NewLeaf);
 
   // Make the seteq instruction...
-  SetCondInst* Comp = new SetCondInst(Instruction::SetEQ, Val,
-                                      Leaf.first, "SwitchLeaf");
+  ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val,
+                                Leaf.first, "SwitchLeaf");
   NewLeaf->getInstList().push_back(Comp);
 
   // Make the conditional branch...
@@ -213,7 +224,7 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI) {
     Cases.push_back(Case(SI->getSuccessorValue(i), SI->getSuccessor(i)));
 
   std::sort(Cases.begin(), Cases.end(), CaseCmp());
-  DEBUG(std::cerr << "Cases: " << Cases << "\n");
+  DOUT << "Cases: " << Cases << "\n";
   BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val,
                                           OrigBlock, NewDefault);