[SCEV][LV] Add SCEV Predicates and use them to re-implement stride versioning
[oota-llvm.git] / lib / Analysis / ScalarEvolutionExpander.cpp
index 81316849847a0ca09a96366707d2ffd087db12de..2c2e5828003aea6dd76c7ce27108b0df39d16a1d 100644 (file)
@@ -1944,6 +1944,43 @@ bool SCEVExpander::isHighCostExpansionHelper(
   return false;
 }
 
+Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
+                                            Instruction *IP) {
+  assert(IP);
+  switch (Pred->getKind()) {
+  case SCEVPredicate::P_Union:
+    return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
+  case SCEVPredicate::P_Equal:
+    return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  }
+  llvm_unreachable("Unknown SCEV predicate type");
+}
+
+Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
+                                          Instruction *IP) {
+  Value *Expr0 = expandCodeFor(Pred->getLHS(), Pred->getLHS()->getType(), IP);
+  Value *Expr1 = expandCodeFor(Pred->getRHS(), Pred->getRHS()->getType(), IP);
+
+  Builder.SetInsertPoint(IP);
+  auto *I = Builder.CreateICmpNE(Expr0, Expr1, "ident.check");
+  return I;
+}
+
+Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union,
+                                          Instruction *IP) {
+  auto *BoolType = IntegerType::get(IP->getContext(), 1);
+  Value *Check = ConstantInt::getNullValue(BoolType);
+
+  // Loop over all checks in this set.
+  for (auto Pred : Union->getPredicates()) {
+    auto *NextCheck = expandCodeForPredicate(Pred, IP);
+    Builder.SetInsertPoint(IP);
+    Check = Builder.CreateOr(Check, NextCheck);
+  }
+
+  return Check;
+}
+
 namespace {
 // Search for a SCEV subexpression that is not safe to expand.  Any expression
 // that may expand to a !isSafeToSpeculativelyExecute value is unsafe, namely