Remove the rest of my instcombine changes. Back to the drawing board on this one.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index 73ff00fe25ea6b62efb681952bc4241e4d1befa1..c44fe9db6e3a7689883f6711af0aca3dab80014f 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "InstCombine.h"
 #include "llvm/Support/PatternMatch.h"
+#include "llvm/Analysis/InstructionSimplify.h"
 using namespace llvm;
 using namespace PatternMatch;
 
@@ -326,45 +327,38 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
         break;
       }
       }
+    }
 
-      // (x <s 0) ? -1 : 0 -> ashr x, 31   -> all ones if signed
-      // (x >s -1) ? -1 : 0 -> ashr x, 31  -> all ones if not signed
-      CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-      if (match(TrueVal, m_ConstantInt<-1>()) &&
-          match(FalseVal, m_ConstantInt<0>()))
-        Pred = ICI->getPredicate();
-      else if (match(TrueVal, m_ConstantInt<0>()) &&
-               match(FalseVal, m_ConstantInt<-1>()))
-        Pred = CmpInst::getInversePredicate(ICI->getPredicate());
-      
-      if (Pred != CmpInst::BAD_ICMP_PREDICATE) {
-        // If we are just checking for a icmp eq of a single bit and zext'ing it
-        // to an integer, then shift the bit to the appropriate place and then
-        // cast to integer to avoid the comparison.
-        const APInt &Op1CV = CI->getValue();
-    
-        // sext (x <s  0) to i32 --> x>>s31      true if signbit set.
-        // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
-        if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) ||
-            (Pred == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) {
-          Value *In = ICI->getOperand(0);
-          Value *Sh = ConstantInt::get(In->getType(),
-                                       In->getType()->getScalarSizeInBits()-1);
-          In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh,
-                                                        In->getName()+".lobit"),
-                                   *ICI);
-          if (In->getType() != SI.getType())
-            In = CastInst::CreateIntegerCast(In, SI.getType(),
-                                             true/*SExt*/, "tmp", ICI);
-    
-          if (Pred == ICmpInst::ICMP_SGT)
-            In = InsertNewInstBefore(BinaryOperator::CreateNot(In,
-                                       In->getName()+".not"), *ICI);
-    
-          return ReplaceInstUsesWith(SI, In);
+  // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1
+  // and       (X <s  0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1
+  // FIXME: Type and constness constraints could be lifted, but we have to
+  //        watch code size carefully. We should consider xor instead of
+  //        sub/add when we decide to do that.
+  if (const IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) {
+    if (TrueVal->getType() == Ty) {
+      if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) {
+        ConstantInt *C1 = NULL, *C2 = NULL;
+        if (Pred == ICmpInst::ICMP_SGT && Cmp->isAllOnesValue()) {
+          C1 = dyn_cast<ConstantInt>(TrueVal);
+          C2 = dyn_cast<ConstantInt>(FalseVal);
+        } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isNullValue()) {
+          C1 = dyn_cast<ConstantInt>(FalseVal);
+          C2 = dyn_cast<ConstantInt>(TrueVal);
+        }
+        if (C1 && C2) {
+          // This shift results in either -1 or 0.
+          Value *AShr = Builder->CreateAShr(CmpLHS, Ty->getBitWidth()-1);
+
+          // Check if we can express the operation with a single or.
+          if (C2->isAllOnesValue())
+            return ReplaceInstUsesWith(SI, Builder->CreateOr(AShr, C1));
+
+          Value *And = Builder->CreateAnd(AShr, C2->getValue()-C1->getValue());
+          return ReplaceInstUsesWith(SI, Builder->CreateAdd(And, C1));
         }
       }
     }
+  }
 
   if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
     // Transform (X == Y) ? X : Y  -> Y
@@ -459,49 +453,30 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *TrueVal = SI.getTrueValue();
   Value *FalseVal = SI.getFalseValue();
 
-  // select true, X, Y  -> X
-  // select false, X, Y -> Y
-  if (ConstantInt *C = dyn_cast<ConstantInt>(CondVal))
-    return ReplaceInstUsesWith(SI, C->getZExtValue() ? TrueVal : FalseVal);
-
-  // select C, X, X -> X
-  if (TrueVal == FalseVal)
-    return ReplaceInstUsesWith(SI, TrueVal);
-
-  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X
-    return ReplaceInstUsesWith(SI, FalseVal);
-  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X
-    return ReplaceInstUsesWith(SI, TrueVal);
-  if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y
-    if (isa<Constant>(TrueVal))
-      return ReplaceInstUsesWith(SI, TrueVal);
-    else
-      return ReplaceInstUsesWith(SI, FalseVal);
-  }
+  if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, TD))
+    return ReplaceInstUsesWith(SI, V);
 
-  if (SI.getType() == Type::getInt1Ty(SI.getContext())) {
+  if (SI.getType()->isIntegerTy(1)) {
     if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) {
       if (C->getZExtValue()) {
         // Change: A = select B, true, C --> A = or B, C
         return BinaryOperator::CreateOr(CondVal, FalseVal);
-      } else {
-        // Change: A = select B, false, C --> A = and !B, C
-        Value *NotCond =
-          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                             "not."+CondVal->getName()), SI);
-        return BinaryOperator::CreateAnd(NotCond, FalseVal);
       }
