Teach ConvertUsesToScalar to handle memset, allowing it to handle
authorChris Lattner <sabre@nondot.org>
Tue, 3 Feb 2009 02:01:43 +0000 (02:01 +0000)
committerChris Lattner <sabre@nondot.org>
Tue, 3 Feb 2009 02:01:43 +0000 (02:01 +0000)
crazy cases like:

struct f {  int A, B, C, D, E, F; };
short test4() {
  struct f A;
  A.A = 1;
  memset(&A.B, 2, 12);
  return A.C;
}

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@63596 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/ScalarReplAggregates.cpp
test/Transforms/ScalarRepl/memset-aggregate.ll

index 5031619b51cb94a5f613839c56e576e4373f6549..83572d65dae3507f7d6989add2fc21b69fda0cfa 100644 (file)
@@ -130,8 +130,8 @@ namespace {
     void ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset);
     Value *ConvertUsesOfLoadToScalar(LoadInst *LI, AllocaInst *NewAI, 
                                      uint64_t Offset);
-    Value *ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI, 
-                                      uint64_t Offset);
+    Value *ConvertUsesOfStoreToScalar(Value *StoredVal, AllocaInst *NewAI, 
+                                      uint64_t Offset, Instruction *InsertPt);
     static Instruction *isOnlyCopiedFromConstantGlobal(AllocationInst *AI);
   };
 }
@@ -1274,6 +1274,18 @@ bool SROA::CanConvertToScalar(Value *V, bool &IsNotTrivial,
       continue;
     }
     
+    // If this is a constant sized memset of a constant value (e.g. 0) we can
+    // handle it.
+    if (isa<MemSetInst>(User) &&
+        // Store of constant value.
+        isa<ConstantInt>(User->getOperand(2)) &&
+        // Store with constant size.
+        isa<ConstantInt>(User->getOperand(3))) {
+      VecTy = Type::VoidTy;
+      IsNotTrivial = true;
+      continue;
+    }
+    
     // Otherwise, we cannot handle this!
     return false;
   }
@@ -1301,7 +1313,8 @@ void SROA::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset) {
 
     if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
       assert(SI->getOperand(0) != Ptr && "Consistency error!");
-      new StoreInst(ConvertUsesOfStoreToScalar(SI, NewAI, Offset), NewAI, SI);
+      new StoreInst(ConvertUsesOfStoreToScalar(SI->getOperand(0), NewAI,
+                                               Offset, SI), NewAI, SI);
       SI->eraseFromParent();
       continue;
     }
@@ -1321,6 +1334,29 @@ void SROA::ConvertUsesToScalar(Value *Ptr, AllocaInst *NewAI, uint64_t Offset) {
       GEP->eraseFromParent();
       continue;
     }
+    
+    // If this is a constant sized memset of a constant value (e.g. 0) we can
+    // transform it into a store of the expanded constant value.
+    if (MemSetInst *MSI = dyn_cast<MemSetInst>(User)) {
+      assert(MSI->getRawDest() == Ptr && "Consistency error!");
+      unsigned NumBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+      unsigned Val = cast<ConstantInt>(MSI->getValue())->getZExtValue();
+      
+      // Compute the value replicated the right number of times.
+      APInt APVal(NumBytes*8, Val);
+
+      // Splat the value if non-zero.
+      if (Val)
+        for (unsigned i = 1; i != NumBytes; ++i)
+          APVal |= APVal << 8;
+      
+      new StoreInst(ConvertUsesOfStoreToScalar(ConstantInt::get(APVal), NewAI,
+                                               Offset, MSI), NewAI, MSI);
+      MSI->eraseFromParent();
+      continue;
+    }
+        
+    
     assert(0 && "Unsupported operation!");
     abort();
   }
