Move TargetData to DataLayout.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index 71a286ef69c47de6fbb230065d12b2c7998be8c0..ecce242a81e32c1c8948b0f8d8911166ec9aae80 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "InstCombine.h"
 #include "llvm/Support/PatternMatch.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -128,14 +129,19 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
     if (TI->isCast()) {
       if (TI->getOperand(0)->getType() != FI->getOperand(0)->getType())
         return 0;
+      // The select condition may be a vector. We may only change the operand
+      // type if the vector width remains the same (and matches the condition).
+      Type *CondTy = SI.getCondition()->getType();
+      if (CondTy->isVectorTy() && CondTy->getVectorNumElements() !=
+          FI->getOperand(0)->getType()->getVectorNumElements())
+        return 0;
     } else {
       return 0;  // unknown unary op.
     }
 
     // Fold this by inserting a select from the input values.
-    SelectInst *NewSI = SelectInst::Create(SI.getCondition(), TI->getOperand(0),
-                                          FI->getOperand(0), SI.getName()+".v");
-    InsertNewInstBefore(NewSI, SI);
+    Value *NewSI = Builder->CreateSelect(SI.getCondition(), TI->getOperand(0),
+                                         FI->getOperand(0), SI.getName()+".v");
     return CastInst::Create(Instruction::CastOps(TI->getOpcode()), NewSI,
                             TI->getType());
   }
@@ -174,9 +180,8 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
   }
 
   // If we reach here, they do have operations in common.
-  SelectInst *NewSI = SelectInst::Create(SI.getCondition(), OtherOpT,
-                                         OtherOpF, SI.getName()+".v");
-  InsertNewInstBefore(NewSI, SI);
+  Value *NewSI = Builder->CreateSelect(SI.getCondition(), OtherOpT,
+                                       OtherOpF, SI.getName()+".v");
 
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TI)) {
     if (MatchIsOpZero)
@@ -185,7 +190,6 @@ Instruction *InstCombiner::FoldSelectOpOp(SelectInst &SI, Instruction *TI,
       return BinaryOperator::Create(BO->getOpcode(), NewSI, MatchOp);
   }
   llvm_unreachable("Shouldn't get here");
-  return 0;
 }
 
 static bool isSelect01(Constant *C1, Constant *C2) {
@@ -214,7 +218,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
         unsigned OpToFold = 0;
         if ((SFO & 1) && FalseVal == TVI->getOperand(0)) {
           OpToFold = 1;
-        } else  if ((SFO & 2) && FalseVal == TVI->getOperand(1)) {
+        } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) {
           OpToFold = 2;
         }
 
@@ -224,12 +228,18 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
           // Avoid creating select between 2 constants unless it's selecting
           // between 0, 1 and -1.
           if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
-            Instruction *NewSel = SelectInst::Create(SI.getCondition(), OOp, C);
-            InsertNewInstBefore(NewSel, SI);
+            Value *NewSel = Builder->CreateSelect(SI.getCondition(), OOp, C);
             NewSel->takeName(TVI);
-            if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TVI))
-              return BinaryOperator::Create(BO->getOpcode(), FalseVal, NewSel);
-            llvm_unreachable("Unknown instruction!!");
+            BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI);
+            BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(),
+                                                        FalseVal, NewSel);
+            if (isa<PossiblyExactOperator>(BO))
+              BO->setIsExact(TVI_BO->isExact());
+            if (isa<OverflowingBinaryOperator>(BO)) {
+              BO->setHasNoUnsignedWrap(TVI_BO->hasNoUnsignedWrap());
+              BO->setHasNoSignedWrap(TVI_BO->hasNoSignedWrap());
+            }
+            return BO;
           }
         }
       }
@@ -243,7 +253,7 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
         unsigned OpToFold = 0;
         if ((SFO & 1) && TrueVal == FVI->getOperand(0)) {
           OpToFold = 1;
-        } else  if ((SFO & 2) && TrueVal == FVI->getOperand(1)) {
+        } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) {
           OpToFold = 2;
         }
 
