Address issues found by Duncan during post-commit review of r177856.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index fcd805b0396ceba7709dd54a31e3d6e1d41264da..a96e754f3dd089c1dd478856c9a7be49cccf0b78 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/IntrinsicInst.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/DataLayout.h"
-#include "llvm/Target/TargetLibraryInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/PatternMatch.h"
+#include "llvm/Target/TargetLibraryInfo.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -139,6 +139,31 @@ static bool isSignBitCheck(ICmpInst::Predicate pred, ConstantInt *RHS,
   }
 }
 
+/// Returns true if the exploded icmp can be expressed as a signed comparison
+/// to zero and updates the predicate accordingly.
+/// The signedness of the comparison is preserved.
+static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) {
+  if (!ICmpInst::isSigned(pred))
+    return false;
+
+  if (RHS->isZero())
+    return ICmpInst::isRelational(pred);
+
+  if (RHS->isOne()) {
+    if (pred == ICmpInst::ICMP_SLT) {
+      pred = ICmpInst::ICMP_SLE;
+      return true;
+    }
+  } else if (RHS->isAllOnesValue()) {
+    if (pred == ICmpInst::ICMP_SGT) {
+      pred = ICmpInst::ICMP_SGE;
+      return true;
+    }
+  }
+
+  return false;
+}
+
 // isHighOnes - Return true if the constant is of the form 1+0+.
 // This is the same as lowones(~X).
 static bool isHighOnes(const ConstantInt *CI) {
@@ -365,13 +390,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // order the state machines in complexity of the generated code.
   Value *Idx = GEP->getOperand(2);
 
-  unsigned AS = GEP->getPointerAddressSpace();
   // If the index is larger than the pointer size of the target, truncate the
   // index down like the GEP would do implicitly.  We don't have to do this for
   // an inbounds GEP because the index can't be out of range.
   if (!GEP->isInBounds() &&
-      Idx->getType()->getPrimitiveSizeInBits() > TD->getPointerSizeInBits(AS))
-    Idx = Builder->CreateTrunc(Idx, TD->getIntPtrType(Idx->getContext(), AS));
+      Idx->getType()->getPrimitiveSizeInBits() > TD->getPointerSizeInBits())
+    Idx = Builder->CreateTrunc(Idx, TD->getIntPtrType(Idx->getContext()));
 
   // If the comparison is only true for one or two elements, emit direct
   // comparisons.
@@ -444,20 +468,29 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   }
 
 
-  // If a 32-bit or 64-bit magic bitvector captures the entire comparison state
+  // If a magic bitvector captures the entire comparison state
   // of this load, replace it with computation that does:
   //   ((magic_cst >> i) & 1) != 0
-  if (ArrayElementCount <= 32 ||
-      (TD && ArrayElementCount <= 64 && TD->isLegalInteger(64))) {
-    Type *Ty;
-    if (ArrayElementCount <= 32)
+  {
+    Type *Ty = 0;
+
+    // Look for an appropriate type:
+    // - The type of Idx if the magic fits
+    // - The smallest fitting legal type if we have a DataLayout
+    // - Default to i32
+    if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
+      Ty = Idx->getType();
+    else if (TD)
+      Ty = TD->getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
+    else if (ArrayElementCount <= 32)
       Ty = Type::getInt32Ty(Init->getContext());
-    else
-      Ty = Type::getInt64Ty(Init->getContext());
-    Value *V = Builder->CreateIntCast(Idx, Ty, false);
-    V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
-    V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V);
-    return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0));
+
+    if (Ty != 0) {
+      Value *V = Builder->CreateIntCast(Idx, Ty, false);
+      V = Builder->CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
+      V = Builder->CreateAnd(ConstantInt::get(Ty, 1), V);
+      return new ICmpInst(ICmpInst::ICMP_NE, V, ConstantInt::get(Ty, 0));
+    }
   }
 
   return 0;
@@ -529,17 +562,16 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
     }
   }
 
