InstCombine: Fix a combine assuming that icmp operands were integers
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCasts.cpp
index ff083d7926cc3b84f9dc5e745e2ce3580ed60811..b41cdc65202f6dc6edbc00765a58b0c7dea9785b 100644 (file)
@@ -335,7 +335,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
 ///
 /// This function works on both vectors and scalars.
 ///
-static bool CanEvaluateTruncated(Value *V, Type *Ty) {
+static bool CanEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC,
+                                 Instruction *CxtI) {
   // We can always evaluate constants in another type.
   if (isa<Constant>(V))
     return true;
@@ -364,8 +365,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
   case Instruction::Or:
   case Instruction::Xor:
     // These operators can all arbitrarily be extended or truncated.
-    return CanEvaluateTruncated(I->getOperand(0), Ty) &&
-           CanEvaluateTruncated(I->getOperand(1), Ty);
+    return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+           CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
 
   case Instruction::UDiv:
   case Instruction::URem: {
@@ -374,10 +375,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
     uint32_t BitWidth = Ty->getScalarSizeInBits();
     if (BitWidth < OrigBitWidth) {
       APInt Mask = APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth);
-      if (MaskedValueIsZero(I->getOperand(0), Mask) &&
-          MaskedValueIsZero(I->getOperand(1), Mask)) {
-        return CanEvaluateTruncated(I->getOperand(0), Ty) &&
-               CanEvaluateTruncated(I->getOperand(1), Ty);
+      if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, CxtI) &&
+          IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, CxtI)) {
+        return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
+               CanEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
       }
     }
     break;
@@ -388,7 +389,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
       uint32_t BitWidth = Ty->getScalarSizeInBits();
       if (CI->getLimitedValue(BitWidth) < BitWidth)
-        return CanEvaluateTruncated(I->getOperand(0), Ty);
+        return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
     }
     break;
   case Instruction::LShr:
@@ -398,10 +399,10 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
       uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
       uint32_t BitWidth = Ty->getScalarSizeInBits();
-      if (MaskedValueIsZero(I->getOperand(0),
-            APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) &&
+      if (IC.MaskedValueIsZero(I->getOperand(0),
+            APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) &&
           CI->getLimitedValue(BitWidth) < BitWidth) {
-        return CanEvaluateTruncated(I->getOperand(0), Ty);
+        return CanEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI);
       }
     }
     break;
@@ -415,8 +416,8 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
     return true;
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
-    return CanEvaluateTruncated(SI->getTrueValue(), Ty) &&
-           CanEvaluateTruncated(SI->getFalseValue(), Ty);
+    return CanEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
+           CanEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands.  Note that we never
@@ -424,7 +425,7 @@ static bool CanEvaluateTruncated(Value *V, Type *Ty) {
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
-      if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty))
+      if (!CanEvaluateTruncated(PN->getIncomingValue(i), Ty, IC, CxtI))
         return false;
     return true;
   }
@@ -453,7 +454,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   // expression tree to something weird like i93 unless the source is also
   // strange.
   if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
-      CanEvaluateTruncated(Src, DestTy)) {
+      CanEvaluateTruncated(Src, DestTy, *this, &CI)) {
 
     // If this cast is a truncate, evaluting in a different type always
     // eliminates the cast, so it is always a win.
@@ -553,7 +554,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
       // If Op1C some other power of two, convert:
       uint32_t BitWidth = Op1C->getType()->getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne);
+      computeKnownBits(ICI->getOperand(0), KnownZero, KnownOne, 0, &CI);
 
       APInt KnownZeroMask(~KnownZero);
       if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
@@ -601,8 +602,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
 
       APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0);
       APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0);
-      computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS);
-      computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS);
+      computeKnownBits(LHS, KnownZeroLHS, KnownOneLHS, 0, &CI);
+      computeKnownBits(RHS, KnownZeroRHS, KnownOneRHS, 0, &CI);
 
       if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) {
         APInt KnownBits = KnownZeroLHS | KnownOneLHS;
@@ -651,7 +652,8 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
 /// clear the top bits anyway, doing this has no extra cost.
 ///
 /// This function works on both vectors and scalars.
-static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
+static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
+                             InstCombiner &IC, Instruction *CxtI) {
   BitsToClear = 0;
   if (isa<Constant>(V))
     return true;
@@ -680,8 +682,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
   case Instruction::Add:
   case Instruction::Sub:
   case Instruction::Mul:
-    if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear) ||
-        !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp))
+    if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) ||
+        !CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI))
       return false;
     // These can all be promoted if neither operand has 'bits to clear'.
     if (BitsToClear == 0 && Tmp == 0)
@@ -695,8 +697,9 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
       // We use MaskedValueIsZero here for generality, but the case we care
       // about the most is constant RHS.
       unsigned VSize = V->getType()->getScalarSizeInBits();
-      if (MaskedValueIsZero(I->getOperand(1),
-                            APInt::getHighBitsSet(VSize, BitsToClear)))
+      if (IC.MaskedValueIsZero(I->getOperand(1),
+                               APInt::getHighBitsSet(VSize, BitsToClear),
+                               0, CxtI))
         return true;
     }
 
@@ -707,7 +710,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
     // We can promote shl(x, cst) if we can promote x.  Since shl overwrites the
     // upper bits we can reduce BitsToClear by the shift amount.
     if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear))
+      if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
         return false;
       uint64_t ShiftAmt = Amt->getZExtValue();
       BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
@@ -718,7 +721,7 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
     // We can promote lshr(x, cst) if we can promote x.  This requires the
     // ultimate 'and' to clear out the high zero bits we're clearing out though.
     if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear))
