Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index e727b2c592db31da7f0f42703d2d990d202ed040..a262d711d3b42fccea2c261fdb9b98353095f3ec 100644 (file)
@@ -12,9 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/Support/PatternMatch.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -129,6 +129,12 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
     if (TI->isCast()) {
       if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType())
         return 0;
+      // The select condition may be a vector. We may only change the operand
+      // type if the vector width remains the same (and matches the condition).
+      Type *CondTy = SI.getCondition()->getType();
+      if (CondTy->isVectorTy() && CondTy->getVectorNumElements() !=
+          FI->getOperand(0)->getType()->getVectorNumElements())
+        return 0;
     } else {
       return 0;  // unknown unary op.
     }
@@ -281,7 +287,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
 /// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
 /// replaced with RepOp.
 static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const TargetData *TD,
+                                     const DataLayout *TD,
                                      const TargetLibraryInfo *TLI) {
   // Trivial replacement.
   if (V == Op)
@@ -327,6 +333,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
 
     // All operands were constants, fold it.
     if (ConstOps.size() == I->getNumOperands()) {
+      if (CmpInst *C = dyn_cast<CmpInst>(I))
+        return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+                                               ConstOps[1], TD, TLI);
+
       if (LoadInst *LI = dyn_cast<LoadInst>(I))
         if (!LI->isVolatile())
           return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
@@ -498,7 +508,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
 
   // NOTE: if we wanted to, this is where to detect integer MIN/MAX
 
-  if (isa<Constant>(CmpRHS)) {
+  if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
     if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
       // Transform (X == C) ? X : Y -> (X == C) ? C : Y
       SI.setOperand(1, CmpRHS);
@@ -875,12 +885,16 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
     if (TrueSI->getCondition() == CondVal) {
+      if (SI.getTrueValue() == TrueSI->getTrueValue())
+        return 0;
       SI.setOperand(1, TrueSI->getTrueValue());
       return &SI;
     }
   }
   if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
     if (FalseSI->getCondition() == CondVal) {
+      if (SI.getFalseValue() == FalseSI->getFalseValue())
+        return 0;
       SI.setOperand(2, FalseSI->getFalseValue());
       return &SI;
     }
@@ -893,5 +907,38 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
     return &SI;
   }
 
+  if (VectorType *VecTy = dyn_cast<VectorType>(SI.getType())) {
+    unsigned VWidth = VecTy->getNumElements();
+    APInt UndefElts(VWidth, 0);
+    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
+      if (V != &SI)
+        return ReplaceInstUsesWith(SI, V);
+      return &SI;
+    }
+
+    if (ConstantVector *CV = dyn_cast<ConstantVector>(CondVal)) {
+      // Form a shufflevector instruction.
+      SmallVector<Constant *, 8> Mask(VWidth);
+      Type *Int32Ty = Type::getInt32Ty(CV->getContext());
+      for (unsigned i = 0; i != VWidth; ++i) {
+        Constant *Elem = cast<Constant>(CV->getOperand(i));
+        if (ConstantInt *E = dyn_cast<ConstantInt>(Elem))
+          Mask[i] = ConstantInt::get(Int32Ty, i + (E->isZero() ? VWidth : 0));
+        else if (isa<UndefValue>(Elem))
+          Mask[i] = UndefValue::get(Int32Ty);
+        else
+          return 0;
+      }
+      Constant *MaskVal = ConstantVector::get(Mask);
+      Value *V = Builder->CreateShuffleVector(TrueVal, FalseVal, MaskVal);
+      return ReplaceInstUsesWith(SI, V);
+    }
+
+    if (isa<ConstantAggregateZero>(CondVal)) {
+      return ReplaceInstUsesWith(SI, FalseVal);
+    }
+  }
+
   return 0;
 }