Fix refactoring mistake in "Teach InstCombine to work with smaller legal types..."
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index e223a049f0b0606a285b5d5d7c4f044a569aef6b..bad46b4dab3accc67a70329f229675f931a17e48 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/IntrinsicInst.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/DataLayout.h"
-#include "llvm/Target/TargetLibraryInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/PatternMatch.h"
+#include "llvm/Target/TargetLibraryInfo.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -1226,6 +1226,16 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         ICI.setOperand(0, NewAnd);
         return &ICI;
       }
+
+      // Replace ((X & AndCST) > RHSV) with ((X & AndCST) != 0), if any
+      // bit set in (X & AndCST) will produce a result greater than RHSV.
+      if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
+        unsigned NTZ = AndCST->getValue().countTrailingZeros();
+        if ((NTZ < AndCST->getBitWidth()) &&
+            APInt::getOneBitSet(AndCST->getBitWidth(), NTZ).ugt(RHSV))
+          return new ICmpInst(ICmpInst::ICMP_NE, LHSI,
+                              Constant::getNullValue(RHS->getType()));
+      }
     }
 
     // Try to optimize things like "A[i]&42 == 0" to index computations.
@@ -1321,6 +1331,25 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
                           And, Constant::getNullValue(And->getType()));
     }
+
+    // Transform (icmp pred iM (shl iM %v, N), CI)
+    // -> (icmp pred i(M-N) (trunc %v iM to i(N-N)), (trunc (CI>>N))
+    // Transform the shl to a trunc if (trunc (CI>>N)) has no loss.
+    // This enables to get rid of the shift in favor of a trunc which can be
+    // free on the target. It has the additional benefit of comparing to a
+    // smaller constant, which will be target friendly.
+    unsigned Amt = ShAmt->getLimitedValue(TypeBits-1);
+    if (Amt != 0 && RHSV.countTrailingZeros() >= Amt) {
+      Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt);
+      Constant *NCI = ConstantExpr::getTrunc(
+                        ConstantExpr::getAShr(RHS,
+                          ConstantInt::get(RHS->getType(), Amt)),
+                        NTy);
+      return new ICmpInst(ICI.getPredicate(),
+                          Builder->CreateTrunc(LHSI->getOperand(0), NTy),
+                          NCI);
+    }
+
     break;
   }
 
@@ -2358,15 +2387,20 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       // Determine Y and Z in the form icmp (X+Y), (X+Z).
       Value *Y, *Z;
       if (A == C) {
+        // C + B == C + D  ->  B == D
         Y = B;
         Z = D;
       } else if (A == D) {
+        // D + B == C + D  ->  B == C
         Y = B;
         Z = C;
       } else if (B == C) {
+        // A + C == C + D  ->  A == D
         Y = A;
         Z = D;
-      } else if (B == D) {
+      } else {
+        assert(B == D);
+        // A + D == C + D  ->  A == C
         Y = A;
         Z = C;
       }