Remove the successor probabilities normalization in tail duplication pass.
lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp
index f7fa7aa61df5190cfc8cdced3c21b60a465e8e71..f770c2acaab51d1353945039f4f38bfee43065ac 100644
@@ -6,12 +6,16 @@
 // License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
+//
+// \file
 // Lower aggregate copies, memset, memcpy, memmove intrinsics into loops when
 // the size is large or is not a compile-time constant.
 //
 //===----------------------------------------------------------------------===//
 
 #include "NVPTXLowerAggrCopies.h"
+#include "llvm/CodeGen/MachineFunctionAnalysis.h"
+#include "llvm/CodeGen/StackProtector.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
-#include "llvm/Support/InstIterator.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+#define DEBUG_TYPE "nvptx"
 
 using namespace llvm;
 
-namespace llvm {
-FunctionPass *createLowerAggrCopies();
-}
+namespace {
 
-char NVPTXLowerAggrCopies::ID = 0;
+// The actual lowering pass, implemented as a FunctionPass.
+struct NVPTXLowerAggrCopies : public FunctionPass {
+  static char ID;
 
-// Lower MemTransferInst or load-store pair to loop
-static void convertTransferToLoop(Instruction *splitAt, Value *srcAddr,
-                                  Value *dstAddr, Value *len,
-                                  //unsigned numLoads,
-                                  bool srcVolatile, bool dstVolatile,
-                                  LLVMContext &Context, Function &F) {
-  Type *indType = len->getType();
+  NVPTXLowerAggrCopies() : FunctionPass(ID) {}
 
-  BasicBlock *origBB = splitAt->getParent();
-  BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
-  BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
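+  // Analyses this pass declares as preserved.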
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addPreserved<MachineFunctionAnalysis>();
+    AU.addPreserved<StackProtector>();
+  }
 
-  origBB->getTerminator()->setSuccessor(0, loopBB);
-  IRBuilder<> builder(origBB, origBB->getTerminator());
+  bool runOnFunction(Function &F) override;
 
-  // srcAddr and dstAddr are expected to be pointer types,
-  // so no check is made here.
-  unsigned srcAS =
-      dyn_cast<PointerType>(srcAddr->getType())->getAddressSpace();
-  unsigned dstAS =
-      dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
+  static const unsigned MaxAggrCopySize = 128;
 
-  // Cast pointers to (char *)
-  srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS));
-  dstAddr = builder.CreateBitCast(dstAddr, Type::getInt8PtrTy(Context, dstAS));
+  const char *getPassName() const override {
+    return "Lower aggregate copies/intrinsics into loops";
+  }
+};
 
-  IRBuilder<> loop(loopBB);
-  // The loop index (ind) is a phi node.
-  PHINode *ind = loop.CreatePHI(indType, 0);
-  // Incoming value for ind is 0
-  ind->addIncoming(ConstantInt::get(indType, 0), origBB);
+char NVPTXLowerAggrCopies::ID = 0;
+
+// Lower memcpy to loop.
+void convertMemCpyToLoop(Instruction *ConvertedInst, Value *SrcAddr,
+                         Value *DstAddr, Value *CopyLen, bool SrcIsVolatile,
+                         bool DstIsVolatile, LLVMContext &Context,
+                         Function &F) {
+  Type *TypeOfCopyLen = CopyLen->getType();
 
-  // load from srcAddr+ind
-  Value *val = loop.CreateLoad(loop.CreateGEP(srcAddr, ind), srcVolatile);
-  // store at dstAddr+ind
-  loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), dstVolatile);
+  BasicBlock *OrigBB = ConvertedInst->getParent();
+  BasicBlock *NewBB =
+      ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split");
+  BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB);
 
-  // The value for ind coming from backedge is (ind + 1)
-  Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1));
-  ind->addIncoming(newind, loopBB);
+  OrigBB->getTerminator()->setSuccessor(0, LoopBB);
+  IRBuilder<> Builder(OrigBB->getTerminator());
 
-  loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
+  // SrcAddr and DstAddr are expected to be of pointer type,
+  // so no check is made here.
+  unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace();
+  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+
+  // Cast pointers to (char *)
+  SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS));
+  DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS));
+
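+  // For reference, the loop emitted below has roughly this IR shape (assuming
+  // an i64 CopyLen; the value names are illustrative only, the builder leaves
+  // them unnamed):
+  //
+  //   loadstoreloop:
+  //     %index = phi i64 [ 0, %OrigBB ], [ %index.next, %loadstoreloop ]
+  //     %src.gep = getelementptr inbounds i8, i8* %SrcAddr, i64 %index
+  //     %byte = load i8, i8* %src.gep
+  //     %dst.gep = getelementptr inbounds i8, i8* %DstAddr, i64 %index
+  //     store i8 %byte, i8* %dst.gep
+  //     %index.next = add i64 %index, 1
+  //     %again = icmp ult i64 %index.next, %CopyLen
+  //     br i1 %again, label %loadstoreloop, label %split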
+  IRBuilder<> LoopBuilder(LoopBB);
+  PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
+  LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB);
+
+  // load from SrcAddr+LoopIndex
+  // TODO: we can leverage the align parameter of llvm.memcpy to emit wider,
+  // word-sized loads and stores instead of copying one byte at a time.
+  Value *Element =
+      LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP(
+                                 LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex),
+                             SrcIsVolatile);
+  // store at DstAddr+LoopIndex
+  LoopBuilder.CreateStore(Element,
+                          LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(),
+                                                        DstAddr, LoopIndex),
+                          DstIsVolatile);
+
+  // The value for LoopIndex coming from backedge is (LoopIndex + 1)
+  Value *NewIndex =
+      LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1));
+  LoopIndex->addIncoming(NewIndex, LoopBB);
+
+  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
+                           NewBB);
 }
 
-// Lower MemSetInst to loop
-static void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr,
-                                Value *len, Value *val, LLVMContext &Context,
-                                Function &F) {
-  BasicBlock *origBB = splitAt->getParent();
-  BasicBlock *newBB = splitAt->getParent()->splitBasicBlock(splitAt, "split");
-  BasicBlock *loopBB = BasicBlock::Create(Context, "loadstoreloop", &F, newBB);
+// Lower memmove to IR. memmove is required to correctly copy overlapping memory
+// regions; therefore, it has to check the relative positions of the source and
+// destination pointers and choose the copy direction accordingly.
+//
+// The code below is an IR rendition of this C function:
+//
+// void* memmove(void* dst, const void* src, size_t n) {
+//   unsigned char* d = dst;
+//   const unsigned char* s = src;
+//   if (s < d) {
+//     // copy backwards
+//     while (n--) {
+//       d[n] = s[n];
+//     }
+//   } else {
+//     // copy forward
+//     for (size_t i = 0; i < n; ++i) {
+//       d[i] = s[i];
+//     }
+//   }
+//   return dst;
+// }
+void convertMemMoveToLoop(Instruction *ConvertedInst, Value *SrcAddr,
+                          Value *DstAddr, Value *CopyLen, bool SrcIsVolatile,
+                          bool DstIsVolatile, LLVMContext &Context,
+                          Function &F) {
+  Type *TypeOfCopyLen = CopyLen->getType();
+  BasicBlock *OrigBB = ConvertedInst->getParent();
+
+  // Create a comparison of src and dst, based on which we jump to either
+  // the forward-copy part of the function (if src >= dst) or the backwards-copy
+  // part (if src < dst).
+  // SplitBlockAndInsertIfThenElse conveniently creates the basic if-then-else
+  // structure. Its block terminators (unconditional branches) are replaced by
+  // the appropriate conditional branches when the loop is built.
+  ICmpInst *PtrCompare = new ICmpInst(ConvertedInst, ICmpInst::ICMP_ULT,
+                                      SrcAddr, DstAddr, "compare_src_dst");
+  TerminatorInst *ThenTerm, *ElseTerm;
+  SplitBlockAndInsertIfThenElse(PtrCompare, ConvertedInst, &ThenTerm,
+                                &ElseTerm);
+
+  // Each copy direction (backwards and forward) consists of two blocks:
+  //   copy_backwards:        used to skip the loop when n == 0
+  //   copy_backwards_loop:   the actual backwards loop BB
+  //   copy_forward:          used to skip the loop when n == 0
+  //   copy_forward_loop:     the actual forward loop BB
+  BasicBlock *CopyBackwardsBB = ThenTerm->getParent();
+  CopyBackwardsBB->setName("copy_backwards");
+  BasicBlock *CopyForwardBB = ElseTerm->getParent();
+  CopyForwardBB->setName("copy_forward");
+  BasicBlock *ExitBB = ConvertedInst->getParent();
+  ExitBB->setName("memmove_done");
+
+  // Initial comparison of n == 0 that lets us skip the loops altogether. Shared
+  // between both backwards and forward copy clauses.
+  ICmpInst *CompareN =
+      new ICmpInst(OrigBB->getTerminator(), ICmpInst::ICMP_EQ, CopyLen,
+                   ConstantInt::get(TypeOfCopyLen, 0), "compare_n_to_0");
+
+  // Copying backwards.
+  BasicBlock *LoopBB =
+      BasicBlock::Create(Context, "copy_backwards_loop", &F, CopyForwardBB);
+  IRBuilder<> LoopBuilder(LoopBB);
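+  // The backwards loop counts down from CopyLen: each iteration first
+  // decrements the index and then copies the element at that index, so the
+  // highest-addressed element is copied first.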
+  PHINode *LoopPhi = LoopBuilder.CreatePHI(TypeOfCopyLen, 0);
+  Value *IndexPtr = LoopBuilder.CreateSub(
+      LoopPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_ptr");
+  Value *Element = LoopBuilder.CreateLoad(
+      LoopBuilder.CreateInBoundsGEP(SrcAddr, IndexPtr), "element");
+  LoopBuilder.CreateStore(Element,
+                          LoopBuilder.CreateInBoundsGEP(DstAddr, IndexPtr));
+  LoopBuilder.CreateCondBr(
+      LoopBuilder.CreateICmpEQ(IndexPtr, ConstantInt::get(TypeOfCopyLen, 0)),
+      ExitBB, LoopBB);
+  LoopPhi->addIncoming(IndexPtr, LoopBB);
+  LoopPhi->addIncoming(CopyLen, CopyBackwardsBB);
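+  // Replace the placeholder terminator of copy_backwards: if n == 0, jump
+  // straight to memmove_done, otherwise fall through into the backwards loop.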
+  BranchInst::Create(ExitBB, LoopBB, CompareN, ThenTerm);
+  ThenTerm->eraseFromParent();
+
+  // Copying forward.
+  BasicBlock *FwdLoopBB =
+      BasicBlock::Create(Context, "copy_forward_loop", &F, ExitBB);
+  IRBuilder<> FwdLoopBuilder(FwdLoopBB);
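+  // The forward loop counts up from 0 to CopyLen, copying the lowest-addressed
+  // element first.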
+  PHINode *FwdCopyPhi = FwdLoopBuilder.CreatePHI(TypeOfCopyLen, 0, "index_ptr");
+  Value *FwdElement = FwdLoopBuilder.CreateLoad(
+      FwdLoopBuilder.CreateInBoundsGEP(SrcAddr, FwdCopyPhi), "element");
+  FwdLoopBuilder.CreateStore(
+      FwdElement, FwdLoopBuilder.CreateInBoundsGEP(DstAddr, FwdCopyPhi));
+  Value *FwdIndexPtr = FwdLoopBuilder.CreateAdd(
+      FwdCopyPhi, ConstantInt::get(TypeOfCopyLen, 1), "index_increment");
+  FwdLoopBuilder.CreateCondBr(FwdLoopBuilder.CreateICmpEQ(FwdIndexPtr, CopyLen),
+                              ExitBB, FwdLoopBB);
+  FwdCopyPhi->addIncoming(FwdIndexPtr, FwdLoopBB);
+  FwdCopyPhi->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), CopyForwardBB);
+
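+  // Likewise for copy_forward: when n == 0, skip the forward loop entirely.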
+  BranchInst::Create(ExitBB, FwdLoopBB, CompareN, ElseTerm);
+  ElseTerm->eraseFromParent();
+}
 
-  origBB->getTerminator()->setSuccessor(0, loopBB);
-  IRBuilder<> builder(origBB, origBB->getTerminator());
+// Lower memset to loop.
+void convertMemSetToLoop(Instruction *ConvertedInst, Value *DstAddr,
+                         Value *CopyLen, Value *SetValue, LLVMContext &Context,
+                         Function &F) {
+  BasicBlock *OrigBB = ConvertedInst->getParent();
+  BasicBlock *NewBB =
+      ConvertedInst->getParent()->splitBasicBlock(ConvertedInst, "split");
+  BasicBlock *LoopBB = BasicBlock::Create(Context, "loadstoreloop", &F, NewBB);
 
-  unsigned dstAS =
-      dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
+  OrigBB->getTerminator()->setSuccessor(0, LoopBB);
+  IRBuilder<> Builder(OrigBB->getTerminator());
 
   // Cast pointer to the type of value getting stored
-  dstAddr = builder.CreateBitCast(dstAddr,
-                                  PointerType::get(val->getType(), dstAS));
+  unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace();
+  DstAddr = Builder.CreateBitCast(DstAddr,
+                                  PointerType::get(SetValue->getType(), DstAS));
 
-  IRBuilder<> loop(loopBB);
-  PHINode *ind = loop.CreatePHI(len->getType(), 0);
-  ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB);
+  IRBuilder<> LoopBuilder(LoopBB);
+  PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLen->getType(), 0);
+  LoopIndex->addIncoming(ConstantInt::get(CopyLen->getType(), 0), OrigBB);
 
-  loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), false);
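+  // DstAddr was cast to SetValue's type above, so each iteration stores
+  // exactly one SetValue-sized element (one byte for llvm.memset).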
+  LoopBuilder.CreateStore(
+      SetValue,
+      LoopBuilder.CreateInBoundsGEP(SetValue->getType(), DstAddr, LoopIndex),
+      false);
 
-  Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1));
-  ind->addIncoming(newind, loopBB);
+  Value *NewIndex =
+      LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLen->getType(), 1));
+  LoopIndex->addIncoming(NewIndex, LoopBB);
 
-  loop.CreateCondBr(loop.CreateICmpULT(newind, len), loopBB, newBB);
+  LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB,
+                           NewBB);
 }
 
 bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {
-  SmallVector<LoadInst *, 4> aggrLoads;
-  SmallVector<MemTransferInst *, 4> aggrMemcpys;
-  SmallVector<MemSetInst *, 4> aggrMemsets;
+  SmallVector<LoadInst *, 4> AggrLoads;
+  SmallVector<MemIntrinsic *, 4> MemCalls;
 
-  DataLayout *TD = &getAnalysis<DataLayout>();
+  const DataLayout &DL = F.getParent()->getDataLayout();
   LLVMContext &Context = F.getParent()->getContext();
 
-  //
-  // Collect all the aggrLoads, aggrMemcpys and addrMemsets.
-  //
-  //const BasicBlock *firstBB = &F.front();  // first BB in F
+  // Collect all aggregate loads and mem* calls.
   for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE; ++BI) {
-    //BasicBlock *bb = BI;
     for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;
-        ++II) {
-      if (LoadInst * load = dyn_cast<LoadInst>(II)) {
-
-        if (load->hasOneUse() == false) continue;
-
-        if (TD->getTypeStoreSize(load->getType()) < MaxAggrCopySize) continue;
+         ++II) {
+      if (LoadInst *LI = dyn_cast<LoadInst>(II)) {
+        if (!LI->hasOneUse())
+          continue;
 
-        User *use = *(load->use_begin());
-        if (StoreInst * store = dyn_cast<StoreInst>(use)) {
-          if (store->getOperand(0) != load) //getValueOperand
+        if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)
           continue;
-          aggrLoads.push_back(load);
-        }
-      } else if (MemTransferInst * intr = dyn_cast<MemTransferInst>(II)) {
-        Value *len = intr->getLength();
-        // If the number of elements being copied is greater
-        // than MaxAggrCopySize, lower it to a loop
-        if (ConstantInt * len_int = dyn_cast < ConstantInt > (len)) {
-          if (len_int->getZExtValue() >= MaxAggrCopySize) {
-            aggrMemcpys.push_back(intr);
-          }
-        } else {
-          // turn variable length memcpy/memmov into loop
-          aggrMemcpys.push_back(intr);
+
+        if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {
+          if (SI->getOperand(0) != LI)
+            continue;
+          AggrLoads.push_back(LI);
         }
-      } else if (MemSetInst * memsetintr = dyn_cast<MemSetInst>(II)) {
-        Value *len = memsetintr->getLength();
-        if (ConstantInt * len_int = dyn_cast<ConstantInt>(len)) {
-          if (len_int->getZExtValue() >= MaxAggrCopySize) {
-            aggrMemsets.push_back(memsetintr);
+      } else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(II)) {
+        // Convert intrinsic calls with variable size or with constant size
+        // larger than the MaxAggrCopySize threshold.
+        if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {
+          if (LenCI->getZExtValue() >= MaxAggrCopySize) {
+            MemCalls.push_back(IntrCall);
           }
         } else {
-          // turn variable length memset into loop
-          aggrMemsets.push_back(memsetintr);
+          MemCalls.push_back(IntrCall);
         }
       }
     }
   }
-  if ((aggrLoads.size() == 0) && (aggrMemcpys.size() == 0)
-      && (aggrMemsets.size() == 0)) return false;
+
+  if (AggrLoads.empty() && MemCalls.empty())
+    return false;
 
   //
   // Do the transformation of an aggr load/copy/set to a loop
   //
-  for (unsigned i = 0, e = aggrLoads.size(); i != e; ++i) {
-    LoadInst *load = aggrLoads[i];
-    StoreInst *store = dyn_cast<StoreInst>(*load->use_begin());
-    Value *srcAddr = load->getOperand(0);
-    Value *dstAddr = store->getOperand(1);
-    unsigned numLoads = TD->getTypeStoreSize(load->getType());
-    Value *len = ConstantInt::get(Type::getInt32Ty(Context), numLoads);
-
-    convertTransferToLoop(store, srcAddr, dstAddr, len, load->isVolatile(),
-                          store->isVolatile(), Context, F);
-
-    store->eraseFromParent();
-    load->eraseFromParent();
-  }
-
-  for (unsigned i = 0, e = aggrMemcpys.size(); i != e; ++i) {
-    MemTransferInst *cpy = aggrMemcpys[i];
-    Value *len = cpy->getLength();
-    // llvm 2.7 version of memcpy does not have volatile
-    // operand yet. So always making it non-volatile
-    // optimistically, so that we don't see unnecessary
-    // st.volatile in ptx
-    convertTransferToLoop(cpy, cpy->getSource(), cpy->getDest(), len, false,
-                          false, Context, F);
-    cpy->eraseFromParent();
+  for (LoadInst *LI : AggrLoads) {
+    StoreInst *SI = dyn_cast<StoreInst>(*LI->user_begin());
+    Value *SrcAddr = LI->getOperand(0);
+    Value *DstAddr = SI->getOperand(1);
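+    // The aggregate is copied byte-by-byte; its store size in bytes becomes
+    // the loop trip count.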
+    unsigned NumLoads = DL.getTypeStoreSize(LI->getType());
+    Value *CopyLen = ConstantInt::get(Type::getInt32Ty(Context), NumLoads);
+
+    convertMemCpyToLoop(/* ConvertedInst */ SI,
+                        /* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,
+                        /* CopyLen */ CopyLen,
+                        /* SrcIsVolatile */ LI->isVolatile(),
+                        /* DstIsVolatile */ SI->isVolatile(),
+                        /* Context */ Context,
+                        /* Function F */ F);
+
+    SI->eraseFromParent();
+    LI->eraseFromParent();
   }
 
-  for (unsigned i = 0, e = aggrMemsets.size(); i != e; ++i) {
-    MemSetInst *memsetinst = aggrMemsets[i];
-    Value *len = memsetinst->getLength();
-    Value *val = memsetinst->getValue();
-    convertMemSetToLoop(memsetinst, memsetinst->getDest(), len, val, Context,
-                        F);
-    memsetinst->eraseFromParent();
+  // Transform mem* intrinsic calls.
+  for (MemIntrinsic *MemCall : MemCalls) {
+    if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {
+      convertMemCpyToLoop(/* ConvertedInst */ Memcpy,
+                          /* SrcAddr */ Memcpy->getRawSource(),
+                          /* DstAddr */ Memcpy->getRawDest(),
+                          /* CopyLen */ Memcpy->getLength(),
+                          /* SrcIsVolatile */ Memcpy->isVolatile(),
+                          /* DstIsVolatile */ Memcpy->isVolatile(),
+                          /* Context */ Context,
+                          /* Function F */ F);
+    } else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {
+      convertMemMoveToLoop(/* ConvertedInst */ Memmove,
+                           /* SrcAddr */ Memmove->getRawSource(),
+                           /* DstAddr */ Memmove->getRawDest(),
+                           /* CopyLen */ Memmove->getLength(),
+                           /* SrcIsVolatile */ Memmove->isVolatile(),
+                           /* DstIsVolatile */ Memmove->isVolatile(),
+                           /* Context */ Context,
+                           /* Function F */ F);
+
+    } else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {
+      convertMemSetToLoop(/* ConvertedInst */ Memset,
+                          /* DstAddr */ Memset->getRawDest(),
+                          /* CopyLen */ Memset->getLength(),
+                          /* SetValue */ Memset->getValue(),
+                          /* Context */ Context,
+                          /* Function F */ F);
+    }
+    MemCall->eraseFromParent();
   }
 
   return true;
 }
 
+} // namespace
+
+namespace llvm {
+void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);
+}
+
+INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",
+                "Lower aggregate copies and llvm.mem* intrinsics into loops",
+                false, false)
+
 FunctionPass *llvm::createLowerAggrCopies() {
   return new NVPTXLowerAggrCopies();
 }