@@ -253,12 +263,18 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
           // Avoid creating select between 2 constants unless it's selecting
           // between 0, 1 and -1.
           if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) {
-            Instruction *NewSel = SelectInst::Create(SI.getCondition(), C, OOp);
-            InsertNewInstBefore(NewSel, SI);
+            Value *NewSel = Builder->CreateSelect(SI.getCondition(), C, OOp);
             NewSel->takeName(FVI);
-            if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FVI))
-              return BinaryOperator::Create(BO->getOpcode(), TrueVal, NewSel);
-            llvm_unreachable("Unknown instruction!!");
+            BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI);
+            BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(),
+                                                        TrueVal, NewSel);
+            if (isa<PossiblyExactOperator>(BO))
+              BO->setIsExact(FVI_BO->isExact());
+            if (isa<OverflowingBinaryOperator>(BO)) {
+              BO->setHasNoUnsignedWrap(FVI_BO->hasNoUnsignedWrap());
+              BO->setHasNoSignedWrap(FVI_BO->hasNoSignedWrap());
+            }
+            return BO;
           }
         }
       }
@@ -268,6 +284,67 @@ Instruction *InstCombiner::FoldSelectIntoOp(SelectInst &SI, Value *TrueVal,
   return 0;
 }
 
+/// SimplifyWithOpReplaced - See if V simplifies when its operand Op is
+/// replaced with RepOp.
+static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+                                     const DataLayout *TD,
+                                     const TargetLibraryInfo *TLI) {
+  // Trivial replacement.
+  if (V == Op)
+    return RepOp;
+
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return 0;
+
+  // If this is a binary operator, try to simplify it with the replaced op.
+  if (BinaryOperator *B = dyn_cast<BinaryOperator>(I)) {
+    if (B->getOperand(0) == Op)
+      return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), TD, TLI);
+    if (B->getOperand(1) == Op)
+      return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, TD, TLI);
+  }
+
+  // Same for CmpInsts.
+  if (CmpInst *C = dyn_cast<CmpInst>(I)) {
+    if (C->getOperand(0) == Op)
+      return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), TD,
+                             TLI);
+    if (C->getOperand(1) == Op)
+      return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, TD,
+                             TLI);
+  }
+
+  // TODO: We could hand off more cases to instsimplify here.
+
+  // If all operands are constant after substituting Op for RepOp then we can
+  // constant fold the instruction.
+  if (Constant *CRepOp = dyn_cast<Constant>(RepOp)) {
+    // Build a list of all constant operands.
+    SmallVector<Constant*, 8> ConstOps;
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
+      if (I->getOperand(i) == Op)
+        ConstOps.push_back(CRepOp);
+      else if (Constant *COp = dyn_cast<Constant>(I->getOperand(i)))
+        ConstOps.push_back(COp);
+      else
+        break;
+    }
+
+    // All operands were constants, fold it.
+    if (ConstOps.size() == I->getNumOperands()) {
+      if (LoadInst *LI = dyn_cast<LoadInst>(I))
+        if (!LI->isVolatile())
+          return ConstantFoldLoadFromConstPtr(ConstOps[0], TD);
+
+      return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                      ConstOps, TD, TLI);
+    }
+  }
+
+  return 0;
+}
+
 /// visitSelectInstWithICmp - Visit a SelectInst that has an
 /// ICmpInst as its first operand.
 ///
@@ -281,8 +358,8 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   Value *FalseVal = SI.getFalseValue();
 
   // Check cases where the comparison is with a constant that
