Add in the first iteration of support for llvm/clang/lldb to allow variable per addre...
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index bb1cbfade34d1a4701ed7f86ffa9a2e465db3629..e3e5ddae80b437604fb18347431e7a63202ec964 100644 (file)
@@ -16,7 +16,8 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
-#include "llvm/Target/TargetData.h"
+#include "llvm/DataLayout.h"
+#include "llvm/Target/TargetLibraryInfo.h"
 #include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/PatternMatch.h"
@@ -203,8 +204,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // We need TD information to know the pointer size unless this is inbounds.
   if (!GEP->isInBounds() && TD == 0) return 0;
 
-  ConstantArray *Init = dyn_cast<ConstantArray>(GV->getInitializer());
-  if (Init == 0 || Init->getNumOperands() > 1024) return 0;
+  Constant *Init = GV->getInitializer();
+  if (!isa<ConstantArray>(Init) && !isa<ConstantDataArray>(Init))
+    return 0;
+  
+  uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
+  if (ArrayElementCount > 1024) return 0;  // Don't blow up on huge arrays.
 
   // There are many forms of this optimization we can handle, for now, just do
   // the simple index into a single-dimensional array.
@@ -221,7 +226,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // structs.
   SmallVector<unsigned, 4> LaterIndices;
 
-  Type *EltTy = cast<ArrayType>(Init->getType())->getElementType();
+  Type *EltTy = Init->getType()->getArrayElementType();
   for (unsigned i = 3, e = GEP->getNumOperands(); i != e; ++i) {
     ConstantInt *Idx = dyn_cast<ConstantInt>(GEP->getOperand(i));
     if (Idx == 0) return 0;  // Variable index.
@@ -272,8 +277,9 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 
   // Scan the array and see if one of our patterns matches.
   Constant *CompareRHS = cast<Constant>(ICI.getOperand(1));
-  for (unsigned i = 0, e = Init->getNumOperands(); i != e; ++i) {
-    Constant *Elt = Init->getOperand(i);
+  for (unsigned i = 0, e = ArrayElementCount; i != e; ++i) {
+    Constant *Elt = Init->getAggregateElement(i);
+    if (Elt == 0) return 0;
 
     // If this is indexing an array of structures, get the structure element.
     if (!LaterIndices.empty())
@@ -284,7 +290,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 
     // Find out if the comparison would be true or false for the i'th element.
     Constant *C = ConstantFoldCompareInstOperands(ICI.getPredicate(), Elt,
-                                                  CompareRHS, TD);
+                                                  CompareRHS, TD, TLI);
     // If the result is undef for this element, ignore it.
     if (isa<UndefValue>(C)) {
       // Extend range state machines to cover this element in case there is an
@@ -359,11 +365,12 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // order the state machines in complexity of the generated code.
   Value *Idx = GEP->getOperand(2);
 
+  unsigned AS = GEP->getPointerAddressSpace();
   // If the index is larger than the pointer size of the target, truncate the
   // index down like the GEP would do implicitly.  We don't have to do this for
   // an inbounds GEP because the index can't be out of range.
   if (!GEP->isInBounds() &&
-      Idx->getType()->getPrimitiveSizeInBits() > TD->getPointerSizeInBits())
+      Idx->getType()->getPrimitiveSizeInBits() > TD->getPointerSizeInBits(AS))
     Idx = Builder->CreateTrunc(Idx, TD->getIntPtrType(Idx->getContext()));
 
   // If the comparison is only true for one or two elements, emit direct
@@ -440,10 +447,10 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
   // If a 32-bit or 64-bit magic bitvector captures the entire comparison state
   // of this load, replace it with computation that does:
   //   ((magic_cst >> i) & 1) != 0
-  if (Init->getNumOperands() <= 32 ||
-      (TD && Init->getNumOperands() <= 64 && TD->isLegalInteger(64))) {
+  if (ArrayElementCount <= 32 ||
+      (TD && ArrayElementCount <= 64 && TD->isLegalInteger(64))) {
     Type *Ty;
-    if (Init->getNumOperands() <= 32)
+    if (ArrayElementCount <= 32)
       Ty = Type::getInt32Ty(Init->getContext());
     else
       Ty = Type::getInt64Ty(Init->getContext());
@@ -468,7 +475,7 @@ FoldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, GlobalVariable *GV,
 /// If we can't emit an optimized form for this expression, this returns null.
 ///
 static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
-  TargetData &TD = *IC.getTargetData();
+  DataLayout &TD = *IC.getDataLayout();
   gep_type_iterator GTI = gep_type_begin(GEP);
 
   // Check to see if this gep only has a single variable index.  If so, and if
@@ -522,10 +529,11 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
     }
   }
 
+  unsigned AS = cast<GetElementPtrInst>(GEP)->getPointerAddressSpace();
   // Okay, we know we have a single variable index, which must be a
   // pointer/array/vector index.  If there is no offset, life is simple, return
   // the index.
-  unsigned IntPtrWidth = TD.getPointerSizeInBits();
+  unsigned IntPtrWidth = TD.getPointerSizeInBits(AS);
   if (Offset == 0) {
     // Cast to intptrty in case a truncation occurs.  If an extension is needed,
     // we don't need to bother extending: the extension won't affect where the
@@ -566,6 +574,14 @@ static Value *EvaluateGEPOffsetExpression(User *GEP, InstCombiner &IC) {
 Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
                                        ICmpInst::Predicate Cond,
                                        Instruction &I) {
+  // Don't transform signed compares of GEPs into index compares. Even if the
+  // GEP is inbounds, the final add of the base pointer can have signed overflow
+  // and would change the result of the icmp.
+  // e.g. "&foo[0] <s &foo[1]" can't be folded to "true" because "foo" could be
+  // the maximum signed value for the pointer type.
+  if (ICmpInst::isSigned(Cond))
+    return 0;
+
   // Look through bitcasts.
   if (BitCastInst *BCI = dyn_cast<BitCastInst>(RHS))
     RHS = BCI->getOperand(0);
@@ -602,6 +618,20 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
         return new ICmpInst(ICmpInst::getSignedPredicate(Cond),
                             GEPLHS->getOperand(0), GEPRHS->getOperand(0));
 
+      // If we're comparing GEPs with two base pointers that only differ in type
+      // and both GEPs have only constant indices or just one use, then fold
+      // the compare with the adjusted indices.
+      if (TD && GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
+          (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
+          (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
+          PtrBase->stripPointerCasts() ==
+            GEPRHS->getOperand(0)->stripPointerCasts()) {
+        Value *Cmp = Builder->CreateICmp(ICmpInst::getSignedPredicate(Cond),
+                                         EmitGEPOffset(GEPLHS),
+                                         EmitGEPOffset(GEPRHS));
+        return ReplaceInstUsesWith(I, Cmp);
+      }
+
       // Otherwise, the base pointers are different and the indices are
       // different, bail out.
       return 0;
@@ -1001,15 +1031,14 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       // of the high bits truncated out of x are known.
       unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
              SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
-      APInt Mask(APInt::getHighBitsSet(SrcBits, SrcBits-DstBits));
       APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
-      ComputeMaskedBits(LHSI->getOperand(0), Mask, KnownZero, KnownOne);
+      ComputeMaskedBits(LHSI->getOperand(0), KnownZero, KnownOne);
 
       // If all the high bits are known, we can do this xform.
       if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
         // Pull in the high bits from known-ones set.
         APInt NewRHS = RHS->getValue().zext(SrcBits);
-        NewRHS |= KnownOne;
+        NewRHS |= KnownOne & APInt::getHighBitsSet(SrcBits, SrcBits-DstBits);
         return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                             ConstantInt::get(ICI.getContext(), NewRHS));
       }
@@ -1525,7 +1554,8 @@ Instruction *InstCombiner::visitICmpInstWithCastAndCast(ICmpInst &ICI) {
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
   if (TD && LHSCI->getOpcode() == Instruction::PtrToInt &&
-      TD->getPointerSizeInBits() ==
+      TD->getPointerSizeInBits(
+        cast<PtrToIntInst>(LHSCI)->getPointerAddressSpace()) ==
          cast<IntegerType>(DestTy)->getBitWidth()) {
     Value *RHSOp = 0;
     if (Constant *RHSC = dyn_cast<Constant>(ICI.getOperand(1))) {
@@ -1657,6 +1687,14 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
       CI1->getValue() != APInt::getLowBitsSet(CI1->getBitWidth(), NewWidth))
     return 0;
 
+  // This is only really a signed overflow check if the inputs have been
+  // sign-extended; check for that condition. For example, if CI2 is 2^31 and
+  // the operands of the add are 64 bits wide, we need at least 33 sign bits.
+  unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
+  if (IC.ComputeNumSignBits(A) < NeededSignBits ||
+      IC.ComputeNumSignBits(B) < NeededSignBits)
+    return 0;
+
   // In order to replace the original add with a narrower
   // llvm.sadd.with.overflow, the only uses allowed are the add-with-constant
   // and truncates that discard the high bits of the add.  Verify that this is
@@ -1787,6 +1825,24 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, TD))
     return ReplaceInstUsesWith(I, V);
 
+  // comparing -val or val with non-zero is the same as just comparing val
+  // ie, abs(val) != 0 -> val != 0
+  if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero()))
+  {
+    Value *Cond, *SelectTrue, *SelectFalse;
+    if (match(Op0, m_Select(m_Value(Cond), m_Value(SelectTrue),
+                            m_Value(SelectFalse)))) {
+      if (Value *V = dyn_castNegVal(SelectTrue)) {
+        if (V == SelectFalse)
+          return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
+      }
+      else if (Value *V = dyn_castNegVal(SelectFalse)) {
+        if (V == SelectTrue)
+          return CmpInst::Create(Instruction::ICmp, I.getPredicate(), V, Op1);
+      }
+    }
+  }
+
   Type *Ty = Op0->getType();
 
   // icmp's with boolean values can always be turned into bitwise operations
@@ -2528,10 +2584,25 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       }
     }
 
