Reapply r237520 with another fix for infinite looping
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index b92d90d07895378a0e4e3a7cedb3303bd2f68571..d2fbcdd39915c8d0cba67be02aa9e66945becba7 100644 (file)
 #include "InstCombineInternal.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
-/// MatchSelectPattern - Pattern match integer [SU]MIN, [SU]MAX, and ABS idioms,
-/// returning the kind and providing the out parameter results if we
-/// successfully match.
 static SelectPatternFlavor
-MatchSelectPattern(Value *V, Value *&LHS, Value *&RHS) {
-  SelectInst *SI = dyn_cast<SelectInst>(V);
-  if (!SI) return SPF_UNKNOWN;
-
-  ICmpInst *ICI = dyn_cast<ICmpInst>(SI->getCondition());
-  if (!ICI) return SPF_UNKNOWN;
-
-  ICmpInst::Predicate Pred = ICI->getPredicate();
-  Value *CmpLHS = ICI->getOperand(0);
-  Value *CmpRHS = ICI->getOperand(1);
-  Value *TrueVal = SI->getTrueValue();
-  Value *FalseVal = SI->getFalseValue();
-
-  LHS = CmpLHS;
-  RHS = CmpRHS;
-
-  // (icmp X, Y) ? X : Y
-  if (TrueVal == CmpLHS && FalseVal == CmpRHS) {
-    switch (Pred) {
-    default: return SPF_UNKNOWN; // Equality.
-    case ICmpInst::ICMP_UGT:
-    case ICmpInst::ICMP_UGE: return SPF_UMAX;
-    case ICmpInst::ICMP_SGT:
-    case ICmpInst::ICMP_SGE: return SPF_SMAX;
-    case ICmpInst::ICMP_ULT:
-    case ICmpInst::ICMP_ULE: return SPF_UMIN;
-    case ICmpInst::ICMP_SLT:
-    case ICmpInst::ICMP_SLE: return SPF_SMIN;
-    }
-  }
-
-  // (icmp X, Y) ? Y : X
-  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
-    switch (Pred) {
-    default: return SPF_UNKNOWN; // Equality.
-    case ICmpInst::ICMP_UGT:
-    case ICmpInst::ICMP_UGE: return SPF_UMIN;
-    case ICmpInst::ICMP_SGT:
-    case ICmpInst::ICMP_SGE: return SPF_SMIN;
-    case ICmpInst::ICMP_ULT:
-    case ICmpInst::ICMP_ULE: return SPF_UMAX;
-    case ICmpInst::ICMP_SLT:
-    case ICmpInst::ICMP_SLE: return SPF_SMAX;
-    }
+getInverseMinMaxSelectPattern(SelectPatternFlavor SPF) {
+  switch (SPF) {
+  default:
+    llvm_unreachable("unhandled!");
+
+  case SPF_SMIN:
+    return SPF_SMAX;
+  case SPF_UMIN:
+    return SPF_UMAX;
+  case SPF_SMAX:
+    return SPF_SMIN;
+  case SPF_UMAX:
+    return SPF_UMIN;
   }
+}
 
-  if (ConstantInt *C1 = dyn_cast<ConstantInt>(CmpRHS)) {
-    if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) ||
-        (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) {
-
-      // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X
-      // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X
-      if (Pred == ICmpInst::ICMP_SGT && (C1->isZero() || C1->isMinusOne())) {
-        return (CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS;
-      }
-
-      // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X
-      // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X
-      if (Pred == ICmpInst::ICMP_SLT && (C1->isZero() || C1->isOne())) {
-        return (CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS;
-      }
-    }
+static CmpInst::Predicate getICmpPredicateForMinMax(SelectPatternFlavor SPF) {
+  switch (SPF) {
+  default:
+    llvm_unreachable("unhandled!");
+
+  case SPF_SMIN:
+    return ICmpInst::ICMP_SLT;
+  case SPF_UMIN:
+    return ICmpInst::ICMP_ULT;
+  case SPF_SMAX:
+    return ICmpInst::ICMP_SGT;
+  case SPF_UMAX:
+    return ICmpInst::ICMP_UGT;
   }
-
-  // TODO: (X > 4) ? X : 5   -->  (X >= 5) ? X : 5  -->  MAX(X, 5)
-
-  return SPF_UNKNOWN;
 }
 
+static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy *Builder,
+                                          SelectPatternFlavor SPF, Value *A,
+                                          Value *B) {
+  CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+  return Builder->CreateSelect(Builder->CreateICmp(Pred, A, B), A, B);
+}
 
 /// GetSelectFoldableOperands - We want to turn code that looks like this:
 ///   %C = or %A, %B
@@ -312,9 +279,9 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
 /// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
 /// replaced with RepOp.
 static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const DataLayout *TD,
                                      const TargetLibraryInfo *TLI,
-                                     DominatorTree *DT, AssumptionCache *AC) {
+                                     const DataLayout &DL, DominatorTree *DT,
+                                     AssumptionCache *AC) {
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -326,18 +293,18 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   // If this is a binary operator, try to simplify it with the replaced op.
   if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
     if (B->getOperand(0) == Op)
-      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI);
+      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), DL, TLI);
     if (B->getOperand(1) == Op)
