simplify some code
[oota-llvm.git] / lib / Transforms / Scalar / InstructionCombining.cpp
index 6eefac2033dab070526e3016768674e03823c440..2a08ced68c85e307e4d0cb695a5dddbb5fd8fa25 100644 (file)
@@ -74,6 +74,7 @@ namespace {
     std::vector<Instruction*> Worklist;
     DenseMap<Instruction*, unsigned> WorklistMap;
     TargetData *TD;
+    bool MustPreserveLCSSA;
   public:
     /// AddToWorkList - Add the specified instruction to the worklist if it
     /// isn't already in it.
@@ -141,6 +142,8 @@ namespace {
 
   public:
     virtual bool runOnFunction(Function &F);
+    
+    bool DoOneIteration(Function &F, unsigned ItNum);
 
     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
       AU.addRequired<TargetData>();
@@ -847,6 +850,7 @@ static void ComputeUnsignedMinMaxValuesFromKnownBits(const Type *Ty,
 bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
                                         uint64_t &KnownZero, uint64_t &KnownOne,
                                         unsigned Depth) {
+  const IntegerType *VTy = cast<IntegerType>(V->getType());
   if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
     // We know all of the bits for a constant!
     KnownOne = CI->getZExtValue() & DemandedMask;
@@ -863,10 +867,10 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     }
     // If this is the root being simplified, allow it to have multiple uses,
     // just set the DemandedMask to all bits.
-    DemandedMask = cast<IntegerType>(V->getType())->getBitMask();
+    DemandedMask = VTy->getBitMask();
   } else if (DemandedMask == 0) {   // Not demanding any bits from V.
-    if (V != UndefValue::get(V->getType()))
-      return UpdateValueUsesWith(V, UndefValue::get(V->getType()));
+    if (V != UndefValue::get(VTy))
+      return UpdateValueUsesWith(V, UndefValue::get(VTy));
     return false;
   } else if (Depth == 6) {        // Limit search depth.
     return false;
@@ -875,7 +879,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) return false;        // Only analyze instructions.
 
-  DemandedMask &= cast<IntegerType>(V->getType())->getBitMask();
+  DemandedMask &= VTy->getBitMask();
   
   uint64_t KnownZero2 = 0, KnownOne2 = 0;
   switch (I->getOpcode()) {
@@ -903,7 +907,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     
     // If all of the demanded bits in the inputs are known zeros, return zero.
     if ((DemandedMask & (KnownZero|KnownZero2)) == DemandedMask)
-      return UpdateValueUsesWith(I, Constant::getNullValue(I->getType()));
+      return UpdateValueUsesWith(I, Constant::getNullValue(VTy));
       
     // If the RHS is a constant, see if we can simplify it.
     if (ShrinkDemandedConstant(I, 1, DemandedMask & ~KnownZero2))
@@ -988,8 +992,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     //    e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
     if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask) { // all known
       if ((KnownOne & KnownOne2) == KnownOne) {
-        Constant *AndC = ConstantInt::get(I->getType(), 
-                                          ~KnownOne & DemandedMask);
+        Constant *AndC = ConstantInt::get(VTy, ~KnownOne & DemandedMask);
         Instruction *And = 
           BinaryOperator::createAnd(I->getOperand(0), AndC, "tmp");
         InsertNewInstBefore(And, *I);
@@ -1045,7 +1048,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     // Compute the bits in the result that are not present in the input.
     const IntegerType *SrcTy = cast<IntegerType>(I->getOperand(0)->getType());
     uint64_t NotIn = ~SrcTy->getBitMask();
-    uint64_t NewBits = cast<IntegerType>(I->getType())->getBitMask() & NotIn;
+    uint64_t NewBits = VTy->getBitMask() & NotIn;
     
     DemandedMask &= SrcTy->getBitMask();
     if (SimplifyDemandedBits(I->getOperand(0), DemandedMask,
@@ -1060,7 +1063,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     // Compute the bits in the result that are not present in the input.
     const IntegerType *SrcTy = cast<IntegerType>(I->getOperand(0)->getType());
     uint64_t NotIn = ~SrcTy->getBitMask();
-    uint64_t NewBits = cast<IntegerType>(I->getType())->getBitMask() & NotIn;
+    uint64_t NewBits = VTy->getBitMask() & NotIn;
     
     // Get the sign bit for the source type
     uint64_t InSignBit = 1ULL << (SrcTy->getPrimitiveSizeInBits()-1);
@@ -1083,8 +1086,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
     // convert this into a zero extension.
     if ((KnownZero & InSignBit) || (NewBits & ~DemandedMask) == NewBits) {
       // Convert to ZExt cast
-      CastInst *NewCast = CastInst::create(
-        Instruction::ZExt, I->getOperand(0), I->getType(), I->getName(), I);
+      CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName(), I);
       return UpdateValueUsesWith(I, NewCast);
     } else if (KnownOne & InSignBit) {    // Input sign bit known set
       KnownOne |= NewBits;
@@ -1109,7 +1111,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
       // either.
       
       // Shift the demanded mask up so that it's at the top of the uint64_t.
-      unsigned BitWidth = I->getType()->getPrimitiveSizeInBits();
+      unsigned BitWidth = VTy->getPrimitiveSizeInBits();
       unsigned NLZ = CountLeadingZeros_64(DemandedMask << (64-BitWidth));
       
       // If the top bit of the output is demanded, demand everything from the
@@ -1205,8 +1207,8 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
       
       // Compute the new bits that are at the top now.
       uint64_t HighBits = (1ULL << ShiftAmt)-1;
-      HighBits <<= I->getType()->getPrimitiveSizeInBits() - ShiftAmt;
-      uint64_t TypeMask = cast<IntegerType>(I->getType())->getBitMask();
+      HighBits <<= VTy->getBitWidth() - ShiftAmt;
+      uint64_t TypeMask = VTy->getBitMask();
       // Unsigned shift right.
       if (SimplifyDemandedBits(I->getOperand(0),
                               (DemandedMask << ShiftAmt) & TypeMask,
@@ -1238,8 +1240,8 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
       
       // Compute the new bits that are at the top now.
       uint64_t HighBits = (1ULL << ShiftAmt)-1;
-      HighBits <<= I->getType()->getPrimitiveSizeInBits() - ShiftAmt;
-      uint64_t TypeMask = cast<IntegerType>(I->getType())->getBitMask();
+      HighBits <<= VTy->getBitWidth() - ShiftAmt;
+      uint64_t TypeMask = VTy->getBitMask();
       // Signed shift right.
       if (SimplifyDemandedBits(I->getOperand(0),
                                (DemandedMask << ShiftAmt) & TypeMask,
@@ -1252,7 +1254,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
       KnownOne  >>= ShiftAmt;
         
       // Handle the sign bits.
-      uint64_t SignBit = 1ULL << (I->getType()->getPrimitiveSizeInBits()-1);
+      uint64_t SignBit = 1ULL << (VTy->getBitWidth()-1);
       SignBit >>= ShiftAmt;  // Adjust to where it is now in the mask.
         
       // If the input sign bit is known to be zero, or if none of the top bits
@@ -1273,7 +1275,7 @@ bool InstCombiner::SimplifyDemandedBits(Value *V, uint64_t DemandedMask,
   // If the client is only demanding bits that we know, return the known
   // constant.
   if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask)
-    return UpdateValueUsesWith(I, ConstantInt::get(I->getType(), KnownOne));
+    return UpdateValueUsesWith(I, ConstantInt::get(VTy, KnownOne));
   return false;
 }  
 
@@ -5906,35 +5908,62 @@ Instruction *InstCombiner::PromoteCastOfAllocation(CastInst &CI,
 }
 
 /// CanEvaluateInDifferentType - Return true if we can take the specified value
-/// and return it without inserting any new casts.  This is used by code that
-/// tries to decide whether promoting or shrinking integer operations to wider
-/// or smaller types will allow us to eliminate a truncate or extend.
-static bool CanEvaluateInDifferentType(Value *V, const Type *Ty,
+/// and return it as type Ty without inserting any new casts and without
+/// changing the computed value.  This is used by code that tries to decide
+/// whether promoting or shrinking integer operations to wider or smaller types
+/// will allow us to eliminate a truncate or extend.
+///
+/// This is a truncation operation if Ty is smaller than V->getType(), or an
+/// extension operation if Ty is larger.
+static bool CanEvaluateInDifferentType(Value *V, const IntegerType *Ty,
                                        int &NumCastsRemoved) {
-  if (isa<Constant>(V)) return true;
+  // We can always evaluate constants in another type.
+  if (isa<ConstantInt>(V))
+    return true;
   
   Instruction *I = dyn_cast<Instruction>(V);
-  if (!I || !I->hasOneUse()) return false;
+  if (!I) return false;
+  
+  const IntegerType *OrigTy = cast<IntegerType>(V->getType());
   
   switch (I->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor:
+    if (!I->hasOneUse()) return false;
     // These operators can all arbitrarily be extended or truncated.
     return CanEvaluateInDifferentType(I->getOperand(0), Ty, NumCastsRemoved) &&
            CanEvaluateInDifferentType(I->getOperand(1), Ty, NumCastsRemoved);
-  case Instruction::AShr:
-  case Instruction::LShr:
+
   case Instruction::Shl:
-    // If this is just a bitcast changing the sign of the operation, we can
-    // convert if the operand can be converted.
-    if (V->getType()->getPrimitiveSizeInBits() == Ty->getPrimitiveSizeInBits())
-      return CanEvaluateInDifferentType(I->getOperand(0), Ty, NumCastsRemoved);
+    if (!I->hasOneUse()) return false;
+    // If we are truncating the result of this SHL, and if it's a shift of a
+    // constant amount, we can always perform a SHL in a smaller type.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      if (Ty->getBitWidth() < OrigTy->getBitWidth() &&
+          CI->getZExtValue() < Ty->getBitWidth())
+        return CanEvaluateInDifferentType(I->getOperand(0), Ty,NumCastsRemoved);
+    }
+    break;
+  case Instruction::LShr:
+    if (!I->hasOneUse()) return false;
+    // If this is a truncate of a logical shr, we can truncate it to a smaller
+    // lshr iff we know that the bits we would otherwise be shifting in are
+    // already zeros.
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      if (Ty->getBitWidth() < OrigTy->getBitWidth() &&
+          MaskedValueIsZero(I->getOperand(0),
+                            OrigTy->getBitMask() & ~Ty->getBitMask()) &&
+          CI->getZExtValue() < Ty->getBitWidth()) {
+        return CanEvaluateInDifferentType(I->getOperand(0), Ty, NumCastsRemoved);
+      }
+    }
     break;
   case Instruction::Trunc:
   case Instruction::ZExt:
   case Instruction::SExt:
-  case Instruction::BitCast:
     // If this is a cast from the destination type, we can trivially eliminate
     // it, and this will remove a cast overall.
     if (I->getOperand(0)->getType() == Ty) {
@@ -5960,7 +5989,7 @@ static bool CanEvaluateInDifferentType(Value *V, const Type *Ty,
 /// CanEvaluateInDifferentType returns true for, actually insert the code to
 /// evaluate the expression.
 Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty, 
-                                             bool isSigned ) {
+                                             bool isSigned) {
   if (Constant *C = dyn_cast<Constant>(V))
     return ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/);
 
@@ -5968,21 +5997,18 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, const Type *Ty,
   Instruction *I = cast<Instruction>(V);
   Instruction *Res = 0;
   switch (I->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::Sub:
   case Instruction::And:
   case Instruction::Or:
-  case Instruction::Xor: {
-    Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
-    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
-    Res = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(),
-                                 LHS, RHS, I->getName());
-    break;
-  }
+  case Instruction::Xor:
   case Instruction::AShr:
   case Instruction::LShr:
   case Instruction::Shl: {
     Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
-    Res = BinaryOperator::create(Instruction::BinaryOps(I->getOpcode()), LHS, 
-                                 I->getOperand(1), I->getName());
+    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
+    Res = BinaryOperator::create((Instruction::BinaryOps)I->getOpcode(),
+                                 LHS, RHS, I->getName());
     break;
   }    
   case Instruction::Trunc:
@@ -6064,8 +6090,8 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
   return 0;
 }
 
-/// Only the TRUNC, ZEXT, SEXT, and BITCONVERT can have both operands as
-/// integers. This function implements the common transforms for all those
+/// Only the TRUNC, ZEXT, SEXT, and BITCAST can both operand and result as
+/// integer types. This function implements the common transforms for all those
 /// cases.
 /// @brief Implement the transforms common to CastInst with integer operands
 Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
@@ -6091,9 +6117,11 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
   if (!SrcI || !Src->hasOneUse())
     return 0;
 
-  // Attempt to propagate the cast into the instruction.
+  // Attempt to propagate the cast into the instruction for int->int casts.
   int NumCastsRemoved = 0;
-  if (CanEvaluateInDifferentType(SrcI, DestTy, NumCastsRemoved)) {
+  if (!isa<BitCastInst>(CI) &&
+      CanEvaluateInDifferentType(SrcI, cast<IntegerType>(DestTy),
+                                 NumCastsRemoved)) {
     // If this cast is a truncate, evaluting in a different type always
     // eliminates the cast, so it is always a win.  If this is a noop-cast
     // this just removes a noop cast which isn't pointful, but simplifies
@@ -6102,27 +6130,24 @@ Instruction *InstCombiner::commonIntCastTransforms(CastInst &CI) {
     // the input have eliminated at least one cast.  If this is a sign
     // extension, we insert two new casts (to do the extension) so we
     // require that two casts have been eliminated.
-    bool DoXForm = CI.isNoopCast(TD->getIntPtrType());
-    if (!DoXForm) {
-      switch (CI.getOpcode()) {
-        case Instruction::Trunc:
-          DoXForm = true;
-          break;
-        case Instruction::ZExt:
-          DoXForm = NumCastsRemoved >= 1;
-          break;
-        case Instruction::SExt:
-          DoXForm = NumCastsRemoved >= 2;
-          break;
-        case Instruction::BitCast:
-          DoXForm = false;
-          break;
-        default:
-          // All the others use floating point so we shouldn't actually 
-          // get here because of the check above.
-          assert(!"Unknown cast type .. unreachable");
-          break;
-      }
+    bool DoXForm;
+    switch (CI.getOpcode()) {
+    default:
+      // All the others use floating point so we shouldn't actually 
+      // get here because of the check above.
+      assert(0 && "Unknown cast type");
+    case Instruction::Trunc:
+      DoXForm = true;
+      break;
+    case Instruction::ZExt:
+      DoXForm = NumCastsRemoved >= 1;
+      break;
+    case Instruction::SExt:
+      DoXForm = NumCastsRemoved >= 2;
+      break;
+    case Instruction::BitCast:
+      DoXForm = false;
+      break;
     }
     
     if (DoXForm) {
@@ -7660,7 +7685,7 @@ static bool DeadPHICycle(PHINode *PN, std::set<PHINode*> &PotentiallyDeadPHIs) {
 //
 Instruction *InstCombiner::visitPHINode(PHINode &PN) {
   // If LCSSA is around, don't mess with Phi nodes
-  if (mustPreserveAnalysisID(LCSSAID)) return 0;
+  if (MustPreserveLCSSA) return 0;
   
   if (Value *V = PN.hasConstantValue())
     return ReplaceInstUsesWith(PN, V);
@@ -9164,9 +9189,12 @@ static void AddReachableCodeToWorklist(BasicBlock *BB,
     AddReachableCodeToWorklist(TI->getSuccessor(i), Visited, IC, TD);
 }
 
-bool InstCombiner::runOnFunction(Function &F) {
+bool InstCombiner::DoOneIteration(Function &F, unsigned Iteration) {
   bool Changed = false;
   TD = &getAnalysis<TargetData>();
+  
+  DEBUG(DOUT << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
+             << F.getNameStr() << "\n");
 
   {
     // Do a depth-first traversal of the function, populate the worklist with
@@ -9295,24 +9323,38 @@ bool InstCombiner::runOnFunction(Function &F) {
         if (isInstructionTriviallyDead(I)) {
           // Make sure we process all operands now that we are reducing their
           // use counts.
-          AddUsesToWorkList(*I);;
+          AddUsesToWorkList(*I);
 
           // Instructions may end up in the worklist more than once.  Erase all
           // occurrences of this instruction.
           RemoveFromWorkList(I);
           I->eraseFromParent();
         } else {
-          AddToWorkList(Result);
-          AddUsersToWorkList(*Result);
+          AddToWorkList(I);
+          AddUsersToWorkList(*I);
         }
       }
       Changed = true;
     }
   }
 
+  assert(WorklistMap.empty() && "Worklist empty, but map not?");
   return Changed;
 }
 
+
+bool InstCombiner::runOnFunction(Function &F) {
+  MustPreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
+  
+  bool EverMadeChange = false;
+
+  // Iterate while there is work to do.
+  unsigned Iteration = 0;
+  while (DoOneIteration(F, Iteration++)) 
+    EverMadeChange = true;
+  return EverMadeChange;
+}
+
 FunctionPass *llvm::createInstructionCombiningPass() {
   return new InstCombiner();
 }