Uniformize the names of type predicates: rather than having isFloatTy and
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCasts.cpp
index e018b351082a23b7abc229ba7f8d415d01ed852d..bb4a0e94968e0dd104b14217b00dd44131688b0f 100644 (file)
@@ -23,7 +23,7 @@ using namespace PatternMatch;
 ///
 static Value *DecomposeSimpleLinearExpr(Value *Val, unsigned &Scale,
                                         int &Offset) {
-  assert(Val->getType()->isInteger(32) && "Unexpected allocation size type!");
+  assert(Val->getType()->isIntegerTy(32) && "Unexpected allocation size type!");
   if (ConstantInt *CI = dyn_cast<ConstantInt>(Val)) {
     Offset = CI->getZExtValue();
     Scale  = 0;
@@ -255,17 +255,26 @@ isEliminableCastPair(
   return Instruction::CastOps(Res);
 }
 
-/// ValueRequiresCast - Return true if the cast from "V to Ty" actually results
-/// in any code being generated.  It does not require codegen if V is simple
-/// enough or if the cast can be folded into other casts.
-bool InstCombiner::ValueRequiresCast(Instruction::CastOps opcode,const Value *V,
-                                     const Type *Ty) {
+/// ShouldOptimizeCast - Return true if the cast from "V to Ty" actually
+/// results in any code being generated and is interesting to optimize out. If
+/// the cast can be eliminated by some other simple transformation, we prefer
+/// to do the simplification first.
+bool InstCombiner::ShouldOptimizeCast(Instruction::CastOps opc, const Value *V,
+                                      const Type *Ty) {
+  // Noop casts and casts of constants should be eliminated trivially.
   if (V->getType() == Ty || isa<Constant>(V)) return false;
   
-  // If this is another cast that can be eliminated, it isn't codegen either.
+  // If this is another cast that can be eliminated, we prefer to have it
+  // eliminated.
   if (const CastInst *CI = dyn_cast<CastInst>(V))
-    if (isEliminableCastPair(CI, opcode, Ty, TD))
+    if (isEliminableCastPair(CI, opc, Ty, TD))
       return false;
+  
+  // If this is a vector sext from a compare, then we don't want to break the
+  // idiom where each element of the extended vector is either zero or all ones.
+  if (opc == Instruction::SExt && isa<CmpInst>(V) && isa<VectorType>(Ty))
+    return false;
+  
   return true;
 }
 
@@ -828,7 +837,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
 
   // zext (xor i1 X, true) to i32  --> xor (zext i1 X to i32), 1
   Value *X;
-  if (SrcI && SrcI->hasOneUse() && SrcI->getType()->isInteger(1) &&
+  if (SrcI && SrcI->hasOneUse() && SrcI->getType()->isIntegerTy(1) &&
       match(SrcI, m_Not(m_Value(X))) &&
       (!X->hasOneUse() || !isa<CmpInst>(X))) {
     Value *New = Builder->CreateZExt(X, CI.getType());
@@ -923,12 +932,6 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
   Value *Src = CI.getOperand(0);
   const Type *SrcTy = Src->getType(), *DestTy = CI.getType();
 
-  // Canonicalize sign-extend from i1 to a select.
-  if (Src->getType()->isInteger(1))
-    return SelectInst::Create(Src,
-                              Constant::getAllOnesValue(CI.getType()),
-                              Constant::getNullValue(CI.getType()));
-  
   // Attempt to extend the entire input expression tree to the destination
   // type.   Only do this if the dest type is a simple type, don't convert the
   // expression tree to something weird like i93 unless the source is also
@@ -955,6 +958,43 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
                                       ShAmt);
   }
 
+  // If this input is a trunc from our destination, then turn sext(trunc(x))
+  // into shifts.
+  if (TruncInst *TI = dyn_cast<TruncInst>(Src))
+    if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) {
+      uint32_t SrcBitSize = SrcTy->getScalarSizeInBits();
+      uint32_t DestBitSize = DestTy->getScalarSizeInBits();
+      
+      // We need to emit a shl + ashr to do the sign extend.
+      Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
+      Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext");
+      return BinaryOperator::CreateAShr(Res, ShAmt);
+    }
+  
+  
+  // (x <s 0) ? -1 : 0 -> ashr x, 31   -> all ones if signed
+  // (x >s -1) ? -1 : 0 -> ashr x, 31  -> all ones if not signed
+  {
+  ICmpInst::Predicate Pred; Value *CmpLHS; ConstantInt *CmpRHS;
+  if (match(Src, m_ICmp(Pred, m_Value(CmpLHS), m_ConstantInt(CmpRHS)))) {
+    // sext (x <s  0) to i32 --> x>>s31       true if signbit set.
+    // sext (x >s -1) to i32 --> (x>>s31)^-1  true if signbit clear.
+    if ((Pred == ICmpInst::ICMP_SLT && CmpRHS->isZero()) ||
+        (Pred == ICmpInst::ICMP_SGT && CmpRHS->isAllOnesValue())) {
+      Value *Sh = ConstantInt::get(CmpLHS->getType(),
+                                   CmpLHS->getType()->getScalarSizeInBits()-1);
+      Value *In = Builder->CreateAShr(CmpLHS, Sh, CmpLHS->getName()+".lobit");
+      if (In->getType() != CI.getType())
+        In = Builder->CreateIntCast(In, CI.getType(), true/*SExt*/, "tmp");
+      
+      if (Pred == ICmpInst::ICMP_SGT)
+        In = Builder->CreateNot(In, In->getName()+".not");
+      return ReplaceInstUsesWith(CI, In);
+    }
+  }
+  }
+  
+  
   // If the input is a shl/ashr pair of a same constant, then this is a sign
   // extension from a smaller value.  If we could trust arbitrary bitwidth
   // integers, we could turn this into a truncate to the smaller bit and then
@@ -1114,16 +1154,22 @@ Instruction *InstCombiner::visitSIToFP(CastInst &CI) {
 }
 
 Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) {
-  // If the source integer type is larger than the intptr_t type for
-  // this target, do a trunc to the intptr_t type, then inttoptr of it.  This
-  // allows the trunc to be exposed to other transforms.  Don't do this for
-  // extending inttoptr's, because we don't know if the target sign or zero
-  // extends to pointers.
-  if (TD && CI.getOperand(0)->getType()->getScalarSizeInBits() >
-      TD->getPointerSizeInBits()) {
-    Value *P = Builder->CreateTrunc(CI.getOperand(0),
-                                    TD->getIntPtrType(CI.getContext()), "tmp");
-    return new IntToPtrInst(P, CI.getType());
+  // If the source integer type is not the intptr_t type for this target, do a
+  // trunc or zext to the intptr_t type, then inttoptr of it.  This allows the
+  // cast to be exposed to other transforms.
+  if (TD) {
+    if (CI.getOperand(0)->getType()->getScalarSizeInBits() >
+        TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreateTrunc(CI.getOperand(0),
+                                      TD->getIntPtrType(CI.getContext()), "tmp");
+      return new IntToPtrInst(P, CI.getType());
+    }
+    if (CI.getOperand(0)->getType()->getScalarSizeInBits() <
+        TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreateZExt(CI.getOperand(0),
+                                     TD->getIntPtrType(CI.getContext()), "tmp");
+      return new IntToPtrInst(P, CI.getType());
+    }
   }
   
   if (Instruction *I = commonCastTransforms(CI))
@@ -1185,17 +1231,22 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) {
 }
 
 Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) {
-  // If the destination integer type is smaller than the intptr_t type for
-  // this target, do a ptrtoint to intptr_t then do a trunc.  This allows the
-  // trunc to be exposed to other transforms.  Don't do this for extending
-  // ptrtoint's, because we don't know if the target sign or zero extends its
-  // pointers.
-  if (TD &&
-      CI.getType()->getScalarSizeInBits() < TD->getPointerSizeInBits()) {
-    Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
-                                       TD->getIntPtrType(CI.getContext()),
-                                       "tmp");
-    return new TruncInst(P, CI.getType());
+  // If the destination integer type is not the intptr_t type for this target,
+  // do a ptrtoint to intptr_t then do a trunc or zext.  This allows the cast
+  // to be exposed to other transforms.
+  if (TD) {
+    if (CI.getType()->getScalarSizeInBits() < TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
+                                         TD->getIntPtrType(CI.getContext()),
+                                         "tmp");
+      return new TruncInst(P, CI.getType());
+    }
+    if (CI.getType()->getScalarSizeInBits() > TD->getPointerSizeInBits()) {
+      Value *P = Builder->CreatePtrToInt(CI.getOperand(0),
+                                         TD->getIntPtrType(CI.getContext()),
+                                         "tmp");
+      return new ZExtInst(P, CI.getType());
+    }
   }
   
   return commonPointerCastTransforms(CI);