+      if (!CanEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
         return false;
       BitsToClear += Amt->getZExtValue();
       if (BitsToClear > V->getType()->getScalarSizeInBits())
@@ -728,8 +731,8 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
     // Cannot promote variable LSHR.
     return false;
   case Instruction::Select:
-    if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp) ||
-        !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear) ||
+    if (!CanEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
+        !CanEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||
         // TODO: If important, we could handle the case when the BitsToClear are
         // known zero in the disagreeing side.
         Tmp != BitsToClear)
@@ -741,10 +744,10 @@ static bool CanEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear) {
     // get into trouble with cyclic PHIs here because we only consider
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
-    if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear))
+    if (!CanEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI))
       return false;
     for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
-      if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp) ||
+      if (!CanEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) ||
           // TODO: If important, we could handle the case when the BitsToClear
           // are known zero in the disagreeing input.
           Tmp != BitsToClear)
@@ -781,7 +784,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
   // strange.
   unsigned BitsToClear;
   if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
-      CanEvaluateZExtd(Src, DestTy, BitsToClear)) {
+      CanEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) {
     assert(BitsToClear < SrcTy->getScalarSizeInBits() &&
            "Unreasonable BitsToClear");
 
@@ -796,8 +799,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
 
     // If the high bits are already filled with zeros, just replace this
     // cast with the result.
-    if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize,
-                                                     DestBitSize-SrcBitsKept)))
+    if (MaskedValueIsZero(Res,
+                          APInt::getHighBitsSet(DestBitSize,
+                                                DestBitSize-SrcBitsKept),
+                             0, &CI))
       return ReplaceInstUsesWith(CI, Res);
 
     // We need to emit an AND to clear the high bits.
@@ -895,6 +900,10 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) {
   Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1);
   ICmpInst::Predicate Pred = ICI->getPredicate();
 
+  // Don't bother if Op1 isn't of vector or integer type.
+  if (!Op1->getType()->isIntOrIntVectorTy())
+    return nullptr;
+
   if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
     // (x <s  0) ? -1 : 0 -> ashr x, 31        -> all ones if negative
     // (x >s -1) ? -1 : 0 -> not (ashr x, 31)  -> all ones if positive
@@ -921,7 +930,7 @@ Instruction *InstCombiner::transformSExtICmp(ICmpInst *ICI, Instruction &CI) {
         ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
       unsigned BitWidth = Op1C->getType()->getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(Op0, KnownZero, KnownOne);
+      computeKnownBits(Op0, KnownZero, KnownOne, 0, &CI);
 
       APInt KnownZeroMask(~KnownZero);
       if (KnownZeroMask.isPowerOf2()) {
@@ -1072,7 +1081,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
 
     // If the high bits are already filled with sign bit, just replace this
     // cast with the result.
-    if (ComputeNumSignBits(Res) > DestBitSize - SrcBitSize)
+    if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize)
       return ReplaceInstUsesWith(CI, Res);
 
     // We need to emit a shl + ashr to do the sign extend.
@@ -1312,42 +1321,6 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
     }
   }
 
-  // Fold (fptrunc (sqrt (fpext x))) -> (sqrtf x)
-  // Note that we restrict this transformation based on
-  // TLI->has(LibFunc::sqrtf), even for the sqrt intrinsic, because
-  // TLI->has(LibFunc::sqrtf) is sufficient to guarantee that the
-  // single-precision intrinsic can be expanded in the backend.
-  CallInst *Call = dyn_cast<CallInst>(CI.getOperand(0));
-  if (Call && Call->getCalledFunction() && TLI->has(LibFunc::sqrtf) &&
-      (Call->getCalledFunction()->getName() == TLI->getName(LibFunc::sqrt) ||
-       Call->getCalledFunction()->getIntrinsicID() == Intrinsic::sqrt) &&
-      Call->getNumArgOperands() == 1 &&
-      Call->hasOneUse()) {
-    CastInst *Arg = dyn_cast<CastInst>(Call->getArgOperand(0));
-    if (Arg && Arg->getOpcode() == Instruction::FPExt &&
-        CI.getType()->isFloatTy() &&
-        Call->getType()->isDoubleTy() &&
-        Arg->getType()->isDoubleTy() &&
-        Arg->getOperand(0)->getType()->isFloatTy()) {
-      Function *Callee = Call->getCalledFunction();
-      Module *M = CI.getParent()->getParent()->getParent();
-      Constant *SqrtfFunc = (Callee->getIntrinsicID() == Intrinsic::sqrt) ?
-        Intrinsic::getDeclaration(M, Intrinsic::sqrt, Builder->getFloatTy()) :
-        M->getOrInsertFunction("sqrtf", Callee->getAttributes(),
-                               Builder->getFloatTy(), Builder->getFloatTy(),
-                               NULL);
-      CallInst *ret = CallInst::Create(SqrtfFunc, Arg->getOperand(0),
-                                       "sqrtfcall");
-      ret->setAttributes(Callee->getAttributes());
-
-
-      // Remove the old Call.  With -fmath-errno, it won't get marked readnone.
-      ReplaceInstUsesWith(*Call, UndefValue::get(Call->getType()));
-      EraseInstFromFunction(*Call);
-      return ret;
-    }
-  }
-
   return nullptr;
 }
 
@@ -1909,9 +1882,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
 }
 
 Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) {
-  // If the destination pointer element type is not the the same as the source's
-  // do the addrspacecast to the same type, and then the bitcast in the new
-  // address space. This allows the cast to be exposed to other transforms.
+  // If the destination pointer element type is not the same as the source's
+  // first do a bitcast to the destination type, and then the addrspacecast.
+  // This allows the cast to be exposed to other transforms.
   Value *Src = CI.getOperand(0);
   PointerType *SrcTy = cast<PointerType>(Src->getType()->getScalarType());
   PointerType *DestTy = cast<PointerType>(CI.getType()->getScalarType());