Revert r245635, "[InstCombine] Transform A & (L - 1) u< L --> L != 0"
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index 3434a62d9e1a2a4df6fe2b78bb92146e2367b771..9ce1c3fc41b302126f33943cd0d1610a0c3b8749 100644 (file)
@@ -1447,6 +1447,23 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           ICI.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_UGT
                                                   : ICmpInst::ICMP_ULE,
           LHSI->getOperand(0), SubOne(RHS));
+
+    // (icmp eq (and %A, C), 0) -> (icmp sgt (trunc %A), -1)
+    //   iff C is a power of 2
+    if (ICI.isEquality() && LHSI->hasOneUse() && match(RHS, m_Zero())) {
+      if (auto *CI = dyn_cast<ConstantInt>(LHSI->getOperand(1))) {
+        const APInt &AI = CI->getValue();
+        int32_t ExactLogBase2 = AI.exactLogBase2();
+        if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
+          Type *NTy = IntegerType::get(ICI.getContext(), ExactLogBase2 + 1);
+          Value *Trunc = Builder->CreateTrunc(LHSI->getOperand(0), NTy);
+          return new ICmpInst(ICI.getPredicate() == ICmpInst::ICMP_EQ
+                                  ? ICmpInst::ICMP_SGE
+                                  : ICmpInst::ICMP_SLT,
+                              Trunc, Constant::getNullValue(NTy));
+        }
+      }
+    }
     break;
 
   case Instruction::Or: {
@@ -2097,7 +2114,7 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
 
   Value *TruncA = Builder->CreateTrunc(A, NewType, A->getName()+".trunc");
   Value *TruncB = Builder->CreateTrunc(B, NewType, B->getName()+".trunc");
-  CallInst *Call = Builder->CreateCall2(F, TruncA, TruncB, "sadd");
+  CallInst *Call = Builder->CreateCall(F, {TruncA, TruncB}, "sadd");
   Value *Add = Builder->CreateExtractValue(Call, 0, "sadd.result");
   Value *ZExt = Builder->CreateZExt(Add, OrigAdd->getType());
 
@@ -2109,33 +2126,95 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
-static Instruction *ProcessUAddIdiom(Instruction &I, Value *OrigAddV,
-                                     InstCombiner &IC) {
-  // Don't bother doing this transformation for pointers, don't do it for
-  // vectors.
-  if (!isa<IntegerType>(OrigAddV->getType())) return nullptr;
+bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
+                                         Value *RHS, Instruction &OrigI,
+                                         Value *&Result, Constant *&Overflow) {
+  if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
+    std::swap(LHS, RHS);
+
+  auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) {
+    Result = OpResult;
+    Overflow = OverflowVal;
+    if (ReuseName)
+      Result->takeName(&OrigI);
+    return true;
+  };
 
-  // If the add is a constant expr, then we don't bother transforming it.
-  Instruction *OrigAdd = dyn_cast<Instruction>(OrigAddV);
-  if (!OrigAdd) return nullptr;
+  switch (OCF) {
+  case OCF_INVALID:
+    llvm_unreachable("bad overflow check kind!");
 
-  Value *LHS = OrigAdd->getOperand(0), *RHS = OrigAdd->getOperand(1);
+  case OCF_UNSIGNED_ADD: {
+    OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI);
+    if (OR == OverflowResult::NeverOverflows)
+      return SetResult(Builder->CreateNUWAdd(LHS, RHS), Builder->getFalse(),
+                       true);
 
-  // Put the new code above the original add, in case there are any uses of the
-  // add between the add and the compare.
-  InstCombiner::BuilderTy *Builder = IC.Builder;
-  Builder->SetInsertPoint(OrigAdd);
+    if (OR == OverflowResult::AlwaysOverflows)
+      return SetResult(Builder->CreateAdd(LHS, RHS), Builder->getTrue(), true);
+  }
+  // FALL THROUGH uadd into sadd
+  case OCF_SIGNED_ADD: {
+    // X + 0 -> {X, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(LHS, Builder->getFalse(), false);
+
+    // We can strength reduce this signed add into a regular add if we can prove
+    // that it will never overflow.
+    if (OCF == OCF_SIGNED_ADD)
+      if (WillNotOverflowSignedAdd(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWAdd(LHS, RHS), Builder->getFalse(),
+                         true);
+    break;
+  }
 
-  Module *M = I.getParent()->getParent()->getParent();
-  Type *Ty = LHS->getType();
-  Value *F = Intrinsic::getDeclaration(M, Intrinsic::uadd_with_overflow, Ty);
-  CallInst *Call = Builder->CreateCall2(F, LHS, RHS, "uadd");
-  Value *Add = Builder->CreateExtractValue(Call, 0);
+  case OCF_UNSIGNED_SUB:
+  case OCF_SIGNED_SUB: {
+    // X - 0 -> {X, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(LHS, Builder->getFalse(), false);
+
+    if (OCF == OCF_SIGNED_SUB) {
+      if (WillNotOverflowSignedSub(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWSub(LHS, RHS), Builder->getFalse(),
+                         true);
+    } else {
+      if (WillNotOverflowUnsignedSub(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNUWSub(LHS, RHS), Builder->getFalse(),
+                         true);
+    }
+    break;
+  }
 
-  IC.ReplaceInstUsesWith(*OrigAdd, Add);
+  case OCF_UNSIGNED_MUL: {
+    OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI);
+    if (OR == OverflowResult::NeverOverflows)
+      return SetResult(Builder->CreateNUWMul(LHS, RHS), Builder->getFalse(),
+                       true);
+    if (OR == OverflowResult::AlwaysOverflows)
+      return SetResult(Builder->CreateMul(LHS, RHS), Builder->getTrue(), true);
+  } // FALL THROUGH
+  case OCF_SIGNED_MUL:
+    // X * undef -> undef
+    if (isa<UndefValue>(RHS))
+      return SetResult(RHS, UndefValue::get(Builder->getInt1Ty()), false);
+
+    // X * 0 -> {0, false}
+    if (match(RHS, m_Zero()))
+      return SetResult(RHS, Builder->getFalse(), false);
+
+    // X * 1 -> {X, false}
+    if (match(RHS, m_One()))
+      return SetResult(LHS, Builder->getFalse(), false);
+
+    if (OCF == OCF_SIGNED_MUL)
+      if (WillNotOverflowSignedMul(LHS, RHS, OrigI))
+        return SetResult(Builder->CreateNSWMul(LHS, RHS), Builder->getFalse(),
+                         true);
+    break;
+  }
 
-  // The original icmp gets replaced with the overflow value.
-  return ExtractValueInst::Create(Call, 1, "uadd.overflow");
+  return false;
 }
 
 /// \brief Recognize and process idiom involving test for multiplication
@@ -2305,7 +2384,7 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulB = Builder->CreateZExt(B, MulType);
   Value *F =
       Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
-  CallInst *Call = Builder->CreateCall2(F, MulA, MulB, "umul");
+  CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
 
   // If there are uses of mul result other than the comparison, we know that
@@ -2583,7 +2662,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC))
+  if (Value *V =
+          SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC, &I))
     return ReplaceInstUsesWith(I, V);
 
   // comparing -val or val with non-zero is the same as just comparing val
