Implement select.ll:test11
[oota-llvm.git] / lib / Transforms / Scalar / InstructionCombining.cpp
index 105ef5b81b40128dc64de5199ce624afad9cf17f..2f2d32ceab630d8d18b35f3962c1cbd8357a9162 100644 (file)
 #include "llvm/Target/TargetData.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/CallSite.h"
+#include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/InstIterator.h"
 #include "llvm/Support/InstVisitor.h"
-#include "llvm/Support/CallSite.h"
 #include "Support/Debug.h"
 #include "Support/Statistic.h"
 #include <algorithm>
@@ -92,6 +93,8 @@ namespace {
       AU.setPreservesCFG();
     }
 
+    TargetData &getTargetData() const { return *TD; }
+
     // Visitation implementation - Implement instruction combining for different
     // instruction types.  The semantics are as follows:
     // Return Value:
@@ -127,6 +130,7 @@ namespace {
     Instruction *visitCallSite(CallSite CS);
     bool transformConstExprCastCall(CallSite CS);
 
+  public:
     // InsertNewInstBefore - insert an instruction New before instruction Old
     // in the program.  Add the new instruction to the worklist.
     //
@@ -139,7 +143,6 @@ namespace {
       return New;
     }
 
-  public:
     // ReplaceInstUsesWith - This method is to be used when an instruction is
     // found to be dead, replacable with another preexisting expression.  Here
     // we add all uses of I to the worklist, replace all uses of I with the new
@@ -148,8 +151,15 @@ namespace {
     //
     Instruction *ReplaceInstUsesWith(Instruction &I, Value *V) {
       AddUsersToWorkList(I);         // Add all modified instrs to worklist
-      I.replaceAllUsesWith(V);
-      return &I;
+      if (&I != V) {
+        I.replaceAllUsesWith(V);
+        return &I;
+      } else {
+        // If we are replacing the instruction with itself, this must be in a
+        // segment of unreachable code, so just clobber the instruction.
+        I.replaceAllUsesWith(Constant::getNullValue(I.getType()));
+        return &I;
+      }
     }
 
     // EraseInstFromFunction - When dealing with an instruction that has side
@@ -421,7 +431,12 @@ Instruction *AssociativeOpt(BinaryOperator &Root, const Functor &F) {
 
       // Make what used to be the LHS of the root be the user of the root...
       Value *ExtraOperand = TmpLHSI->getOperand(1);
-      Root.replaceAllUsesWith(TmpLHSI);          // Users now use TmpLHSI
+      if (&Root != TmpLHSI)
+        Root.replaceAllUsesWith(TmpLHSI);        // Users now use TmpLHSI
+      else {
+        Root.replaceAllUsesWith(Constant::getNullValue(TmpLHSI->getType()));
+        return 0;
+      }
       TmpLHSI->setOperand(1, &Root);             // TmpLHSI now uses the root
       BB->getInstList().remove(&Root);           // Remove root from the BB
       BB->getInstList().insert(TmpLHSI, &Root);  // Insert root before TmpLHSI
@@ -808,6 +823,8 @@ Instruction *InstCombiner::visitRem(BinaryOperator &I) {
   if (ConstantInt *RHS = dyn_cast<ConstantInt>(I.getOperand(1))) {
     if (RHS->equalsInt(1))  // X % 1 == 0
       return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+    if (RHS->isAllOnesValue())  // X % -1 == 0
+      return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
 
     // Check to see if this is an unsigned remainder with an exact power of 2,
     // if so, convert to a bitwise and.
@@ -1992,15 +2009,68 @@ Instruction *InstCombiner::visitCastInst(CastInst &CI) {
 }
 
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
-  if (ConstantBool *C = dyn_cast<ConstantBool>(SI.getCondition()))
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  // select true, X, Y  -> X
+  // select false, X, Y -> Y
+  if (ConstantBool *C = dyn_cast<ConstantBool>(CondVal))
     if (C == ConstantBool::True)
-      return ReplaceInstUsesWith(SI, SI.getTrueValue());
+      return ReplaceInstUsesWith(SI, TrueVal);
     else {
       assert(C == ConstantBool::False);
-      return ReplaceInstUsesWith(SI, SI.getFalseValue());
+      return ReplaceInstUsesWith(SI, FalseVal);
+    }
+
+  // select C, X, X -> X
+  if (TrueVal == FalseVal)
+    return ReplaceInstUsesWith(SI, TrueVal);
+
+  if (SI.getType() == Type::BoolTy)
+    if (ConstantBool *C = dyn_cast<ConstantBool>(TrueVal)) {
+      if (C == ConstantBool::True) {
+        // Change: A = select B, true, C --> A = or B, C
+        return BinaryOperator::create(Instruction::Or, CondVal, FalseVal);
+      } else {
+        // Change: A = select B, false, C --> A = and !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::create(Instruction::And, NotCond, FalseVal);
+      }
+    } else if (ConstantBool *C = dyn_cast<ConstantBool>(FalseVal)) {
+      if (C == ConstantBool::False) {
+        // Change: A = select B, C, false --> A = and B, C
+        return BinaryOperator::create(Instruction::And, CondVal, TrueVal);
+      } else {
+        // Change: A = select B, C, true --> A = or !B, C
+        Value *NotCond =
+          InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                             "not."+CondVal->getName()), SI);
+        return BinaryOperator::create(Instruction::Or, NotCond, TrueVal);
+      }
     }
-  // Other transformations are possible!
 
+  // Selecting between two constants?
+  if (Constant *TrueValC = dyn_cast<Constant>(TrueVal))
+    if (Constant *FalseValC = dyn_cast<Constant>(FalseVal)) {
+      if (SI.getType()->isInteger()) {
+        // select C, 1, 0 -> cast C to int
+        if (FalseValC->isNullValue() && isa<ConstantInt>(TrueValC) &&
+            cast<ConstantInt>(TrueValC)->getRawValue() == 1) {
+          return new CastInst(CondVal, SI.getType());
+        } else if (TrueValC->isNullValue() && isa<ConstantInt>(FalseValC) &&
+                   cast<ConstantInt>(FalseValC)->getRawValue() == 1) {
+          // select C, 0, 1 -> cast !C to int
+          Value *NotCond =
+            InsertNewInstBefore(BinaryOperator::createNot(CondVal,
+                                               "not."+CondVal->getName()), SI);
+          return new CastInst(NotCond, SI.getType());
+        }
+      }
+    }
+  
   return 0;
 }
 
@@ -2132,9 +2202,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
     if ((*AI)->getType() == ParamTy) {
       Args.push_back(*AI);
     } else {
-      Instruction *Cast = new CastInst(*AI, ParamTy, "tmp");
-      InsertNewInstBefore(Cast, *Caller);
-      Args.push_back(Cast);
+      Args.push_back(InsertNewInstBefore(new CastInst(*AI, ParamTy, "tmp"),
+                                         *Caller));
     }
   }
 
@@ -2242,6 +2311,20 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) {
   return 0;
 }
 