-      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI);
+      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, DL, TLI);
   }
 
   // Same for CmpInsts.
   if (CmpInst *C = dyn_cast<CmpInst>(I)) {
     if (C->getOperand(0) == Op)
-      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD,
+      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), DL,
                              TLI, DT, AC);
     if (C->getOperand(1) == Op)
-      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD,
+      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, DL,
                              TLI, DT, AC);
   }
 
@@ -361,14 +328,14 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
     if (ConstOps.size() == I->getNumOperands()) {
       if (CmpInst *C = dyn_cast<CmpInst>(I))
         return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
-                                               ConstOps[1], TD, TLI);
+                                               ConstOps[1], DL, TLI);
 
       if (LoadInst *LI = dyn_cast<LoadInst>(I))
         if (!LI->isVolatile())
-          return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
+          return ConstantFoldLoadFromConstPtr(ConstOps[0], DL);
 
-      return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
-                                      ConstOps, TD, TLI);
+      return ConstantFoldInstOperands(I->getOpcode(), I->getType(), ConstOps,
+                                      DL, TLI);
     }
   }
 
@@ -635,25 +602,25 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
             TrueVal)
       return ReplaceInstUsesWith(SI, FalseVal);
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
             FalseVal)
       return ReplaceInstUsesWith(SI, FalseVal);
   } else if (Pred == ICmpInst::ICMP_NE) {
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
             FalseVal)
       return ReplaceInstUsesWith(SI, TrueVal);
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, DL, TLI, DT, AC) ==
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TLI, DL, DT, AC) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, DL, TLI, DT, AC) ==
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TLI, DL, DT, AC) ==
             TrueVal)
       return ReplaceInstUsesWith(SI, TrueVal);
   }
@@ -829,6 +796,52 @@ Instruction *InstCombiner::FoldSPFofSPF(Instruction *Inner,
         SI->getCondition(), SI->getFalseValue(), SI->getTrueValue());
     return ReplaceInstUsesWith(Outer, NewSI);
   }
