Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index eb463902d6688e8094b9609782dc00b5ab7d7230..a262d711d3b42fccea2c261fdb9b98353095f3ec 100644 (file)
@@ -12,8 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/Support/PatternMatch.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -128,6 +129,12 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
     if (TI->isCast()) {
       if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType())
         return 0;
+      // The select condition may be a vector. We may only change the operand
+      // type if the vector width remains the same (and matches the condition).
+      Type *CondTy = SI.getCondition()->getType();
+      if (CondTy->isVectorTy() && CondTy->getVectorNumElements() !=
+          FI->getOperand(0)->getType()->getVectorNumElements())
+        return 0;
     } else {
       return 0;  // unknown unary op.
     }
@@ -183,7 +190,6 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
       return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp);
   }
   llvm_unreachable("Shouldn't get here");
-  return 0;
 }
 
 static bool isSelect01(Constant *C1, Constant *C2) {
@@ -281,7 +287,8 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
 /// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
 /// replaced with RepOp.
 static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const TargetData *TD) {
+                                     const DataLayout *TD,
+                                     const TargetLibraryInfo *TLI) {
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -293,17 +300,19 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   // If this is a binary operator, try to simplify it with the replaced op.
   if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
     if (B->getOperand(0) == Op)
-      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD);
+      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI);
     if (B->getOperand(1) == Op)
-      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD);
+      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI);
   }
 
   // Same for CmpInsts.
   if (CmpInst *C = dyn_cast<CmpInst>(I)) {
     if (C->getOperand(0) == Op)
-      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD);
+      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD,
+                             TLI);
     if (C->getOperand(1) == Op)
-      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD);
+      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD,
+                             TLI);
   }
 
   // TODO: We could hand off more cases to instsimplify here.
@@ -323,9 +332,18 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
     }
 
     // All operands were constants, fold it.
-    if (ConstOps.size() == I->getNumOperands())
+    if (ConstOps.size() == I->getNumOperands()) {
+      if (CmpInst *C = dyn_cast<CmpInst>(I))
+        return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+                                               ConstOps[1], TD, TLI);
+
+      if (LoadInst *LI = dyn_cast<LoadInst>(I))
+        if (!LI->isVolatile())
+          return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
+
       return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
-                                      ConstOps.data(), ConstOps.size(), TD);
+                                      ConstOps, TD, TLI);
+    }
   }
 
   return 0;
@@ -473,18 +491,24 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   // arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
   if (Pred == ICmpInst::ICMP_EQ) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD) == TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD) == TrueVal)
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
+      return ReplaceInstUsesWith(SI, FalseVal);
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
       return ReplaceInstUsesWith(SI, FalseVal);
   } else if (Pred == ICmpInst::ICMP_NE) {
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD) == FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD) == FalseVal)
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
+      return ReplaceInstUsesWith(SI, TrueVal);
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
       return ReplaceInstUsesWith(SI, TrueVal);
   }
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
 
-  if (isa<Constant>(CmpRHS)) {
+  if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
     if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
       // Transform (X == C) ? X : Y -> (X == C) ? C : Y
       SI.setOperand(1, CmpRHS);
@@ -667,6 +691,13 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       return BinaryOperator::CreateOr(CondVal, FalseVal);
     else if (CondVal == FalseVal)
       return BinaryOperator::CreateAnd(CondVal, TrueVal);
+
+    // select a, ~a, b -> (~a)&b
+    // select a, b, ~a -> (~a)|b
+    if (match(TrueVal, m_Not(m_Specific(CondVal))))
+      return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+    else if (match(FalseVal, m_Not(m_Specific(CondVal))))
+      return BinaryOperator::CreateOr(TrueVal, FalseVal);
   }
 
   // Selecting between two integer constants?
@@ -854,12 +885,16 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
     if (TrueSI->getCondition() == CondVal) {
+      if (SI.getTrueValue() == TrueSI->getTrueValue())
+        return 0;
       SI.setOperand(1, TrueSI->getTrueValue());
       return &SI;
     }
   }
   if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
     if (FalseSI->getCondition() == CondVal) {
+      if (SI.getFalseValue() == FalseSI->getFalseValue())
+        return 0;
       SI.setOperand(2, FalseSI->getFalseValue());
       return &SI;
     }
@@ -872,5 +907,38 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
     return &SI;
   }
 
+  if (VectorType *VecTy = dyn_cast<VectorType>(SI.getType())) {
+    unsigned VWidth = VecTy->getNumElements();
+    APInt UndefElts(VWidth, 0);
+    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
+      if (V != &SI)
+        return ReplaceInstUsesWith(SI, V);
+      return &SI;
+    }
+
+    if (ConstantVector *CV = dyn_cast<ConstantVector>(CondVal)) {
+      // Form a shufflevector instruction.
+      SmallVector<Constant *, 8> Mask(VWidth);
+      Type *Int32Ty = Type::getInt32Ty(CV->getContext());
+      for (unsigned i = 0; i != VWidth; ++i) {
+        Constant *Elem = cast<Constant>(CV->getOperand(i));
+        if (ConstantInt *E = dyn_cast<ConstantInt>(Elem))
+          Mask[i] = ConstantInt::get(Int32Ty, i + (E->isZero() ? VWidth : 0));
+        else if (isa<UndefValue>(Elem))
+          Mask[i] = UndefValue::get(Int32Ty);
+        else
+          return 0;
+      }
+      Constant *MaskVal = ConstantVector::get(Mask);
+      Value *V = Builder->CreateShuffleVector(TrueVal, FalseVal, MaskVal);
+      return ReplaceInstUsesWith(SI, V);
+    }
+
+    if (isa<ConstantAggregateZero>(CondVal)) {
+      return ReplaceInstUsesWith(SI, FalseVal);
+    }
+  }
+
   return 0;
 }