-  // can be adjusted to fit the min/max idiom. We may edit ICI in
-  // place here, so make sure the select is the only user.
+  // can be adjusted to fit the min/max idiom. We may move or edit ICI
+  // here, so make sure the select is the only user.
   if (ICI->hasOneUse())
     if (ConstantInt *CI = dyn_cast<ConstantInt>(CmpRHS)) {
       // X < MIN ? T : F  -->  F
@@ -300,7 +377,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
       case ICmpInst::ICMP_UGT:
       case ICmpInst::ICMP_SGT: {
         // These transformations only work for selects over integers.
-        const IntegerType *SelectTy = dyn_cast<IntegerType>(SI.getType());
+        IntegerType *SelectTy = dyn_cast<IntegerType>(SI.getType());
         if (!SelectTy)
           break;
 
@@ -364,6 +441,11 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
         ICI->setOperand(1, CmpRHS);
         SI.setOperand(1, TrueVal);
         SI.setOperand(2, FalseVal);
+
+        // Move ICI instruction right before the select instruction. Otherwise
+        // the sext/zext value may be defined after the ICI instruction uses it.
+        ICI->moveBefore(&SI);
+
         Changed = true;
         break;
       }
@@ -375,7 +457,7 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
   // FIXME: Type and constness constraints could be lifted, but we have to
   //        watch code size carefully. We should consider xor instead of
   //        sub/add when we decide to do that.
-  if (const IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) {
+  if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) {
     if (TrueVal->getType() == Ty) {
       if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) {
         ConstantInt *C1 = NULL, *C2 = NULL;
@@ -401,24 +483,39 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
     }
   }
 
-  if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
-    // Transform (X == Y) ? X : Y  -> Y
-    if (Pred == ICmpInst::ICMP_EQ)
+  // If we have an equality comparison then we know the value in one of the
+  // arms of the select. See if substituting this value into the arm and
+  // simplifying the result yields the same value as the other arm.
+  if (Pred == ICmpInst::ICMP_EQ) {
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
       return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? X : Y  -> X
-    if (Pred == ICmpInst::ICMP_NE)
-      return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
-
-  } else if (CmpLHS == FalseVal && CmpRHS == TrueVal) {
-    // Transform (X == Y) ? Y : X  -> X
-    if (Pred == ICmpInst::ICMP_EQ)
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
       return ReplaceInstUsesWith(SI, FalseVal);
-    // Transform (X != Y) ? Y : X  -> Y
-    if (Pred == ICmpInst::ICMP_NE)
+  } else if (Pred == ICmpInst::ICMP_NE) {
+    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, TD, TLI) == FalseVal ||
+        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, TD, TLI) == FalseVal)
+      return ReplaceInstUsesWith(SI, TrueVal);
+    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, TD, TLI) == TrueVal ||
+        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, TD, TLI) == TrueVal)
       return ReplaceInstUsesWith(SI, TrueVal);
-    /// NOTE: if we wanted to, this is where to detect integer MIN/MAX
   }
+
+  // NOTE: if we wanted to, this is where to detect integer MIN/MAX
+
+  if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) {
+    if (CmpLHS == TrueVal && Pred == ICmpInst::ICMP_EQ) {
+      // Transform (X == C) ? X : Y -> (X == C) ? C : Y
+      SI.setOperand(1, CmpRHS);
+      Changed = true;
+    } else if (CmpLHS == FalseVal && Pred == ICmpInst::ICMP_NE) {
+      // Transform (X != C) ? Y : X -> (X != C) ? Y : C
+      SI.setOperand(2, CmpRHS);
+      Changed = true;
+    }
+  }
+
   return Changed ? &SI : 0;
 }
 
@@ -498,9 +595,8 @@ static Value *foldSelectICmpAnd(const SelectInst &SI, ConstantInt *TrueVal,
   if (!IC || !IC->isEquality())
     return 0;
 
-  if (ConstantInt *C = dyn_cast<ConstantInt>(IC->getOperand(1)))
-    if (!C->isZero())
-      return 0;
+  if (!match(IC->getOperand(1), m_Zero()))
+    return 0;
 
   ConstantInt *AndRHS;
   Value *LHS = IC->getOperand(0);
@@ -573,9 +669,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
         return BinaryOperator::CreateOr(CondVal, FalseVal);
       }
       // Change: A = select B, false, C --> A = and !B, C
-      Value *NotCond =
-        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                           "not."+CondVal->getName()), SI);
+      Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
       return BinaryOperator::CreateAnd(NotCond, FalseVal);
     } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
       if (C->getZExtValue() == false) {
@@ -583,9 +677,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
         return BinaryOperator::CreateAnd(CondVal, TrueVal);
       }
       // Change: A = select B, C, true --> A = or !B, C
-      Value *NotCond =
-        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                           "not."+CondVal->getName()), SI);
+      Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
       return BinaryOperator::CreateOr(NotCond, TrueVal);
     }
 
@@ -595,6 +687,13 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       return BinaryOperator::CreateOr(CondVal, FalseVal);
     else if (CondVal == FalseVal)
       return BinaryOperator::CreateAnd(CondVal, TrueVal);
+
+    // select a, ~a, b -> (~a)&b
+    // select a, b, ~a -> (~a)|b
+    if (match(TrueVal, m_Not(m_Specific(CondVal))))
+      return BinaryOperator::CreateAnd(TrueVal, FalseVal);
+    else if (match(FalseVal, m_Not(m_Specific(CondVal))))
+      return BinaryOperator::CreateOr(TrueVal, FalseVal);
   }
 
   // Selecting between two integer constants?