+
+  auto IsFreeOrProfitableToInvert =
+      [&](Value *V, Value *&NotV, bool &ElidesXor) {
+    if (match(V, m_Not(m_Value(NotV)))) {
+      // If V has at most 2 uses then we can get rid of the xor operation
+      // entirely.
+      ElidesXor |= !V->hasNUsesOrMore(3);
+      return true;
+    }
+
+    if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) {
+      NotV = nullptr;
+      return true;
+    }
+
+    return false;
+  };
+
+  Value *NotA, *NotB, *NotC;
+  bool ElidesXor = false;
+
+  // MIN(MIN(~A, ~B), ~C) == ~MAX(MAX(A, B), C)
+  // MIN(MAX(~A, ~B), ~C) == ~MAX(MIN(A, B), C)
+  // MAX(MIN(~A, ~B), ~C) == ~MIN(MAX(A, B), C)
+  // MAX(MAX(~A, ~B), ~C) == ~MIN(MIN(A, B), C)
+  //
+  // This transform is performance neutral if we can elide at least one xor from
+  // the set of three operands, since we'll be tacking on an xor at the very
+  // end.
+  if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
+      IsFreeOrProfitableToInvert(B, NotB, ElidesXor) &&
+      IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) {
+    if (!NotA)
+      NotA = Builder->CreateNot(A);
+    if (!NotB)
+      NotB = Builder->CreateNot(B);
+    if (!NotC)
+      NotC = Builder->CreateNot(C);
+
+    Value *NewInner = generateMinMaxSelectPattern(
+        Builder, getInverseMinMaxSelectPattern(SPF1), NotA, NotB);
+    Value *NewOuter = Builder->CreateNot(generateMinMaxSelectPattern(
+        Builder, getInverseMinMaxSelectPattern(SPF2), NewInner, NotC));
+    return ReplaceInstUsesWith(Outer, NewOuter);
+  }
+
   return nullptr;
 }
 
@@ -927,7 +940,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       return BinaryOperator::CreateAnd(NotCond, FalseVal);
     }
     if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
-      if (C->getZExtValue() == false) {
+      if (!C->getZExtValue()) {
         // Change: A = select B, C, false --> A = and B, C
         return BinaryOperator::CreateAnd(CondVal, TrueVal);
       }
@@ -1141,26 +1154,67 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       }
 
   // See if we can fold the select into one of our operands.
