InstCombine: Don't create an unused instruction
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCompares.cpp
index f204eeb548c7bdc6aa740b416e5466c9c21e6c0f..f59e3758101dd9bf9cb8f9c0e3bdf637edccb245 100644 (file)
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Target/TargetLibraryInfo.h"
+
 using namespace llvm;
 using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
+// How many times is a select replaced by one of its operands?
+STATISTIC(NumSel, "Number of select opts");
+
+// Initialization Routines
+
 static ConstantInt *getOne(Constant *C) {
   return ConstantInt::get(cast<IntegerType>(C->getType()), 1);
 }
@@ -683,26 +692,12 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     }
 
     // If one of the GEPs has all zero indices, recurse.
-    bool AllZeros = true;
-    for (unsigned i = 1, e = GEPLHS->getNumOperands(); i != e; ++i)
-      if (!isa<Constant>(GEPLHS->getOperand(i)) ||
-          !cast<Constant>(GEPLHS->getOperand(i))->isNullValue()) {
-        AllZeros = false;
-        break;
-      }
-    if (AllZeros)
+    if (GEPLHS->hasAllZeroIndices())
       return FoldGEPICmp(GEPRHS, GEPLHS->getOperand(0),
                          ICmpInst::getSwappedPredicate(Cond), I);
 
     // If the other GEP has all zero indices, recurse.
-    AllZeros = true;
-    for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i)
-      if (!isa<Constant>(GEPRHS->getOperand(i)) ||
-          !cast<Constant>(GEPRHS->getOperand(i))->isNullValue()) {
-        AllZeros = false;
-        break;
-      }
-    if (AllZeros)
+    if (GEPRHS->hasAllZeroIndices())
       return FoldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I);
 
     bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds();
