[PM/AA] Hoist ScopedNoAliasAA's interface into a header and move the
[oota-llvm.git] / lib / Transforms / Scalar / MemCpyOptimizer.cpp
index 041312b7ac619a5aed446eea584881cee8188ff3..3c2a498669e15423af38ca2db719c4d53d9719af 100644 (file)
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Dominators.h"
@@ -28,9 +29,8 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
-#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include <list>
+#include <algorithm>
 using namespace llvm;
 
 #define DEBUG_TYPE "memcpyopt"
@@ -41,7 +41,8 @@ STATISTIC(NumMoveToCpy,   "Number of memmoves converted to memcpy");
 STATISTIC(NumCpyToSet,    "Number of memcpys converted to memset");
 
 static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx,
-                                  bool &VariableIdxFound, const DataLayout &TD){
+                                  bool &VariableIdxFound,
+                                  const DataLayout &DL) {
   // Skip over the first indices.
   gep_type_iterator GTI = gep_type_begin(GEP);
   for (unsigned i = 1; i != Idx; ++i, ++GTI)
@@ -57,24 +58,24 @@ static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx,
 
     // Handle struct indices, which add their field offset to the pointer.
     if (StructType *STy = dyn_cast<StructType>(*GTI)) {
-      Offset += TD.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
+      Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue());
       continue;
     }
 
     // Otherwise, we have a sequential type like an array or vector.  Multiply
     // the index by the ElementSize.
-    uint64_t Size = TD.getTypeAllocSize(GTI.getIndexedType());
+    uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
     Offset += Size*OpC->getSExtValue();
   }
 
   return Offset;
 }
 
-/// IsPointerOffset - Return true if Ptr1 is provably equal to Ptr2 plus a
-/// constant offset, and return that constant offset.  For example, Ptr1 might
-/// be &A[42], and Ptr2 might be &A[40].  In this case offset would be -8.
+/// Return true if Ptr1 is provably equal to Ptr2 plus a constant offset, and
+/// return that constant offset. For example, Ptr1 might be &A[42], and Ptr2
+/// might be &A[40]. In this case offset would be -8.
 static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
-                            const DataLayout &TD) {
+                            const DataLayout &DL) {
   Ptr1 = Ptr1->stripPointerCasts();
   Ptr2 = Ptr2->stripPointerCasts();
 
@@ -92,12 +93,12 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
   // If one pointer is a GEP and the other isn't, then see if the GEP is a
   // constant offset from the base, as in "P" and "gep P, 1".
   if (GEP1 && !GEP2 && GEP1->getOperand(0)->stripPointerCasts() == Ptr2) {
-    Offset = -GetOffsetFromIndex(GEP1, 1, VariableIdxFound, TD);
+    Offset = -GetOffsetFromIndex(GEP1, 1, VariableIdxFound, DL);
     return !VariableIdxFound;
   }
 
   if (GEP2 && !GEP1 && GEP2->getOperand(0)->stripPointerCasts() == Ptr1) {
-    Offset = GetOffsetFromIndex(GEP2, 1, VariableIdxFound, TD);
+    Offset = GetOffsetFromIndex(GEP2, 1, VariableIdxFound, DL);
     return !VariableIdxFound;
   }
 
@@ -115,8 +116,8 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
     if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx))
       break;
 
-  int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, TD);
-  int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, TD);
+  int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, DL);
+  int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, DL);
   if (VariableIdxFound) return false;
 
   Offset = Offset2-Offset1;
@@ -124,7 +125,7 @@ static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset,
 }
 
 
-/// MemsetRange - Represents a range of memset'd bytes with the ByteVal value.
+/// Represents a range of memset'd bytes with the ByteVal value.
 /// This allows us to analyze stores like:
 ///   store 0 -> P+1
 ///   store 0 -> P+0
@@ -150,12 +151,11 @@ struct MemsetRange {
   /// TheStores - The actual stores that make up this range.
   SmallVector<Instruction*, 16> TheStores;
 
-  bool isProfitableToUseMemset(const DataLayout &TD) const;
-
+  bool isProfitableToUseMemset(const DataLayout &DL) const;
 };
 } // end anon namespace
 
