IR: Split Metadata from Value
[oota-llvm.git] / lib / Analysis / ScalarEvolution.cpp
index 55472ecb6671f6b9291749fc053f902219043d74..177bd23030a7bec57f22db21cafa2d7e5c8462e2 100644 (file)
@@ -703,6 +703,34 @@ static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) {
   return APIntOps::sdiv(A, B);
 }
 
+static const APInt urem(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue();
+  APInt B = C2->getValue()->getValue();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::urem(A, B);
+}
+
+static const APInt udiv(const SCEVConstant *C1, const SCEVConstant *C2) {
+  APInt A = C1->getValue()->getValue();
+  APInt B = C2->getValue()->getValue();
+  uint32_t ABW = A.getBitWidth();
+  uint32_t BBW = B.getBitWidth();
+
+  if (ABW > BBW)
+    B = B.zext(ABW);
+  else if (ABW < BBW)
+    A = A.zext(BBW);
+
+  return APIntOps::udiv(A, B);
+}
+
 namespace {
 struct FindSCEVSize {
   int Size;
@@ -729,7 +757,8 @@ static inline int sizeOfSCEV(const SCEV *S) {
 
 namespace {
 
-struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> {
+template <typename Derived>
+struct SCEVDivision : public SCEVVisitor<Derived, void> {
 public:
   // Computes the Quotient and Remainder of the division of Numerator by
   // Denominator.
@@ -738,7 +767,7 @@ public:
                      const SCEV **Remainder) {
     assert(Numerator && Denominator && "Uninitialized SCEV");
 
-    SCEVDivision D(SE, Numerator, Denominator);
+    Derived D(SE, Numerator, Denominator);
 
     // Check for the trivial case here to avoid having to check for it in the
     // rest of the code.
@@ -779,17 +808,6 @@ public:
     *Remainder = D.Remainder;
   }
 
-  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator)
-      : SE(S), Denominator(Denominator) {
-    Zero = SE.getConstant(Denominator->getType(), 0);
-    One = SE.getConstant(Denominator->getType(), 1);
-
-    // By default, we don't know how to divide Expr by Denominator.
-    // Providing the default here simplifies the rest of the code.
-    Quotient = Zero;
-    Remainder = Numerator;
-  }
-
   // Except in the trivial case described above, we do not know how to divide
   // Expr by Denominator for the following functions with empty implementation.
   void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {}
@@ -801,14 +819,6 @@ public:
   void visitUnknown(const SCEVUnknown *Numerator) {}
   void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {}
 
-  void visitConstant(const SCEVConstant *Numerator) {
-    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
-      Quotient = SE.getConstant(sdiv(Numerator, D));
-      Remainder = SE.getConstant(srem(Numerator, D));
-      return;
-    }
-  }
-
   void visitAddRecExpr(const SCEVAddRecExpr *Numerator) {
     const SCEV *StartQ, *StartR, *StepQ, *StepR;
     assert(Numerator->isAffine() && "Numerator should be affine");
@@ -932,12 +942,54 @@ public:
   }
 
 private:
+  SCEVDivision(ScalarEvolution &S, const SCEV *Numerator,
+               const SCEV *Denominator)
+      : SE(S), Denominator(Denominator) {
+    Zero = SE.getConstant(Denominator->getType(), 0);
+    One = SE.getConstant(Denominator->getType(), 1);
+
+    // By default, we don't know how to divide Expr by Denominator.
+    // Providing the default here simplifies the rest of the code.
+    Quotient = Zero;
+    Remainder = Numerator;
+  }
+
   ScalarEvolution &SE;
   const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
+
+  friend struct SCEVSDivision;
+  friend struct SCEVUDivision;
 };
-}
 
+struct SCEVSDivision : public SCEVDivision<SCEVSDivision> {
+  SCEVSDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(sdiv(Numerator, D));
+      Remainder = SE.getConstant(srem(Numerator, D));
+      return;
+    }
+  }
+};
 
+struct SCEVUDivision : public SCEVDivision<SCEVUDivision> {
+  SCEVUDivision(ScalarEvolution &S, const SCEV *Numerator,
+                const SCEV *Denominator)
+      : SCEVDivision(S, Numerator, Denominator) {}
+
+  void visitConstant(const SCEVConstant *Numerator) {
+    if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) {
+      Quotient = SE.getConstant(udiv(Numerator, D));
+      Remainder = SE.getConstant(urem(Numerator, D));
+      return;
+    }
+  }
+};
+
+}
 
 //===----------------------------------------------------------------------===//
 //                      Simple SCEV method implementations
@@ -2155,6 +2207,25 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
   return r;
 }
 