+    // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B)
+    // and       (B & (1<<X)-1) == (zext A) --> A == (trunc B)
+    ConstantInt *Cst1;
+    if ((Op0->hasOneUse() &&
+         match(Op0, m_ZExt(m_Value(A))) &&
+         match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) ||
+        (Op1->hasOneUse() &&
+         match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) &&
+         match(Op1, m_ZExt(m_Value(A))))) {
+      APInt Pow2 = Cst1->getValue() + 1;
+      if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) &&
+          Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth())
+        return new ICmpInst(I.getPredicate(), A,
+                            Builder->CreateTrunc(B, A->getType()));
+    }
+
     // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to
     // "icmp (and X, mask), cst"
     uint64_t ShAmt = 0;
-    ConstantInt *Cst1;
     if (Op0->hasOneUse() &&
         match(Op0, m_Trunc(m_OneUse(m_LShr(m_Value(A),
                                            m_ConstantInt(ShAmt))))) &&
@@ -2683,6 +2754,17 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
         return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
       return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
     }
+  } else {
+    // See if the RHS value is < UnsignedMin.
+    APFloat SMin(RHS.getSemantics(), APFloat::fcZero, false);
+    SMin.convertFromAPInt(APInt::getMinValue(IntWidth), true,
+                          APFloat::rmNearestTiesToEven);
+    if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0
+      if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
+          Pred == ICmpInst::ICMP_UGE)
+        return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
+      return ReplaceInstUsesWith(I, ConstantInt::getFalse(I.getContext()));
+    }
   }
 
   // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
