Peephole optimization in switch table lookup: reuse the guarding table comparison...
[oota-llvm.git] / lib / Transforms / Utils / SimplifyCFG.cpp
index 318773d64c294144db8a49e91afa74b33f8e309f..c4b45ed147363aa443899440d547cc9211a34f9a 100644 (file)
@@ -73,6 +73,7 @@ STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
 STATISTIC(NumLinearMaps, "Number of switch instructions turned into linear mapping");
 STATISTIC(NumLookupTables, "Number of switch instructions turned into lookup tables");
 STATISTIC(NumLookupTablesHoles, "Number of switch instructions turned into lookup tables (holes checked)");
+STATISTIC(NumTableCmpReuses, "Number of reused switch table lookup compares");
 STATISTIC(NumSinkCommons, "Number of common instructions sunk down to the end block");
 STATISTIC(NumSpeculations, "Number of speculative executed instructions");
 
@@ -357,6 +358,8 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) {
   return nullptr;
 }
 
+namespace {
+
 /// Given a chain of or (||) or and (&&) comparison of a value against a
 /// constant, this will try to recover the information required for a switch
 /// structure.
@@ -369,19 +372,22 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) {
 /// fail.
 struct ConstantComparesGatherer {
 
-  Value *CompValue = nullptr; /// Value found for the switch comparison
-  Value *Extra = nullptr;  /// Extra clause to be checked before the switch
-  SmallVector<ConstantInt*, 8> Vals; /// Set of integers to match in switch
-  unsigned UsedICmps = 0; /// Number of comparisons matched in the and/or chain
+  Value *CompValue; /// Value found for the switch comparison
+  Value *Extra;     /// Extra clause to be checked before the switch
+  SmallVector<ConstantInt *, 8> Vals; /// Set of integers to match in switch
+  unsigned UsedICmps; /// Number of comparisons matched in the and/or chain
 
   /// Construct and compute the result for the comparison instruction Cond
-  ConstantComparesGatherer(Instruction *Cond, const DataLayout *DL) {
+  ConstantComparesGatherer(Instruction *Cond, const DataLayout *DL)
+      : CompValue(nullptr), Extra(nullptr), UsedICmps(0) {
     gather(Cond, DL);
   }
 
   /// Prevent copy
-  ConstantComparesGatherer(const ConstantComparesGatherer&) = delete;
-  ConstantComparesGatherer &operator=(const ConstantComparesGatherer&) = delete;
+  ConstantComparesGatherer(const ConstantComparesGatherer &)
+      LLVM_DELETED_FUNCTION;
+  ConstantComparesGatherer &
+  operator=(const ConstantComparesGatherer &) LLVM_DELETED_FUNCTION;
 
 private:
 
@@ -389,7 +395,8 @@ private:
   /// it wasn't set before or if the new value is the same as the old one
   bool setValueOnce(Value *NewVal) {
     if(CompValue && CompValue != NewVal) return false;
-    return CompValue = NewVal;
+    CompValue = NewVal;
+    return (CompValue != nullptr);
   }
 
   /// Try to match Instruction "I" as a comparison against a constant and
@@ -522,6 +529,8 @@ private:
   }
 };
 
+}
+
 static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) {
   Instruction *Cond = nullptr;
   if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
@@ -2825,7 +2834,7 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL,
   // 'setne's and'ed together, collect them.
 
   // Try to gather values from a chain of and/or to be turned into a switch
-  ConstantComparesGatherer ConstantCompare{Cond, DL};
+  ConstantComparesGatherer ConstantCompare(Cond, DL);
   // Unpack the result
   SmallVectorImpl<ConstantInt*> &Values = ConstantCompare.Vals;
   Value *CompVal = ConstantCompare.CompValue;
@@ -3974,6 +3983,78 @@ static bool ShouldBuildLookupTable(SwitchInst *SI,
   return SI->getNumCases() * 10 >= TableSize * 4;
 }
 
+/// Try to reuse the switch table index compare. Following pattern:
+/// \code
+///     if (idx < tablesize)
+///        r = table[idx]; // table does not contain default_value
+///     else
+///        r = default_value;
+///     if (r != default_value)
+///        ...
+/// \endcode
+/// Is optimized to:
+/// \code
+///     cond = idx < tablesize;
+///     if (cond)
+///        r = table[idx];
+///     else
+///        r = default_value;
+///     if (cond)
+///        ...
+/// \endcode
+/// Jump threading will then eliminate the second if(cond).
+static void reuseTableCompare(User *PhiUser, BasicBlock *PhiBlock,
+          BranchInst *RangeCheckBranch, Constant *DefaultValue,
+          const SmallVectorImpl<std::pair<ConstantInt*, Constant*> >& Values) {
+
+  ICmpInst *CmpInst = dyn_cast<ICmpInst>(PhiUser);
+  if (!CmpInst)
+    return;
+
+  // We require that the compare is in the same block as the phi so that jump
+  // threading can do its work afterwards.
+  if (CmpInst->getParent() != PhiBlock)
+    return;
+
+  Constant *CmpOp1 = dyn_cast<Constant>(CmpInst->getOperand(1));
+  if (!CmpOp1)
+    return;
+
+  Value *RangeCmp = RangeCheckBranch->getCondition();
+  Constant *TrueConst = ConstantInt::getTrue(RangeCmp->getType());
+  Constant *FalseConst = ConstantInt::getFalse(RangeCmp->getType());
+
+  // Check if the compare with the default value is constant true or false.
+  Constant *DefaultConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
+                                                 DefaultValue, CmpOp1, true);
+  if (DefaultConst != TrueConst && DefaultConst != FalseConst)
+    return;
+
+  // Check if the compare with the case values is distinct from the default
+  // compare result.
+  for (auto ValuePair : Values) {
+    Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(),
+                              ValuePair.second, CmpOp1, true);
+    if (!CaseConst || CaseConst == DefaultConst)
+      return;
+    assert((CaseConst == TrueConst || CaseConst == FalseConst) &&
+           "Expect true or false as compare result.");
+  }
+
+  if (DefaultConst == FalseConst) {
+    // The compare yields the same result. We can replace it.
+    CmpInst->replaceAllUsesWith(RangeCmp);
+    ++NumTableCmpReuses;
+  } else {
+    // The compare yields the same result, just inverted. We can replace it.
+    Value *InvertedTableCmp = BinaryOperator::CreateXor(RangeCmp,
+                ConstantInt::get(RangeCmp->getType(), 1), "inverted.cmp",
+                RangeCheckBranch);
+    CmpInst->replaceAllUsesWith(InvertedTableCmp);
+    ++NumTableCmpReuses;
+  }
+}
+
 /// SwitchToLookupTable - If the switch is only used to initialize one or more
 /// phi nodes in a common successor block with different constant values,
 /// replace the switch with lookup tables.