@@ -754,21 +749,6 @@ Instruction *InstCombiner::FoldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
 Instruction *InstCombiner::FoldICmpAddOpCst(Instruction &ICI,
                                             Value *X, ConstantInt *CI,
                                             ICmpInst::Predicate Pred) {
-  // If we have X+0, exit early (simplifying logic below) and let it get folded
-  // elsewhere.   icmp X+0, X  -> icmp X, X
-  if (CI->isZero()) {
-    bool isTrue = ICmpInst::isTrueWhenEqual(Pred);
-    return ReplaceInstUsesWith(ICI, ConstantInt::get(ICI.getType(), isTrue));
-  }
-
-  // (X+4) == X -> false.
-  if (Pred == ICmpInst::ICMP_EQ)
-    return ReplaceInstUsesWith(ICI, Builder->getFalse());
-
-  // (X+4) != X -> true.
-  if (Pred == ICmpInst::ICMP_NE)
-    return ReplaceInstUsesWith(ICI, Builder->getTrue());
-
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similarly for all other "or equals"
   // operators.
@@ -1058,6 +1038,111 @@ Instruction *InstCombiner::FoldICmpShrCst(ICmpInst &ICI, BinaryOperator *Shr,
   return nullptr;
 }
 
+/// FoldICmpCstShrCst - Handle "(icmp eq/ne (ashr/lshr const2, A), const1)" ->
+/// (icmp eq/ne A, Log2(const2/const1)) ->
+/// (icmp eq/ne A, Log2(const2) - Log2(const1)).
+Instruction *InstCombiner::FoldICmpCstShrCst(ICmpInst &I, Value *Op, Value *A,
+                                             ConstantInt *CI1,
+                                             ConstantInt *CI2) {
+  assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+  auto getConstant = [&I, this](bool IsTrue) {
+    if (I.getPredicate() == I.ICMP_NE)
+      IsTrue = !IsTrue;
+    return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+  };
+
+  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+    if (I.getPredicate() == I.ICMP_NE)
+      Pred = CmpInst::getInversePredicate(Pred);
+    return new ICmpInst(Pred, LHS, RHS);
+  };
+
+  APInt AP1 = CI1->getValue();
+  APInt AP2 = CI2->getValue();
+
+  // Don't bother doing any work for cases which InstSimplify handles.
+  if (AP2 == 0)
+    return nullptr;
+  bool IsAShr = isa<AShrOperator>(Op);
+  if (IsAShr) {
+    if (AP2.isAllOnesValue())
+      return nullptr;
+    if (AP2.isNegative() != AP1.isNegative())
+      return nullptr;
+    if (AP2.sgt(AP1))
+      return nullptr;
+  }
+
+  if (!AP1)
+    // 'A' must be large enough to shift out the highest set bit.
+    return getICmp(I.ICMP_UGT, A,
+                   ConstantInt::get(A->getType(), AP2.logBase2()));
+
+  if (AP1 == AP2)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+  // Get the distance between the highest bit that's set.
+  int Shift;
+  // Both the constants are negative, take their positive to calculate log.
+  if (IsAShr && AP1.isNegative())
+    // Get the ones' complement of AP2 and AP1 when computing the distance.
+    Shift = (~AP2).logBase2() - (~AP1).logBase2();
+  else
+    Shift = AP2.logBase2() - AP1.logBase2();
+
+  if (Shift > 0) {
+    if (IsAShr ? AP1 == AP2.ashr(Shift) : AP1 == AP2.lshr(Shift))
+      return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+  }
+  // Shifting const2 will never be equal to const1.
+  return getConstant(false);
+}
+
+/// FoldICmpCstShlCst - Handle "(icmp eq/ne (shl const2, A), const1)" ->
+/// (icmp eq/ne A, TrailingZeros(const1) - TrailingZeros(const2)).
+Instruction *InstCombiner::FoldICmpCstShlCst(ICmpInst &I, Value *Op, Value *A,
+                                             ConstantInt *CI1,
+                                             ConstantInt *CI2) {
+  assert(I.isEquality() && "Cannot fold icmp gt/lt");
+
+  auto getConstant = [&I, this](bool IsTrue) {
+    if (I.getPredicate() == I.ICMP_NE)
+      IsTrue = !IsTrue;
+    return ReplaceInstUsesWith(I, ConstantInt::get(I.getType(), IsTrue));
+  };
+
+  auto getICmp = [&I](CmpInst::Predicate Pred, Value *LHS, Value *RHS) {
+    if (I.getPredicate() == I.ICMP_NE)
+      Pred = CmpInst::getInversePredicate(Pred);
+    return new ICmpInst(Pred, LHS, RHS);
+  };
+
+  APInt AP1 = CI1->getValue();
+  APInt AP2 = CI2->getValue();
+
+  // Don't bother doing any work for cases which InstSimplify handles.
+  if (AP2 == 0)
+    return nullptr;
+
+  unsigned AP2TrailingZeros = AP2.countTrailingZeros();
+
+  if (!AP1 && AP2TrailingZeros != 0)
+    return getICmp(I.ICMP_UGE, A,
+                   ConstantInt::get(A->getType(), AP2.getBitWidth() - AP2TrailingZeros));
+
+  if (AP1 == AP2)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::getNullValue(A->getType()));
+
+  // Get the distance between the lowest bits that are set.
+  int Shift = AP1.countTrailingZeros() - AP2TrailingZeros;
+
+  if (Shift > 0 && AP2.shl(Shift) == AP1)
+    return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
+
+  // Shifting const2 will never be equal to const1.
+  return getConstant(false);
+}
 
 /// visitICmpInstWithInstAndIntCst - Handle "icmp (instr, intcst)".
 ///
@@ -1074,7 +1159,7 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
       unsigned DstBits = LHSI->getType()->getPrimitiveSizeInBits(),
              SrcBits = LHSI->getOperand(0)->getType()->getPrimitiveSizeInBits();
       APInt KnownZero(SrcBits, 0), KnownOne(SrcBits, 0);
-      computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne);
+      computeKnownBits(LHSI->getOperand(0), KnownZero, KnownOne, 0, &ICI);
 
       // If all the high bits are known, we can do this xform.
       if ((KnownZero|KnownOne).countLeadingOnes() >= SrcBits-DstBits) {
@@ -1296,6 +1381,48 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
         return &ICI;
       }
 