+/// Determine if any of the operands in this SCEV are a constant or if
+/// any of the add or multiply expressions in this SCEV contain a constant.
+static bool containsConstantSomewhere(const SCEV *StartExpr) {
+  SmallVector<const SCEV *, 4> Ops;
+  Ops.push_back(StartExpr);
+  while (!Ops.empty()) {
+    const SCEV *CurrentExpr = Ops.pop_back_val();
+    if (isa<SCEVConstant>(*CurrentExpr))
+      return true;
+
+    if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
+      const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
+      for (const SCEV *Operand : CurrentNAry->operands())
+        Ops.push_back(Operand);
+    }
+  }
+  return false;
+}
+
 /// getMulExpr - Get a canonical multiply expression, or something simpler if
 /// possible.
 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
@@ -2194,11 +2265,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
 
     // C1*(C2+V) -> C1*C2 + C1*V
     if (Ops.size() == 2)
-      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
-        if (Add->getNumOperands() == 2 &&
-            isa<SCEVConstant>(Add->getOperand(0)))
-          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
-                            getMulExpr(LHSC, Add->getOperand(1)));
+        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+          // If any of Add's ops are Adds or Muls with a constant,
+          // apply this transformation as well.
+          if (Add->getNumOperands() == 2)
+            if (containsConstantSomewhere(Add))
+              return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
+                                getMulExpr(LHSC, Add->getOperand(1)));
 
     ++Idx;
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
@@ -3343,7 +3416,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
   Visited.insert(PN);
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -3676,8 +3750,10 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
       assert(NumRanges >= 1);
 
       for (unsigned i = 0; i < NumRanges; ++i) {
-        ConstantInt *Lower = cast<ConstantInt>(MD->getOperand(2*i + 0));
-        ConstantInt *Upper = cast<ConstantInt>(MD->getOperand(2*i + 1));
+        ConstantInt *Lower =
+            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
+        ConstantInt *Upper =
+            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
         ConstantRange Range(Lower->getValue(), Upper->getValue());
         TotalRange = TotalRange.unionWith(Range);
       }
@@ -4541,7 +4617,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
     SmallPtrSet<Instruction *, 8> Visited;
     while (!Worklist.empty()) {
       Instruction *I = Worklist.pop_back_val();
-      if (!Visited.insert(I)) continue;
+      if (!Visited.insert(I).second)
+        continue;
 
       ValueExprMapType::iterator It =
         ValueExprMap.find_as(static_cast<Value *>(I));
@@ -4593,7 +4670,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
   SmallPtrSet<Instruction *, 8> Visited;
   while (!Worklist.empty()) {
     Instruction *I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -4627,7 +4705,8 @@ void ScalarEvolution::forgetValue(Value *V) {
   SmallPtrSet<Instruction *, 8> Visited;
   while (!Worklist.empty()) {
     I = Worklist.pop_back_val();
-    if (!Visited.insert(I)) continue;
+    if (!Visited.insert(I).second)
+      continue;
 
     ValueExprMapType::iterator It =
       ValueExprMap.find_as(static_cast<Value *>(I));
@@ -6086,7 +6165,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
   // backedge count.
   const SCEV *Q, *R;
   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
-  SCEVDivision::divide(SE, Distance, Step, &Q, &R);
+  SCEVUDivision::divide(SE, Distance, Step, &Q, &R);
   if (R->isZero()) {
     const SCEV *Exact =
         getUDivExactExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
@@ -7401,7 +7480,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   for (const SCEV *&Term : Terms) {
     // Normalize the terms before the next call to findArrayDimensionsRec.
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Term, Step, &Q, &R);
+    SCEVSDivision::divide(SE, Term, Step, &Q, &R);
 
     // Bail out when GCD does not evenly divide one of the terms.
     if (!R->isZero())
@@ -7538,7 +7617,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
   // Divide all terms by the element size.
   for (const SCEV *&Term : Terms) {
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Term, ElementSize, &Q, &R);
+    SCEVSDivision::divide(SE, Term, ElementSize, &Q, &R);
     Term = Q;
   }
 
@@ -7585,7 +7664,7 @@ void SCEVAddRecExpr::computeAccessFunctions(
   int Last = Sizes.size() - 1;
   for (int i = Last; i >= 0; i--) {
     const SCEV *Q, *R;
-    SCEVDivision::divide(SE, Res, Sizes[i], &Q, &R);
+    SCEVSDivision::divide(SE, Res, Sizes[i], &Q, &R);
 
     DEBUG({
         dbgs() << "Res: " << *Res << "\n";
@@ -7740,7 +7819,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
     // that until everything else is done.
     if (U == Old)
       continue;
-    if (!Visited.insert(U))
+    if (!Visited.insert(U).second)
       continue;
     if (PHINode *PN = dyn_cast<PHINode>(U))
       SE->ConstantEvolutionLoopExitValue.erase(PN);