Switch-to-lookup tables: Don't require a result for the default
[oota-llvm.git] / lib / Transforms / Utils / SimplifyCFG.cpp
index 0a45897961535a8994e95cb0963c542e4eb27b76..90f8847bc1c29ce0d27a6329218827d8a0cba488 100644 (file)
@@ -3428,7 +3428,7 @@ GetCaseResults(SwitchInst *SI,
     Res.push_back(std::make_pair(PHI, ConstVal));
   }
 
-  return true;
+  return Res.size() > 0;
 }
 
 namespace {
@@ -3499,12 +3499,14 @@ SwitchLookupTable::SwitchLookupTable(Module &M,
   // If all values in the table are equal, this is that value.
   SingleValue = Values.begin()->second;
 
+  Type *ValueType = Values.begin()->second->getType();
+
   // Build up the table contents.
   SmallVector<Constant*, 64> TableContents(TableSize);
   for (size_t I = 0, E = Values.size(); I != E; ++I) {
     ConstantInt *CaseVal = Values[I].first;
     Constant *CaseRes = Values[I].second;
-    assert(CaseRes->getType() == DefaultValue->getType());
+    assert(CaseRes->getType() == ValueType);
 
     uint64_t Idx = (CaseVal->getValue() - Offset->getValue())
                    .getLimitedValue();
@@ -3516,6 +3518,8 @@ SwitchLookupTable::SwitchLookupTable(Module &M,
 
   // Fill in any holes in the table with the default result.
   if (Values.size() < TableSize) {
+    assert(DefaultValue && "Need a default value to fill the lookup table holes.");
+    assert(DefaultValue->getType() == ValueType);
     for (uint64_t I = 0; I < TableSize; ++I) {
       if (!TableContents[I])
         TableContents[I] = DefaultValue;
@@ -3533,8 +3537,8 @@ SwitchLookupTable::SwitchLookupTable(Module &M,
   }
 
   // If the type is integer and the table fits in a register, build a bitmap.
-  if (WouldFitInRegister(TD, TableSize, DefaultValue->getType())) {
-    IntegerType *IT = cast<IntegerType>(DefaultValue->getType());
+  if (WouldFitInRegister(TD, TableSize, ValueType)) {
+    IntegerType *IT = cast<IntegerType>(ValueType);
     APInt TableInt(TableSize * IT->getBitWidth(), 0);
     for (uint64_t I = TableSize; I > 0; --I) {
       TableInt <<= IT->getBitWidth();
@@ -3552,7 +3556,7 @@ SwitchLookupTable::SwitchLookupTable(Module &M,
   }
 
   // Store the table in an array.
-  ArrayType *ArrayTy = ArrayType::get(DefaultValue->getType(), TableSize);
+  ArrayType *ArrayTy = ArrayType::get(ValueType, TableSize);
   Constant *Initializer = ConstantArray::get(ArrayTy, TableContents);
 
   Array = new GlobalVariable(M, ArrayTy, /*constant=*/ true,
@@ -3723,20 +3727,29 @@ static bool SwitchToLookupTable(SwitchInst *SI,
     }
   }
 
-  // Get the resulting values for the default case.
+  // Keep track of the result types.
+  for (size_t I = 0, E = PHIs.size(); I != E; ++I) {
+    PHINode *PHI = PHIs[I];
+    ResultTypes[PHI] = ResultLists[PHI][0].second->getType();
+  }
+
+  uint64_t NumResults = ResultLists[PHIs[0]].size();
+  APInt RangeSpread = MaxCaseVal->getValue() - MinCaseVal->getValue();
+  uint64_t TableSize = RangeSpread.getLimitedValue() + 1;
+  bool TableHasHoles = (NumResults < TableSize);
+
+  // If the table has holes, we need a constant result for the default case.
   SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList;
-  if (!GetCaseResults(SI, 0, SI->getDefaultDest(), &CommonDest,
-                      DefaultResultsList, TD))
+  if (TableHasHoles && !GetCaseResults(SI, 0, SI->getDefaultDest(), &CommonDest,
+                                       DefaultResultsList, TD))
     return false;
+
   for (size_t I = 0, E = DefaultResultsList.size(); I != E; ++I) {
     PHINode *PHI = DefaultResultsList[I].first;
     Constant *Result = DefaultResultsList[I].second;
     DefaultResults[PHI] = Result;
-    ResultTypes[PHI] = Result->getType();
   }
 
-  APInt RangeSpread = MaxCaseVal->getValue() - MinCaseVal->getValue();
-  uint64_t TableSize = RangeSpread.getLimitedValue() + 1;
   if (!ShouldBuildLookupTable(SI, TableSize, TTI, TD, ResultTypes))
     return false;
 
@@ -3755,7 +3768,7 @@ static bool SwitchToLookupTable(SwitchInst *SI,
   // Compute the maximum table size representable by the integer type we are
   // switching upon.
   unsigned CaseSize = MinCaseVal->getType()->getPrimitiveSizeInBits();
-  uint64_t MaxTableSize = CaseSize > 63? UINT64_MAX : 1ULL << CaseSize;
+  uint64_t MaxTableSize = CaseSize > 63 ? UINT64_MAX : 1ULL << CaseSize;
   assert(MaxTableSize >= TableSize &&
          "It is impossible for a switch to have more entries than the max "
          "representable value of its input integer type's size.");