@@ -4050,11 +4131,8 @@ static bool SwitchToLookupTable(SwitchInst *SI,
   // If the table has holes, we need a constant result for the default case
   // or a bitmask that fits in a register.
   SmallVector<std::pair<PHINode*, Constant*>, 4> DefaultResultsList;
-  bool HasDefaultResults = false;
-  if (TableHasHoles) {
-    HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(),
+  bool HasDefaultResults = GetCaseResults(SI, nullptr, SI->getDefaultDest(),
                                        &CommonDest, DefaultResultsList, DL);
-  }
 
   bool NeedMask = (TableHasHoles && !HasDefaultResults);
   if (NeedMask) {
@@ -4098,6 +4176,8 @@ static bool SwitchToLookupTable(SwitchInst *SI,
   // lookup table BB. Otherwise, check if the condition value is within the case
   // range. If it is so, branch to the new BB. Otherwise branch to SI's default
   // destination.
+  BranchInst *RangeCheckBranch = nullptr;
+
   const bool GeneratingCoveredLookupTable = MaxTableSize == TableSize;
   if (GeneratingCoveredLookupTable) {
     Builder.CreateBr(LookupBB);
@@ -4108,7 +4188,7 @@ static bool SwitchToLookupTable(SwitchInst *SI,
   } else {
     Value *Cmp = Builder.CreateICmpULT(TableIndex, ConstantInt::get(
                                        MinCaseVal->getType(), TableSize));
-    Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest());
+    RangeCheckBranch = Builder.CreateCondBr(Cmp, LookupBB, SI->getDefaultDest());
   }
 
   // Populate the BB that does the lookups.
@@ -4159,11 +4239,11 @@ static bool SwitchToLookupTable(SwitchInst *SI,
   bool ReturnedEarly = false;
   for (size_t I = 0, E = PHIs.size(); I != E; ++I) {
     PHINode *PHI = PHIs[I];
+    const ResultListTy &ResultList = ResultLists[PHI];
 
     // If using a bitmask, use any value to fill the lookup table holes.
     Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI];
-    SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultLists[PHI],
-                            DV, DL);
+    SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL);
 
     Value *Result = Table.BuildLookup(TableIndex, Builder);
 
@@ -4176,6 +4256,16 @@ static bool SwitchToLookupTable(SwitchInst *SI,
       break;
     }
 
+    // Do a small peephole optimization: re-use the switch table compare if
+    // possible.
+    if (!TableHasHoles && HasDefaultResults && RangeCheckBranch) {
+      BasicBlock *PhiBlock = PHI->getParent();
+      // Search for compare instructions which use the phi.
+      for (auto *User : PHI->users()) {
+        reuseTableCompare(User, PhiBlock, RangeCheckBranch, DV, ResultList);
+      }
+    }
+
     PHI->addIncoming(Result, LookupBB);
   }