-bool MemsetRange::isProfitableToUseMemset(const DataLayout &TD) const {
+bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const {
   // If we found more than 4 stores to merge or 16 bytes, use memset.
   if (TheStores.size() >= 4 || End-Start >= 16) return true;
 
@@ -183,7 +183,7 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &TD) const {
   // size. If so, check to see whether we will end up actually reducing the
   // number of stores used.
   unsigned Bytes = unsigned(End-Start);
-  unsigned MaxIntSize = TD.getLargestLegalIntTypeSize();
+  unsigned MaxIntSize = DL.getLargestLegalIntTypeSize();
   if (MaxIntSize == 0)
     MaxIntSize = 1;
   unsigned NumPointerStores = Bytes / MaxIntSize;
@@ -200,15 +200,14 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &TD) const {
 
 namespace {
 class MemsetRanges {
-  /// Ranges - A sorted list of the memset ranges.  We use std::list here
-  /// because each element is relatively large and expensive to copy.
-  std::list<MemsetRange> Ranges;
-  typedef std::list<MemsetRange>::iterator range_iterator;
+  /// A sorted list of the memset ranges.
+  SmallVector<MemsetRange, 8> Ranges;
+  typedef SmallVectorImpl<MemsetRange>::iterator range_iterator;
   const DataLayout &DL;
 public:
   MemsetRanges(const DataLayout &DL) : DL(DL) {}
 
-  typedef std::list<MemsetRange>::const_iterator const_iterator;
+  typedef SmallVectorImpl<MemsetRange>::const_iterator const_iterator;
   const_iterator begin() const { return Ranges.begin(); }
   const_iterator end() const { return Ranges.end(); }
   bool empty() const { return Ranges.empty(); }
@@ -240,26 +239,20 @@ public:
 } // end anon namespace
 
 
-/// addRange - Add a new store to the MemsetRanges data structure.  This adds a
+/// Add a new store to the MemsetRanges data structure.  This adds a
 /// new range for the specified store at the specified offset, merging into
 /// existing ranges as appropriate.
-///
-/// Do a linear search of the ranges to see if this can be joined and/or to
-/// find the insertion point in the list.  We keep the ranges sorted for
-/// simplicity here.  This is a linear search of a linked list, which is ugly,
-/// however the number of ranges is limited, so this won't get crazy slow.
 void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
                             unsigned Alignment, Instruction *Inst) {
   int64_t End = Start+Size;
-  range_iterator I = Ranges.begin(), E = Ranges.end();
 
-  while (I != E && Start > I->End)
-    ++I;
+  range_iterator I = std::lower_bound(Ranges.begin(), Ranges.end(), Start,
+    [](const MemsetRange &LHS, int64_t RHS) { return LHS.End < RHS; });
 
   // We now know that I == E, in which case we didn't find anything to merge
   // with, or that Start <= I->End.  If End < I->Start or I == E, then we need
   // to insert a new range.  Handle this now.
-  if (I == E || End < I->Start) {
+  if (I == Ranges.end() || End < I->Start) {
     MemsetRange &R = *Ranges.insert(I, MemsetRange());
     R.Start        = Start;
     R.End          = End;
@@ -295,7 +288,7 @@ void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr,
   if (End > I->End) {
     I->End = End;
     range_iterator NextI = I;
-    while (++NextI != E && End >= NextI->Start) {
+    while (++NextI != Ranges.end() && End >= NextI->Start) {
       // Merge the range in.
       I->TheStores.append(NextI->TheStores.begin(), NextI->TheStores.end());
       if (NextI->End > I->End)
@@ -314,14 +307,12 @@ namespace {
   class MemCpyOpt : public FunctionPass {
     MemoryDependenceAnalysis *MD;
     TargetLibraryInfo *TLI;
-    const DataLayout *DL;
   public:
     static char ID; // Pass identification, replacement for typeid
     MemCpyOpt() : FunctionPass(ID) {
       initializeMemCpyOptPass(*PassRegistry::getPassRegistry());
       MD = nullptr;
       TLI = nullptr;
-      DL = nullptr;
     }
 
     bool runOnFunction(Function &F) override;
@@ -339,15 +330,16 @@ namespace {
       AU.addPreserved<MemoryDependenceAnalysis>();
     }
 
-    // Helper fuctions
+    // Helper functions
     bool processStore(StoreInst *SI, BasicBlock::iterator &BBI);
     bool processMemSet(MemSetInst *SI, BasicBlock::iterator &BBI);
     bool processMemCpy(MemCpyInst *M);
     bool processMemMove(MemMoveInst *M);
     bool performCallSlotOptzn(Instruction *cpy, Value *cpyDst, Value *cpySrc,
                               uint64_t cpyLen, unsigned cpyAlign, CallInst *C);
-    bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
-                                       uint64_t MSize);
+    bool processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep);
+    bool processMemSetMemCpyDependence(MemCpyInst *M, MemSetInst *MDep);
+    bool performMemCpyToMemSetOptzn(MemCpyInst *M, MemSetInst *MDep);
     bool processByValArgument(CallSite CS, unsigned ArgNo);
     Instruction *tryMergingIntoMemset(Instruction *I, Value *StartPtr,
                                       Value *ByteVal);
@@ -358,7 +350,7 @@ namespace {
   char MemCpyOpt::ID = 0;
 }
 
-// createMemCpyOptPass - The public interface to this file...
+/// The public interface to this file...
 FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOpt(); }
 
 INITIALIZE_PASS_BEGIN(MemCpyOpt, "memcpyopt", "MemCpy Optimization",
@@ -371,19 +363,19 @@ INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
 INITIALIZE_PASS_END(MemCpyOpt, "memcpyopt", "MemCpy Optimization",
                     false, false)
 
-/// tryMergingIntoMemset - When scanning forward over instructions, we look for
-/// some other patterns to fold away.  In particular, this looks for stores to
-/// neighboring locations of memory.  If it sees enough consecutive ones, it
-/// attempts to merge them together into a memcpy/memset.
+/// When scanning forward over instructions, we look for some other patterns to
+/// fold away. In particular, this looks for stores to neighboring locations of
+/// memory. If it sees enough consecutive ones, it attempts to merge them
+/// together into a memcpy/memset.
 Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
                                              Value *StartPtr, Value *ByteVal) {
-  if (!DL) return nullptr;
+  const DataLayout &DL = StartInst->getModule()->getDataLayout();
 
   // Okay, so we now have a single store that can be splatable.  Scan to find
   // all subsequent stores of the same value to offset from the same pointer.
   // Join these together into ranges, so we can decide whether contiguous blocks
   // are stored.
-  MemsetRanges Ranges(*DL);
+  MemsetRanges Ranges(DL);
 
   BasicBlock::iterator BI = StartInst;
   for (++BI; !isa<TerminatorInst>(BI); ++BI) {
@@ -406,8 +398,8 @@ Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
 
       // Check to see if this store is to a constant offset from the start ptr.
       int64_t Offset;
-      if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(),
-                           Offset, *DL))
+      if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(), Offset,
+                           DL))
         break;
 
       Ranges.addStore(Offset, NextStore);
@@ -420,7 +412,7 @@ Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
 
       // Check to see if this store is to a constant offset from the start ptr.
       int64_t Offset;
-      if (!IsPointerOffset(StartPtr, MSI->getDest(), Offset, *DL))
+      if (!IsPointerOffset(StartPtr, MSI->getDest(), Offset, DL))
         break;
 
       Ranges.addMemSet(Offset, MSI);
@@ -452,7 +444,7 @@ Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
     if (Range.TheStores.size() == 1) continue;
 
     // If it is profitable to lower this range to memset, do so now.
-    if (!Range.isProfitableToUseMemset(*DL))
+    if (!Range.isProfitableToUseMemset(DL))
       continue;
 
     // Otherwise, we do want to transform this!  Create a new memset.
@@ -464,7 +456,7 @@ Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
     if (Alignment == 0) {
       Type *EltType =
         cast<PointerType>(StartPtr->getType())->getElementType();
-      Alignment = DL->getABITypeAlignment(EltType);
+      Alignment = DL.getABITypeAlignment(EltType);
     }
 
     AMemSet =
@@ -494,8 +486,7 @@ Instruction *MemCpyOpt::tryMergingIntoMemset(Instruction *StartInst,
 
 bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
   if (!SI->isSimple()) return false;
-
-  if (!DL) return false;
+  const DataLayout &DL = SI->getModule()->getDataLayout();
 
   // Detect cases where we're performing call slot forwarding, but
   // happen to be using a load-store pair to implement it, rather than
@@ -512,10 +503,10 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
         // Check that nothing touches the dest of the "copy" between
         // the call and the store.
         AliasAnalysis &AA = getAnalysis<AliasAnalysis>();
-        AliasAnalysis::Location StoreLoc = AA.getLocation(SI);
+        MemoryLocation StoreLoc = MemoryLocation::get(SI);
         for (BasicBlock::iterator I = --BasicBlock::iterator(SI),
                                   E = C; I != E; --I) {
-          if (AA.getModRefInfo(&*I, StoreLoc) != AliasAnalysis::NoModRef) {
+          if (AA.getModRefInfo(&*I, StoreLoc) != MRI_NoModRef) {
             C = nullptr;
             break;
           }
@@ -525,16 +516,16 @@ bool MemCpyOpt::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
       if (C) {
         unsigned storeAlign = SI->getAlignment();
         if (!storeAlign)
-          storeAlign = DL->getABITypeAlignment(SI->getOperand(0)->getType());
+          storeAlign = DL.getABITypeAlignment(SI->getOperand(0)->getType());
         unsigned loadAlign = LI->getAlignment();
         if (!loadAlign)
-          loadAlign = DL->getABITypeAlignment(LI->getType());
+          loadAlign = DL.getABITypeAlignment(LI->getType());
 
-        bool changed = performCallSlotOptzn(LI,
-                        SI->getPointerOperand()->stripPointerCasts(),
-                        LI->getPointerOperand()->stripPointerCasts(),
-                        DL->getTypeStoreSize(SI->getOperand(0)->getType()),
-                        std::min(storeAlign, loadAlign), C);
+        bool changed = performCallSlotOptzn(
+            LI, SI->getPointerOperand()->stripPointerCasts(),
+            LI->getPointerOperand()->stripPointerCasts(),
+            DL.getTypeStoreSize(SI->getOperand(0)->getType()),
+            std::min(storeAlign, loadAlign), C);
         if (changed) {
           MD->removeInstruction(SI);
           SI->eraseFromParent();
@@ -576,7 +567,7 @@ bool MemCpyOpt::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) {
 }
 
 
-/// performCallSlotOptzn - takes a memcpy and a call that it depends on,
+/// Takes a memcpy and a call that it depends on,
 /// and checks for the possibility of a call slot optimization by having
 /// the call write its result directly into the destination of the memcpy.
 bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
@@ -606,15 +597,13 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
   if (!srcAlloca)
     return false;
 
-  // Check that all of src is copied to dest.
-  if (!DL) return false;
-
   ConstantInt *srcArraySize = dyn_cast<ConstantInt>(srcAlloca->getArraySize());
   if (!srcArraySize)
     return false;
 
-  uint64_t srcSize = DL->getTypeAllocSize(srcAlloca->getAllocatedType()) *
-    srcArraySize->getZExtValue();
+  const DataLayout &DL = cpy->getModule()->getDataLayout();
+  uint64_t srcSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()) *
+                     srcArraySize->getZExtValue();
 
   if (cpyLen < srcSize)
     return false;
@@ -628,8 +617,8 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
     if (!destArraySize)
       return false;
 
-    uint64_t destSize = DL->getTypeAllocSize(A->getAllocatedType()) *
-      destArraySize->getZExtValue();
+    uint64_t destSize = DL.getTypeAllocSize(A->getAllocatedType()) *
+                        destArraySize->getZExtValue();
 
     if (destSize < srcSize)
       return false;
@@ -648,7 +637,7 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
         return false;
       }
 
-      uint64_t destSize = DL->getTypeAllocSize(StructTy);
+      uint64_t destSize = DL.getTypeAllocSize(StructTy);
       if (destSize < srcSize)
         return false;
     }
@@ -659,7 +648,7 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
   // Check that dest points to memory that is at least as aligned as src.
   unsigned srcAlign = srcAlloca->getAlignment();
   if (!srcAlign)
-    srcAlign = DL->getABITypeAlignment(srcAlloca->getAllocatedType());
+    srcAlign = DL.getABITypeAlignment(srcAlloca->getAllocatedType());
   bool isDestSufficientlyAligned = srcAlign <= cpyAlign;
   // If dest is not aligned enough and we can't increase its alignment then
   // bail out.
@@ -715,11 +704,11 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
   // the use analysis, we also need to know that it does not sneakily
   // access dest.  We rely on AA to figure this out for us.
   AliasAnalysis &AA = getAnalysis<AliasAnalysis>();
-  AliasAnalysis::ModRefResult MR = AA.getModRefInfo(C, cpyDest, srcSize);
+  ModRefInfo MR = AA.getModRefInfo(C, cpyDest, srcSize);
   // If necessary, perform additional analysis.
-  if (MR != AliasAnalysis::NoModRef)
+  if (MR != MRI_NoModRef)
     MR = AA.callCapturesBefore(C, cpyDest, srcSize, &DT);
-  if (MR != AliasAnalysis::NoModRef)
+  if (MR != MRI_NoModRef)
     return false;
 
   // All the checks have passed, so do the transformation.
@@ -750,6 +739,16 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
   // its dependence information by changing its parameter.
   MD->removeInstruction(C);
 
+  // Update AA metadata
+  // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be
+  // handled here, but combineMetadata doesn't support them yet
+  unsigned KnownIDs[] = {
+    LLVMContext::MD_tbaa,
+    LLVMContext::MD_alias_scope,
+    LLVMContext::MD_noalias,
+  };
+  combineMetadata(C, cpy, KnownIDs);
+
   // Remove the memcpy.
   MD->removeInstruction(cpy);
   ++NumMemCpyInstr;
@@ -757,12 +756,9 @@ bool MemCpyOpt::performCallSlotOptzn(Instruction *cpy,
   return true;
 }
 
-/// processMemCpyMemCpyDependence - We've found that the (upward scanning)
-/// memory dependence of memcpy 'M' is the memcpy 'MDep'.  Try to simplify M to
-/// copy from MDep's input if we can.  MSize is the size of M's copy.
-///
-bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
-                                              uint64_t MSize) {
+/// We've found that the (upward scanning) memory dependence of memcpy 'M' is
+/// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can.
+bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep) {
   // We can only transforms memcpy's where the dest of one is the source of the
   // other.
   if (M->getSource() != MDep->getDest() || MDep->isVolatile())
@@ -797,9 +793,8 @@ bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
   //
   // NOTE: This is conservative, it will stop on any read from the source loc,
   // not just the defining memcpy.
-  MemDepResult SourceDep =
-    MD->getPointerDependencyFrom(AA.getLocationForSource(MDep),
-                                 false, M, M->getParent());
+  MemDepResult SourceDep = MD->getPointerDependencyFrom(
+      MemoryLocation::getForSource(MDep), false, M, M->getParent());
   if (!SourceDep.isClobber() || SourceDep.getInst() != MDep)
     return false;
 
@@ -807,7 +802,8 @@ bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
   // source and dest might overlap.  We still want to eliminate the intermediate
   // value, but we have to generate a memmove instead of memcpy.
   bool UseMemMove = false;
-  if (!AA.isNoAlias(AA.getLocationForDest(M), AA.getLocationForSource(MDep)))
+  if (!AA.isNoAlias(MemoryLocation::getForDest(M),
+                    MemoryLocation::getForSource(MDep)))
     UseMemMove = true;
 
   // If all checks passed, then we can transform M.
@@ -834,8 +830,104 @@ bool MemCpyOpt::processMemCpyMemCpyDependence(MemCpyInst *M, MemCpyInst *MDep,
   return true;
 }
 
+/// We've found that the (upward scanning) memory dependence of \p MemCpy is
+/// \p MemSet.  Try to simplify \p MemSet to only set the trailing bytes that
+/// weren't copied over by \p MemCpy.
+///
+/// In other words, transform:
+/// \code
+///   memset(dst, c, dst_size);
+///   memcpy(dst, src, src_size);
+/// \endcode
+/// into:
+/// \code
+///   memcpy(dst, src, src_size);
+///   memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size);
+/// \endcode
+bool MemCpyOpt::processMemSetMemCpyDependence(MemCpyInst *MemCpy,
+                                              MemSetInst *MemSet) {
+  // We can only transform memset/memcpy with the same destination.
+  if (MemSet->getDest() != MemCpy->getDest())
+    return false;
+
+  // Check that there are no other dependencies on the memset destination.
+  MemDepResult DstDepInfo = MD->getPointerDependencyFrom(
+      MemoryLocation::getForDest(MemSet), false, MemCpy, MemCpy->getParent());
+  if (DstDepInfo.getInst() != MemSet)
+    return false;
+
+  // Use the same i8* dest as the memcpy, killing the memset dest if different.
+  Value *Dest = MemCpy->getRawDest();
+  Value *DestSize = MemSet->getLength();
+  Value *SrcSize = MemCpy->getLength();
+
+  // By default, create an unaligned memset.
+  unsigned Align = 1;
+  // If Dest is aligned, and SrcSize is constant, use the minimum alignment
+  // of the sum.
+  const unsigned DestAlign =
+      std::max(MemSet->getAlignment(), MemCpy->getAlignment());
+  if (DestAlign > 1)
+    if (ConstantInt *SrcSizeC = dyn_cast<ConstantInt>(SrcSize))
+      Align = MinAlign(SrcSizeC->getZExtValue(), DestAlign);
+
+  IRBuilder<> Builder(MemCpy);
+
+  // If the sizes have different types, zext the smaller one.
+  if (DestSize->getType() != SrcSize->getType()) {
+    if (DestSize->getType()->getIntegerBitWidth() >
+        SrcSize->getType()->getIntegerBitWidth())
+      SrcSize = Builder.CreateZExt(SrcSize, DestSize->getType());
+    else
+      DestSize = Builder.CreateZExt(DestSize, SrcSize->getType());
+  }
 
-/// processMemCpy - perform simplification of memcpy's.  If we have memcpy A
+  Value *MemsetLen =
+      Builder.CreateSelect(Builder.CreateICmpULE(DestSize, SrcSize),
+                           ConstantInt::getNullValue(DestSize->getType()),
+                           Builder.CreateSub(DestSize, SrcSize));
+  Builder.CreateMemSet(Builder.CreateGEP(Dest, SrcSize), MemSet->getOperand(1),
+                       MemsetLen, Align);
+
+  MD->removeInstruction(MemSet);
+  MemSet->eraseFromParent();
+  return true;
+}
+
+/// Transform memcpy to memset when its source was just memset.
+/// In other words, turn:
+/// \code
+///   memset(dst1, c, dst1_size);
+///   memcpy(dst2, dst1, dst2_size);
+/// \endcode
+/// into:
+/// \code
+///   memset(dst1, c, dst1_size);
+///   memset(dst2, c, dst2_size);
+/// \endcode
+/// When dst2_size <= dst1_size.
+///
+/// The \p MemCpy must have a Constant length.
+bool MemCpyOpt::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy,
+                                           MemSetInst *MemSet) {
+  // This only makes sense on memcpy(..., memset(...), ...).
+  if (MemSet->getRawDest() != MemCpy->getRawSource())
+    return false;
+
+  ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength());
+  ConstantInt *MemSetSize = dyn_cast<ConstantInt>(MemSet->getLength());
+  // Make sure the memcpy doesn't read any more than what the memset wrote.
+  // Don't worry about sizes larger than i64.
+  if (!MemSetSize || CopySize->getZExtValue() > MemSetSize->getZExtValue())
+    return false;
+
+  IRBuilder<> Builder(MemCpy);
+  Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1),
+                       CopySize, MemCpy->getAlignment());
+  return true;
+}
+
+/// Perform simplification of memcpy's.  If we have memcpy A
 /// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite
 /// B to be a memcpy from X to Z (or potentially a memmove, depending on
 /// circumstances). This allows later passes to remove the first memcpy
@@ -864,17 +956,26 @@ bool MemCpyOpt::processMemCpy(MemCpyInst *M) {
         return true;
       }
 
+  MemDepResult DepInfo = MD->getDependency(M);
+
+  // Try to turn a partially redundant memset + memcpy into
+  // memcpy + smaller memset.  We don't need the memcpy size for this.
+  if (DepInfo.isClobber())
+    if (MemSetInst *MDep = dyn_cast<MemSetInst>(DepInfo.getInst()))
+      if (processMemSetMemCpyDependence(M, MDep))
+        return true;
+
   // The optimizations after this point require the memcpy size.
   ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength());
   if (!CopySize) return false;
 
-  // The are three possible optimizations we can do for memcpy:
+  // There are four possible optimizations we can do for memcpy:
   //   a) memcpy-memcpy xform which exposes redundance for DSE.
   //   b) call-memcpy xform for return slot optimization.
   //   c) memcpy from freshly alloca'd space or space that has just started its
   //      lifetime copies undefined data, and we can therefore eliminate the
   //      memcpy in favor of the data that was already at the destination.
-  MemDepResult DepInfo = MD->getDependency(M);
+  //   d) memcpy from a just-memset'd source can be turned into memset.
   if (DepInfo.isClobber()) {
     if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) {
       if (performCallSlotOptzn(M, M->getDest(), M->getSource(),
@@ -887,12 +988,13 @@ bool MemCpyOpt::processMemCpy(MemCpyInst *M) {
     }
   }
 
-  AliasAnalysis::Location SrcLoc = AliasAnalysis::getLocationForSource(M);
+  MemoryLocation SrcLoc = MemoryLocation::getForSource(M);
   MemDepResult SrcDepInfo = MD->getPointerDependencyFrom(SrcLoc, true,
                                                          M, M->getParent());
+
   if (SrcDepInfo.isClobber()) {
     if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst()))
-      return processMemCpyMemCpyDependence(M, MDep, CopySize->getZExtValue());
+      return processMemCpyMemCpyDependence(M, MDep);
   } else if (SrcDepInfo.isDef()) {
     Instruction *I = SrcDepInfo.getInst();
     bool hasUndefContents = false;
@@ -914,11 +1016,20 @@ bool MemCpyOpt::processMemCpy(MemCpyInst *M) {
     }
   }
 
+  if (SrcDepInfo.isClobber())
+    if (MemSetInst *MDep = dyn_cast<MemSetInst>(SrcDepInfo.getInst()))
+      if (performMemCpyToMemSetOptzn(M, MDep)) {
+        MD->removeInstruction(M);
+        M->eraseFromParent();
+        ++NumCpyToSet;
+        return true;
+      }
+
   return false;
 }
 