-  unsigned AS = cast<GetElementPtrInst>(GEP)->getPointerAddressSpace();
   // Okay, we know we have a single variable index, which must be a
   // pointer/array/vector index.  If there is no offset, life is simple, return
   // the index.
-  unsigned IntPtrWidth = TD.getPointerSizeInBits(AS);
+  unsigned IntPtrWidth = TD.getPointerSizeInBits();
   if (Offset == 0) {
     // Cast to intptrty in case a truncation occurs.  If an extension is needed,
     // we don't need to bother extending: the extension won't affect where the
     // computation crosses zero.
     if (VariableIdx->getType()->getPrimitiveSizeInBits() > IntPtrWidth) {
-      Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext(), AS);
+      Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext());
       VariableIdx = IC.Builder->CreateTrunc(VariableIdx, IntPtrTy);
     }
     return VariableIdx;
@@ -561,7 +593,7 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
     return 0;
 
   // Okay, we can do this evaluation.  Start by converting the index to intptr.
-  Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext(), AS);
+  Type *IntPtrTy = TD.getIntPtrType(VariableIdx->getContext());
   if (VariableIdx->getType() != IntPtrTy)
     VariableIdx = IC.Builder->CreateIntCast(VariableIdx, IntPtrTy,
                                             true /*Signed*/);
@@ -1228,6 +1260,16 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         ICI.setOperand(0, NewAnd);
         return &ICI;
       }
+
+      // Replace ((X & AndCST) > RHSV) with ((X & AndCST) != 0), if any
+      // bit set in (X & AndCST) will produce a result greater than RHSV.
+      if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
+        unsigned NTZ = AndCST->getValue().countTrailingZeros();
+        if ((NTZ < AndCST->getBitWidth()) &&
+            APInt::getOneBitSet(AndCST->getBitWidth(), NTZ).ugt(RHSV))
+          return new ICmpInst(ICmpInst::ICMP_NE, LHSI,
+                              Constant::getNullValue(RHS->getType()));
+      }
     }
 
     // Try to optimize things like "A[i]&42 == 0" to index computations.
@@ -1265,6 +1307,23 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
     break;
   }
 
+  case Instruction::Mul: {       // (icmp pred (mul X, Val), CI)
+    ConstantInt *Val = dyn_cast<ConstantInt>(LHSI->getOperand(1));
+    if (!Val) break;
+
+    // If this is a signed comparison to 0 and the mul is sign preserving,
+    // use the mul LHS operand instead.
+    ICmpInst::Predicate pred = ICI.getPredicate();
+    if (isSignTest(pred, RHS) && !Val->isZero() &&
+        cast<BinaryOperator>(LHSI)->hasNoSignedWrap())
+      return new ICmpInst(Val->isNegative() ?
+                          ICmpInst::getSwappedPredicate(pred) : pred,
+                          LHSI->getOperand(0),
+                          Constant::getNullValue(RHS->getType()));
+
+    break;
+  }
+
   case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
     ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
     if (!ShAmt) break;
@@ -1296,6 +1355,12 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                             ConstantExpr::getLShr(RHS, ShAmt));
 
+      // If the shift is NSW and we compare to 0, then it is just shifting out
+      // sign bits, no need for an AND either.
+      if (cast<BinaryOperator>(LHSI)->hasNoSignedWrap() && RHSV == 0)
+        return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
+                            ConstantExpr::getLShr(RHS, ShAmt));
+
       if (LHSI->hasOneUse()) {
         // Otherwise strength reduce the shift into an and.
         uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
@@ -1310,6 +1375,15 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       }
     }
 