@@ -724,28 +823,21 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
             // So at this point we know we have (Y -> OtherAddOp):
             //        select C, (add X, Y), (sub X, Z)
             Value *NegVal;  // Compute -Z
-            if (Constant *C = dyn_cast<Constant>(SubOp->getOperand(1))) {
-              NegVal = ConstantExpr::getNeg(C);
-            } else if (SI.getType()->isFloatingPointTy()) {
-              NegVal = InsertNewInstBefore(
-                    BinaryOperator::CreateFNeg(SubOp->getOperand(1),
-                                              "tmp"), SI);
+            if (SI.getType()->isFPOrFPVectorTy()) {
+              NegVal = Builder->CreateFNeg(SubOp->getOperand(1));
             } else {
-              NegVal = InsertNewInstBefore(
-                    BinaryOperator::CreateNeg(SubOp->getOperand(1),
-                                              "tmp"), SI);
+              NegVal = Builder->CreateNeg(SubOp->getOperand(1));
             }
 
             Value *NewTrueOp = OtherAddOp;
             Value *NewFalseOp = NegVal;
             if (AddOp != TI)
               std::swap(NewTrueOp, NewFalseOp);
-            Instruction *NewSel =
-              SelectInst::Create(CondVal, NewTrueOp,
-                                 NewFalseOp, SI.getName() + ".p");
+            Value *NewSel = 
+              Builder->CreateSelect(CondVal, NewTrueOp,
+                                    NewFalseOp, SI.getName() + ".p");
 
-            NewSel = InsertNewInstBefore(NewSel, SI);
-            if (SI.getType()->isFloatingPointTy())
+            if (SI.getType()->isFPOrFPVectorTy())
               return BinaryOperator::CreateFAdd(SubOp->getOperand(0), NewSel);
             else
               return BinaryOperator::CreateAdd(SubOp->getOperand(0), NewSel);
@@ -787,6 +879,23 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       if (Instruction *NV = FoldOpIntoPhi(SI))
         return NV;
 
+  if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
+    if (TrueSI->getCondition() == CondVal) {
+      if (SI.getTrueValue() == TrueSI->getTrueValue())
+        return 0;
+      SI.setOperand(1, TrueSI->getTrueValue());
+      return &SI;
+    }
+  }
+  if (SelectInst *FalseSI = dyn_cast<SelectInst>(FalseVal)) {
+    if (FalseSI->getCondition() == CondVal) {
+      if (SI.getFalseValue() == FalseSI->getFalseValue())
+        return 0;
+      SI.setOperand(2, FalseSI->getFalseValue());
+      return &SI;
+    }
+  }
+
   if (BinaryOperator::isNot(CondVal)) {
     SI.setOperand(0, BinaryOperator::getNotArgument(CondVal));
     SI.setOperand(1, FalseVal);
@@ -794,5 +903,38 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
     return &SI;
   }
 
+  if (VectorType *VecTy = dyn_cast<VectorType>(SI.getType())) {
+    unsigned VWidth = VecTy->getNumElements();
+    APInt UndefElts(VWidth, 0);
+    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) {
+      if (V != &SI)
+        return ReplaceInstUsesWith(SI, V);
+      return &SI;
+    }
+
+    if (ConstantVector *CV = dyn_cast<ConstantVector>(CondVal)) {
+      // Form a shufflevector instruction.
+      SmallVector<Constant *, 8> Mask(VWidth);
+      Type *Int32Ty = Type::getInt32Ty(CV->getContext());
+      for (unsigned i = 0; i != VWidth; ++i) {
+        Constant *Elem = cast<Constant>(CV->getOperand(i));
+        if (ConstantInt *E = dyn_cast<ConstantInt>(Elem))
+          Mask[i] = ConstantInt::get(Int32Ty, i + (E->isZero() ? VWidth : 0));
+        else if (isa<UndefValue>(Elem))
+          Mask[i] = UndefValue::get(Int32Ty);
+        else
+          return 0;
+      }
+      Constant *MaskVal = ConstantVector::get(Mask);
+      Value *V = Builder->CreateShuffleVector(TrueVal, FalseVal, MaskVal);
+      return ReplaceInstUsesWith(SI, V);
+    }
+
+    if (isa<ConstantAggregateZero>(CondVal)) {
+      return ReplaceInstUsesWith(SI, FalseVal);
+    }
+  }
+
   return 0;
 }