-/// processMemMove - Transforms memmove calls to memcpy calls when the src/dst
-/// are guaranteed not to alias.
+/// Transforms memmove calls to memcpy calls when the src/dst are guaranteed
+/// not to alias.
 bool MemCpyOpt::processMemMove(MemMoveInst *M) {
   AliasAnalysis &AA = getAnalysis<AliasAnalysis>();
 
@@ -926,7 +1037,8 @@ bool MemCpyOpt::processMemMove(MemMoveInst *M) {
     return false;
 
   // See if the pointers alias.
-  if (!AA.isNoAlias(AA.getLocationForDest(M), AA.getLocationForSource(M)))
+  if (!AA.isNoAlias(MemoryLocation::getForDest(M),
+                    MemoryLocation::getForSource(M)))
     return false;
 
   DEBUG(dbgs() << "MemCpyOpt: Optimizing memmove -> memcpy: " << *M << "\n");
@@ -947,18 +1059,16 @@ bool MemCpyOpt::processMemMove(MemMoveInst *M) {
   return true;
 }
 
-/// processByValArgument - This is called on every byval argument in call sites.
+/// This is called on every byval argument in call sites.
 bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) {
-  if (!DL) return false;
-
+  const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout();
   // Find out what feeds this byval argument.
   Value *ByValArg = CS.getArgument(ArgNo);
   Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType();
-  uint64_t ByValSize = DL->getTypeAllocSize(ByValTy);
-  MemDepResult DepInfo =
-    MD->getPointerDependencyFrom(AliasAnalysis::Location(ByValArg, ByValSize),
-                                 true, CS.getInstruction(),
-                                 CS.getInstruction()->getParent());
+  uint64_t ByValSize = DL.getTypeAllocSize(ByValTy);
+  MemDepResult DepInfo = MD->getPointerDependencyFrom(
+      MemoryLocation(ByValArg, ByValSize), true, CS.getInstruction(),
+      CS.getInstruction()->getParent());
   if (!DepInfo.isClobber())
     return false;
 
@@ -987,8 +1097,8 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) {
           *CS->getParent()->getParent());
   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   if (MDep->getAlignment() < ByValAlign &&
-      getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, &AC,
-                                 CS.getInstruction(), &DT) < ByValAlign)
+      getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL,
+                                 CS.getInstruction(), &AC, &DT) < ByValAlign)
     return false;
 
   // Verify that the copied-from memory doesn't change in between the memcpy and
