Fix indentation.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineSelect.cpp
index 6c36ebc2790d4fbb611b7537270b19f7e0ed9edf..2fc932594f7b0931cf4dcad099dc5cfa5154bf17 100644 (file)
@@ -1,4 +1,4 @@
-//===- InstCombineLoadStoreAlloca.cpp -------------------------------------===//
+//===- InstCombineSelect.cpp ----------------------------------------------===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -7,16 +7,11 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file implements the visit functions for load, store and alloca.
+// This file implements the visitSelect function.
 //
 //===----------------------------------------------------------------------===//
 
 #include "InstCombine.h"
-//#include "llvm/IntrinsicInst.h"
-//#include "llvm/Target/TargetData.h"
-//#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-//#include "llvm/Transforms/Utils/Local.h"
-//#include "llvm/ADT/Statistic.h"
 #include "llvm/Support/PatternMatch.h"
 using namespace llvm;
 using namespace PatternMatch;
@@ -331,44 +326,6 @@ Instruction *InstCombiner::visitSelectInstWithICmp(SelectInst &SI,
         break;
       }
       }
-
-      // (x <s 0) ? -1 : 0 -> ashr x, 31   -> all ones if signed
-      // (x >s -1) ? -1 : 0 -> ashr x, 31  -> all ones if not signed
-      CmpInst::Predicate Pred = CmpInst::BAD_ICMP_PREDICATE;
-      if (match(TrueVal, m_ConstantInt<-1>()) &&
-          match(FalseVal, m_ConstantInt<0>()))
-        Pred = ICI->getPredicate();
-      else if (match(TrueVal, m_ConstantInt<0>()) &&
-               match(FalseVal, m_ConstantInt<-1>()))
-        Pred = CmpInst::getInversePredicate(ICI->getPredicate());
-      
-      if (Pred != CmpInst::BAD_ICMP_PREDICATE) {
-        // If we are just checking for a icmp eq of a single bit and zext'ing it
-        // to an integer, then shift the bit to the appropriate place and then
-        // cast to integer to avoid the comparison.
-        const APInt &Op1CV = CI->getValue();
-    
-        // sext (x <s  0) to i32 --> x>>s31      true if signbit set.
-        // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
-        if ((Pred == ICmpInst::ICMP_SLT && Op1CV == 0) ||
-            (Pred == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) {
-          Value *In = ICI->getOperand(0);
-          Value *Sh = ConstantInt::get(In->getType(),
-                                       In->getType()->getScalarSizeInBits()-1);
-          In = InsertNewInstBefore(BinaryOperator::CreateAShr(In, Sh,
-                                                        In->getName()+".lobit"),
-                                   *ICI);
-          if (In->getType() != SI.getType())
-            In = CastInst::CreateIntegerCast(In, SI.getType(),
-                                             true/*SExt*/, "tmp", ICI);
-    
-          if (Pred == ICmpInst::ICMP_SGT)
-            In = InsertNewInstBefore(BinaryOperator::CreateNot(In,
-                                       In->getName()+".not"), *ICI);
-    
-          return ReplaceInstUsesWith(SI, In);
-        }
-      }
     }
 
   if (CmpLHS == TrueVal && CmpRHS == FalseVal) {
@@ -484,7 +441,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       return ReplaceInstUsesWith(SI, FalseVal);
   }
 
