reapply r101364, which has been backed out in r101368
[oota-llvm.git] / lib / Transforms / Scalar / ScalarReplAggregates.cpp
index de02cafb37d1b7fcdabaec80e7685ef943d295fd..c58c858b960f5a2612a5ad73fa3b062640ebbdb4 100644 (file)
@@ -202,12 +202,18 @@ bool SROA::performPromotion(Function &F) {
   return Changed;
 }
 
-/// getNumSAElements - Return the number of elements in the specific struct or
-/// array.
-static uint64_t getNumSAElements(const Type *T) {
+/// ShouldAttemptScalarRepl - Decide if an alloca is a good candidate for
+/// SROA.  It must be a struct or array type with a small number of elements.
+static bool ShouldAttemptScalarRepl(AllocaInst *AI) {
+  const Type *T = AI->getAllocatedType();
+  // Do not promote any struct into more than 32 separate vars.
   if (const StructType *ST = dyn_cast<StructType>(T))
-    return ST->getNumElements();
-  return cast<ArrayType>(T)->getNumElements();
+    return ST->getNumElements() <= 32;
+  // Arrays are much less likely to be safe for SROA; only consider
+  // them if they are very small.
+  if (const ArrayType *AT = dyn_cast<ArrayType>(T))
+    return AT->getNumElements() <= 8;
+  return false;
 }
 
 // performScalarRepl - This algorithm is a simple worklist driven algorithm,
@@ -248,7 +254,7 @@ bool SROA::performScalarRepl(Function &F) {
     if (Instruction *TheCopy = isOnlyCopiedFromConstantGlobal(AI)) {
       DEBUG(dbgs() << "Found alloca equal to global: " << *AI << '\n');
       DEBUG(dbgs() << "  memcpy = " << *TheCopy << '\n');
-      Constant *TheSrc = cast<Constant>(TheCopy->getOperand(2));
+      Constant *TheSrc = cast<Constant>(TheCopy->getOperand(1));
       AI->replaceAllUsesWith(ConstantExpr::getBitCast(TheSrc, AI->getType()));
       TheCopy->eraseFromParent();  // Don't mutate the global.
       AI->eraseFromParent();
@@ -266,22 +272,18 @@ bool SROA::performScalarRepl(Function &F) {
     // Do not promote [0 x %struct].
     if (AllocaSize == 0) continue;
 
+    // If the alloca looks like a good candidate for scalar replacement, and if
+    // all its users can be transformed, then split up the aggregate into its
+    // separate elements.
+    if (ShouldAttemptScalarRepl(AI) && isSafeAllocaToScalarRepl(AI)) {
+      DoScalarReplacement(AI, WorkList);
+      Changed = true;
+      continue;
+    }
+
     // Do not promote any struct whose size is too big.
     if (AllocaSize > SRThreshold) continue;
 
-    if ((isa<StructType>(AI->getAllocatedType()) ||
-         isa<ArrayType>(AI->getAllocatedType())) &&
-        // Do not promote any struct into more than "32" separate vars.
-        getNumSAElements(AI->getAllocatedType()) <= SRThreshold/4) {
-      // Check that all of the users of the allocation are capable of being
-      // transformed.
-      if (isSafeAllocaToScalarRepl(AI)) {
-        DoScalarReplacement(AI, WorkList);
-        Changed = true;
-        continue;
-      }
-    }
-
     // If we can turn this aggregate value (potentially with casts) into a
     // simple scalar value that can be mem2reg'd into a register value.
     // IsNotTrivial tracks whether this is something that mem2reg could have
@@ -300,7 +302,7 @@ bool SROA::performScalarRepl(Function &F) {
       // random stuff that doesn't use vectors (e.g. <9 x double>) because then
       // we just get a lot of insert/extracts.  If at least one vector is
       // involved, then we probably really do have a union of vector/array.
-      if (VectorTy && isa<VectorType>(VectorTy) && HadAVector) {
+      if (VectorTy && VectorTy->isVectorTy() && HadAVector) {
         DEBUG(dbgs() << "CONVERT TO VECTOR: " << *AI << "\n  TYPE = "
                      << *VectorTy << '\n');
         
@@ -402,11 +404,11 @@ void SROA::isSafeForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset,
       isSafeGEP(GEPI, AI, GEPOffset, Info);
       if (!Info.isUnsafe)
         isSafeForScalarRepl(GEPI, AI, GEPOffset, Info);
-    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(UI)) {
+    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(User)) {
       ConstantInt *Length = dyn_cast<ConstantInt>(MI->getLength());
       if (Length)
         isSafeMemAccess(AI, Offset, Length->getZExtValue(), 0,
-                        UI.getOperandNo() == 1, Info);
+                        UI.getOperandNo() == 0, Info);
       else
         MarkUnsafe(Info);
     } else if (LoadInst *LI = dyn_cast<LoadInst>(User)) {
@@ -447,7 +449,7 @@ void SROA::isSafeGEP(GetElementPtrInst *GEPI, AllocaInst *AI,
   // into.
   for (; GEPIt != E; ++GEPIt) {
     // Ignore struct elements, no extra checking needed for these.
-    if (isa<StructType>(*GEPIt))
+    if ((*GEPIt)->isStructTy())
       continue;
 
     ConstantInt *IdxVal = dyn_cast<ConstantInt>(GEPIt.getOperand());
@@ -478,7 +480,7 @@ void SROA::isSafeMemAccess(AllocaInst *AI, uint64_t Offset, uint64_t MemSize,
     // (which are essentially the same as the MemIntrinsics, especially with
     // regard to copying padding between elements), or references using the
     // aggregate type of the alloca.
-    if (!MemOpType || isa<IntegerType>(MemOpType) || UsesAggregateType) {
+    if (!MemOpType || MemOpType->isIntegerTy() || UsesAggregateType) {
       if (!UsesAggregateType) {
         if (isStore)
           Info.isMemCpyDst = true;
@@ -563,7 +565,7 @@ void SROA::RewriteForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset,
         }
         LI->replaceAllUsesWith(Insert);
         DeadInsts.push_back(LI);
-      } else if (isa<IntegerType>(LIType) &&
+      } else if (LIType->isIntegerTy() &&
                  TD->getTypeAllocSize(LIType) ==
                  TD->getTypeAllocSize(AI->getAllocatedType())) {
         // If this is a load of the entire alloca to an integer, rewrite it.
@@ -586,7 +588,7 @@ void SROA::RewriteForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset,
           new StoreInst(Extract, NewElts[i], SI);
         }
         DeadInsts.push_back(SI);
-      } else if (isa<IntegerType>(SIType) &&
+      } else if (SIType->isIntegerTy() &&
                  TD->getTypeAllocSize(SIType) ==
                  TD->getTypeAllocSize(AI->getAllocatedType())) {
         // If this is a store of the entire alloca from an integer, rewrite it.
@@ -754,7 +756,7 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
   }
   
   // Process each element of the aggregate.
-  Value *TheFn = MI->getOperand(0);
+  Value *TheFn = MI->getCalledValue();
   const Type *BytePtrTy = MI->getRawDest()->getType();
   bool SROADest = MI->getRawDest() == Inst;
   
@@ -812,7 +814,7 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       // If the stored element is zero (common case), just store a null
       // constant.
       Constant *StoreVal;
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(2))) {
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(1))) {
         if (CI->isZero()) {
           StoreVal = Constant::getNullValue(EltTy);  // 0.0, null, 0, <0,0>
         } else {
@@ -831,9 +833,9 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
           
           // Convert the integer value to the appropriate type.
           StoreVal = ConstantInt::get(Context, TotalVal);
-          if (isa<PointerType>(ValTy))
+          if (ValTy->isPointerTy())
             StoreVal = ConstantExpr::getIntToPtr(StoreVal, ValTy);
-          else if (ValTy->isFloatingPoint())
+          else if (ValTy->isFloatingPointTy())
             StoreVal = ConstantExpr::getBitCast(StoreVal, ValTy);
           assert(StoreVal->getType() == ValTy && "Type mismatch!");
           
@@ -856,8 +858,17 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       EltPtr = new BitCastInst(EltPtr, BytePtrTy, EltPtr->getName(), MI);
     
     // Cast the other pointer (if we have one) to BytePtrTy. 
-    if (OtherElt && OtherElt->getType() != BytePtrTy)
-      OtherElt = new BitCastInst(OtherElt, BytePtrTy, OtherElt->getName(), MI);
+    if (OtherElt && OtherElt->getType() != BytePtrTy) {
+      // Preserve address space of OtherElt
+      const PointerType* OtherPTy = cast<PointerType>(OtherElt->getType());
+      const PointerType* PTy = cast<PointerType>(BytePtrTy);
+      if (OtherPTy->getElementType() != PTy->getElementType()) {
+        Type *NewOtherPTy = PointerType::get(PTy->getElementType(),
+                                             OtherPTy->getAddressSpace());
+        OtherElt = new BitCastInst(OtherElt, NewOtherPTy,
+                                   OtherElt->getNameStr(), MI);
+      }
+    }
     
     unsigned EltSize = TD->getTypeAllocSize(EltTy);
     
@@ -866,19 +877,30 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       Value *Ops[] = {
         SROADest ? EltPtr : OtherElt,  // Dest ptr
         SROADest ? OtherElt : EltPtr,  // Src ptr
-        ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
+        ConstantInt::get(MI->getOperand(2)->getType(), EltSize), // Size
         // Align
-        ConstantInt::get(Type::getInt32Ty(MI->getContext()), OtherEltAlign)
+        ConstantInt::get(Type::getInt32Ty(MI->getContext()), OtherEltAlign),
+        MI->getVolatileCst()
       };
-      CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
+      // In case we fold the address space overloaded memcpy of A to B
+      // with memcpy of B to C, change the function to be a memcpy of A to C.
+      const Type *Tys[] = { Ops[0]->getType(), Ops[1]->getType(),
+                            Ops[2]->getType() };
+      Module *M = MI->getParent()->getParent()->getParent();
+      TheFn = Intrinsic::getDeclaration(M, MI->getIntrinsicID(), Tys, 3);
+      CallInst::Create(TheFn, Ops, Ops + 5, "", MI);
     } else {
       assert(isa<MemSetInst>(MI));
       Value *Ops[] = {
-        EltPtr, MI->getOperand(2),  // Dest, Value,
-        ConstantInt::get(MI->getOperand(3)->getType(), EltSize), // Size
-        Zero  // Align
+        EltPtr, MI->getOperand(1),  // Dest, Value,
+        ConstantInt::get(MI->getOperand(2)->getType(), EltSize), // Size
+        Zero,  // Align
+        ConstantInt::get(Type::getInt1Ty(MI->getContext()), 0) // isVolatile
       };
-      CallInst::Create(TheFn, Ops, Ops + 4, "", MI);
+      const Type *Tys[] = { Ops[0]->getType(), Ops[2]->getType() };
+      Module *M = MI->getParent()->getParent()->getParent();
+      TheFn = Intrinsic::getDeclaration(M, Intrinsic::memset, Tys, 2);
+      CallInst::Create(TheFn, Ops, Ops + 5, "", MI);
     }
   }
   DeadInsts.push_back(MI);
@@ -937,7 +959,7 @@ void SROA::RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI,
       Value *DestField = NewElts[i];
       if (EltVal->getType() == FieldTy) {
         // Storing to an integer field of this size, just do it.
-      } else if (FieldTy->isFloatingPoint() || isa<VectorType>(FieldTy)) {
+      } else if (FieldTy->isFloatingPointTy() || FieldTy->isVectorTy()) {
         // Bitcast to the right element type (for fp/vector values).
         EltVal = new BitCastInst(EltVal, FieldTy, "", SI);
       } else {
@@ -981,7 +1003,8 @@ void SROA::RewriteStoreUserOfWholeAlloca(StoreInst *SI, AllocaInst *AI,
       Value *DestField = NewElts[i];
       if (EltVal->getType() == ArrayEltTy) {
         // Storing to an integer field of this size, just do it.
-      } else if (ArrayEltTy->isFloatingPoint() || isa<VectorType>(ArrayEltTy)) {
+      } else if (ArrayEltTy->isFloatingPointTy() ||
+                 ArrayEltTy->isVectorTy()) {
         // Bitcast to the right element type (for fp/vector values).
         EltVal = new BitCastInst(EltVal, ArrayEltTy, "", SI);
       } else {
@@ -1041,8 +1064,8 @@ void SROA::RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI,
     
     const IntegerType *FieldIntTy = IntegerType::get(LI->getContext(), 
                                                      FieldSizeBits);
-    if (!isa<IntegerType>(FieldTy) && !FieldTy->isFloatingPoint() &&
-        !isa<VectorType>(FieldTy))
+    if (!FieldTy->isIntegerTy() && !FieldTy->isFloatingPointTy() &&
+        !FieldTy->isVectorTy())
       SrcField = new BitCastInst(SrcField,
                                  PointerType::getUnqual(FieldIntTy),
                                  "", LI);
@@ -1180,7 +1203,7 @@ static void MergeInType(const Type *In, uint64_t Offset, const Type *&VecTy,
         return;
       }
     } else if (In->isFloatTy() || In->isDoubleTy() ||
-               (isa<IntegerType>(In) && In->getPrimitiveSizeInBits() >= 8 &&
+               (In->isIntegerTy() && In->getPrimitiveSizeInBits() >= 8 &&
                 isPowerOf2_32(In->getPrimitiveSizeInBits()))) {
       // If we're accessing something that could be an element of a vector, see
       // if the implied vector agrees with what we already have and if Offset is
@@ -1224,7 +1247,7 @@ bool SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial, const Type *&VecTy,
         return false;
       MergeInType(LI->getType(), Offset, VecTy,
                   AllocaSize, *TD, V->getContext());
-      SawVec |= isa<VectorType>(LI->getType());
+      SawVec |= LI->getType()->isVectorTy();
       continue;
     }
     
@@ -1233,7 +1256,7 @@ bool SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial, const Type *&VecTy,
       if (SI->getOperand(0) == V || SI->isVolatile()) return 0;
       MergeInType(SI->getOperand(0)->getType(), Offset,
                   VecTy, AllocaSize, *TD, V->getContext());
-      SawVec |= isa<VectorType>(SI->getOperand(0)->getType());
+      SawVec |= SI->getOperand(0)->getType()->isVectorTy();
       continue;
     }
     
@@ -1435,7 +1458,7 @@ Value *SROA::ConvertScalar_ExtractValue(Value *FromVal, const Type *ToType,
   // If the result alloca is a vector type, this is either an element
   // access or a bitcast to another vector type of the same size.
   if (const VectorType *VTy = dyn_cast<VectorType>(FromVal->getType())) {
-    if (isa<VectorType>(ToType))
+    if (ToType->isVectorTy())
       return Builder.CreateBitCast(FromVal, ToType, "tmp");
 
     // Otherwise it must be an element access.
@@ -1518,9 +1541,9 @@ Value *SROA::ConvertScalar_ExtractValue(Value *FromVal, const Type *ToType,
                                                     LIBitWidth), "tmp");
 
   // If the result is an integer, this is a trunc or bitcast.
-  if (isa<IntegerType>(ToType)) {
+  if (ToType->isIntegerTy()) {
     // Should be done.
-  } else if (ToType->isFloatingPoint() || isa<VectorType>(ToType)) {
+  } else if (ToType->isFloatingPointTy() || ToType->isVectorTy()) {
     // Just do a bitcast, we know the sizes match up.
     FromVal = Builder.CreateBitCast(FromVal, ToType, "tmp");
   } else {
@@ -1598,10 +1621,10 @@ Value *SROA::ConvertScalar_InsertValue(Value *SV, Value *Old,
   unsigned DestWidth = TD->getTypeSizeInBits(AllocaType);
   unsigned SrcStoreWidth = TD->getTypeStoreSizeInBits(SV->getType());
   unsigned DestStoreWidth = TD->getTypeStoreSizeInBits(AllocaType);
-  if (SV->getType()->isFloatingPoint() || isa<VectorType>(SV->getType()))
+  if (SV->getType()->isFloatingPointTy() || SV->getType()->isVectorTy())
     SV = Builder.CreateBitCast(SV,
                             IntegerType::get(SV->getContext(),SrcWidth), "tmp");
-  else if (isa<PointerType>(SV->getType()))
+  else if (SV->getType()->isPointerTy())
     SV = Builder.CreatePtrToInt(SV, TD->getIntPtrType(SV->getContext()), "tmp");
 
   // Zero extend or truncate the value if needed.
@@ -1679,18 +1702,20 @@ static bool PointsToConstantGlobal(Value *V) {
 static bool isOnlyCopiedFromConstantGlobal(Value *V, Instruction *&TheCopy,
                                            bool isOffset) {
   for (Value::use_iterator UI = V->use_begin(), E = V->use_end(); UI!=E; ++UI) {
-    if (LoadInst *LI = dyn_cast<LoadInst>(*UI))
+    User *U = cast<Instruction>(*UI);
+
+    if (LoadInst *LI = dyn_cast<LoadInst>(U))
       // Ignore non-volatile loads, they are always ok.
       if (!LI->isVolatile())
         continue;
     
-    if (BitCastInst *BCI = dyn_cast<BitCastInst>(*UI)) {
+    if (BitCastInst *BCI = dyn_cast<BitCastInst>(U)) {
       // If uses of the bitcast are ok, we are ok.
       if (!isOnlyCopiedFromConstantGlobal(BCI, TheCopy, isOffset))
         return false;
       continue;
     }
-    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(*UI)) {
+    if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
       // If the GEP has all zero indices, it doesn't offset the pointer.  If it
       // doesn't, it does.
       if (!isOnlyCopiedFromConstantGlobal(GEP, TheCopy,
@@ -1701,7 +1726,7 @@ static bool isOnlyCopiedFromConstantGlobal(Value *V, Instruction *&TheCopy,
     
     // If this is isn't our memcpy/memmove, reject it as something we can't
     // handle.
-    if (!isa<MemTransferInst>(*UI))
+    if (!isa<MemTransferInst>(U))
       return false;
 
     // If we already have seen a copy, reject the second one.
@@ -1712,12 +1737,12 @@ static bool isOnlyCopiedFromConstantGlobal(Value *V, Instruction *&TheCopy,
     if (isOffset) return false;
 
     // If the memintrinsic isn't using the alloca as the dest, reject it.
-    if (UI.getOperandNo() != 1) return false;
+    if (UI.getOperandNo() != 0) return false;
     
-    MemIntrinsic *MI = cast<MemIntrinsic>(*UI);
+    MemIntrinsic *MI = cast<MemIntrinsic>(U);
     
     // If the source of the memcpy/move is not a constant global, reject it.
-    if (!PointsToConstantGlobal(MI->getOperand(2)))
+    if (!PointsToConstantGlobal(MI->getOperand(1)))
       return false;
     
     // Otherwise, the transform is safe.  Remember the copy instruction.