@@ -1001,8 +1111,8 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) {
   // NOTE: This is conservative, it will stop on any read from the source loc,
   // not just the defining memcpy.
   MemDepResult SourceDep =
-    MD->getPointerDependencyFrom(AliasAnalysis::getLocationForSource(MDep),
-                                 false, CS.getInstruction(), MDep->getParent());
+      MD->getPointerDependencyFrom(MemoryLocation::getForSource(MDep), false,
+                                   CS.getInstruction(), MDep->getParent());
   if (!SourceDep.isClobber() || SourceDep.getInst() != MDep)
     return false;
 
@@ -1021,7 +1131,7 @@ bool MemCpyOpt::processByValArgument(CallSite CS, unsigned ArgNo) {
   return true;
 }
 
-/// iterateOnFunction - Executes one iteration of MemCpyOpt.
+/// Executes one iteration of MemCpyOpt.
 bool MemCpyOpt::iterateOnFunction(Function &F) {
   bool MadeChange = false;
 
@@ -1041,7 +1151,7 @@ bool MemCpyOpt::iterateOnFunction(Function &F) {
         RepeatInstruction = processMemCpy(M);
       else if (MemMoveInst *M = dyn_cast<MemMoveInst>(I))
         RepeatInstruction = processMemMove(M);
-      else if (CallSite CS = (Value*)I) {
+      else if (auto CS = CallSite(I)) {
         for (unsigned i = 0, e = CS.arg_size(); i != e; ++i)
           if (CS.isByValArgument(i))
             MadeChange |= processByValArgument(CS, i);
@@ -1058,17 +1168,13 @@ bool MemCpyOpt::iterateOnFunction(Function &F) {
   return MadeChange;
 }
 
-// MemCpyOpt::runOnFunction - This is the main transformation entry point for a
-// function.
-//
+/// This is the main transformation entry point for a function.
 bool MemCpyOpt::runOnFunction(Function &F) {
   if (skipOptnoneFunction(F))
     return false;
 
   bool MadeChange = false;
   MD = &getAnalysis<MemoryDependenceAnalysis>();
-  DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
-  DL = DLP ? &DLP->getDataLayout() : nullptr;
   TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
 
   // If we don't have at least memset and memcpy, there is little point of doing