@@ -2746,7 +2828,7 @@ Instruction *InstCombiner::FoldFCmp_IntToFP_Cst(FCmpInst &I,
       case ICmpInst::ICMP_UGE:
         // (float)int >= -4.4   --> true
         // (float)int >= 4.4    --> int > 4
-        if (!RHS.isNegative())
+        if (RHS.isNegative())
           return ReplaceInstUsesWith(I, ConstantInt::getTrue(I.getContext()));
         Pred = ICmpInst::ICMP_UGT;
         break;
@@ -2822,7 +2904,9 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
         const fltSemantics *Sem;
         // FIXME: This shouldn't be here.
-        if (LHSExt->getSrcTy()->isFloatTy())
+        if (LHSExt->getSrcTy()->isHalfTy())
+          Sem = &APFloat::IEEEhalf;
+        else if (LHSExt->getSrcTy()->isFloatTy())
           Sem = &APFloat::IEEEsingle;
         else if (LHSExt->getSrcTy()->isDoubleTy())
           Sem = &APFloat::IEEEdouble;
@@ -2905,6 +2989,44 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
                 return Res;
         }
         break;
+      case Instruction::Call: {
+        CallInst *CI = cast<CallInst>(LHSI);
+        LibFunc::Func Func;
+        // Various optimization for fabs compared with zero.
+        if (RHSC->isNullValue() && CI->getCalledFunction() &&
+            TLI->getLibFunc(CI->getCalledFunction()->getName(), Func) &&
+            TLI->has(Func)) {
+          if (Func == LibFunc::fabs || Func == LibFunc::fabsf ||
+              Func == LibFunc::fabsl) {
+            switch (I.getPredicate()) {
+            default: break;
+            // fabs(x) < 0 --> false
+            case FCmpInst::FCMP_OLT:
+              return ReplaceInstUsesWith(I, Builder->getFalse());
+            // fabs(x) > 0 --> x != 0
+            case FCmpInst::FCMP_OGT:
+              return new FCmpInst(FCmpInst::FCMP_ONE, CI->getArgOperand(0),
+                                  RHSC);
+            // fabs(x) <= 0 --> x == 0
+            case FCmpInst::FCMP_OLE:
+              return new FCmpInst(FCmpInst::FCMP_OEQ, CI->getArgOperand(0),
+                                  RHSC);
+            // fabs(x) >= 0 --> !isnan(x)
+            case FCmpInst::FCMP_OGE:
+              return new FCmpInst(FCmpInst::FCMP_ORD, CI->getArgOperand(0),
+                                  RHSC);
+            // fabs(x) == 0 --> x == 0
+            // fabs(x) != 0 --> x != 0
+            case FCmpInst::FCMP_OEQ:
+            case FCmpInst::FCMP_UEQ:
+            case FCmpInst::FCMP_ONE:
+            case FCmpInst::FCMP_UNE:
+              return new FCmpInst(I.getPredicate(), CI->getArgOperand(0),
+                                  RHSC);
+            }
+          }
+        }
+      }
       }
   }