-  if (SI.getType() == Type::getInt1Ty(SI.getContext())) {
+  if (SI.getType()->isIntegerTy(1)) {
     if (ConstantInt *C = dyn_cast<ConstantInt>(TrueVal)) {
       if (C->getZExtValue()) {
         // Change: A = select B, true, C --> A = or B, C
@@ -521,16 +478,25 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal))
     if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) {
       // select C, 1, 0 -> zext C to int
-      if (FalseValC->isZero() && TrueValC->getValue() == 1) {
-        return CastInst::Create(Instruction::ZExt, CondVal, SI.getType());
-      } else if (TrueValC->isZero() && FalseValC->getValue() == 1) {
-        // select C, 0, 1 -> zext !C to int
-        Value *NotCond =
-          InsertNewInstBefore(BinaryOperator::CreateNot(CondVal,
-                                               "not."+CondVal->getName()), SI);
-        return CastInst::Create(Instruction::ZExt, NotCond, SI.getType());
+      if (FalseValC->isZero() && TrueValC->getValue() == 1)
+        return new ZExtInst(CondVal, SI.getType());
+
+      // select C, -1, 0 -> sext C to int
+      if (FalseValC->isZero() && TrueValC->isAllOnesValue())
+        return new SExtInst(CondVal, SI.getType());
+      
+      // select C, 0, 1 -> zext !C to int
+      if (TrueValC->isZero() && FalseValC->getValue() == 1) {
+        Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+        return new ZExtInst(NotCond, SI.getType());
       }
 
+      // select C, 0, -1 -> sext !C to int
+      if (TrueValC->isZero() && FalseValC->isAllOnesValue()) {
+        Value *NotCond = Builder->CreateNot(CondVal, "not."+CondVal->getName());
+        return new SExtInst(NotCond, SI.getType());
+      }
+      
       if (ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition())) {
         // If one of the constants is zero (we know they can't both be) and we
         // have an icmp instruction with zero, and we have an 'and' with the
@@ -552,8 +518,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
                 ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE;
                 Value *V = ICA;
                 if (ShouldNotVal)
-                  V = InsertNewInstBefore(BinaryOperator::Create(
-                                  Instruction::Xor, V, ICA->getOperand(1)), SI);
+                  V = Builder->CreateXor(V, ICA->getOperand(1));
                 return ReplaceInstUsesWith(SI, V);
               }
       }
@@ -574,9 +539,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
              !CFPf->getValueAPF().isZero()))
         return ReplaceInstUsesWith(SI, FalseVal);
       }
-      // Transform (X != Y) ? X : Y  -> X
-      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
+      // Transform (X une Y) ? X : Y  -> X
+      if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
+        // This is not safe in general for floating point:  
+        // consider X== -0, Y== +0.
+        // It becomes safe if either operand is a nonzero constant.
+        ConstantFP *CFPt, *CFPf;
+        if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) &&
+              !CFPt->getValueAPF().isZero()) ||
+            ((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
+             !CFPf->getValueAPF().isZero()))
         return ReplaceInstUsesWith(SI, TrueVal);
+      }
       // NOTE: if we wanted to, this is where to detect MIN/MAX
 
     } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){
@@ -592,9 +566,18 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
              !CFPf->getValueAPF().isZero()))
           return ReplaceInstUsesWith(SI, FalseVal);
       }
-      // Transform (X != Y) ? Y : X  -> Y
-      if (FCI->getPredicate() == FCmpInst::FCMP_ONE)
-        return ReplaceInstUsesWith(SI, TrueVal);
+      // Transform (X une Y) ? Y : X  -> Y
+      if (FCI->getPredicate() == FCmpInst::FCMP_UNE) {
+        // This is not safe in general for floating point:  
+        // consider X== -0, Y== +0.
+        // It becomes safe if either operand is a nonzero constant.
+        ConstantFP *CFPt, *CFPf;
+        if (((CFPt = dyn_cast<ConstantFP>(TrueVal)) &&
+              !CFPt->getValueAPF().isZero()) ||
+            ((CFPf = dyn_cast<ConstantFP>(FalseVal)) &&
+             !CFPf->getValueAPF().isZero()))
+          return ReplaceInstUsesWith(SI, TrueVal);
+      }
       // NOTE: if we wanted to, this is where to detect MIN/MAX
     }
     // NOTE: if we wanted to, this is where to detect ABS
@@ -664,7 +647,7 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
       }
 
   // See if we can fold the select into one of our operands.
-  if (SI.getType()->isInteger()) {
+  if (SI.getType()->isIntegerTy()) {
     if (Instruction *FoldI = FoldSelectIntoOp(SI, TrueVal, FalseVal))
       return FoldI;