+      // (icmp pred (and (or (lshr X, Y), X), 1), 0) -->
+      //    (icmp pred (and X, (or (shl 1, Y), 1), 0))
+      //
+      // iff pred isn't signed
+      {
+        Value *X, *Y, *LShr;
+        if (!ICI.isSigned() && RHSV == 0) {
+          if (match(LHSI->getOperand(1), m_One())) {
+            Constant *One = cast<Constant>(LHSI->getOperand(1));
+            Value *Or = LHSI->getOperand(0);
+            if (match(Or, m_Or(m_Value(LShr), m_Value(X))) &&
+                match(LShr, m_LShr(m_Specific(X), m_Value(Y)))) {
+              unsigned UsesRemoved = 0;
+              if (LHSI->hasOneUse())
+                ++UsesRemoved;
+              if (Or->hasOneUse())
+                ++UsesRemoved;
+              if (LShr->hasOneUse())
+                ++UsesRemoved;
+              Value *NewOr = nullptr;
+              // Compute X & ((1 << Y) | 1)
+              if (auto *C = dyn_cast<Constant>(Y)) {
+                if (UsesRemoved >= 1)
+                  NewOr =
+                      ConstantExpr::getOr(ConstantExpr::getNUWShl(One, C), One);
+              } else {
+                if (UsesRemoved >= 3)
+                  NewOr = Builder->CreateOr(Builder->CreateShl(One, Y,
+                                                               LShr->getName(),
+                                                               /*HasNUW=*/true),
+                                            One, Or->getName());
+              }
+              if (NewOr) {
+                Value *NewAnd = Builder->CreateAnd(X, NewOr, LHSI->getName());
+                ICI.setOperand(0, NewAnd);
+                return &ICI;
+              }
+            }
+          }
+        }
+      }
+
       // Replace ((X & AndCst) > RHSV) with ((X & AndCst) != 0), if any
       // bit set in (X & AndCst) will produce a result greater than RHSV.
       if (ICI.getPredicate() == ICmpInst::ICMP_UGT) {
@@ -1391,16 +1518,10 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           unsigned RHSLog2 = RHSV.logBase2();
 
           // (1 << X) >= 2147483648 -> X >= 31 -> X == 31
-          // (1 << X) >  2147483648 -> X >  31 -> false
-          // (1 << X) <= 2147483648 -> X <= 31 -> true
           // (1 << X) <  2147483648 -> X <  31 -> X != 31
           if (RHSLog2 == TypeBits-1) {
             if (Pred == ICmpInst::ICMP_UGE)
               Pred = ICmpInst::ICMP_EQ;
-            else if (Pred == ICmpInst::ICMP_UGT)
-              return ReplaceInstUsesWith(ICI, Builder->getFalse());
-            else if (Pred == ICmpInst::ICMP_ULE)
-              return ReplaceInstUsesWith(ICI, Builder->getTrue());
             else if (Pred == ICmpInst::ICMP_ULT)
               Pred = ICmpInst::ICMP_NE;
           }
@@ -1435,10 +1556,6 @@ Instruction *InstCombiner::visitICmpInstWithInstAndIntCst(ICmpInst &ICI,
           if (RHSVIsPowerOf2)
             return new ICmpInst(
                 Pred, X, ConstantInt::get(RHS->getType(), RHSV.logBase2()));
-
-          return ReplaceInstUsesWith(
-              ICI, Pred == ICmpInst::ICMP_EQ ? Builder->getFalse()
-                                             : Builder->getTrue());
         }
       }
       break;
@@ -1946,8 +2063,8 @@ static Instruction *ProcessUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // sign-extended; check for that condition. For example, if CI2 is 2^31 and
   // the operands of the add are 64 bits wide, we need at least 33 sign bits.
   unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
-  if (IC.ComputeNumSignBits(A) < NeededSignBits ||
-      IC.ComputeNumSignBits(B) < NeededSignBits)
+  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
+      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
     return nullptr;
 
   // In order to replace the original add with a narrower
@@ -2052,8 +2169,8 @@ static Instruction *ProcessUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   Instruction *MulInstr = cast<Instruction>(MulVal);
   assert(MulInstr->getOpcode() == Instruction::Mul);
 
-  Instruction *LHS = cast<Instruction>(MulInstr->getOperand(0)),
-              *RHS = cast<Instruction>(MulInstr->getOperand(1));
+  auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)),
+       *RHS = cast<ZExtOperator>(MulInstr->getOperand(1));
   assert(LHS->getOpcode() == Instruction::ZExt);
   assert(RHS->getOpcode() == Instruction::ZExt);
   Value *A = LHS->getOperand(0), *B = RHS->getOperand(0);
@@ -2338,6 +2455,122 @@ static bool swapMayExposeCSEOpportunities(const Value * Op0,
   return GlobalSwapBenefits > 0;
 }
 
