Use the new script to sort the includes of every file under lib.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index eb9945b681727c913418d5520b59253e41ca9209..a262d711d3b42fccea2c261fdb9b98353095f3ec 100644 (file)
@@ -12,9 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-#include "llvm/Support/PatternMatch.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -287,7 +287,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
 /// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
 /// replaced with RepOp.
 static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
-                                     const TargetData *TD,
+                                     const DataLayout *TD,
                                      const TargetLibraryInfo *TLI) {
   // Trivial replacement.
   if (V == Op)
@@ -333,6 +333,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
 
     // All operands were constants, fold it.
     if (ConstOps.size() == I->getNumOperands()) {
+      if (CmpInst *C = dyn_cast<CmpInst>(I))
+        return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0],
+                                               ConstOps[1], TD, TLI);
+
       if (LoadInst *LI = dyn_cast<LoadInst>(I))
         if (!LI->isVolatile())
           return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
@@ -881,12 +885,16 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
 
   if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
     if (TrueSI->getCondition() == CondVal) {
+      if (SI.getTrueValue() == TrueSI->getTrueValue())
+        return 0;
       SI.setOperand(1, TrueSI->getTrueValue());
       return &SI;
     }
   }
   if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
     if (FalseSI->getCondition() == CondVal) {
+      if (SI.getFalseValue() == FalseSI->getFalseValue())
+        return 0;
       SI.setOperand(2, FalseSI->getFalseValue());
       return &SI;
     }
@@ -899,5 +907,38 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
     return &SI;
   }
 
+  if (VectorType *VecTy = dyn_cast<VectorType>(SI.getType())) {
+    unsigned VWidth = VecTy->getNumElements();
+    APInt UndefElts(VWidth, 0);
+    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
+      if (V != &SI)
+        return ReplaceInstUsesWith(SI, V);
+      return &SI;
+    }
+
+    if (ConstantVector *CV = dyn_cast<ConstantVector>(CondVal)) {
+      // Form a shufflevector instruction.
+      SmallVector<Constant *, 8> Mask(VWidth);
+      Type *Int32Ty = Type::getInt32Ty(CV->getContext());
+      for (unsigned i = 0; i != VWidth; ++i) {
+        Constant *Elem = cast<Constant>(CV->getOperand(i));
+        if (ConstantInt *E = dyn_cast<ConstantInt>(Elem))
+          Mask[i] = ConstantInt::get(Int32Ty, i + (E->isZero() ? VWidth : 0));
+        else if (isa<UndefValue>(Elem))
+          Mask[i] = UndefValue::get(Int32Ty);
+        else
+          return 0;
+      }
+      Constant *MaskVal = ConstantVector::get(Mask);
+      Value *V = Builder->CreateShuffleVector(TrueVal, FalseVal, MaskVal);
+      return ReplaceInstUsesWith(SI, V);
+    }
+
+    if (isa<ConstantAggregateZero>(CondVal)) {
+      return ReplaceInstUsesWith(SI, FalseVal);
+    }
+  }
+
   return 0;
 }