Massive rewrite of MMX:
[oota-llvm.git] / lib / Transforms / Scalar / ScalarReplAggregates.cpp
index 54e13c44d6904866c57582f4b661ceb99fa71972..33ecb5bf9bdf7d2ad73726edf0d9cf3bbd65ac52 100644 (file)
@@ -28,6 +28,7 @@
 #include "llvm/Instructions.h"
 #include "llvm/IntrinsicInst.h"
 #include "llvm/LLVMContext.h"
+#include "llvm/Module.h"
 #include "llvm/Pass.h"
 #include "llvm/Analysis/Dominators.h"
 #include "llvm/Target/TargetData.h"
@@ -51,7 +52,7 @@ STATISTIC(NumGlobals,   "Number of allocas copied from constant global");
 namespace {
   struct SROA : public FunctionPass {
     static char ID; // Pass identification, replacement for typeid
-    explicit SROA(signed T = -1) : FunctionPass(&ID) {
+    explicit SROA(signed T = -1) : FunctionPass(ID) {
       if (T == -1)
         SRThreshold = 128;
       else
@@ -114,8 +115,7 @@ namespace {
     void DoScalarReplacement(AllocaInst *AI, 
                              std::vector<AllocaInst*> &WorkList);
     void DeleteDeadInstructions();
-    AllocaInst *AddNewAlloca(Function &F, const Type *Ty, AllocaInst *Base);
-    
+   
     void RewriteForScalarRepl(Instruction *I, AllocaInst *AI, uint64_t Offset,
                               SmallVector<AllocaInst*, 32> &NewElts);
     void RewriteBitCast(BitCastInst *BC, AllocaInst *AI, uint64_t Offset,
@@ -135,7 +135,8 @@ namespace {
 }
 
 char SROA::ID = 0;
-static RegisterPass<SROA> X("scalarrepl", "Scalar Replacement of Aggregates");
+INITIALIZE_PASS(SROA, "scalarrepl",
+                "Scalar Replacement of Aggregates", false, false);
 
 // Public interface to the ScalarReplAggregates pass
 FunctionPass *llvm::createScalarReplAggregatesPass(signed int Threshold) { 
@@ -156,7 +157,7 @@ class ConvertToScalarInfo {
   unsigned AllocaSize;
   const TargetData &TD;
  
-  /// IsNotTrivial - This is set to true if there is somee access to the object
+  /// IsNotTrivial - This is set to true if there is some access to the object
   /// which means that mem2reg can't promote it.
   bool IsNotTrivial;
   
@@ -193,6 +194,27 @@ private:
 };
 } // end anonymous namespace.
 
+
+/// IsVerbotenVectorType - Return true if this is a vector type ScalarRepl isn't
+/// allowed to form.  We do this to avoid MMX types, which is a complete hack,
+/// but is required until the backend is fixed.
+static bool IsVerbotenVectorType(const VectorType *VTy, const Instruction *I) {
+  StringRef Triple(I->getParent()->getParent()->getParent()->getTargetTriple());
+  if (!Triple.startswith("i386") &&
+      !Triple.startswith("x86_64"))
+    return false;
+  
+  // Reject all the MMX vector types.
+  switch (VTy->getNumElements()) {
+  default: return false;
+  case 1: return VTy->getElementType()->isIntegerTy(64);
+  case 2: return VTy->getElementType()->isIntegerTy(32);
+  case 4: return VTy->getElementType()->isIntegerTy(16);
+  case 8: return VTy->getElementType()->isIntegerTy(8);
+  }
+}
+
+
 /// TryConvert - Analyze the specified alloca, and if it is safe to do so,
 /// rewrite it to be a new alloca which is mem2reg'able.  This returns the new
 /// alloca if possible or null if not.
@@ -209,7 +231,8 @@ AllocaInst *ConvertToScalarInfo::TryConvert(AllocaInst *AI) {
   // we just get a lot of insert/extracts.  If at least one vector is
   // involved, then we probably really do have a union of vector/array.
   const Type *NewTy;
-  if (VectorTy && VectorTy->isVectorTy() && HadAVector) {
+  if (VectorTy && VectorTy->isVectorTy() && HadAVector &&
+      !IsVerbotenVectorType(cast<VectorType>(VectorTy), AI)) {
     DEBUG(dbgs() << "CONVERT TO VECTOR: " << *AI << "\n  TYPE = "
           << *VectorTy << '\n');
     NewTy = VectorTy;  // Use the vector type.
@@ -298,6 +321,9 @@ bool ConvertToScalarInfo::CanConvertToScalar(Value *V, uint64_t Offset) {
       // Don't break volatile loads.
       if (LI->isVolatile())
         return false;
+      // Don't touch MMX operations.
+      if (LI->getType()->isX86_MMXTy())
+        return false;
       MergeInType(LI->getType(), Offset);
       continue;
     }
@@ -305,6 +331,9 @@ bool ConvertToScalarInfo::CanConvertToScalar(Value *V, uint64_t Offset) {
     if (StoreInst *SI = dyn_cast<StoreInst>(User)) {
       // Storing the pointer, not into the value?
       if (SI->getOperand(0) == V || SI->isVolatile()) return false;
+      // Don't touch MMX operations.
+      if (SI->getOperand(0)->getType()->isX86_MMXTy())
+        return false;
       MergeInType(SI->getOperand(0)->getType(), Offset);
       continue;
     }
@@ -926,7 +955,7 @@ void SROA::DoScalarReplacement(AllocaInst *AI,
   DeleteDeadInstructions();
   AI->eraseFromParent();
 
-  NumReplaced++;
+  ++NumReplaced;
 }
 
 /// DeleteDeadInstructions - Erase instructions on the DeadInstrs list,
@@ -1272,6 +1301,8 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
   // If there is an other pointer, we want to convert it to the same pointer
   // type as AI has, so we can GEP through it safely.
   if (OtherPtr) {
+    unsigned AddrSpace =
+      cast<PointerType>(OtherPtr->getType())->getAddressSpace();
 
     // Remove bitcasts and all-zero GEPs from OtherPtr.  This is an
     // optimization, but it's also required to detect the corner case where
@@ -1279,20 +1310,8 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
     // OtherPtr may be a bitcast or GEP that currently being rewritten.  (This
     // function is only called for mem intrinsics that access the whole
     // aggregate, so non-zero GEPs are not an issue here.)
-    while (1) {
-      if (BitCastInst *BC = dyn_cast<BitCastInst>(OtherPtr)) {
-        OtherPtr = BC->getOperand(0);
-        continue;
-      }
-      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(OtherPtr)) {
-        // All zero GEPs are effectively bitcasts.
-        if (GEP->hasAllZeroIndices()) {
-          OtherPtr = GEP->getOperand(0);
-          continue;
-        }
-      }
-      break;
-    }
+    OtherPtr = OtherPtr->stripPointerCasts();
+    
     // Copying the alloca to itself is a no-op: just delete it.
     if (OtherPtr == AI || OtherPtr == NewElts[0]) {
       // This code will run twice for a no-op memcpy -- once for each operand.
@@ -1304,15 +1323,13 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       return;
     }
     
-    if (ConstantExpr *BCE = dyn_cast<ConstantExpr>(OtherPtr))
-      if (BCE->getOpcode() == Instruction::BitCast)
-        OtherPtr = BCE->getOperand(0);
-    
     // If the pointer is not the right type, insert a bitcast to the right
     // type.
-    if (OtherPtr->getType() != AI->getType())
-      OtherPtr = new BitCastInst(OtherPtr, AI->getType(), OtherPtr->getName(),
-                                 MI);
+    const Type *NewTy =
+      PointerType::get(AI->getType()->getElementType(), AddrSpace);
+    
+    if (OtherPtr->getType() != NewTy)
+      OtherPtr = new BitCastInst(OtherPtr, NewTy, OtherPtr->getName(), MI);
   }
   
   // Process each element of the aggregate.
@@ -1373,7 +1390,7 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       // If the stored element is zero (common case), just store a null
       // constant.
       Constant *StoreVal;
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getOperand(1))) {
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(MI->getArgOperand(1))) {
         if (CI->isZero()) {
           StoreVal = Constant::getNullValue(EltTy);  // 0.0, null, 0, <0,0>
         } else {
@@ -1436,7 +1453,7 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
       Value *Ops[] = {
         SROADest ? EltPtr : OtherElt,  // Dest ptr
         SROADest ? OtherElt : EltPtr,  // Src ptr
-        ConstantInt::get(MI->getOperand(2)->getType(), EltSize), // Size
+        ConstantInt::get(MI->getArgOperand(2)->getType(), EltSize), // Size
         // Align
         ConstantInt::get(Type::getInt32Ty(MI->getContext()), OtherEltAlign),
         MI->getVolatileCst()
@@ -1451,8 +1468,8 @@ void SROA::RewriteMemIntrinUserOfAlloca(MemIntrinsic *MI, Instruction *Inst,
     } else {
       assert(isa<MemSetInst>(MI));
       Value *Ops[] = {
-        EltPtr, MI->getOperand(1),  // Dest, Value,
-        ConstantInt::get(MI->getOperand(2)->getType(), EltSize), // Size
+        EltPtr, MI->getArgOperand(1),  // Dest, Value,
+        ConstantInt::get(MI->getArgOperand(2)->getType(), EltSize), // Size
         Zero,  // Align
         ConstantInt::get(Type::getInt1Ty(MI->getContext()), 0) // isVolatile
       };
@@ -1655,7 +1672,12 @@ void SROA::RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI,
       SrcField = BinaryOperator::CreateShl(SrcField, ShiftVal, "", LI);
     }
 
-    ResultVal = BinaryOperator::CreateOr(SrcField, ResultVal, "", LI);
+    // Don't create an 'or x, 0' on the first iteration.
+    if (!isa<Constant>(ResultVal) ||
+        !cast<Constant>(ResultVal)->isNullValue())
+      ResultVal = BinaryOperator::CreateOr(SrcField, ResultVal, "", LI);
+    else
+      ResultVal = SrcField;
   }
 
   // Handle tail padding by truncating the result
@@ -1669,6 +1691,12 @@ void SROA::RewriteLoadUserOfWholeAlloca(LoadInst *LI, AllocaInst *AI,
 /// HasPadding - Return true if the specified type has any structure or
 /// alignment padding, false otherwise.
 static bool HasPadding(const Type *Ty, const TargetData &TD) {
+  if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty))
+    return HasPadding(ATy->getElementType(), TD);
+  
+  if (const VectorType *VTy = dyn_cast<VectorType>(Ty))
+    return HasPadding(VTy->getElementType(), TD);
+  
   if (const StructType *STy = dyn_cast<StructType>(Ty)) {
     const StructLayout *SL = TD.getStructLayout(STy);
     unsigned PrevFieldBitOffset = 0;
@@ -1698,12 +1726,8 @@ static bool HasPadding(const Type *Ty, const TargetData &TD) {
       if (PrevFieldEnd < SL->getSizeInBits())
         return true;
     }
-
-  } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
-    return HasPadding(ATy->getElementType(), TD);
-  } else if (const VectorType *VTy = dyn_cast<VectorType>(Ty)) {
-    return HasPadding(VTy->getElementType(), TD);
   }
+  
   return TD.getTypeSizeInBits(Ty) != TD.getTypeAllocSizeInBits(Ty);
 }