@@ -1422,40 +1458,38 @@ Value *SROA::ConvertUsesOfLoadToScalar(LoadInst *LI, AllocaInst *NewAI,
 ///
 /// Offset is an offset from the original alloca, in bits that need to be
 /// shifted to the right.  By the end of this, there should be no uses of Ptr.
-Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
-                                        uint64_t Offset) {
+Value *SROA::ConvertUsesOfStoreToScalar(Value *SV, AllocaInst *NewAI,
+                                        uint64_t Offset, Instruction *IP) {
 
   // Convert the stored type to the actual type, shift it left to insert
   // then 'or' into place.
-  Value *SV = SI->getOperand(0);
   const Type *AllocaType = NewAI->getType()->getElementType();
-  if (SV->getType() == AllocaType && Offset == 0) {
+  if (SV->getType() == AllocaType && Offset == 0)
     return SV;
-  }
 
   if (const VectorType *VTy = dyn_cast<VectorType>(AllocaType)) {
-    Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", SI);
+    Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", IP);
 
     // If the result alloca is a vector type, this is either an element
     // access or a bitcast to another vector type.
     if (isa<VectorType>(SV->getType())) {
-      SV = new BitCastInst(SV, AllocaType, SV->getName(), SI);
+      SV = new BitCastInst(SV, AllocaType, SV->getName(), IP);
     } else {
       // Must be an element insertion.
       unsigned Elt = Offset/TD->getTypePaddedSizeInBits(VTy->getElementType());
       
       if (SV->getType() != VTy->getElementType())
-        SV = new BitCastInst(SV, VTy->getElementType(), "tmp", SI);
+        SV = new BitCastInst(SV, VTy->getElementType(), "tmp", IP);
       
       SV = InsertElementInst::Create(Old, SV,
                                      ConstantInt::get(Type::Int32Ty, Elt),
-                                     "tmp", SI);
+                                     "tmp", IP);
     }
     return SV;
   }
 
 
-  Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", SI);
+  Value *Old = new LoadInst(NewAI, NewAI->getName()+".in", IP);
 
   // If SV is a float, convert it to the appropriate integer type.
   // If it is a pointer, do the same, and also handle ptr->ptr casts
@@ -1465,19 +1499,19 @@ Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
   unsigned SrcStoreWidth = TD->getTypeStoreSizeInBits(SV->getType());
   unsigned DestStoreWidth = TD->getTypeStoreSizeInBits(AllocaType);
   if (SV->getType()->isFloatingPoint() || isa<VectorType>(SV->getType()))
-    SV = new BitCastInst(SV, IntegerType::get(SrcWidth), SV->getName(), SI);
+    SV = new BitCastInst(SV, IntegerType::get(SrcWidth), SV->getName(), IP);
   else if (isa<PointerType>(SV->getType()))
-    SV = new PtrToIntInst(SV, TD->getIntPtrType(), SV->getName(), SI);
+    SV = new PtrToIntInst(SV, TD->getIntPtrType(), SV->getName(), IP);
 
   // Zero extend or truncate the value if needed.
   if (SV->getType() != AllocaType) {
     if (SV->getType()->getPrimitiveSizeInBits() <
              AllocaType->getPrimitiveSizeInBits())
-      SV = new ZExtInst(SV, AllocaType, SV->getName(), SI);
+      SV = new ZExtInst(SV, AllocaType, SV->getName(), IP);
     else {
       // Truncation may be needed if storing more than the alloca can hold
       // (undefined behavior).
-      SV = new TruncInst(SV, AllocaType, SV->getName(), SI);
+      SV = new TruncInst(SV, AllocaType, SV->getName(), IP);
       SrcWidth = DestWidth;
       SrcStoreWidth = DestStoreWidth;
     }
@@ -1502,12 +1536,12 @@ Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
   if (ShAmt > 0 && (unsigned)ShAmt < DestWidth) {
     SV = BinaryOperator::CreateShl(SV,
                                    ConstantInt::get(SV->getType(), ShAmt),
-                                   SV->getName(), SI);
+                                   SV->getName(), IP);
     Mask <<= ShAmt;
   } else if (ShAmt < 0 && (unsigned)-ShAmt < DestWidth) {
     SV = BinaryOperator::CreateLShr(SV,
                                     ConstantInt::get(SV->getType(), -ShAmt),
-                                    SV->getName(), SI);
+                                    SV->getName(), IP);
     Mask = Mask.lshr(-ShAmt);
   }
 
@@ -1516,8 +1550,8 @@ Value *SROA::ConvertUsesOfStoreToScalar(StoreInst *SI, AllocaInst *NewAI,
   if (SrcWidth != DestWidth) {
     assert(DestWidth > SrcWidth);
     Old = BinaryOperator::CreateAnd(Old, ConstantInt::get(~Mask),
-                                    Old->getName()+".mask", SI);
-    SV = BinaryOperator::CreateOr(Old, SV, SV->getName()+".ins", SI);
+                                    Old->getName()+".mask", IP);
+    SV = BinaryOperator::CreateOr(Old, SV, SV->getName()+".ins", IP);
   }
   return SV;
 }
index 4febda5b9a793f2761f9d9d3eaf94d01590bb218..b7b33521bbcefcd40ceace01245a9d07260c2723 100644 (file)
@@ -1,6 +1,7 @@
 ; PR1226
 ; RUN: llvm-as < %s | opt -scalarrepl | llvm-dis | grep {ret i32 16843009}
 ; RUN: llvm-as < %s | opt -scalarrepl | llvm-dis | not grep alloca
+; RUN: llvm-as < %s | opt -scalarrepl -instcombine | llvm-dis | grep {ret i16 514}
 
 target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64"
 target triple = "i686-apple-darwin8"
@@ -46,3 +47,20 @@ entry:
        %tmp7 = load i32* %tmp6         ; <i32> [#uses=1]
        ret i32 %tmp7
 }
+
+
+       %struct.f = type { i32, i32, i32, i32, i32, i32 }
+
+define i16 @test4() nounwind {
+entry:
+       %A = alloca %struct.f, align 8          ; <%struct.f*> [#uses=3]
+       %0 = getelementptr %struct.f* %A, i32 0, i32 0          ; <i32*> [#uses=1]
+       store i32 1, i32* %0, align 8
+       %1 = getelementptr %struct.f* %A, i32 0, i32 1          ; <i32*> [#uses=1]
+       %2 = bitcast i32* %1 to i8*             ; <i8*> [#uses=1]
+       call void @llvm.memset.i32(i8* %2, i8 2, i32 12, i32 4)
+       %3 = getelementptr %struct.f* %A, i32 0, i32 2          ; <i32*> [#uses=1]
+       %4 = load i32* %3, align 8              ; <i32> [#uses=1]
+       %retval12 = trunc i32 %4 to i16         ; <i16> [#uses=1]
+       ret i16 %retval12
+}