+static Value *InsertSignExtendToPtrTy(Value *V, const Type *DTy,
+                                      Instruction *InsertPoint,
+                                      InstCombiner *IC) {
+  unsigned PS = IC->getTargetData().getPointerSize();
+  const Type *VTy = V->getType();
+  Instruction *Cast;
+  if (!VTy->isSigned() && VTy->getPrimitiveSize() < PS)
+    // We must insert a cast to ensure we sign-extend.
+    V = IC->InsertNewInstBefore(new CastInst(V, VTy->getSignedVersion(),
+                                             V->getName()), *InsertPoint);
+  return IC->InsertNewInstBefore(new CastInst(V, DTy, V->getName()),
+                                 *InsertPoint);
+}
+
 
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   // Is it 'getelementptr %P, long 0'  or 'getelementptr %P'
@@ -2256,27 +2339,85 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   if (GEP.getNumOperands() == 2 && HasZeroPointerIndex)
     return ReplaceInstUsesWith(GEP, GEP.getOperand(0));
 
+  // Eliminate unneeded casts for indices.
+  bool MadeChange = false;
+  gep_type_iterator GTI = gep_type_begin(GEP);
+  for (unsigned i = 1, e = GEP.getNumOperands(); i != e; ++i, ++GTI)
+    if (isa<SequentialType>(*GTI)) {
+      if (CastInst *CI = dyn_cast<CastInst>(GEP.getOperand(i))) {
+        Value *Src = CI->getOperand(0);
+        const Type *SrcTy = Src->getType();
+        const Type *DestTy = CI->getType();
+        if (Src->getType()->isInteger()) {
+          if (SrcTy->getPrimitiveSize() == DestTy->getPrimitiveSize()) {
+            // We can always eliminate a cast from ulong or long to the other.
+            // We can always eliminate a cast from uint to int or the other on
+            // 32-bit pointer platforms.
+            if (DestTy->getPrimitiveSize() >= TD->getPointerSize()) {
+              MadeChange = true;
+              GEP.setOperand(i, Src);
+            }
+          } else if (SrcTy->getPrimitiveSize() < DestTy->getPrimitiveSize() &&
+                     SrcTy->getPrimitiveSize() == 4) {
+            // We can always eliminate a cast from int to [u]long.  We can
+            // eliminate a cast from uint to [u]long iff the target is a 32-bit
+            // pointer target.
+            if (SrcTy->isSigned() || 
+                SrcTy->getPrimitiveSize() >= TD->getPointerSize()) {
+              MadeChange = true;
+              GEP.setOperand(i, Src);
+            }
+          }
+        }
+      }
+      // If we are using a wider index than needed for this platform, shrink it
+      // to what we need.  If the incoming value needs a cast instruction,
+      // insert it.  This explicit cast can make subsequent optimizations more
+      // obvious.
+      Value *Op = GEP.getOperand(i);
+      if (Op->getType()->getPrimitiveSize() > TD->getPointerSize())
+        if (!isa<Constant>(Op)) {
+          Op = InsertNewInstBefore(new CastInst(Op, TD->getIntPtrType(),
+                                                Op->getName()), GEP);
+          GEP.setOperand(i, Op);
+          MadeChange = true;
+        }
+    }
+  if (MadeChange) return &GEP;
+
   // Combine Indices - If the source pointer to this getelementptr instruction
   // is a getelementptr instruction, combine the indices of the two
   // getelementptr instructions into a single instruction.
   //
