Merging r259375:
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index 3afb5ed6d8ac5de3758ffbcd63884547220c36c9..d9311a343eadb57fdcfd9d02a592e91985335a31 100644 (file)
@@ -730,6 +730,83 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
   return nullptr;
 }
 
+Instruction *InstCombiner::FoldAllocaCmp(ICmpInst &ICI, AllocaInst *Alloca,
+                                         Value *Other) {
+  assert(ICI.isEquality() && "Cannot fold non-equality comparison.");
+
+  // It would be tempting to fold away comparisons between allocas and any
+  // pointer not based on that alloca (e.g. an argument). However, even
+  // though such pointers cannot alias, they can still compare equal.
+  //
+  // But LLVM doesn't specify where allocas get their memory, so if the alloca
+  // doesn't escape we can argue that it's impossible to guess its value, and we
+  // can therefore act as if any such guesses are wrong.
+  //
+  // The code below checks that the alloca doesn't escape, and that it's only
+  // used in a comparison once (the current instruction). The
+  // single-comparison-use condition ensures that we're trivially folding all
+  // comparisons against the alloca consistently, and avoids the risk of
+  // erroneously folding a comparison of the pointer with itself.
+
+  unsigned MaxIter = 32; // Break cycles and bound to constant-time.
+
+  SmallVector<Use *, 32> Worklist;
+  for (Use &U : Alloca->uses()) {
+    if (Worklist.size() >= MaxIter)
+      return nullptr;
+    Worklist.push_back(&U);
+  }
+
+  unsigned NumCmps = 0;
+  while (!Worklist.empty()) {
+    assert(Worklist.size() <= MaxIter);
+    Use *U = Worklist.pop_back_val();
+    Value *V = U->getUser();
+    --MaxIter;
+
+    if (isa<BitCastInst>(V) || isa<GetElementPtrInst>(V) || isa<PHINode>(V) ||
+        isa<SelectInst>(V)) {
+      // Track the uses.
+    } else if (isa<LoadInst>(V)) {
+      // Loading from the pointer doesn't escape it.
+      continue;
+    } else if (auto *SI = dyn_cast<StoreInst>(V)) {
+      // Storing *to* the pointer is fine, but storing the pointer escapes it.
+      if (SI->getValueOperand() == U->get())
+        return nullptr;
+      continue;
+    } else if (isa<ICmpInst>(V)) {
+      if (NumCmps++)
+        return nullptr; // Found more than one cmp.
+      continue;
+    } else if (auto *Intrin = dyn_cast<IntrinsicInst>(V)) {
+      switch (Intrin->getIntrinsicID()) {
+        // These intrinsics don't escape or compare the pointer. Memset is safe
+        // because we don't allow ptrtoint. Memcpy and memmove are safe because
+        // we don't allow stores, so src cannot point to V.
+        case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
+        case Intrinsic::dbg_declare: case Intrinsic::dbg_value:
+        case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
+          continue;
+        default:
+          return nullptr;
+      }
+    } else {
+      return nullptr;
+    }
+    for (Use &U : V->uses()) {
+      if (Worklist.size() >= MaxIter)
+        return nullptr;
+      Worklist.push_back(&U);
+    }
+  }
+
+  Type *CmpTy = CmpInst::makeCmpResultType(Other->getType());
+  return ReplaceInstUsesWith(
+      ICI,
+      ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate())));
+}
+
 /// FoldICmpAddOpCst - Fold "icmp pred (X+CI), X".
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
@@ -2112,11 +2189,9 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // If the pattern matches, truncate the inputs to the narrower type and
   // use the sadd_with_overflow intrinsic to efficiently compute both the
   // result and the overflow bit.
-  Module *M = I.getParent()->getParent()->getParent();
-
   Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
-  Value *F = Intrinsic::getDeclaration(M, Intrinsic::sadd_with_overflow,
-                                       NewType);
+  Value *F = Intrinsic::getDeclaration(I.getModule(),
+                                       Intrinsic::sadd_with_overflow, NewType);
 
   InstCombiner::BuilderTy *Builder = IC.Builder;
 
@@ -2394,7 +2469,6 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
 
   InstCombiner::BuilderTy *Builder = IC.Builder;
   Builder->SetInsertPoint(MulInstr);
-  Module *M = I.getParent()->getParent()->getParent();
 
   // Replace: mul(zext A, zext B) --> mul.with.overflow(A, B)
   Value *MulA = A, *MulB = B;
@@ -2402,8 +2476,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulA = Builder->CreateZExt(A, MulType);
   if (WidthB < MulWidth)
     MulB = Builder->CreateZExt(B, MulType);
-  Value *F =
-      Intrinsic::getDeclaration(M, Intrinsic::umul_with_overflow, MulType);
+  Value *F = Intrinsic::getDeclaration(I.getModule(),
+                                       Intrinsic::umul_with_overflow, MulType);
   CallInst *Call = Builder->CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
 
@@ -3211,6 +3285,17 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                            ICmpInst::getSwappedPredicate(I.getPredicate()), I))
       return NI;
 
+  // Try to optimize equality comparisons against alloca-based pointers.
+  if (Op0->getType()->isPointerTy() && I.isEquality()) {
+    assert(Op1->getType()->isPointerTy() && "Comparing pointer with non-pointer?");
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op0, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op1))
+        return New;
+    if (auto *Alloca = dyn_cast<AllocaInst>(GetUnderlyingObject(Op1, DL)))
+      if (Instruction *New = FoldAllocaCmp(I, Alloca, Op0))
+        return New;
+  }
+
   // Test to see if the operands of the icmp are casted versions of other
   // values.  If the ptr->ptr cast can be stripped off both arguments, we do so
   // now.
@@ -3338,6 +3423,26 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         match(B, m_One()))
       return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
 
+    // icmp sgt X, (Y + -1) -> icmp sge X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
+        match(D, m_AllOnes()))
+      return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
+
+    // icmp sle X, (Y + -1) -> icmp slt X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
+        match(D, m_AllOnes()))
+      return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
+
+    // icmp sge X, (Y + 1) -> icmp sgt X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE &&
+        match(D, m_One()))
+      return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
+
+    // icmp slt X, (Y + 1) -> icmp sle X, Y
+    if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT &&
+        match(D, m_One()))
+      return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
+
     // if C1 has greater magnitude than C2:
     //  icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
     //  s.t. C3 = C1 - C2
@@ -3455,7 +3560,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                                 BO1->getOperand(0));
           }
 
-          if (CI->isMaxValue(true)) {
+          if (BO0->getOpcode() == Instruction::Xor && CI->isMaxValue(true)) {
             ICmpInst::Predicate Pred = I.isSigned()
                                            ? I.getUnsignedPredicate()
                                            : I.getSignedPredicate();