Clean up previous cast optimization a bit. Also make zext elimination a bit more...
authorEvan Cheng <evan.cheng@apple.com>
Fri, 16 Jan 2009 02:11:43 +0000 (02:11 +0000)
committerEvan Cheng <evan.cheng@apple.com>
Fri, 16 Jan 2009 02:11:43 +0000 (02:11 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@62297 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/InstructionCombining.cpp
test/Transforms/InstCombine/cast.ll

index 31f82ad196461fea3c67f0a299ea86d1872f8a04..36a2ad5e455d406042b0bb391a6b688043d75db2 100644 (file)
@@ -394,8 +394,7 @@ namespace {
     Value *EvaluateInDifferentType(Value *V, const Type *Ty, bool isSigned);
 
     bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
-                                    unsigned CastOpc,
-                                    int &NumCastsRemoved, bool &SeenTrunc);
+                                    unsigned CastOpc, int &NumCastsRemoved);
     unsigned GetOrEnforceKnownAlignment(Value *V,
                                         unsigned PrefAlign = 0);
 
@@ -7497,10 +7496,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
 /// If CastOpc is a sext or zext, we are asking if the low bits of the value can
 /// bit computed in a larger type, which is then and'd or sext_in_reg'd to get
 /// the final result.
-bool
-InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
-                                         unsigned CastOpc,
-                                         int &NumCastsRemoved, bool &SeenTrunc){
+bool InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
+                                              unsigned CastOpc,
+                                              int &NumCastsRemoved){
   // We can always evaluate constants in another type.
   if (isa<ConstantInt>(V))
     return true;
@@ -7520,8 +7518,6 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
       // casts first.
       if (!isa<CastInst>(I->getOperand(0)) && I->hasOneUse())
         ++NumCastsRemoved;
-      if (isa<TruncInst>(I))
-        SeenTrunc = true;
       return true;
     }
   }
@@ -7540,9 +7536,9 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
   case Instruction::Xor:
     // These operators can all arbitrarily be extended or truncated.
     return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
-                                      NumCastsRemoved, SeenTrunc) &&
+                                      NumCastsRemoved) &&
            CanEvaluateInDifferentType(I->getOperand(1), Ty, CastOpc,
-                                      NumCastsRemoved, SeenTrunc);
+                                      NumCastsRemoved);
 
   case Instruction::Shl:
     // If we are truncating the result of this SHL, and if it's a shift of a
@@ -7552,7 +7548,7 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
       if (BitWidth < OrigTy->getBitWidth() && 
           CI->getLimitedValue(BitWidth) < BitWidth)
         return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
-                                          NumCastsRemoved, SeenTrunc);
+                                          NumCastsRemoved);
     }
     break;
   case Instruction::LShr:
@@ -7567,7 +7563,7 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
             APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth)) &&
           CI->getLimitedValue(BitWidth) < BitWidth) {
         return CanEvaluateInDifferentType(I->getOperand(0), Ty, CastOpc,
-                                          NumCastsRemoved, SeenTrunc);
+                                          NumCastsRemoved);
       }
     }
     break;
@@ -7587,16 +7583,16 @@ InstCombiner::CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     return CanEvaluateInDifferentType(SI->getTrueValue(), Ty, CastOpc,
-                                      NumCastsRemoved, SeenTrunc) &&
+                                      NumCastsRemoved) &&
            CanEvaluateInDifferentType(SI->getFalseValue(), Ty, CastOpc,
-                                      NumCastsRemoved, SeenTrunc);
+                                      NumCastsRemoved);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands.
     PHINode *PN = cast<PHINode>(I);
     for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
       if (!CanEvaluateInDifferentType(PN->getIncomingValue(i), Ty, CastOpc,
-                                      NumCastsRemoved, SeenTrunc))
+                                      NumCastsRemoved))
         return false;
     return true;
   }
@@ -7845,10 +7841,9 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
 
   // Attempt to propagate the cast into the instruction for int->int casts.
   int NumCastsRemoved = 0;
