[Orc] Fix a bug in the CompileOnDemand layer where stub decls were not cloned
[oota-llvm.git] / lib / ExecutionEngine / Orc / OrcMCJITReplacement.h
index 28fc04b283d315d981d37c5f2862b395c6bc8701..4023344d2f3d5f6f8196b27cd2ddbe75e2ba420d 100644 (file)
 #include "llvm/ExecutionEngine/Orc/LazyEmittingLayer.h"
 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
 #include "llvm/Object/Archive.h"
-#include "llvm/Target/TargetSubtargetInfo.h"
 
 namespace llvm {
+namespace orc {
 
 class OrcMCJITReplacement : public ExecutionEngine {
 
-  class ForwardingRTDyldMM : public RTDyldMemoryManager {
+  // OrcMCJITReplacement needs to do a little extra book-keeping to ensure that
+  // Orc's automatic finalization doesn't kick in earlier than MCJIT clients are
+  // expecting - see finalizeMemory.
+  class MCJITReplacementMemMgr : public MCJITMemoryManager {
   public:
-    ForwardingRTDyldMM(OrcMCJITReplacement &M) : M(M) {}
+    MCJITReplacementMemMgr(OrcMCJITReplacement &M,
+                           std::shared_ptr<MCJITMemoryManager> ClientMM)
+      : M(M), ClientMM(std::move(ClientMM)) {}
 
     uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment,
                                  unsigned SectionID,
                                  StringRef SectionName) override {
       uint8_t *Addr =
-          M.MM->allocateCodeSection(Size, Alignment, SectionID, SectionName);
+          ClientMM->allocateCodeSection(Size, Alignment, SectionID,
+                                        SectionName);
       M.SectionsAllocatedSinceLastLoad.insert(Addr);
       return Addr;
     }
@@ -42,43 +48,35 @@ class OrcMCJITReplacement : public ExecutionEngine {
     uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment,
                                  unsigned SectionID, StringRef SectionName,
                                  bool IsReadOnly) override {
-      uint8_t *Addr = M.MM->allocateDataSection(Size, Alignment, SectionID,
-                                                SectionName, IsReadOnly);
+      uint8_t *Addr = ClientMM->allocateDataSection(Size, Alignment, SectionID,
+                                                    SectionName, IsReadOnly);
       M.SectionsAllocatedSinceLastLoad.insert(Addr);
       return Addr;
     }
 
     void reserveAllocationSpace(uintptr_t CodeSize, uintptr_t DataSizeRO,
                                 uintptr_t DataSizeRW) override {
-      return M.MM->reserveAllocationSpace(CodeSize, DataSizeRO, DataSizeRW);
+      return ClientMM->reserveAllocationSpace(CodeSize, DataSizeRO,
+                                                DataSizeRW);
     }
 
     bool needsToReserveAllocationSpace() override {
-      return M.MM->needsToReserveAllocationSpace();
+      return ClientMM->needsToReserveAllocationSpace();
     }
 
     void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr,
                           size_t Size) override {
-      return M.MM->registerEHFrames(Addr, LoadAddr, Size);
+      return ClientMM->registerEHFrames(Addr, LoadAddr, Size);
     }
 
     void deregisterEHFrames(uint8_t *Addr, uint64_t LoadAddr,
                             size_t Size) override {
-      return M.MM->deregisterEHFrames(Addr, LoadAddr, Size);
-    }
-
-    uint64_t getSymbolAddress(const std::string &Name) override {
-      return M.getSymbolAddressWithoutMangling(Name);
-    }
-
-    void *getPointerToNamedFunction(const std::string &Name,
-                                    bool AbortOnFailure = true) override {
-      return M.MM->getPointerToNamedFunction(Name, AbortOnFailure);
+      return ClientMM->deregisterEHFrames(Addr, LoadAddr, Size);
     }
 
     void notifyObjectLoaded(ExecutionEngine *EE,
                             const object::ObjectFile &O) override {
-      return M.MM->notifyObjectLoaded(EE, O);
+      return ClientMM->notifyObjectLoaded(EE, O);
     }
 
     bool finalizeMemory(std::string *ErrMsg = nullptr) override {
@@ -96,20 +94,41 @@ class OrcMCJITReplacement : public ExecutionEngine {
       // get more than one set of objects loaded but not yet finalized is if
       // they were loaded during relocation of another set.
       if (M.UnfinalizedSections.size() == 1)
-        return M.MM->finalizeMemory(ErrMsg);
+        return ClientMM->finalizeMemory(ErrMsg);
       return false;
     }
 
+  private:
+    OrcMCJITReplacement &M;
+    std::shared_ptr<MCJITMemoryManager> ClientMM;
+  };
+
+  class LinkingResolver : public RuntimeDyld::SymbolResolver {
+  public:
+    LinkingResolver(OrcMCJITReplacement &M) : M(M) {}
+
+    RuntimeDyld::SymbolInfo findSymbol(const std::string &Name) override {
+      return M.findMangledSymbol(Name);
+    }
+
+    RuntimeDyld::SymbolInfo
+    findSymbolInLogicalDylib(const std::string &Name) override {
+      return M.ClientResolver->findSymbolInLogicalDylib(Name);
+    }
+
   private:
     OrcMCJITReplacement &M;
   };
 
 private:
+
   static ExecutionEngine *
   createOrcMCJITReplacement(std::string *ErrorMsg,
-                            std::unique_ptr<RTDyldMemoryManager> OrcJMM,
-                            std::unique_ptr<llvm::TargetMachine> TM) {
-    return new llvm::OrcMCJITReplacement(std::move(OrcJMM), std::move(TM));
+                            std::shared_ptr<MCJITMemoryManager> MemMgr,
+                            std::shared_ptr<RuntimeDyld::SymbolResolver> Resolver,
+                            std::unique_ptr<TargetMachine> TM) {
+    return new OrcMCJITReplacement(std::move(MemMgr), std::move(Resolver),
+                                   std::move(TM));
   }
 
 public:
@@ -117,12 +136,15 @@ public:
     OrcMCJITReplacementCtor = createOrcMCJITReplacement;
   }
 
-  OrcMCJITReplacement(std::unique_ptr<RTDyldMemoryManager> MM,
-                      std::unique_ptr<TargetMachine> TM)
-      : TM(std::move(TM)), MM(std::move(MM)), Mang(this->TM->getDataLayout()),
+  OrcMCJITReplacement(
+                    std::shared_ptr<MCJITMemoryManager> MemMgr,
+                    std::shared_ptr<RuntimeDyld::SymbolResolver> ClientResolver,
+                    std::unique_ptr<TargetMachine> TM)
+      : TM(std::move(TM)), MemMgr(*this, std::move(MemMgr)),
+        Resolver(*this), ClientResolver(std::move(ClientResolver)),
+        Mang(this->TM->getDataLayout()),
         NotifyObjectLoaded(*this), NotifyFinalized(*this),
-        ObjectLayer(ObjectLayerT::CreateRTDyldMMFtor(), NotifyObjectLoaded,
-                    NotifyFinalized),
+        ObjectLayer(NotifyObjectLoaded, NotifyFinalized),
         CompileLayer(ObjectLayer, SimpleCompiler(*this->TM)),
         LazyEmitLayer(CompileLayer) {
     setDataLayout(this->TM->getDataLayout());
@@ -132,21 +154,19 @@ public:
 
     // If this module doesn't have a DataLayout attached then attach the
     // default.
-    if (!M->getDataLayout())
-      M->setDataLayout(getDataLayout());
+    if (M->getDataLayout().isDefault())
+      M->setDataLayout(*getDataLayout());
 
-    OwnedModules.push_back(std::move(M));
+    Modules.push_back(std::move(M));
     std::vector<Module *> Ms;
-    Ms.push_back(&*OwnedModules.back());
-    LazyEmitLayer.addModuleSet(std::move(Ms),
-                               llvm::make_unique<ForwardingRTDyldMM>(*this));
+    Ms.push_back(&*Modules.back());
+    LazyEmitLayer.addModuleSet(std::move(Ms), &MemMgr, &Resolver);
   }
 
   void addObjectFile(std::unique_ptr<object::ObjectFile> O) override {
     std::vector<std::unique_ptr<object::ObjectFile>> Objs;
     Objs.push_back(std::move(O));
-    ObjectLayer.addObjectSet(std::move(Objs),
-                             llvm::make_unique<ForwardingRTDyldMM>(*this));
+    ObjectLayer.addObjectSet(std::move(Objs), &MemMgr, &Resolver);
   }
 
   void addObjectFile(object::OwningBinary<object::ObjectFile> O) override {
@@ -155,8 +175,12 @@ public:
     std::tie(Obj, Buf) = O.takeBinary();
     std::vector<std::unique_ptr<object::ObjectFile>> Objs;
     Objs.push_back(std::move(Obj));
-    ObjectLayer.addObjectSet(std::move(Objs),
-                             llvm::make_unique<ForwardingRTDyldMM>(*this));
+    auto H =
+      ObjectLayer.addObjectSet(std::move(Objs), &MemMgr, &Resolver);
+
+    std::vector<std::unique_ptr<MemoryBuffer>> Bufs;
+    Bufs.push_back(std::move(Buf));
+    ObjectLayer.takeOwnershipOfBuffers(H, std::move(Bufs));
   }
 
   void addArchive(object::OwningBinary<object::Archive> A) override {
@@ -164,7 +188,11 @@ public:
   }
 
   uint64_t getSymbolAddress(StringRef Name) {
-    return getSymbolAddressWithoutMangling(Mangle(Name));
+    return findSymbol(Name).getAddress();
+  }
+
+  RuntimeDyld::SymbolInfo findSymbol(StringRef Name) {
+    return findMangledSymbol(Mangle(Name));
   }
 
   void finalizeObject() override {
@@ -208,18 +236,19 @@ public:
   }
 
 private:
-  uint64_t getSymbolAddressWithoutMangling(StringRef Name) {
-    if (uint64_t Addr = LazyEmitLayer.getSymbolAddress(Name, false))
-      return Addr;
-    if (uint64_t Addr = MM->getSymbolAddress(Name))
-      return Addr;
-    if (uint64_t Addr = scanArchives(Name))
-      return Addr;
 
-    return 0;
+  RuntimeDyld::SymbolInfo findMangledSymbol(StringRef Name) {
+    if (auto Sym = LazyEmitLayer.findSymbol(Name, false))
+      return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+    if (auto Sym = ClientResolver->findSymbol(Name))
+      return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+    if (auto Sym = scanArchives(Name))
+      return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+
+    return nullptr;
   }
 
-  uint64_t scanArchives(StringRef Name) {
+  JITSymbol scanArchives(StringRef Name) {
     for (object::OwningBinary<object::Archive> &OB : Archives) {
       object::Archive *A = OB.getBinary();
       // Look for our symbols in each Archive
@@ -235,14 +264,13 @@ private:
           std::vector<std::unique_ptr<object::ObjectFile>> ObjSet;
           ObjSet.push_back(std::unique_ptr<object::ObjectFile>(
               static_cast<object::ObjectFile *>(ChildBin.release())));
-          ObjectLayer.addObjectSet(
-              std::move(ObjSet), llvm::make_unique<ForwardingRTDyldMM>(*this));
-          if (uint64_t Addr = ObjectLayer.getSymbolAddress(Name, true))
-            return Addr;
+          ObjectLayer.addObjectSet(std::move(ObjSet), &MemMgr, &Resolver);
+          if (auto Sym = ObjectLayer.findSymbol(Name, true))
+            return Sym;
         }
       }
     }
-    return 0;
+    return nullptr;
   }
 
   class NotifyObjectLoadedT {
@@ -261,7 +289,7 @@ private:
       assert(Objects.size() == Infos.size() &&
              "Incorrect number of Infos for Objects.");
       for (unsigned I = 0; I < Objects.size(); ++I)
-        M.MM->notifyObjectLoaded(&M, *Objects[I]);
+        M.MemMgr.notifyObjectLoaded(&M, *Objects[I]);
     };
 
   private:
@@ -293,7 +321,9 @@ private:
   typedef LazyEmittingLayer<CompileLayerT> LazyEmitLayerT;
 
   std::unique_ptr<TargetMachine> TM;
-  std::unique_ptr<RTDyldMemoryManager> MM;
+  MCJITReplacementMemMgr MemMgr;
+  LinkingResolver Resolver;
+  std::shared_ptr<RuntimeDyld::SymbolResolver> ClientResolver;
   Mangler Mang;
 
   NotifyObjectLoadedT NotifyObjectLoaded;
@@ -303,10 +333,6 @@ private:
   CompileLayerT CompileLayer;
   LazyEmitLayerT LazyEmitLayer;
 
-  // MCJIT keeps modules alive - we need to do the same for backwards
-  // compatibility.
-  std::vector<std::unique_ptr<Module>> OwnedModules;
-
   // We need to store ObjLayerT::ObjSetHandles for each of the object sets
   // that have been emitted but not yet finalized so that we can forward the
   // mapSectionAddress calls appropriately.
@@ -323,6 +349,8 @@ private:
 
   std::vector<object::OwningBinary<object::Archive>> Archives;
 };
-}
+
+} // End namespace orc.
+} // End namespace llvm.
 
 #endif // LLVM_LIB_EXECUTIONENGINE_ORC_MCJITREPLACEMENT_H