+/// \brief Check that one use is in the same block as the definition and all
+/// other uses are in blocks dominated by a given block
+///
+/// \param DI Definition
+/// \param UI Use
+/// \param DB Block that must dominate all uses of \p DI outside
+///           the parent block
+/// \return true when \p UI is the only use of \p DI in the parent block
+/// and all other uses of \p DI are in blocks dominated by \p DB.
+///
+bool InstCombiner::dominatesAllUses(const Instruction *DI,
+                                    const Instruction *UI,
+                                    const BasicBlock *DB) const {
+  assert(DI && UI && "Instruction not defined\n");
+  // ignore incomplete definitions
+  if (!DI->getParent())
+    return false;
+  // DI and UI must be in the same block
+  if (DI->getParent() != UI->getParent())
+    return false;
+  // Protect from self-referencing blocks
+  if (DI->getParent() == DB)
+    return false;
+  // DominatorTree available?
+  if (!DT)
+    return false;
+  for (const User *U : DI->users()) {
+    auto *Usr = cast<Instruction>(U);
+    if (Usr != UI && !DT->dominates(DB, Usr->getParent()))
+      return false;
+  }
+  return true;
+}
+
+///
+/// true when the instruction sequence within a block is select-cmp-br.
+///
+static bool isChainSelectCmpBranch(const SelectInst *SI) {
+  const BasicBlock *BB = SI->getParent();
+  if (!BB)
+    return false;
+  auto *BI = dyn_cast_or_null<BranchInst>(BB->getTerminator());
+  if (!BI || BI->getNumSuccessors() != 2)
+    return false;
+  auto *IC = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!IC || (IC->getOperand(0) != SI && IC->getOperand(1) != SI))
+    return false;
+  return true;
+}
+
+///
+/// \brief True when a select result is replaced by one of its operands
+/// in select-icmp sequence. This will eventually result in the elimination
+/// of the select.
+///
+/// \param SI    Select instruction
+/// \param Icmp  Compare instruction
+/// \param SIOpd Operand that replaces the select
+///
+/// Notes:
+/// - The replacement is global and requires dominator information
+/// - The caller is responsible for the actual replacement
+///
+/// Example:
+///
+/// entry:
+///  %4 = select i1 %3, %C* %0, %C* null
+///  %5 = icmp eq %C* %4, null
+///  br i1 %5, label %9, label %7
+///  ...
+///  ; <label>:7                                       ; preds = %entry
+///  %8 = getelementptr inbounds %C* %4, i64 0, i32 0
+///  ...
+///
+/// can be transformed to
+///
+///  %5 = icmp eq %C* %0, null
+///  %6 = select i1 %3, i1 %5, i1 true
+///  br i1 %6, label %9, label %7
+///  ...
+///  ; <label>:7                                       ; preds = %entry
+///  %8 = getelementptr inbounds %C* %0, i64 0, i32 0  // replace by %0!
+///
+/// Similar when the first operand of the select is a constant or/and
+/// the compare is for not equal rather than equal.
+///
+/// NOTE: The function is only called when the select and compare constants
+/// are equal, the optimization can work only for EQ predicates. This is not a
+/// major restriction since a NE compare should be 'normalized' to an equal
+/// compare, which usually happens in the combiner and test case
+/// select-cmp-br.ll
+/// checks for it.
+bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
+                                             const ICmpInst *Icmp,
+                                             const unsigned SIOpd) {
+  assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
+  if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
+    BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
+    // The check for the unique predecessor is not the best that can be
+    // done. But it protects efficiently against cases like  when SI's
+    // home block has two successors, Succ and Succ1, and Succ1 predecessor
+    // of Succ. Then SI can't be replaced by SIOpd because the use that gets
+    // replaced can be reached on either path. So the uniqueness check
+    // guarantees that the path all uses of SI (outside SI's parent) are on
+    // is disjoint from all other paths out of SI. But that information
+    // is more expensive to compute, and the trade-off here is in favor
+    // of compile-time.
+    if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
+      NumSel++;
+      SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
+      return true;
+    }
+  }
+  return false;
+}
+
 Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -2355,7 +2588,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Changed = true;
   }
 
-  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL))
+  if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // comparing -val or val with non-zero is the same as just comparing val
@@ -2483,6 +2716,21 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                           Builder->getInt(CI->getValue()-1));
     }
 
