Fix a missing newline in debug output.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCasts.cpp
index f25dd3582b2e7f6761a1c6bc1b1104fee4dc7e13..b0137c4a32ca128b0a65ed8a0dc3ea63cdc437ba 100644 (file)
@@ -23,7 +23,7 @@ using namespace PatternMatch;
 ///
 static Value *DecomposeSimpleLinearExpr(Value *Val, unsigned &Scale,
                                         int &Offset) {
-  assert(Val->getType()->isInteger(32) && "Unexpected allocation size type!");
+  assert(Val->getType()->isIntegerTy(32) && "Unexpected allocation size type!");
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) {
     Offset = CI->getZExtValue();
     Scale  = 0;
@@ -255,17 +255,26 @@ isEliminableCastPair(
   return Instruction::CastOps(Res);
 }
 
-/// ValueRequiresCast - Return true if the cast from "V to Ty" actually results
-/// in any code being generated.  It does not require codegen if V is simple
-/// enough or if the cast can be folded into other casts.
-bool InstCombiner::ValueRequiresCast(Instruction::CastOps opcode,const Value *V,
-                                     const Type *Ty) {
+/// ShouldOptimizeCast - Return true if the cast from "V to Ty" actually
+/// results in any code being generated and is interesting to optimize out. If
+/// the cast can be eliminated by some other simple transformation, we prefer
+/// to do the simplification first.
+bool InstCombiner::ShouldOptimizeCast(Instruction::CastOps opc, const Value *V,
+                                      const Type *Ty) {
+  // Noop casts and casts of constants should be eliminated trivially.
   if (V->getType() == Ty || isa<Constant>(V)) return false;
   
-  // If this is another cast that can be eliminated, it isn't codegen either.
+  // If this is another cast that can be eliminated, we prefer to have it
+  // eliminated.
   if (const CastInst *CI = dyn_cast<CastInst>(V))
-    if (isEliminableCastPair(CI, opcode, Ty, TD))
+    if (isEliminableCastPair(CI, opc, Ty, TD))
       return false;
+  
+  // If this is a vector sext from a compare, then we don't want to break the
+  // idiom where each element of the extended vector is either zero or all ones.
+  if (opc == Instruction::SExt && isa<CmpInst>(V) && Ty->isVectorTy())
+    return false;
+  
   return true;
 }
 
@@ -294,8 +303,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
   if (isa<PHINode>(Src)) {
     // We don't do this if this would create a PHI node with an illegal type if
     // it is currently legal.
-    if (!isa<IntegerType>(Src->getType()) ||
-        !isa<IntegerType>(CI.getType()) ||
+    if (!Src->getType()->isIntegerTy() ||
+        !CI.getType()->isIntegerTy() ||
         ShouldChangeType(CI.getType(), Src->getType()))
       if (Instruction *NV = FoldOpIntoPhi(CI))
         return NV;
@@ -427,13 +436,13 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
   // type.   Only do this if the dest type is a simple type, don't convert the
   // expression tree to something weird like i93 unless the source is also
   // strange.
-  if ((isa<VectorType>(DestTy) || ShouldChangeType(SrcTy, DestTy)) &&
+  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
       CanEvaluateTruncated(Src, DestTy)) {
       
     // If this cast is a truncate, evaluting in a different type always
     // eliminates the cast, so it is always a win.
     DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
-          " to avoid cast: " << CI);
+          " to avoid cast: " << CI << '\n');
     Value *Res = EvaluateInDifferentType(Src, DestTy, false);
     assert(Res->getType() == DestTy);
     return ReplaceInstUsesWith(CI, Res);
@@ -719,7 +728,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
   // expression tree to something weird like i93 unless the source is also
   // strange.
   unsigned BitsToClear;
-  if ((isa<VectorType>(DestTy) || ShouldChangeType(SrcTy, DestTy)) &&
+  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
       CanEvaluateZExtd(Src, DestTy, BitsToClear)) { 
     assert(BitsToClear < SrcTy->getScalarSizeInBits() &&
            "Unreasonable BitsToClear");
@@ -828,7 +837,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
 
   // zext (xor i1 X, true) to i32  --> xor (zext i1 X to i32), 1
   Value *X;
-  if (SrcI && SrcI->hasOneUse() && SrcI->getType()->isInteger(1) &&
+  if (SrcI && SrcI->hasOneUse() && SrcI->getType()->isIntegerTy(1) &&
       match(SrcI, m_Not(m_Value(X))) &&
       (!X->hasOneUse() || !isa<CmpInst>(X))) {
     Value *New = Builder->CreateZExt(X, CI.getType());
@@ -923,17 +932,11 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
   Value *Src = CI.getOperand(0);
   const Type *SrcTy = Src->getType(), *DestTy = CI.getType();
 
-  // Canonicalize sign-extend from i1 to a select.
-  if (Src->getType()->isInteger(1))
-    return SelectInst::Create(Src,
-                              Constant::getAllOnesValue(CI.getType()),
-                              Constant::getNullValue(CI.getType()));
-  
   // Attempt to extend the entire input expression tree to the destination
   // type.   Only do this if the dest type is a simple type, don't convert the
   // expression tree to something weird like i93 unless the source is also
   // strange.
-  if ((isa<VectorType>(DestTy) || ShouldChangeType(SrcTy, DestTy)) &&
+  if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
       CanEvaluateSExtd(Src, DestTy)) {
     // Okay, we can transform this!  Insert the new expression now.
     DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -968,6 +971,30 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
       return BinaryOperator::CreateAShr(Res, ShAmt);
     }
   
+  
+  // (x <s 0) ? -1 : 0 -> ashr x, 31   -> all ones if signed
+  // (x >s -1) ? -1 : 0 -> ashr x, 31  -> all ones if not signed
+  {
+  ICmpInst::Predicate Pred; Value *CmpLHS; ConstantInt *CmpRHS;
+  if (match(Src, m_ICmp(Pred, m_Value(CmpLHS), m_ConstantInt(CmpRHS)))) {
+    // sext (x <s  0) to i32 --> x>>s31       true if signbit set.
+    // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
+    if ((Pred == ICmpInst::ICMP_SLT && CmpRHS->isZero()) ||
+        (Pred == ICmpInst::ICMP_SGT && CmpRHS->isAllOnesValue())) {
+      Value *Sh = ConstantInt::get(CmpLHS->getType(),
+                                   CmpLHS->getType()->getScalarSizeInBits()-1);
+      Value *In = Builder->CreateAShr(CmpLHS, Sh, CmpLHS->getName()+".lobit");
+      if (In->getType() != CI.getType())
+        In = Builder->CreateIntCast(In, CI.getType(), true/*SExt*/, "tmp");
+      
+      if (Pred == ICmpInst::ICMP_SGT)
+        In = Builder->CreateNot(In, In->getName()+".not");
+      return ReplaceInstUsesWith(CI, In);
+    }
+  }
+  }
+  
+  
   // If the input is a shl/ashr pair of a same constant, then this is a sign
   // extension from a smaller value.  If we could trust arbitrary bitwidth
   // integers, we could turn this into a truncate to the smaller bit and then
@@ -1127,16 +1154,22 @@ Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
 }
 
 Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
-  // If the source integer type is larger than the intptr_t type for
-  // this target, do a trunc to the intptr_t type, then inttoptr of it.  This
-  // allows the trunc to be exposed to other transforms.  Don't do this for
-  // extending inttoptr's, because we don't know if the target sign or zero
-  // extends to pointers.
-  if (TD && CI.getOperand(0)->getType()->getScalarSizeInBits() >
-      TD->getPointerSizeInBits()) {
-    Value *P = Builder->CreateTrunc(CI.getOperand(0),
-                                    TD->getIntPtrType(CI.getContext()), "tmp");
-    return new IntToPtrInst(P, CI.getType());
+  // If the source integer type is not the intptr_t type for this target, do a
+  // trunc or zext to the intptr_t type, then inttoptr of it.  This allows the
+  // cast to be exposed to other transforms.
+  if (TD) {
+    if (CI.getOperand(0)->getType()->getScalarSizeInBits() >
+        TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreateTrunc(CI.getOperand(0),
+                                      TD->getIntPtrType(CI.getContext()), "tmp");
+      return new IntToPtrInst(P, CI.getType());
+    }
+    if (CI.getOperand(0)->getType()->getScalarSizeInBits() <
+        TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreateZExt(CI.getOperand(0),
+                                     TD->getIntPtrType(CI.getContext()), "tmp");
+      return new IntToPtrInst(P, CI.getType());
+    }
   }
   
   if (Instruction *I = commonCastTransforms(CI))
@@ -1198,22 +1231,85 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) {
 }
 
 Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
-  // If the destination integer type is smaller than the intptr_t type for
-  // this target, do a ptrtoint to intptr_t then do a trunc.  This allows the
-  // trunc to be exposed to other transforms.  Don't do this for extending
-  // ptrtoint's, because we don't know if the target sign or zero extends its
-  // pointers.
-  if (TD &&
-      CI.getType()->getScalarSizeInBits() < TD->getPointerSizeInBits()) {
-    Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
-                                       TD->getIntPtrType(CI.getContext()),
-                                       "tmp");
-    return new TruncInst(P, CI.getType());
+  // If the destination integer type is not the intptr_t type for this target,
+  // do a ptrtoint to intptr_t then do a trunc or zext.  This allows the cast
+  // to be exposed to other transforms.
+  if (TD) {
+    if (CI.getType()->getScalarSizeInBits() < TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
+                                         TD->getIntPtrType(CI.getContext()),
+                                         "tmp");
+      return new TruncInst(P, CI.getType());
+    }
+    if (CI.getType()->getScalarSizeInBits() > TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
+                                         TD->getIntPtrType(CI.getContext()),
+                                         "tmp");
+      return new ZExtInst(P, CI.getType());
+    }
   }
   
   return commonPointerCastTransforms(CI);
 }
 
+/// OptimizeVectorResize - This input value (which is known to have vector type)
+/// is being zero extended or truncated to the specified vector type.  Try to
+/// replace it with a shuffle (and vector/vector bitcast) if possible.
+///
+/// The source and destination vector types may have different element types.
+static Instruction *OptimizeVectorResize(Value *InVal, const VectorType *DestTy,
+                                         InstCombiner &IC) {
+  // We can only do this optimization if the output is a multiple of the input
+  // element size, or the input is a multiple of the output element size.
+  // Convert the input type to have the same element type as the output.
+  const VectorType *SrcTy = cast<VectorType>(InVal->getType());
+  
+  if (SrcTy->getElementType() != DestTy->getElementType()) {
+    // The input types don't need to be identical, but for now they must be the
+    // same size.  There is no specific reason we couldn't handle things like
+    // <4 x i16> -> <4 x i32> by bitcasting to <2 x i32> but haven't gotten
+    // there yet. 
+    if (SrcTy->getElementType()->getPrimitiveSizeInBits() !=
+        DestTy->getElementType()->getPrimitiveSizeInBits())
+      return 0;
+    
+    SrcTy = VectorType::get(DestTy->getElementType(), SrcTy->getNumElements());
+    InVal = IC.Builder->CreateBitCast(InVal, SrcTy);
+  }
+  
+  // Now that the element types match, get the shuffle mask and RHS of the
+  // shuffle to use, which depends on whether we're increasing or decreasing the
+  // size of the input.
+  SmallVector<Constant*, 16> ShuffleMask;
+  Value *V2;
+  const IntegerType *Int32Ty = Type::getInt32Ty(SrcTy->getContext());
+  
+  if (SrcTy->getNumElements() > DestTy->getNumElements()) {
+    // If we're shrinking the number of elements, just shuffle in the low
+    // elements from the input and use undef as the second shuffle input.
+    V2 = UndefValue::get(SrcTy);
+    for (unsigned i = 0, e = DestTy->getNumElements(); i != e; ++i)
+      ShuffleMask.push_back(ConstantInt::get(Int32Ty, i));
+    
+  } else {
+    // If we're increasing the number of elements, shuffle in all of the
+    // elements from InVal and fill the rest of the result elements with zeros
+    // from a constant zero.
+    V2 = Constant::getNullValue(SrcTy);
+    unsigned SrcElts = SrcTy->getNumElements();
+    for (unsigned i = 0, e = SrcElts; i != e; ++i)
+      ShuffleMask.push_back(ConstantInt::get(Int32Ty, i));
+
+    // The excess elements reference the first element of the zero input.
+    ShuffleMask.append(DestTy->getNumElements()-SrcElts,
+                       ConstantInt::get(Int32Ty, SrcElts));
+  }
+  
+  Constant *Mask = ConstantVector::get(ShuffleMask.data(), ShuffleMask.size());
+  return new ShuffleVectorInst(InVal, V2, Mask);
+}
+
+
 Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
   // If the operands are integer typed then apply the integer transforms,
   // otherwise just apply the common ones.
@@ -1251,7 +1347,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
       Constant::getNullValue(Type::getInt32Ty(CI.getContext()));
     unsigned NumZeros = 0;
     while (SrcElTy != DstElTy && 
-           isa<CompositeType>(SrcElTy) && !isa<PointerType>(SrcElTy) &&
+           isa<CompositeType>(SrcElTy) && !SrcElTy->isPointerTy() &&
            SrcElTy->getNumContainedTypes() /* not "{}" */) {
       SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(ZeroUInt);
       ++NumZeros;
@@ -1266,16 +1362,28 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
   }
 
   if (const VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) {
-    if (DestVTy->getNumElements() == 1 && !isa<VectorType>(SrcTy)) {
+    if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) {
       Value *Elem = Builder->CreateBitCast(Src, DestVTy->getElementType());
       return InsertElementInst::Create(UndefValue::get(DestTy), Elem,
                      Constant::getNullValue(Type::getInt32Ty(CI.getContext())));
       // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast)
     }
+    
+    // If this is a cast from an integer to vector, check to see if the input
+    // is a trunc or zext of a bitcast from vector.  If so, we can replace all
+    // the casts with a shuffle and (potentially) a bitcast.
+    if (isa<IntegerType>(SrcTy) && (isa<TruncInst>(Src) || isa<ZExtInst>(Src))){
+      CastInst *SrcCast = cast<CastInst>(Src);
+      if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0)))
+        if (isa<VectorType>(BCIn->getOperand(0)->getType()))
+          if (Instruction *I = OptimizeVectorResize(BCIn->getOperand(0),
+                                               cast<VectorType>(DestTy), *this))
+            return I;
+    }
   }
 
   if (const VectorType *SrcVTy = dyn_cast<VectorType>(SrcTy)) {
-    if (SrcVTy->getNumElements() == 1 && !isa<VectorType>(DestTy)) {
+    if (SrcVTy->getNumElements() == 1 && !DestTy->isVectorTy()) {
       Value *Elem = 
         Builder->CreateExtractElement(Src,
                    Constant::getNullValue(Type::getInt32Ty(CI.getContext())));
@@ -1285,8 +1393,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
 
   if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) {
     // Okay, we have (bitcast (shuffle ..)).  Check to see if this is
-    // a bitconvert to a vector with the same # elts.
-    if (SVI->hasOneUse() && isa<VectorType>(DestTy) && 
+    // a bitcast to a vector with the same # elts.
+    if (SVI->hasOneUse() && DestTy->isVectorTy() && 
         cast<VectorType>(DestTy)->getNumElements() ==
               SVI->getType()->getNumElements() &&
         SVI->getType()->getNumElements() ==
@@ -1308,7 +1416,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
     }
   }
   
-  if (isa<PointerType>(SrcTy))
+  if (SrcTy->isPointerTy())
     return commonPointerCastTransforms(CI);
   return commonCastTransforms(CI);
 }