+  std::vector<Value*> SrcGEPOperands;
   if (GetElementPtrInst *Src = dyn_cast<GetElementPtrInst>(GEP.getOperand(0))) {
+    SrcGEPOperands.assign(Src->op_begin(), Src->op_end());
+  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(GEP.getOperand(0))) {
+    if (CE->getOpcode() == Instruction::GetElementPtr)
+      SrcGEPOperands.assign(CE->op_begin(), CE->op_end());
+  }
+
+  if (!SrcGEPOperands.empty()) {
     std::vector<Value *> Indices;
   
     // Can we combine the two pointer arithmetics offsets?
-    if (Src->getNumOperands() == 2 && isa<Constant>(Src->getOperand(1)) &&
+    if (SrcGEPOperands.size() == 2 && isa<Constant>(SrcGEPOperands[1]) &&
         isa<Constant>(GEP.getOperand(1))) {
+      Constant *SGC = cast<Constant>(SrcGEPOperands[1]);
+      Constant *GC  = cast<Constant>(GEP.getOperand(1));
+      if (SGC->getType() != GC->getType()) {
+        SGC = ConstantExpr::getSignExtend(SGC, Type::LongTy);
+        GC = ConstantExpr::getSignExtend(GC, Type::LongTy);
+      }
+      
       // Replace: gep (gep %P, long C1), long C2, ...
       // With:    gep %P, long (C1+C2), ...
-      Value *Sum = ConstantExpr::get(Instruction::Add,
-                                     cast<Constant>(Src->getOperand(1)),
-                                     cast<Constant>(GEP.getOperand(1)));
-      assert(Sum && "Constant folding of longs failed!?");
-      GEP.setOperand(0, Src->getOperand(0));
-      GEP.setOperand(1, Sum);
-      AddUsersToWorkList(*Src);   // Reduce use count of Src
+      GEP.setOperand(0, SrcGEPOperands[0]);
+      GEP.setOperand(1, ConstantExpr::getAdd(SGC, GC));
+      if (Instruction *I = dyn_cast<Instruction>(GEP.getOperand(0)))
+        AddUsersToWorkList(*I);   // Reduce use count of Src
       return &GEP;
-    } else if (Src->getNumOperands() == 2) {
+    } else if (SrcGEPOperands.size() == 2) {
       // Replace: gep (gep %P, long B), long A, ...
       // With:    T = long A+B; gep %P, T, ...
       //
@@ -2284,32 +2425,73 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       // chain to be resolved before we perform this transformation.  This
       // avoids us creating a TON of code in some cases.
       //
-      if (isa<GetElementPtrInst>(Src->getOperand(0)) &&
-          cast<Instruction>(Src->getOperand(0))->getNumOperands() == 2)
+      if (isa<GetElementPtrInst>(SrcGEPOperands[0]) &&
+          cast<Instruction>(SrcGEPOperands[0])->getNumOperands() == 2)
         return 0;   // Wait until our source is folded to completion.
 
-      Value *Sum = BinaryOperator::create(Instruction::Add, Src->getOperand(1),
-                                          GEP.getOperand(1),
-                                          Src->getName()+".sum", &GEP);
-      GEP.setOperand(0, Src->getOperand(0));
+      Value *Sum, *SO1 = SrcGEPOperands[1], *GO1 = GEP.getOperand(1);
+      if (SO1 == Constant::getNullValue(SO1->getType())) {
+        Sum = GO1;
+      } else if (GO1 == Constant::getNullValue(GO1->getType())) {
+        Sum = SO1;
+      } else {
+        // If they aren't the same type, convert both to an integer of the
+        // target's pointer size.
+        if (SO1->getType() != GO1->getType()) {
+          if (Constant *SO1C = dyn_cast<Constant>(SO1)) {
+            SO1 = ConstantExpr::getCast(SO1C, GO1->getType());
+          } else if (Constant *GO1C = dyn_cast<Constant>(GO1)) {
+            GO1 = ConstantExpr::getCast(GO1C, SO1->getType());
+          } else {
+            unsigned PS = TD->getPointerSize();
+            Instruction *Cast;
+            if (SO1->getType()->getPrimitiveSize() == PS) {
+              // Convert GO1 to SO1's type.
+              GO1 = InsertSignExtendToPtrTy(GO1, SO1->getType(), &GEP, this);
+
+            } else if (GO1->getType()->getPrimitiveSize() == PS) {
+              // Convert SO1 to GO1's type.
+              SO1 = InsertSignExtendToPtrTy(SO1, GO1->getType(), &GEP, this);
+            } else {
+              const Type *PT = TD->getIntPtrType();
+              SO1 = InsertSignExtendToPtrTy(SO1, PT, &GEP, this);
+              GO1 = InsertSignExtendToPtrTy(GO1, PT, &GEP, this);
+            }
+          }
+        }
+        Sum = BinaryOperator::create(Instruction::Add, SO1, GO1,
+                                     GEP.getOperand(0)->getName()+".sum", &GEP);
+        WorkList.push_back(cast<Instruction>(Sum));
+      }
+      GEP.setOperand(0, SrcGEPOperands[0]);
       GEP.setOperand(1, Sum);
-      WorkList.push_back(cast<Instruction>(Sum));
       return &GEP;
-    } else if (*GEP.idx_begin() == Constant::getNullValue(Type::LongTy) &&
-               Src->getNumOperands() != 1) { 
+    } else if (isa<Constant>(*GEP.idx_begin()) && 
+               cast<Constant>(*GEP.idx_begin())->isNullValue() &&
+               SrcGEPOperands.size() != 1) { 
       // Otherwise we can do the fold if the first index of the GEP is a zero
-      Indices.insert(Indices.end(), Src->idx_begin(), Src->idx_end());
+      Indices.insert(Indices.end(), SrcGEPOperands.begin()+1,
+                     SrcGEPOperands.end());
       Indices.insert(Indices.end(), GEP.idx_begin()+1, GEP.idx_end());
-    } else if (Src->getOperand(Src->getNumOperands()-1) == 
-               Constant::getNullValue(Type::LongTy)) {
-      // If the src gep ends with a constant array index, merge this get into
-      // it, even if we have a non-zero array index.
-      Indices.insert(Indices.end(), Src->idx_begin(), Src->idx_end()-1);
-      Indices.insert(Indices.end(), GEP.idx_begin(), GEP.idx_end());
+    } else if (SrcGEPOperands.back() ==
+               Constant::getNullValue(SrcGEPOperands.back()->getType())) {
+      // We have to check to make sure this really is an ARRAY index we are
+      // ending up with, not a struct index.
+      generic_gep_type_iterator<std::vector<Value*>::iterator>
+        GTI = gep_type_begin(SrcGEPOperands[0]->getType(),
+                             SrcGEPOperands.begin()+1, SrcGEPOperands.end());
+      std::advance(GTI, SrcGEPOperands.size()-2);
+      if (isa<SequentialType>(*GTI)) {
+        // If the src gep ends with a constant array index, merge this get into
+        // it, even if we have a non-zero array index.
+        Indices.insert(Indices.end(), SrcGEPOperands.begin()+1,
+                       SrcGEPOperands.end()-1);
+        Indices.insert(Indices.end(), GEP.idx_begin(), GEP.idx_end());
+      }
     }
 
     if (!Indices.empty())
-      return new GetElementPtrInst(Src->getOperand(0), Indices, GEP.getName());
+      return new GetElementPtrInst(SrcGEPOperands[0], Indices, GEP.getName());
 
   } else if (GlobalValue *GV = dyn_cast<GlobalValue>(GEP.getOperand(0))) {
     // GEP of global variable.  If all of the indices for this GEP are
@@ -2384,7 +2566,7 @@ Instruction *InstCombiner::visitAllocationInst(AllocationInst &AI) {
       // Now that I is pointing to the first non-allocation-inst in the block,
       // insert our getelementptr instruction...
       //
-      std::vector<Value*> Idx(2, Constant::getNullValue(Type::LongTy));
+      std::vector<Value*> Idx(2, Constant::getNullValue(Type::IntTy));
       Value *V = new GetElementPtrInst(New, Idx, New->getName()+".sub", It);
 
       // Now make everything use the getelementptr instead of the original
@@ -2425,7 +2607,7 @@ Instruction *InstCombiner::visitFreeInst(FreeInst &FI) {
 /// expression, or null if something is funny.
 ///
 static Constant *GetGEPGlobalInitializer(Constant *C, ConstantExpr *CE) {
-  if (CE->getOperand(1) != Constant::getNullValue(Type::LongTy))
+  if (CE->getOperand(1) != Constant::getNullValue(CE->getOperand(1)->getType()))
     return 0;  // Do not allow stepping over the value!
 
   // Loop over all of the operands, tracking down which value we are
@@ -2466,6 +2648,27 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) {
           if (GV->isConstant() && !GV->isExternal())
             if (Constant *V = GetGEPGlobalInitializer(GV->getInitializer(), CE))
               return ReplaceInstUsesWith(LI, V);
+
+  // load (cast X) --> cast (load X) iff safe
+  if (CastInst *CI = dyn_cast<CastInst>(Op)) {
+    const Type *DestPTy = cast<PointerType>(CI->getType())->getElementType();
+    if (const PointerType *SrcTy =
+        dyn_cast<PointerType>(CI->getOperand(0)->getType())) {
+      const Type *SrcPTy = SrcTy->getElementType();
+      if (TD->getTypeSize(SrcPTy) == TD->getTypeSize(DestPTy) &&
+          (SrcPTy->isInteger() || isa<PointerType>(SrcPTy)) &&
+          (DestPTy->isInteger() || isa<PointerType>(DestPTy))) {
+        // Okay, we are casting from one integer or pointer type to another of
+        // the same size.  Instead of casting the pointer before the load, cast
+        // the result of the loaded value.
+        Value *NewLoad = InsertNewInstBefore(new LoadInst(CI->getOperand(0),
+                                                          CI->getName()), LI);
+        // Now cast the result of the load.
+        return new CastInst(NewLoad, LI.getType());
+      }
+    }
+  }
+
   return 0;
 }
 
@@ -2548,6 +2751,16 @@ bool InstCombiner::runOnFunction(Function &F) {
       continue;
     }
 
+    // Check to see if any of the operands of this instruction are a
+    // ConstantPointerRef.  Since they sneak in all over the place and inhibit
+    // optimization, we want to strip them out unconditionally!
+    for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+      if (ConstantPointerRef *CPR =
+          dyn_cast<ConstantPointerRef>(I->getOperand(i))) {
+        I->setOperand(i, CPR->getValue());
+        Changed = true;
+      }
+
     // Now that we have an instruction, try combining it to simplify it...
     if (Instruction *Result = visit(*I)) {
       ++NumCombined;