SimplifyCFG: turn recursive GatherConstantCompares into iterative
authorMehdi Amini <mehdi.amini@apple.com>
Wed, 19 Nov 2014 20:09:11 +0000 (20:09 +0000)
committerMehdi Amini <mehdi.amini@apple.com>
Wed, 19 Nov 2014 20:09:11 +0000 (20:09 +0000)
A long sequence of || or && could lead to a stack explosion.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@222384 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Utils/SimplifyCFG.cpp

index 7b3d2fb..aa3baf0 100644 (file)
@@ -357,114 +357,159 @@ static ConstantInt *GetConstantInt(Value *V, const DataLayout *DL) {
   return nullptr;
 }
 
+
+
+// Try to match Instruction I as a comparison against a constant and populates
+// Vals with the set of value that match (or does not depending on isEQ).
+// Return nullptr on failure, or return the Value the comparison matched against
+// on success
+// CurrValue, if supplied, is the value we want to match against. The function
+// is expected to fail if a match is found but the value compared to is not the
+// one expected. If CurrValue is supplied, the return value has to be either
+// nullptr or CurrValue
+static Value* GatherConstantComparesMatch(Instruction *I,
+                                          Value *CurrValue,
+                                          SmallVectorImpl<ConstantInt*> &Vals,
+                                          const DataLayout *DL,
+                                          unsigned &UsedICmps,
+                                          bool isEQ) {
+
+  // If this is an icmp against a constant, handle this as one of the cases.
+  ICmpInst *ICI;
+  ConstantInt *C;
+  if (not ((ICI = dyn_cast<ICmpInst>(I)) &&
+           (C = GetConstantInt(I->getOperand(1), DL)))) {
+    return nullptr;
+  }
+
+  Value *RHSVal;
+  ConstantInt *RHSC;
+
+  // Pattern match a special case
+  // (x & ~2^x) == y --> x == y || x == y|2^x
+  // This undoes a transformation done by instcombine to fuse 2 compares.
+  if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) {
+    if (match(ICI->getOperand(0),
+              m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) {
+      APInt Not = ~RHSC->getValue();
+      if (Not.isPowerOf2()) {
+        // If we already have a value for the switch, it has to match!
+        if(CurrValue && CurrValue != RHSVal)
+          return nullptr;
+
+        Vals.push_back(C);
+        Vals.push_back(ConstantInt::get(C->getContext(),
+                                        C->getValue() | Not));
+        UsedICmps++;
+        return RHSVal;
+      }
+    }
+
+    // If we already have a value for the switch, it has to match!
+    if(CurrValue && CurrValue != ICI->getOperand(0))
+      return nullptr;
+
+    UsedICmps++;
+    Vals.push_back(C);
+    return ICI->getOperand(0);
+  }
+
+  // If we have "x ult 3", for example, then we can add 0,1,2 to the set.
+  ConstantRange Span = ConstantRange::makeICmpRegion(ICI->getPredicate(),
+                                                     C->getValue());
+
+  // Shift the range if the compare is fed by an add. This is the range
+  // compare idiom as emitted by instcombine.
+  Value *CandidateVal = I->getOperand(0);
+  if(match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)))) {
+    Span = Span.subtract(RHSC->getValue());
+    CandidateVal = RHSVal;
+  }
+
+  // If we already have a value for the switch, it has to match!
+  if(CurrValue && CurrValue != CandidateVal)
+    return nullptr;
+
+  // If this is an and/!= check, then we are looking to build the set of
+  // value that *don't* pass the and chain. I.e. to turn "x ugt 2" into
+  // x != 0 && x != 1.
+  if (!isEQ)
+    Span = Span.inverse();
+
+  // If there are a ton of values, we don't want to make a ginormous switch.
+  if (Span.getSetSize().ugt(8) || Span.isEmptySet()) {
+    return nullptr;
+  }
+
+  // Add all values from the range to the set
+  for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp)
+    Vals.push_back(ConstantInt::get(I->getContext(), Tmp));
+
+  UsedICmps++;
+  return CandidateVal;
+
+}
+
 /// GatherConstantCompares - Given a potentially 'or'd or 'and'd together
 /// collection of icmp eq/ne instructions that compare a value against a
 /// constant, return the value being compared, and stick the constant into the
 /// Values vector.
+/// One "Extra" case is allowed to differ from the other.
 static Value *
-GatherConstantCompares(Value *V, std::vector<ConstantInt*> &Vals, Value *&Extra,
-                       const DataLayout *DL, bool isEQ, unsigned &UsedICmps) {
+GatherConstantCompares(Value *V, SmallVectorImpl<ConstantInt*> &Vals, Value *&Extra,
+                       const DataLayout *DL, unsigned &UsedICmps) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) return nullptr;
 
-  // If this is an icmp against a constant, handle this as one of the cases.
-  if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) {
-    if (ConstantInt *C = GetConstantInt(I->getOperand(1), DL)) {
-      Value *RHSVal;
-      ConstantInt *RHSC;
-
-      if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ:ICmpInst::ICMP_NE)) {
-        // (x & ~2^x) == y --> x == y || x == y|2^x
-        // This undoes a transformation done by instcombine to fuse 2 compares.
-        if (match(ICI->getOperand(0),
-                  m_And(m_Value(RHSVal), m_ConstantInt(RHSC)))) {
-          APInt Not = ~RHSC->getValue();
-          if (Not.isPowerOf2()) {
-            Vals.push_back(C);
-            Vals.push_back(
-                ConstantInt::get(C->getContext(), C->getValue() | Not));
-            UsedICmps++;
-            return RHSVal;
-          }
-        }
+  bool isEQ = (I->getOpcode() == Instruction::Or);
 
-        UsedICmps++;
-        Vals.push_back(C);
-        return I->getOperand(0);
-      }
+  // Keep a stack (SmallVector for efficiency) for depth-first traversal
+  SmallVector<Value *, 8> DFT;
 