-  bool SeenTrunc = false;
   if (!isa<BitCastInst>(CI) &&
       CanEvaluateInDifferentType(SrcI, cast<IntegerType>(DestTy),
-                                 CI.getOpcode(), NumCastsRemoved, SeenTrunc)) {
+                                 CI.getOpcode(), NumCastsRemoved)) {
     // If this cast is a truncate, evaluting in a different type always
     // eliminates the cast, so it is always a win.  If this is a zero-extension,
     // we need to do an AND to maintain the clear top-part of the computation,
@@ -7865,14 +7860,27 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
     case Instruction::Trunc:
       DoXForm = true;
       break;
-    case Instruction::ZExt:
+    case Instruction::ZExt: {
       DoXForm = NumCastsRemoved >= 1;
-      // TODO: Check if we need to insert an AND.
+      if (!DoXForm) {
+        // If it's unnecessary to issue an AND to clear the high bits, it's
+        // always profitable to do this xform.
+        Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, 
+                                           CI.getOpcode() == Instruction::SExt);
+        APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize));
+        if (MaskedValueIsZero(TryRes, Mask))
+          return ReplaceInstUsesWith(CI, TryRes);
+        else if (Instruction *TryI = dyn_cast<Instruction>(TryRes))
+          if (TryI->use_empty())
+            EraseInstFromFunction(*TryI);
+      }
       break;
+    }
     case Instruction::SExt: {
       DoXForm = NumCastsRemoved >= 2;
-      if (!SeenTrunc) {
-        // Do we have to emit a truncate to SrcBitSize followed by a sext?
+      if (!DoXForm && !isa<TruncInst>(SrcI)) {
+        // If we do not have to emit the truncate + sext pair, then it's always
+        // profitable to do this xform.
         //
         // It's not safe to eliminate the trunc + sext pair if one of the
         // eliminated cast is a truncate. e.g.
@@ -7880,11 +7888,14 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
         // t3 = sext i16 t2 to i32
         // !=
         // i32 t1
-        unsigned NumSignBits = ComputeNumSignBits(&CI);
-        if (NumSignBits > (DestBitSize - SrcBitSize)) {
-          DoXForm = true;
-          JustReplace = true;
-        }
+        Value *TryRes = EvaluateInDifferentType(SrcI, DestTy, 
+                                           CI.getOpcode() == Instruction::SExt);
+        unsigned NumSignBits = ComputeNumSignBits(TryRes);
+        if (NumSignBits > (DestBitSize - SrcBitSize))
+          return ReplaceInstUsesWith(CI, TryRes);
+        else if (Instruction *TryI = dyn_cast<Instruction>(TryRes))
+          if (TryI->use_empty())
+            EraseInstFromFunction(*TryI);
       }
       break;
     }
@@ -7893,6 +7904,10 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
     if (DoXForm) {
       Value *Res = EvaluateInDifferentType(SrcI, DestTy, 
                                            CI.getOpcode() == Instruction::SExt);
+      if (JustReplace)
+          // Just replace this cast with the result.
+          return ReplaceInstUsesWith(CI, Res);
+
       assert(Res->getType() == DestTy);
       switch (CI.getOpcode()) {
       default: assert(0 && "Unknown cast type!");
@@ -7901,15 +7916,24 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
         // Just replace this cast with the result.
         return ReplaceInstUsesWith(CI, Res);
       case Instruction::ZExt: {
-        // We need to emit an AND to clear the high bits.
         assert(SrcBitSize < DestBitSize && "Not a zext?");
+
+        // If the high bits are already zero, just replace this cast with the
+        // result.
+        APInt Mask(APInt::getBitsSet(DestBitSize, SrcBitSize, DestBitSize));
+        if (MaskedValueIsZero(Res, Mask))
+          return ReplaceInstUsesWith(CI, Res);
+
+        // We need to emit an AND to clear the high bits.
         Constant *C = ConstantInt::get(APInt::getLowBitsSet(DestBitSize,
                                                             SrcBitSize));
         return BinaryOperator::CreateAnd(Res, C);
       }
-      case Instruction::SExt:
-        if (JustReplace)
-          // Just replace this cast with the result.
+      case Instruction::SExt: {
+        // If the high bits are already filled with sign bit, just replace this
+        // cast with the result.
+        unsigned NumSignBits = ComputeNumSignBits(Res);
+        if (NumSignBits > (DestBitSize - SrcBitSize))
           return ReplaceInstUsesWith(CI, Res);
 
         // We need to emit a cast to truncate, then a cast to sext.
@@ -7917,6 +7941,7 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
             InsertCastBefore(Instruction::Trunc, Res, Src->getType(), 
                              CI), DestTy);
       }
+      }
     }
   }
   
index 9361ff24e9b8ae8e77e7e16dfeebfc3ffab38775..7a1e7a802dd36f32b9e6b092100b6ddb6c7a657e 100644 (file)
@@ -254,3 +254,10 @@ define i1 @test37(i32 %a) {
         ret i1 %e
 }
 
+define i64 @test38(i32 %a) {
+       %1 = icmp eq i32 %a, -2
+       %2 = zext i1 %1 to i8
+       %3 = xor i8 %2, 1
+       %4 = zext i8 %3 to i64
+        ret i64 %4
+}