[ValueTracking] Add a framework for encoding implication rules
authorSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 6 Nov 2015 19:00:57 +0000 (19:00 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Fri, 6 Nov 2015 19:00:57 +0000 (19:00 +0000)
Summary:
This change adds a framework for adding more smarts to
`isImpliedCondition` around inequalities.  Informally,
`isImpliedCondition` will now try to prove "A < B ==> C < D" by proving
"C <= A && B <= D", since then it follows "C <= A < B <= D".

While this change is in principle NFC, I could not think of a way to not
handle cases like "i +_nsw 1 < L ==> i < L +_nsw 1" (that ValueTracking
did not handle before) while keeping the change understandable.  I've
added tests for these cases.

Reviewers: reames, majnemer, hfinkel

Subscribers: llvm-commits

Differential Revision: http://reviews.llvm.org/D14368

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@252331 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Analysis/ValueTracking.cpp
test/Transforms/InstSimplify/implies.ll

index 1187de7b59bd40e98aefc4536451199eca5213e2..3dc9f3a10370d19d02b84272af162f826147d965 100644 (file)
@@ -4082,6 +4082,65 @@ ConstantRange llvm::getConstantRangeFromMetadata(MDNode &Ranges) {
   return CR;
 }
 
   return CR;
 }
 
+/// Return true if "icmp Pred LHS RHS" is always true.
+static bool isTruePredicate(CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+  if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS)
+    return true;
+
+  switch (Pred) {
+  default:
+    return false;
+
+  case CmpInst::ICMP_SLT:
+  case CmpInst::ICMP_SLE: {
+    ConstantInt *CI;
+
+    // LHS s<  LHS +_{nsw} C   if C > 0
+    // LHS s<= LHS +_{nsw} C   if C >= 0
+    if (match(RHS, m_NSWAdd(m_Specific(LHS), m_ConstantInt(CI)))) {
+      if (Pred == CmpInst::ICMP_SLT)
+        return CI->getValue().isStrictlyPositive();
+      return !CI->isNegative();
+    }
+    return false;
+  }
+
+  case CmpInst::ICMP_ULT:
+  case CmpInst::ICMP_ULE: {
+    ConstantInt *CI;
+
+    // LHS u<  LHS +_{nuw} C   if C > 0
+    // LHS u<= LHS +_{nuw} C   if C >= 0
+    if (match(RHS, m_NUWAdd(m_Specific(LHS), m_ConstantInt(CI)))) {
+      if (Pred == CmpInst::ICMP_ULT)
+        return CI->getValue().isStrictlyPositive();
+      return !CI->isNegative();
+    }
+    return false;
+  }
+  }
+}
+
+/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
+/// ALHS ARHS" is true.
+static bool isImpliedCondOperands(CmpInst::Predicate Pred, Value *ALHS,
+                                  Value *ARHS, Value *BLHS, Value *BRHS) {
+  switch (Pred) {
+  default:
+    return false;
+
+  case CmpInst::ICMP_SLT:
+  case CmpInst::ICMP_SLE:
+    return isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS) &&
+           isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS);
+
+  case CmpInst::ICMP_ULT:
+  case CmpInst::ICMP_ULE:
+    return isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS) &&
+           isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS);
+  }
+}
+
 bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
   assert(LHS->getType() == RHS->getType() && "mismatched type");
   Type *OpTy = LHS->getType();
 bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
   assert(LHS->getType() == RHS->getType() && "mismatched type");
   Type *OpTy = LHS->getType();
@@ -4096,28 +4155,15 @@ bool llvm::isImpliedCondition(Value *LHS, Value *RHS) {
   assert(OpTy->isIntegerTy(1) && "implied by above");
 
   ICmpInst::Predicate APred, BPred;
   assert(OpTy->isIntegerTy(1) && "implied by above");
 
   ICmpInst::Predicate APred, BPred;
-  Value *I;
-  Value *L;
-  ConstantInt *CI;
-  // i +_{nsw} C_{>0} <s L ==> i <s L
-  if (match(LHS, m_ICmp(APred,
-                        m_NSWAdd(m_Value(I), m_ConstantInt(CI)),
-                        m_Value(L))) &&
-      APred == ICmpInst::ICMP_SLT &&
-      !CI->isNegative() &&
-      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
-      BPred == ICmpInst::ICMP_SLT)
-    return true;
+  Value *ALHS, *ARHS;
+  Value *BLHS, *BRHS;
 
 
-  // i +_{nuw} C_{>0} <u L ==> i <u L
-  if (match(LHS, m_ICmp(APred,
-                        m_NUWAdd(m_Value(I), m_ConstantInt(CI)),
-                        m_Value(L))) &&
-      APred == ICmpInst::ICMP_ULT &&
-      !CI->isNegative() &&
-      match(RHS, m_ICmp(BPred, m_Specific(I), m_Specific(L))) &&
-      BPred == ICmpInst::ICMP_ULT)
-    return true;
+  if (!match(LHS, m_ICmp(APred, m_Value(ALHS), m_Value(ARHS))) ||
+      !match(RHS, m_ICmp(BPred, m_Value(BLHS), m_Value(BRHS))))
+    return false;
+
+  if (APred == BPred)
+    return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS);
 
   return false;
 }
 
   return false;
 }
index ac46b8d28280dcc4346804e58a17403584a99464..8e5bbf2c89709603b67dda3ae13f29cdce1dc161 100644 (file)
@@ -92,6 +92,30 @@ define <4 x i1> @test6(<4 x i1> %a, <4 x i1> %b) {
   ret <4 x i1> %res
 }
 
   ret <4 x i1> %res
 }
 
+; i +_{nsw} 1 <s L  ==> i < L +_{nsw} 1
+define i1 @test7(i32 %length.i, i32 %i) {
+; CHECK-LABEL: @test7(
+; CHECK: ret i1 true
+  %iplus1 = add nsw i32 %i, 1
+  %len.plus.one = add nsw i32 %length.i, 1
+  %var29 = icmp slt i32 %i, %len.plus.one
+  %var30 = icmp slt i32 %iplus1, %length.i
+  %res = icmp ule i1 %var30, %var29
+  ret i1 %res
+}
+
+; i +_{nuw} 1 <s L  ==> i < L +_{nuw} 1
+define i1 @test8(i32 %length.i, i32 %i) {
+; CHECK-LABEL: @test8(
+; CHECK: ret i1 true
+  %iplus1 = add nuw i32 %i, 1
+  %len.plus.one = add nuw i32 %length.i, 1
+  %var29 = icmp ult i32 %i, %len.plus.one
+  %var30 = icmp ult i32 %iplus1, %length.i
+  %res = icmp ule i1 %var30, %var29
+  ret i1 %res
+}
+
 ; X >=(s) Y == X ==> Y (i1 1 becomes -1 for reasoning)
 define i1 @test_sge(i32 %length.i, i32 %i) {
 ; CHECK-LABEL: @test_sge
 ; X >=(s) Y == X ==> Y (i1 1 becomes -1 for reasoning)
 define i1 @test_sge(i32 %length.i, i32 %i) {
 ; CHECK-LABEL: @test_sge