+    if (I.isEquality()) {
+      ConstantInt *CI2;
+      if (match(Op0, m_AShr(m_ConstantInt(CI2), m_Value(A))) ||
+          match(Op0, m_LShr(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (ashr/lshr const2, A), const1)
+        if (Instruction *Inst = FoldICmpCstShrCst(I, Op0, A, CI, CI2))
+          return Inst;
+      }
+      if (match(Op0, m_Shl(m_ConstantInt(CI2), m_Value(A)))) {
+        // (icmp eq/ne (shl const2, A), const1)
+        if (Instruction *Inst = FoldICmpCstShlCst(I, Op0, A, CI, CI2))
+          return Inst;
+      }
+    }
+
     // If this comparison is a normal comparison, it demands all
     // bits, if it is a sign bit comparison, it only demands the sign bit.
     bool UnusedBit;
@@ -2775,18 +3023,39 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         // comparison into the select arms, which will cause one to be
         // constant folded and the select turned into a bitwise or.
         Value *Op1 = nullptr, *Op2 = nullptr;
-        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1)))
+        ConstantInt *CI = 0;
+        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
           Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
-        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2)))
+          CI = dyn_cast<ConstantInt>(Op1);
+        }
+        if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
           Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+          CI = dyn_cast<ConstantInt>(Op2);
+        }
 
         // We only want to perform this transformation if it will not lead to
         // additional code. This is true if either both sides of the select
         // fold to a constant (in which case the icmp is replaced with a select
         // which will usually simplify) or this is the only user of the
         // select (in which case we are trading a select+icmp for a simpler
-        // select+icmp).
-        if ((Op1 && Op2) || (LHSI->hasOneUse() && (Op1 || Op2))) {
+        // select+icmp) or all uses of the select can be replaced based on
+        // dominance information ("Global cases").
+        bool Transform = false;
+        if (Op1 && Op2)
+          Transform = true;
+        else if (Op1 || Op2) {
+          // Local case
+          if (LHSI->hasOneUse())
+            Transform = true;
+          // Global cases
+          else if (CI && !CI->isZero())
+            // When Op1 is constant try replacing select with second operand.
+            // Otherwise Op2 is constant and try replacing select with first
+            // operand.
+            Transform = replacedSelectWithOperand(cast<SelectInst>(LHSI), &I,
+                                                  Op1 ? 2 : 1);
+        }
+        if (Transform) {
           if (!Op1)
             Op1 = Builder->CreateICmp(I.getPredicate(), LHSI->getOperand(1),
                                       RHSC, I.getName());
@@ -2892,6 +3161,12 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     if (BO1 && BO1->getOpcode() == Instruction::Add)
       C = BO1->getOperand(0), D = BO1->getOperand(1);
 
+    // icmp (X+cst) < 0 --> X < -cst
+    if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
+      if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
+        if (!RHSC->isMinValue(/*isSigned=*/true))
+          return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
+
     // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
     if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
       return new ICmpInst(Pred, A == Op1 ? B : A,
@@ -3126,7 +3401,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     // and       (A & ~B) != 0 --> (A & B) == 0
     // if A is a power of 2.
     if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
-        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A) && I.isEquality())
+        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(A, false,
+                                                       0, AT, &I, DT) &&
+                                I.isEquality())
       return new ICmpInst(I.getInversePredicate(),
                           Builder->CreateAnd(A, B),
                           Op1);
@@ -3287,6 +3564,22 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // The 'cmpxchg' instruction returns an aggregate containing the old value and
+  // an i1 which indicates whether or not we successfully did the swap.
+  //
+  // Replace comparisons between the old value and the expected value with the
+  // indicator that 'cmpxchg' returns.
+  //
+  // N.B.  This transform is only valid when the 'cmpxchg' is not permitted to
+  // spuriously fail.  In those cases, the old value may equal the expected
+  // value but it is possible for the swap to not occur.
+  if (I.getPredicate() == ICmpInst::ICMP_EQ)
+    if (auto *EVI = dyn_cast<ExtractValueInst>(Op0))
+      if (auto *ACXI = dyn_cast<AtomicCmpXchgInst>(EVI->getAggregateOperand()))
+        if (EVI->getIndices()[0] == 0 && ACXI->getCompareOperand() == Op1 &&
+            !ACXI->isWeak())
+          return ExtractValueInst::Create(ACXI, 1);
+
   {
     Value *X; ConstantInt *Cst;
     // icmp X+Cst, X
@@ -3516,7 +3809,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
-  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL))
+  if (Value *V = SimplifyFCmpInst(I.getPredicate(), Op0, Op1, DL, TLI, DT, AT))
     return ReplaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'