[Orc] Refactor the CompileOnDemandLayer to make its addModuleSet method
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / CompileOnDemandLayer.h
index 3dc3927a15c211cc9fc8ae2333bca4902d0c4ce7..30f7f1cd5f593757fb92a5ef371a4fbe0dbe1a56 100644 (file)
@@ -36,7 +36,7 @@ namespace orc {
 /// compiled only when it is first called.
 template <typename BaseLayerT, typename CompileCallbackMgrT>
 class CompileOnDemandLayer {
-public:
+private:
   /// @brief Lookup helper that provides compatibility with the classic
   ///        static-compilation symbol resolution process.
   ///
@@ -64,6 +64,8 @@ public:
     /// @brief Construct a scoped lookup.
     CODScopedLookup(BaseLayerT &BaseLayer) : BaseLayer(BaseLayer) {}
 
+    virtual ~CODScopedLookup() {}
+
     /// @brief Start a new context for a single logical module.
     LMHandle createLogicalModule() {
       Handles.push_back(SiblingHandlesList());
@@ -92,6 +94,10 @@ public:
       return nullptr;
     }
 
+    /// @brief Find an external symbol (via the user supplied SymbolResolver).
+    virtual RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const = 0;
+
   private:
 
     JITSymbol findSymbolIn(LMHandle LMH, const std::string &Name) {
@@ -105,7 +111,29 @@ public:
     PseudoDylibModuleSetHandlesList Handles;
   };
 
-private:
+  template <typename ResolverPtrT>
+  class CODScopedLookupImpl : public CODScopedLookup {
+  public:
+    CODScopedLookupImpl(BaseLayerT &BaseLayer, ResolverPtrT Resolver)
+      : CODScopedLookup(BaseLayer), Resolver(std::move(Resolver)) {}
+
+    RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const override {
+      return Resolver->findSymbol(Name);
+    }
+
+  private:
+    ResolverPtrT Resolver;
+  };
+
+  template <typename ResolverPtrT>
+  static std::shared_ptr<CODScopedLookup>
+  createCODScopedLookup(BaseLayerT &BaseLayer,
+                        ResolverPtrT Resolver) {
+    typedef CODScopedLookupImpl<ResolverPtrT> Impl;
+    return std::make_shared<Impl>(BaseLayer, std::move(Resolver));
+  }
+
   typedef typename BaseLayerT::ModuleSetHandleT BaseLayerModuleSetHandleT;
   typedef std::vector<BaseLayerModuleSetHandleT> BaseLayerModuleSetHandleListT;
 
@@ -138,40 +166,31 @@ public:
   /// @brief Handle to a set of loaded modules.
   typedef typename ModuleSetInfoListT::iterator ModuleSetHandleT;
 
-  // @brief Fallback lookup functor.
-  typedef std::function<RuntimeDyld::SymbolInfo(const std::string &)> LookupFtor;
-
   /// @brief Construct a compile-on-demand layer instance.
   CompileOnDemandLayer(BaseLayerT &BaseLayer, CompileCallbackMgrT &CallbackMgr)
       : BaseLayer(BaseLayer), CompileCallbackMgr(CallbackMgr) {}
 
   /// @brief Add a module to the compile-on-demand layer.
-  template <typename ModuleSetT>
+  template <typename ModuleSetT, typename MemoryManagerPtrT,
+            typename SymbolResolverPtrT>
   ModuleSetHandleT addModuleSet(ModuleSetT Ms,
-                                LookupFtor FallbackLookup = nullptr) {
+                                MemoryManagerPtrT MemMgr,
+                                SymbolResolverPtrT Resolver) {
 
-    // If the user didn't supply a fallback lookup then just use
-    // getSymbolAddress.
-    if (!FallbackLookup)
-      FallbackLookup =
-        [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
-          if (auto Symbol = findSymbol(Name, true))
-            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
-                                           Symbol.getFlags());
-          return nullptr;
-        };
+    assert(MemMgr == nullptr &&
+           "User supplied memory managers not supported with COD yet.");
 
     // Create a lookup context and ModuleSetInfo for this module set.
     // For the purposes of symbol resolution the set Ms will be treated as if
     // the modules it contained had been linked together as a dylib.
-    auto DylibLookup = std::make_shared<CODScopedLookup>(BaseLayer);
+    auto DylibLookup = createCODScopedLookup(BaseLayer, std::move(Resolver));
     ModuleSetHandleT H =
         ModuleSetInfos.insert(ModuleSetInfos.end(), ModuleSetInfo(DylibLookup));
     ModuleSetInfo &MSI = ModuleSetInfos.back();
 
     // Process each of the modules in this module set.
     for (auto &M : Ms)
-      partitionAndAdd(*M, MSI, FallbackLookup);
+      partitionAndAdd(*M, MSI);
 
     return H;
   }
@@ -207,8 +226,7 @@ public:
 
 private:
 
-  void partitionAndAdd(Module &M, ModuleSetInfo &MSI,
-                       LookupFtor FallbackLookup) {
+  void partitionAndAdd(Module &M, ModuleSetInfo &MSI) {
     const char *AddrSuffix = "$orc_addr";
     const char *BodySuffix = "$orc_body";
 
@@ -228,8 +246,7 @@ private:
     auto FunctionModules = std::move(PartitionedModule.Functions);
 
     // Emit the commons stright away.
-    auto CommonHandle = addModule(std::move(CommonsModule), MSI, LogicalModule,
-                                  FallbackLookup);
+    auto CommonHandle = addModule(std::move(CommonsModule), MSI, LogicalModule);
     BaseLayer.emitAndFinalize(CommonHandle);
 
     // Map of definition names to callback-info data structures. We'll use
@@ -274,7 +291,7 @@ private:
         NewStubInfos.push_back(StubInfos.insert(StubInfos.begin(), KV));
       }
 
-      auto H = addModule(std::move(SubM), MSI, LogicalModule, FallbackLookup);
+      auto H = addModule(std::move(SubM), MSI, LogicalModule);
 
       // Set the compile actions for this module:
       for (auto &KVPair : NewStubInfos) {
@@ -292,7 +309,7 @@ private:
     // Ok - we've processed all the partitioned modules. Now add the
     // stubs/globals module and set the update actions.
     auto StubsH =
-      addModule(std::move(StubsModule), MSI, LogicalModule, FallbackLookup);
+      addModule(std::move(StubsModule), MSI, LogicalModule);
 
     for (auto &KVPair : StubInfos) {
       std::string AddrName = Mangle(KVPair.first + AddrSuffix,
@@ -310,8 +327,7 @@ private:
   BaseLayerModuleSetHandleT addModule(
                                std::unique_ptr<Module> M,
                                ModuleSetInfo &MSI,
-                               typename CODScopedLookup::LMHandle LogicalModule,
-                               LookupFtor FallbackLookup) {
+                               typename CODScopedLookup::LMHandle LogicalModule) {
 
     // Add this module to the JIT with a memory manager that uses the
     // DylibLookup to resolve symbols.
@@ -325,7 +341,7 @@ private:
           if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
             return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
                                            Symbol.getFlags());
-          return FallbackLookup(Name);
+          return DylibLookup->externalLookup(Name);
         },
         [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
           if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))