-  if (SI.getType()->isIntegerTy()) {
+  if (SI.getType()->isIntOrIntVectorTy()) {
     if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
       return FoldI;
 
-    // MAX(MAX(a, b), a) -> MAX(a, b)
-    // MIN(MIN(a, b), a) -> MIN(a, b)
-    // MAX(MIN(a, b), a) -> a
-    // MIN(MAX(a, b), a) -> a
     Value *LHS, *RHS, *LHS2, *RHS2;
-    if (SelectPatternFlavor SPF = MatchSelectPattern(&SI, LHS, RHS)) {
-      if (SelectPatternFlavor SPF2 = MatchSelectPattern(LHS, LHS2, RHS2))
+    Instruction::CastOps CastOp;
+    SelectPatternFlavor SPF = matchSelectPattern(&SI, LHS, RHS, &CastOp);
+
+    if (SPF) {
+      // Canonicalize so that type casts are outside select patterns.
+      if (LHS->getType()->getPrimitiveSizeInBits() !=
+          SI.getType()->getPrimitiveSizeInBits()) {
+        CmpInst::Predicate Pred = getICmpPredicateForMinMax(SPF);
+        Value *Cmp = Builder->CreateICmp(Pred, LHS, RHS);
+        Value *NewSI = Builder->CreateCast(CastOp,
+                                           Builder->CreateSelect(Cmp, LHS, RHS),
+                                           SI.getType());
+        return ReplaceInstUsesWith(SI, NewSI);
+      }
+
+      // MAX(MAX(a, b), a) -> MAX(a, b)
+      // MIN(MIN(a, b), a) -> MIN(a, b)
+      // MAX(MIN(a, b), a) -> a
+      // MIN(MAX(a, b), a) -> a
+      if (SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2))
         if (Instruction *R = FoldSPFofSPF(cast<Instruction>(LHS),SPF2,LHS2,RHS2,
                                           SI, SPF, RHS))
           return R;
-      if (SelectPatternFlavor SPF2 = MatchSelectPattern(RHS, LHS2, RHS2))
+      if (SelectPatternFlavor SPF2 = matchSelectPattern(RHS, LHS2, RHS2))
         if (Instruction *R = FoldSPFofSPF(cast<Instruction>(RHS),SPF2,LHS2,RHS2,
                                           SI, SPF, LHS))
           return R;
     }
 
+    // MAX(~a, ~b) -> ~MIN(a, b)
+    if (SPF == SPF_SMAX || SPF == SPF_UMAX) {
+      if (IsFreeToInvert(LHS, LHS->hasNUses(2)) &&
+          IsFreeToInvert(RHS, RHS->hasNUses(2))) {
+
+        // This transform adds a xor operation and that extra cost needs to be
+        // justified.  We look for simplifications that will result from
+        // applying this rule:
+
+        bool Profitable =
+            (LHS->hasNUses(2) && match(LHS, m_Not(m_Value()))) ||
+            (RHS->hasNUses(2) && match(RHS, m_Not(m_Value()))) ||
+            (SI.hasOneUse() && match(*SI.user_begin(), m_Not(m_Value())));
+
+        if (Profitable) {
+          Value *NewLHS = Builder->CreateNot(LHS);
+          Value *NewRHS = Builder->CreateNot(RHS);
+          Value *NewCmp = SPF == SPF_SMAX
+                              ? Builder->CreateICmpSLT(NewLHS, NewRHS)
+                              : Builder->CreateICmpULT(NewLHS, NewRHS);
+          Value *NewSI =
+              Builder->CreateNot(Builder->CreateSelect(NewCmp, NewLHS, NewRHS));
+          return ReplaceInstUsesWith(SI, NewSI);
+        }
+      }
+    }
+
     // TODO.
     // ABS(-X) -> ABS(X)
   }
@@ -1174,37 +1228,41 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
         return NV;
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
-    // select(C, select(C, a, b), c) -> select(C, a, c)
-    if (TrueSI->getCondition() == CondVal) {
-      if (SI.getTrueValue() == TrueSI->getTrueValue())
-        return nullptr;
-      SI.setOperand(1, TrueSI->getTrueValue());
-      return &SI;
-    }
-    // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
-    // We choose this as normal form to enable folding on the And and shortening
-    // paths for the values (this helps GetUnderlyingObjects() for example).
-    if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
-      Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition());
-      SI.setOperand(0, And);
-      SI.setOperand(1, TrueSI->getTrueValue());
-      return &SI;
+    if (TrueSI->getCondition()->getType() == CondVal->getType()) {
+      // select(C, select(C, a, b), c) -> select(C, a, c)
+      if (TrueSI->getCondition() == CondVal) {
+        if (SI.getTrueValue() == TrueSI->getTrueValue())
+          return nullptr;
+        SI.setOperand(1, TrueSI->getTrueValue());
+        return &SI;
+      }
+      // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b)
+      // We choose this as normal form to enable folding on the And and shortening
+      // paths for the values (this helps GetUnderlyingObjects() for example).
+      if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) {
+        Value *And = Builder->CreateAnd(CondVal, TrueSI->getCondition());
+        SI.setOperand(0, And);
+        SI.setOperand(1, TrueSI->getTrueValue());
+        return &SI;
+      }
     }
   }
   if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
-    // select(C, a, select(C, b, c)) -> select(C, a, c)
-    if (FalseSI->getCondition() == CondVal) {
-      if (SI.getFalseValue() == FalseSI->getFalseValue())
-        return nullptr;
-      SI.setOperand(2, FalseSI->getFalseValue());
-      return &SI;
-    }
-    // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
-    if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
-      Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition());
-      SI.setOperand(0, Or);
-      SI.setOperand(2, FalseSI->getFalseValue());
-      return &SI;
+    if (FalseSI->getCondition()->getType() == CondVal->getType()) {
+      // select(C, a, select(C, b, c)) -> select(C, a, c)
+      if (FalseSI->getCondition() == CondVal) {
+        if (SI.getFalseValue() == FalseSI->getFalseValue())
+          return nullptr;
+        SI.setOperand(2, FalseSI->getFalseValue());
+        return &SI;
+      }
+      // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b)
+      if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) {
+        Value *Or = Builder->CreateOr(CondVal, FalseSI->getCondition());
+        SI.setOperand(0, Or);
+        SI.setOperand(2, FalseSI->getFalseValue());
+        return &SI;
+      }
     }
   }