Do simple constant propagation in lookup table formation for switches
[oota-llvm.git] / lib / Transforms / Utils / SimplifyCFG.cpp
index f05a8bae6006a82fcf3255ddf345113306010e2c..9a9ce8568f16c89f32d08c9172728ca1379683ef 100644 (file)
@@ -14,6 +14,7 @@
 #define DEBUG_TYPE "simplifycfg"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Constants.h"
+#include "llvm/DataLayout.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/GlobalVariable.h"
 #include "llvm/IRBuilder.h"
@@ -39,7 +40,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/NoFolder.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/DataLayout.h"
+#include "llvm/TargetTransformInfo.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include <algorithm>
 #include <set>
@@ -82,6 +83,7 @@ namespace {
 
 class SimplifyCFGOpt {
   const DataLayout *const TD;
+  const TargetTransformInfo *const TTI;
 
   Value *isValueEqualityComparison(TerminatorInst *TI);
   BasicBlock *GetValueEqualityComparisonCases(TerminatorInst *TI,
@@ -101,7 +103,8 @@ class SimplifyCFGOpt {
   bool SimplifyCondBranch(BranchInst *BI, IRBuilder <>&Builder);
 
 public:
-  explicit SimplifyCFGOpt(const DataLayout *td) : TD(td) {}
+  SimplifyCFGOpt(const DataLayout *td, const TargetTransformInfo *tti)
+      : TD(td), TTI(tti) {}
   bool run(BasicBlock *BB);
 };
 }
@@ -3197,26 +3200,99 @@ static bool ValidLookupTableConstant(Constant *C) {
       isa<UndefValue>(C);
 }
 
-/// GetCaseResulsts - Try to determine the resulting constant values in phi
-/// nodes at the common destination basic block for one of the case
-/// destinations of a switch instruction.
+/// ConstantFold - Try to fold instruction I into a constant. This works for
+/// simple instructions such as binary operations where both operands are
+/// constant or can be replaced by constants from the ConstantPool. Returns the
+/// resulting constant on success, NULL otherwise.
+static Constant* ConstantFold(Instruction *I,
+                         const SmallDenseMap<Value*, Constant*>& ConstantPool) {
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
+    Constant *A = dyn_cast<Constant>(BO->getOperand(0));
+    if (!A) A = ConstantPool.lookup(BO->getOperand(0));
+    if (!A) return NULL;
+
+    Constant *B = dyn_cast<Constant>(BO->getOperand(1));
+    if (!B) B = ConstantPool.lookup(BO->getOperand(1));
+    if (!B) return NULL;
+
+    Constant *C = ConstantExpr::get(BO->getOpcode(), A, B);
+    return C;
+  }
+
+  if (CmpInst *Cmp = dyn_cast<CmpInst>(I)) {
+    Constant *A = dyn_cast<Constant>(I->getOperand(0));
+    if (!A) A = ConstantPool.lookup(I->getOperand(0));
+    if (!A) return NULL;
+
+    Constant *B = dyn_cast<Constant>(I->getOperand(1));
+    if (!B) B = ConstantPool.lookup(I->getOperand(1));
+    if (!B) return NULL;
+
+    Constant *C = ConstantExpr::getCompare(Cmp->getPredicate(), A, B);
+    return C;
+  }
+
+  if (SelectInst *Select = dyn_cast<SelectInst>(I)) {
+    Constant *A = dyn_cast<Constant>(Select->getCondition());
+    if (!A) A = ConstantPool.lookup(Select->getCondition());
+    if (!A) return NULL;
+
+    Value *Res;
+    if (A->isAllOnesValue()) Res = Select->getTrueValue();
+    else if (A->isNullValue()) Res = Select->getFalseValue();
+    else return NULL;
+
+    Constant *C = dyn_cast<Constant>(Res);
+    if (!C) C = ConstantPool.lookup(Res);
+    if (!C) return NULL;
+    return C;
+  }
+
+  if (CastInst *Cast = dyn_cast<CastInst>(I)) {
+    Constant *A = dyn_cast<Constant>(I->getOperand(0));
+    if (!A) A = ConstantPool.lookup(I->getOperand(0));
+    if (!A) return false;
+
+    Constant *C = ConstantExpr::getCast(Cast->getOpcode(), A, Cast->getDestTy());
+    return C;
+  }
+
+  return NULL;
+}
+
+/// GetCaseResults - Try to determine the resulting constant values in phi nodes
+/// at the common destination basic block, *CommonDest, for one of the case
+/// destionations CaseDest corresponding to value CaseVal (NULL for the default
+/// case), of a switch instruction SI.
 static bool GetCaseResults(SwitchInst *SI,
+                           ConstantInt *CaseVal,
                            BasicBlock *CaseDest,
                            BasicBlock **CommonDest,
                            SmallVector<std::pair<PHINode*,Constant*>, 4> &Res) {
   // The block from which we enter the common destination.
   BasicBlock *Pred = SI->getParent();
 
-  // If CaseDest is empty, continue to its successor.
-  if (CaseDest->getFirstNonPHIOrDbg() == CaseDest->getTerminator() &&
-      !isa<PHINode>(CaseDest->begin())) {
-
-    TerminatorInst *Terminator = CaseDest->getTerminator();
-    if (Terminator->getNumSuccessors() != 1)
-      return false;
-
-    Pred = CaseDest;
-    CaseDest = Terminator->getSuccessor(0);
+  // If CaseDest is empty except for some side-effect free instructions through
+  // which we can constant-propagate the CaseVal, continue to its successor.
+  SmallDenseMap<Value*, Constant*> ConstantPool;
+  ConstantPool.insert(std::make_pair(SI->getCondition(), CaseVal));
+  for (BasicBlock::iterator I = CaseDest->begin(), E = CaseDest->end(); I != E;
+       ++I) {
+    if (TerminatorInst *T = dyn_cast<TerminatorInst>(I)) {
+      // If the terminator is a simple branch, continue to the next block.
+      if (T->getNumSuccessors() != 1)
+        return false;
+      Pred = CaseDest;
+      CaseDest = T->getSuccessor(0);
+    } else if (isa<DbgInfoIntrinsic>(I)) {
+      // Skip debug intrinsic.
+      continue;
+    } else if (Constant *C = ConstantFold(I, ConstantPool)) {
+      // Instruction is side-effect free and constant.
+      ConstantPool.insert(std::make_pair(I, C));
+    } else {
+      break;
+    }
   }
 
   // If we did not have a CommonDest before, use the current one.
@@ -3234,9 +3310,17 @@ static bool GetCaseResults(SwitchInst *SI,
       continue;
 
     Constant *ConstVal = dyn_cast<Constant>(PHI->getIncomingValue(Idx));
+    if (!ConstVal)
+      ConstVal = ConstantPool.lookup(PHI->getIncomingValue(Idx));
     if (!ConstVal)
       return false;
 
+    // Note: If the constant comes from constant-propagating the case value
+    // through the CaseDest basic block, it will be safe to remove the
+    // instructions in that block. They cannot be used (except in the phi nodes
+    // we visit) outside CaseDest, because that block does not dominate its
+    // successor. If it did, we would not be in this phi node.
+
     // Be conservative about which kinds of constants we support.
     if (!ValidLookupTableConstant(ConstVal))
       return false;
@@ -3459,8 +3543,13 @@ static bool ShouldBuildLookupTable(SwitchInst *SI,
 /// replace the switch with lookup tables.
 static bool SwitchToLookupTable(SwitchInst *SI,
                                 IRBuilder<> &Builder,
-                                const DataLayout* TD) {
+                                const DataLayout* TD,
+                                const TargetTransformInfo *TTI) {
   assert(SI->getNumCases() > 1 && "Degenerate switch?");
+
+  if (TTI && !TTI->getScalarTargetTransformInfo()->shouldBuildLookupTables())
+    return false;
+
   // FIXME: Handle unreachable cases.
 
   // FIXME: If the switch is too sparse for a lookup table, perhaps we could
@@ -3501,7 +3590,8 @@ static bool SwitchToLookupTable(SwitchInst *SI,
     // Resulting value at phi nodes for this case value.
     typedef SmallVector<std::pair<PHINode*, Constant*>, 4> ResultsTy;
     ResultsTy Results;
-    if (!GetCaseResults(SI, CI.getCaseSuccessor(), &CommonDest, Results))
+    if (!GetCaseResults(SI, CaseVal, CI.getCaseSuccessor(), &CommonDest,
+                        Results))
       return false;
 
     // Append the result from this case to the list for each phi.
@@ -3514,7 +3604,8 @@ static bool SwitchToLookupTable(SwitchInst *SI,
 
   // Get the resulting values for the default case.
   SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList;
-  if (!GetCaseResults(SI, SI->getDefaultDest(), &CommonDest, DefaultResultsList))
+  if (!GetCaseResults(SI, NULL, SI->getDefaultDest(), &CommonDest,
+                      DefaultResultsList))
     return false;
   for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) {
     PHINode *PHI = DefaultResultsList[I].first;
@@ -3619,7 +3710,7 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
   if (ForwardSwitchConditionToPHI(SI))
     return SimplifyCFG(BB) | true;
 
-  if (SwitchToLookupTable(SI, Builder, TD))
+  if (SwitchToLookupTable(SI, Builder, TD, TTI))
     return SimplifyCFG(BB) | true;
 
   return false;
@@ -3916,6 +4007,7 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) {
 /// eliminates unreachable basic blocks, and does other "peephole" optimization
 /// of the CFG.  It returns true if a modification was made.
 ///
-bool llvm::SimplifyCFG(BasicBlock *BB, const DataLayout *TD) {
-  return SimplifyCFGOpt(TD).run(BB);
+bool llvm::SimplifyCFG(BasicBlock *BB, const DataLayout *TD,
+                       const TargetTransformInfo *TTI) {
+  return SimplifyCFGOpt(TD, TTI).run(BB);
 }