+      // Change: A = select B, false, C --> A = and !B, C
+      Value *NotCond =
+        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
+                                           "not."+CondVal->getName()), SI);
+      return BinaryOperator::CreateAnd(NotCond, FalseVal);
     } else if (ConstantInt *C = dyn_cast<ConstantInt>(FalseVal)) {
       if (C->getZExtValue() == false) {
         // Change: A = select B, C, false --> A = and B, C
         return BinaryOperator::CreateAnd(CondVal, TrueVal);
-      } else {
-        // Change: A = select B, C, true --> A = or !B, C
-        Value *NotCond =
-          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                             "not."+CondVal->getName()), SI);
-        return BinaryOperator::CreateOr(NotCond, TrueVal);
       }
+      // Change: A = select B, C, true --> A = or !B, C
+      Value *NotCond =
+        InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
+                                           "not."+CondVal->getName()), SI);
+      return BinaryOperator::CreateOr(NotCond, TrueVal);
     }
     
     // select a, b, a  -> a&b
@@ -516,16 +491,25 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
     if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) {
       // select C, 1, 0 -> zext C to int
-      if (FalseValC->isZero() && TrueValC->getValue() == 1) {
-        return CastInst::Create(Instruction::ZExt, CondVal, SI.getType());
-      } else if (TrueValC->isZero() && FalseValC->getValue() == 1) {
-        // select C, 0, 1 -> zext !C to int
-        Value *NotCond =
-          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                               "not."+CondVal->getName()), SI);
-        return CastInst::Create(Instruction::ZExt, NotCond, SI.getType());
+      if (FalseValC->isZero() && TrueValC->getValue() == 1)
+        return new ZExtInst(CondVal, SI.getType());
+
+      // select C, -1, 0 -> sext C to int
+      if (FalseValC->isZero() && TrueValC->isAllOnesValue())
+        return new SExtInst(CondVal, SI.getType());
+      
+      // select C, 0, 1 -> zext !C to int
+      if (TrueValC->isZero() && FalseValC->getValue() == 1) {
+        Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+        return new ZExtInst(NotCond, SI.getType());
       }
 
+      // select C, 0, -1 -> sext !C to int
+      if (TrueValC->isZero() && FalseValC->isAllOnesValue()) {
+        Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+        return new SExtInst(NotCond, SI.getType());
+      }
+      
       if (ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition())) {
         // If one of the constants is zero (we know they can't both be) and we
         // have an icmp instruction with zero, and we have an 'and' with the
@@ -547,8 +531,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
                 ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE;
                 Value *V = ICA;
                 if (ShouldNotVal)
-                  V = InsertNewInstBefore(BinaryOperator::Create(
-                                  Instruction::Xor, V, ICA->getOperand(1)), SI);
+                  V = Builder->CreateXor(V, ICA->getOperand(1));
                 return ReplaceInstUsesWith(SI, V);
               }
       }
@@ -569,9 +552,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
              !CFPf->getValueAPF().isZero()))
         return ReplaceInstUsesWith(SI, FalseVal);
       }
-      // Transform (X != Y) ? X : Y  -> X
-      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
+      // Transform (X une Y) ? X : Y  -> X
+      if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
+        // This is not safe in general for floating point:  
+        // consider X== -0, Y== +0.
+        // It becomes safe if either operand is a nonzero constant.
+        ConstantFP *CFPt, *CFPf;
+        if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) &&
+              !CFPt->getValueAPF().isZero()) ||
+            ((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
+             !CFPf->getValueAPF().isZero()))
         return ReplaceInstUsesWith(SI, TrueVal);
+      }
       // NOTE: if we wanted to, this is where to detect MIN/MAX
 
     } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){
@@ -587,9 +579,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
              !CFPf->getValueAPF().isZero()))
           return ReplaceInstUsesWith(SI, FalseVal);
       }
-      // Transform (X != Y) ? Y : X  -> Y
-      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
-        return ReplaceInstUsesWith(SI, TrueVal);
+      // Transform (X une Y) ? Y : X  -> Y
+      if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
+        // This is not safe in general for floating point:  
+        // consider X== -0, Y== +0.
+        // It becomes safe if either operand is a nonzero constant.
+        ConstantFP *CFPt, *CFPf;
+        if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) &&
+              !CFPt->getValueAPF().isZero()) ||
+            ((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
+             !CFPf->getValueAPF().isZero()))
+          return ReplaceInstUsesWith(SI, TrueVal);
+      }
       // NOTE: if we wanted to, this is where to detect MIN/MAX
     }
     // NOTE: if we wanted to, this is where to detect ABS
@@ -659,7 +660,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       }
 
   // See if we can fold the select into one of our operands.
-  if (SI.getType()->isInteger()) {
+  if (SI.getType()->isIntegerTy()) {
     if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
       return FoldI;