+    // If this is a signed comparison to 0 and the shift is sign preserving,
+    // use the shift LHS operand instead.
+    ICmpInst::Predicate pred = ICI.getPredicate();
+    if (isSignTest(pred, RHS) &&
+        cast<BinaryOperator>(LHSI)->hasNoSignedWrap())
+      return new ICmpInst(pred,
+                          LHSI->getOperand(0),
+                          Constant::getNullValue(RHS->getType()));
+
     // Otherwise, if this is a comparison of the sign bit, simplify to and/test.
     bool TrueIfSigned = false;
     if (LHSI->hasOneUse() &&
@@ -1323,6 +1397,26 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
                           And, Constant::getNullValue(And->getType()));
     }
+
+    // Transform (icmp pred iM (shl iM %v, N), CI)
+    // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (CI>>N))
+    // Transform the shl to a trunc if (trunc (CI>>N)) has no loss and M-N.
+    // This enables to get rid of the shift in favor of a trunc which can be
+    // free on the target. It has the additional benefit of comparing to a
+    // smaller constant, which will be target friendly.
+    unsigned Amt = ShAmt->getLimitedValue(TypeBits-1);
+    if (LHSI->hasOneUse() &&
+        Amt != 0 && RHSV.countTrailingZeros() >= Amt) {
+      Type *NTy = IntegerType::get(ICI.getContext(), TypeBits - Amt);
+      Constant *NCI = ConstantExpr::getTrunc(
+                        ConstantExpr::getAShr(RHS,
+                          ConstantInt::get(RHS->getType(), Amt)),
+                        NTy);
+      return new ICmpInst(ICI.getPredicate(),
+                          Builder->CreateTrunc(LHSI->getOperand(0), NTy),
+                          NCI);
+    }
+
     break;
   }
 
@@ -1504,6 +1598,19 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
             return new ICmpInst(pred, X, NegX);
           }
         }
+        break;
+      case Instruction::Mul:
+        if (RHSV == 0 && BO->hasNoSignedWrap()) {
+          if (ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+            // The trivial case (mul X, 0) is handled by InstSimplify
+            // General case : (mul X, C) != 0 iff X != 0
+            //                (mul X, C) == 0 iff X == 0
+            if (!BOC->isZero())
+              return new ICmpInst(ICI.getPredicate(), BO->getOperand(0),
+                                  Constant::getNullValue(RHS->getType()));
+          }
+        }
+        break;
       default: break;
       }
     } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) {
@@ -1554,7 +1661,7 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
   if (TD && LHSCI->getOpcode() == Instruction::PtrToInt &&
-      TD->getTypeSizeInBits(DestTy) ==
+      TD->getPointerSizeInBits() ==
          cast<IntegerType>(DestTy)->getBitWidth()) {
     Value *RHSOp = 0;
     if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) {
@@ -2250,7 +2357,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       case Instruction::IntToPtr:
         // icmp pred inttoptr(X), null -> icmp pred X, 0
         if (RHSC->isNullValue() && TD &&
-            TD->getIntPtrType(LHSI->getType()) ==
+            TD->getIntPtrType(RHSC->getContext()) ==
                LHSI->getOperand(0)->getType())
           return new ICmpInst(I.getPredicate(), LHSI->getOperand(0),
                         Constant::getNullValue(LHSI->getOperand(0)->getType()));
@@ -2358,8 +2465,25 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         // Try not to increase register pressure.
         BO0->hasOneUse() && BO1->hasOneUse()) {
       // Determine Y and Z in the form icmp (X+Y), (X+Z).
-      Value *Y = (A == C || A == D) ? B : A;
-      Value *Z = (C == A || C == B) ? D : C;
+      Value *Y, *Z;
+      if (A == C) {
+        // C + B == C + D  ->  B == D
+        Y = B;
+        Z = D;
+      } else if (A == D) {
+        // D + B == C + D  ->  B == C
+        Y = B;
+        Z = C;
+      } else if (B == C) {
+        // A + C == C + D  ->  A == D
+        Y = A;
+        Z = D;
+      } else {
+        assert(B == D);
+        // A + D == C + D  ->  A == C
+        Y = A;
+        Z = C;
+      }
       return new ICmpInst(Pred, Y, Z);
     }