[Orc] Refactor the CompileOnDemandLayer to make its addModuleSet method
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / CompileOnDemandLayer.h
index 77b0c48d0a7a3344f7995172795a60ffac4fb531..30f7f1cd5f593757fb92a5ef371a4fbe0dbe1a56 100644 (file)
@@ -16,7 +16,7 @@
 #define LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H
 
 #include "IndirectionUtils.h"
-#include "LookasideRTDyldMM.h"
+#include "LambdaResolver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include <list>
@@ -36,7 +36,7 @@ namespace orc {
 /// compiled only when it is first called.
 template <typename BaseLayerT, typename CompileCallbackMgrT>
 class CompileOnDemandLayer {
-public:
+private:
   /// @brief Lookup helper that provides compatibility with the classic
   ///        static-compilation symbol resolution process.
   ///
@@ -64,6 +64,8 @@ public:
     /// @brief Construct a scoped lookup.
     CODScopedLookup(BaseLayerT &BaseLayer) : BaseLayer(BaseLayer) {}
 
+    virtual ~CODScopedLookup() {}
+
     /// @brief Start a new context for a single logical module.
     LMHandle createLogicalModule() {
       Handles.push_back(SiblingHandlesList());
@@ -92,6 +94,10 @@ public:
       return nullptr;
     }
 
+    /// @brief Find an external symbol (via the user supplied SymbolResolver).
+    virtual RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const = 0;
+
   private:
 
     JITSymbol findSymbolIn(LMHandle LMH, const std::string &Name) {
@@ -105,7 +111,29 @@ public:
     PseudoDylibModuleSetHandlesList Handles;
   };
 
-private:
+  template <typename ResolverPtrT>
+  class CODScopedLookupImpl : public CODScopedLookup {
+  public:
+    CODScopedLookupImpl(BaseLayerT &BaseLayer, ResolverPtrT Resolver)
+      : CODScopedLookup(BaseLayer), Resolver(std::move(Resolver)) {}
+
+    RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const override {
+      return Resolver->findSymbol(Name);
+    }
+
+  private:
+    ResolverPtrT Resolver;
+  };
+
+  template <typename ResolverPtrT>
+  static std::shared_ptr<CODScopedLookup>
+  createCODScopedLookup(BaseLayerT &BaseLayer,
+                        ResolverPtrT Resolver) {
+    typedef CODScopedLookupImpl<ResolverPtrT> Impl;
+    return std::make_shared<Impl>(BaseLayer, std::move(Resolver));
+  }
+
   typedef typename BaseLayerT::ModuleSetHandleT BaseLayerModuleSetHandleT;
   typedef std::vector<BaseLayerModuleSetHandleT> BaseLayerModuleSetHandleListT;
 
@@ -138,36 +166,31 @@ public:
   /// @brief Handle to a set of loaded modules.
   typedef typename ModuleSetInfoListT::iterator ModuleSetHandleT;
 
-  // @brief Fallback lookup functor.
-  typedef std::function<uint64_t(const std::string &)> LookupFtor;
-
   /// @brief Construct a compile-on-demand layer instance.
   CompileOnDemandLayer(BaseLayerT &BaseLayer, CompileCallbackMgrT &CallbackMgr)
       : BaseLayer(BaseLayer), CompileCallbackMgr(CallbackMgr) {}
 
   /// @brief Add a module to the compile-on-demand layer.
-  template <typename ModuleSetT>
+  template <typename ModuleSetT, typename MemoryManagerPtrT,
+            typename SymbolResolverPtrT>
   ModuleSetHandleT addModuleSet(ModuleSetT Ms,
-                                LookupFtor FallbackLookup = nullptr) {
+                                MemoryManagerPtrT MemMgr,
+                                SymbolResolverPtrT Resolver) {
 
-    // If the user didn't supply a fallback lookup then just use
-    // getSymbolAddress.
-    if (!FallbackLookup)
-      FallbackLookup = [=](const std::string &Name) {
-                         return findSymbol(Name, true).getAddress();
-                       };
+    assert(MemMgr == nullptr &&
+           "User supplied memory managers not supported with COD yet.");
 
     // Create a lookup context and ModuleSetInfo for this module set.
     // For the purposes of symbol resolution the set Ms will be treated as if
     // the modules it contained had been linked together as a dylib.
-    auto DylibLookup = std::make_shared<CODScopedLookup>(BaseLayer);
+    auto DylibLookup = createCODScopedLookup(BaseLayer, std::move(Resolver));
     ModuleSetHandleT H =
         ModuleSetInfos.insert(ModuleSetInfos.end(), ModuleSetInfo(DylibLookup));
     ModuleSetInfo &MSI = ModuleSetInfos.back();
 
     // Process each of the modules in this module set.
     for (auto &M : Ms)
-      partitionAndAdd(*M, MSI, FallbackLookup);
+      partitionAndAdd(*M, MSI);
 
     return H;
   }
@@ -203,8 +226,7 @@ public:
 
 private:
 
-  void partitionAndAdd(Module &M, ModuleSetInfo &MSI,
-                       LookupFtor FallbackLookup) {
+  void partitionAndAdd(Module &M, ModuleSetInfo &MSI) {
     const char *AddrSuffix = "$orc_addr";
     const char *BodySuffix = "$orc_body";
 
@@ -224,8 +246,7 @@ private:
     auto FunctionModules = std::move(PartitionedModule.Functions);
 
     // Emit the commons stright away.
-    auto CommonHandle = addModule(std::move(CommonsModule), MSI, LogicalModule,
-                                  FallbackLookup);
+    auto CommonHandle = addModule(std::move(CommonsModule), MSI, LogicalModule);
     BaseLayer.emitAndFinalize(CommonHandle);
 
     // Map of definition names to callback-info data structures. We'll use
@@ -255,10 +276,12 @@ private:
         Function *Proto = StubsModule->getFunction(Name);
         assert(Proto && "Failed to clone function decl into stubs module.");
         auto CallbackInfo =
-          CompileCallbackMgr.getCompileCallback(*Proto->getFunctionType());
+          CompileCallbackMgr.getCompileCallback(Proto->getContext());
         GlobalVariable *FunctionBodyPointer =
-          createImplPointer(*Proto, Name + AddrSuffix,
-                            CallbackInfo.getAddress());
+          createImplPointer(*Proto->getType(), *Proto->getParent(),
+                            Name + AddrSuffix,
+                            createIRTypedAddress(*Proto->getFunctionType(),
+                                                 CallbackInfo.getAddress()));
         makeStub(*Proto, *FunctionBodyPointer);
 
         F.setName(Name + BodySuffix);
@@ -268,7 +291,7 @@ private:
         NewStubInfos.push_back(StubInfos.insert(StubInfos.begin(), KV));
       }
 
-      auto H = addModule(std::move(SubM), MSI, LogicalModule, FallbackLookup);
+      auto H = addModule(std::move(SubM), MSI, LogicalModule);
 
       // Set the compile actions for this module:
       for (auto &KVPair : NewStubInfos) {
@@ -286,7 +309,7 @@ private:
     // Ok - we've processed all the partitioned modules. Now add the
     // stubs/globals module and set the update actions.
     auto StubsH =
-      addModule(std::move(StubsModule), MSI, LogicalModule, FallbackLookup);
+      addModule(std::move(StubsModule), MSI, LogicalModule);
 
     for (auto &KVPair : StubInfos) {
       std::string AddrName = Mangle(KVPair.first + AddrSuffix,
@@ -304,8 +327,7 @@ private:
   BaseLayerModuleSetHandleT addModule(
                                std::unique_ptr<Module> M,
                                ModuleSetInfo &MSI,
-                               typename CODScopedLookup::LMHandle LogicalModule,
-                               LookupFtor FallbackLookup) {
+                               typename CODScopedLookup::LMHandle LogicalModule) {
 
     // Add this module to the JIT with a memory manager that uses the
     // DylibLookup to resolve symbols.
@@ -313,19 +335,25 @@ private:
     MSet.push_back(std::move(M));
 
     auto DylibLookup = MSI.Lookup;
-    auto MM =
-      createLookasideRTDyldMM<SectionMemoryManager>(
+    auto Resolver =
+      createLambdaResolver(
         [=](const std::string &Name) {
           if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
-            return Symbol.getAddress();
-          return FallbackLookup(Name);
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return DylibLookup->externalLookup(Name);
         },
-        [=](const std::string &Name) {
-          return DylibLookup->findSymbol(LogicalModule, Name).getAddress();
+        [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+          if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return nullptr;
         });
 
     BaseLayerModuleSetHandleT H =
-      BaseLayer.addModuleSet(std::move(MSet), std::move(MM));
+      BaseLayer.addModuleSet(std::move(MSet),
+                             make_unique<SectionMemoryManager>(),
+                             std::move(Resolver));
     // Add this module to the logical module lookup.
     DylibLookup->addToLogicalModule(LogicalModule, H);
     MSI.BaseLayerModuleSetHandles.push_back(H);