-      // If we have "x ult 3" comparison, for example, then we can add 0,1,2 to
-      // the set.
-      ConstantRange Span =
-        ConstantRange::makeICmpRegion(ICI->getPredicate(), C->getValue());
-
-      // Shift the range if the compare is fed by an add. This is the range
-      // compare idiom as emitted by instcombine.
-      bool hasAdd =
-          match(I->getOperand(0), m_Add(m_Value(RHSVal), m_ConstantInt(RHSC)));
-      if (hasAdd)
-        Span = Span.subtract(RHSC->getValue());
-
-      // If this is an and/!= check then we want to optimize "x ugt 2" into
-      // x != 0 && x != 1.
-      if (!isEQ)
-        Span = Span.inverse();
-
-      // If there are a ton of values, we don't want to make a ginormous switch.
-      if (Span.getSetSize().ugt(8) || Span.isEmptySet())
-        return nullptr;
-
-      for (APInt Tmp = Span.getLower(); Tmp != Span.getUpper(); ++Tmp)
-        Vals.push_back(ConstantInt::get(V->getContext(), Tmp));
-      UsedICmps++;
-      return hasAdd ? RHSVal : I->getOperand(0);
-    }
-    return nullptr;
-  }
+  // Initialize
+  DFT.push_back(V);
 
-  // Otherwise, we can only handle an | or &, depending on isEQ.
-  if (I->getOpcode() != (isEQ ? Instruction::Or : Instruction::And))
-    return nullptr;
+  // Will hold the value used for the switch comparison
+  Value *CurrValue = nullptr;
 
-  unsigned NumValsBeforeLHS = Vals.size();
-  unsigned UsedICmpsBeforeLHS = UsedICmps;
-  if (Value *LHS = GatherConstantCompares(I->getOperand(0), Vals, Extra, DL,
-                                          isEQ, UsedICmps)) {
-    unsigned NumVals = Vals.size();
-    unsigned UsedICmpsBeforeRHS = UsedICmps;
-    if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL,
-                                            isEQ, UsedICmps)) {
-      if (LHS == RHS)
-        return LHS;
-      Vals.resize(NumVals);
-      UsedICmps = UsedICmpsBeforeRHS;
-    }
+  while(not DFT.empty()) {
+    V = DFT.pop_back_val();
 
-    // The RHS of the or/and can't be folded in and we haven't used "Extra" yet,
-    // set it and return success.
-    if (Extra == nullptr || Extra == I->getOperand(1)) {
-      Extra = I->getOperand(1);
-      return LHS;
+    if (Instruction *I = dyn_cast<Instruction>(V)) {
+
+      // If it is a || (or && depending on isEQ), process the operands.
+      if (I->getOpcode() == (isEQ ? Instruction::Or : Instruction::And)) {
+        DFT.push_back(I->getOperand(1));
+        DFT.push_back(I->getOperand(0));
+        continue;
+      }
+
+      // Try to match the current instruction
+      if (Value *Matched = GatherConstantComparesMatch(I,
+                                                       CurrValue,
+                                                       Vals,
+                                                       DL,
+                                                       UsedICmps,
+                                                       isEQ)) {
+        // Match succeed, continue the loop
+        CurrValue = Matched;
+        continue;
+      }
     }
 
-    Vals.resize(NumValsBeforeLHS);
-    UsedICmps = UsedICmpsBeforeLHS;
+    // One element of the sequence of || (or &&) could not be match as a
+    // comparison against the same value as the others.
+    // We allow only one "Extra" case to be checked before the switch
+    if (Extra == nullptr) {
+      Extra = V;
+      continue;
+    }
     return nullptr;
-  }
 
-  // If the LHS can't be folded in, but Extra is available and RHS can, try to
-  // use LHS as Extra.
-  if (Extra == nullptr || Extra == I->getOperand(0)) {
-    Value *OldExtra = Extra;
-    Extra = I->getOperand(0);
-    if (Value *RHS = GatherConstantCompares(I->getOperand(1), Vals, Extra, DL,
-                                            isEQ, UsedICmps))
-      return RHS;
-    assert(Vals.size() == NumValsBeforeLHS);
-    Extra = OldExtra;
   }
 
-  return nullptr;
+  // Return the value to be used for the switch comparison (if any)
+  return CurrValue;
 }
 
 static void EraseTerminatorInstAndDCECond(TerminatorInst *TI) {
@@ -2770,19 +2815,13 @@ static bool SimplifyBranchOnICmpChain(BranchInst *BI, const DataLayout *DL,
   // If this is a bunch of seteq's or'd together, or if it's a bunch of
   // 'setne's and'ed together, collect them.
   Value *CompVal = nullptr;
-  std::vector<ConstantInt*> Values;
-  bool TrueWhenEqual = true;
+  SmallVector<ConstantInt*, 8> Values;
+  bool TrueWhenEqual = (Cond->getOpcode() == Instruction::Or);
   Value *ExtraCase = nullptr;
   unsigned UsedICmps = 0;
 
-  if (Cond->getOpcode() == Instruction::Or) {
-    CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, true,
-                                     UsedICmps);
-  } else if (Cond->getOpcode() == Instruction::And) {
-    CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, false,
-                                     UsedICmps);
-    TrueWhenEqual = false;
-  }
+  // Try to gather values from a chain of and/or to be turned into a switch
+  CompVal = GatherConstantCompares(Cond, Values, ExtraCase, DL, UsedICmps);
 
   // If we didn't have a multiply compared value, fail.
   if (!CompVal) return false;