@@ -3432,21 +3512,18 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         return new ICmpInst(I.getPredicate(), ConstantExpr::getNot(RHSC), A);
     }
 
-    // (a+b) <u a  --> llvm.uadd.with.overflow.
-    // (a+b) <u b  --> llvm.uadd.with.overflow.
-    if (I.getPredicate() == ICmpInst::ICMP_ULT &&
-        match(Op0, m_Add(m_Value(A), m_Value(B))) &&
-        (Op1 == A || Op1 == B))
-      if (Instruction *R = ProcessUAddIdiom(I, Op0, *this))
-        return R;
-
-    // a >u (a+b)  --> llvm.uadd.with.overflow.
-    // b >u (a+b)  --> llvm.uadd.with.overflow.
-    if (I.getPredicate() == ICmpInst::ICMP_UGT &&
-        match(Op1, m_Add(m_Value(A), m_Value(B))) &&
-        (Op0 == A || Op0 == B))
-      if (Instruction *R = ProcessUAddIdiom(I, Op1, *this))
-        return R;
+    Instruction *AddI = nullptr;
+    if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B),
+                                     m_Instruction(AddI))) &&
+        isa<IntegerType>(A->getType())) {
+      Value *Result;
+      Constant *Overflow;
+      if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result,
+                                Overflow)) {
+        ReplaceInstUsesWith(*AddI, Result);
+        return ReplaceInstUsesWith(I, Overflow);
+      }
+    }
 
     // (zext a) * (zext b)  --> llvm.umul.with.overflow.
     if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
@@ -3553,6 +3630,21 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
       }
     }
 
+    // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0
+    if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
+        match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
+      unsigned TypeBits = Cst1->getBitWidth();
+      unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits);
+      if (ShAmt < TypeBits && ShAmt != 0) {
+        Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted");
+        APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt);
+        Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal),
+                                        I.getName() + ".mask");
+        return new ICmpInst(I.getPredicate(), And,
+                            Constant::getNullValue(Cst1->getType()));
+      }
+    }
+
     // Transform "icmp eq (trunc (lshr(X, cst1)), cst" to
     // "icmp (and X, mask), cst"
     uint64_t ShAmt = 0;
@@ -3852,7 +3944,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AC))
+  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1,
+                                  I.getFastMathFlags(), DL, TLI, DT, AC, &I))
     return ReplaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'
@@ -3879,6 +3972,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     }
   }
 
+  // Test if the FCmpInst instruction is used exclusively by a select as
+  // part of a minimum or maximum operation. If so, refrain from doing
+  // any other folding. This helps out other analyses which understand
+  // non-obfuscated minimum and maximum idioms, such as ScalarEvolution
+  // and CodeGen. And in this case, at least one of the comparison
+  // operands has at least one user besides the compare (the select),
+  // which would often largely negate the benefit of folding anyway.
+  if (I.hasOneUse())
+    if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
+      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
+          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+        return nullptr;
+
   // Handle fcmp with constant RHS
   if (Constant *RHSC = dyn_cast<Constant>(Op1)) {
     if (Instruction *LHSI = dyn_cast<Instruction>(Op0))