[Orc] Refactor the CompileOnDemandLayer to make its addModuleSet method
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / CompileOnDemandLayer.h
index 31eb1430943617f76ea3ac0e5c7cde7e929adce2..30f7f1cd5f593757fb92a5ef371a4fbe0dbe1a56 100644 (file)
 #define LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H
 
 #include "IndirectionUtils.h"
-#include "LookasideRTDyldMM.h"
+#include "LambdaResolver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include <list>
 
 namespace llvm {
+namespace orc {
 
 /// @brief Compile-on-demand layer.
 ///
@@ -33,8 +34,9 @@ namespace llvm {
 /// It is expected that this layer will frequently be used on top of a
 /// LazyEmittingLayer. The combination of the two ensures that each function is
 /// compiled only when it is first called.
-template <typename BaseLayerT> class CompileOnDemandLayer {
-public:
+template <typename BaseLayerT, typename CompileCallbackMgrT>
+class CompileOnDemandLayer {
+private:
   /// @brief Lookup helper that provides compatibility with the classic
   ///        static-compilation symbol resolution process.
   ///
@@ -62,6 +64,8 @@ public:
     /// @brief Construct a scoped lookup.
     CODScopedLookup(BaseLayerT &BaseLayer) : BaseLayer(BaseLayer) {}
 
+    virtual ~CODScopedLookup() {}
+
     /// @brief Start a new context for a single logical module.
     LMHandle createLogicalModule() {
       Handles.push_back(SiblingHandlesList());
@@ -78,31 +82,58 @@ public:
     void removeLogicalModule(LMHandle LMH) { Handles.erase(LMH); }
 
     /// @brief Look up a symbol in this context.
-    uint64_t lookup(LMHandle LMH, const std::string &Name) {
-      if (uint64_t Addr = lookupOnlyIn(LMH, Name))
-        return Addr;
+    JITSymbol findSymbol(LMHandle LMH, const std::string &Name) {
+      if (auto Symbol = findSymbolIn(LMH, Name))
+        return Symbol;
 
       for (auto I = Handles.begin(), E = Handles.end(); I != E; ++I)
         if (I != LMH)
-          if (uint64_t Addr = lookupOnlyIn(I, Name))
-            return Addr;
+          if (auto Symbol = findSymbolIn(I, Name))
+            return Symbol;
 
-      return 0;
+      return nullptr;
     }
 
+    /// @brief Find an external symbol (via the user supplied SymbolResolver).
+    virtual RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const = 0;
+
   private:
-    uint64_t lookupOnlyIn(LMHandle LMH, const std::string &Name) {
+
+    JITSymbol findSymbolIn(LMHandle LMH, const std::string &Name) {
       for (auto H : *LMH)
-        if (uint64_t Addr = BaseLayer.lookupSymbolAddressIn(H, Name, false))
-          return Addr;
-      return 0;
+        if (auto Symbol = BaseLayer.findSymbolIn(H, Name, false))
+          return Symbol;
+      return nullptr;
     }
 
     BaseLayerT &BaseLayer;
     PseudoDylibModuleSetHandlesList Handles;
   };
 
-private:
+  template <typename ResolverPtrT>
+  class CODScopedLookupImpl : public CODScopedLookup {
+  public:
+    CODScopedLookupImpl(BaseLayerT &BaseLayer, ResolverPtrT Resolver)
+      : CODScopedLookup(BaseLayer), Resolver(std::move(Resolver)) {}
+
+    RuntimeDyld::SymbolInfo
+    externalLookup(const std::string &Name) const override {
+      return Resolver->findSymbol(Name);
+    }
+
+  private:
+    ResolverPtrT Resolver;
+  };
+
+  template <typename ResolverPtrT>
+  static std::shared_ptr<CODScopedLookup>
+  createCODScopedLookup(BaseLayerT &BaseLayer,
+                        ResolverPtrT Resolver) {
+    typedef CODScopedLookupImpl<ResolverPtrT> Impl;
+    return std::make_shared<Impl>(BaseLayer, std::move(Resolver));
+  }
+
   typedef typename BaseLayerT::ModuleSetHandleT BaseLayerModuleSetHandleT;
   typedef std::vector<BaseLayerModuleSetHandleT> BaseLayerModuleSetHandleListT;
 
@@ -113,13 +144,6 @@ private:
     // Logical module handles.
     std::vector<typename CODScopedLookup::LMHandle> LMHandles;
 
-    // Persistent manglers - one per TU.
-    std::vector<PersistentMangler> PersistentManglers;
-
-    // Symbol resolution callback handlers - one per TU.
-    std::vector<std::unique_ptr<JITResolveCallbackHandler>>
-        JITResolveCallbackHandlers;
-
     // List of vectors of module set handles:
     // One vector per logical module - each vector holds the handles for the
     // exploded modules for that logical module in the base layer.
@@ -142,89 +166,31 @@ public:
   /// @brief Handle to a set of loaded modules.
   typedef typename ModuleSetInfoListT::iterator ModuleSetHandleT;
 
-  /// @brief Convenience typedef for callback inserter.
-  typedef std::function<void(Module&, JITResolveCallbackHandler&)>
-    InsertCallbackAsmFtor;
-
   /// @brief Construct a compile-on-demand layer instance.
-  CompileOnDemandLayer(BaseLayerT &BaseLayer,
-                       InsertCallbackAsmFtor InsertCallbackAsm)
-    : BaseLayer(BaseLayer), InsertCallbackAsm(InsertCallbackAsm) {}
+  CompileOnDemandLayer(BaseLayerT &BaseLayer, CompileCallbackMgrT &CallbackMgr)
+      : BaseLayer(BaseLayer), CompileCallbackMgr(CallbackMgr) {}
 
   /// @brief Add a module to the compile-on-demand layer.
-  template <typename ModuleSetT>
+  template <typename ModuleSetT, typename MemoryManagerPtrT,
+            typename SymbolResolverPtrT>
   ModuleSetHandleT addModuleSet(ModuleSetT Ms,
-                                std::unique_ptr<RTDyldMemoryManager> MM) {
+                                MemoryManagerPtrT MemMgr,
+                                SymbolResolverPtrT Resolver) {
 
-    const char *JITAddrSuffix = "$orc_addr";
-    const char *JITImplSuffix = "$orc_impl";
+    assert(MemMgr == nullptr &&
+           "User supplied memory managers not supported with COD yet.");
 
-    // Create a symbol lookup context and ModuleSetInfo for this module set.
-    auto DylibLookup = std::make_shared<CODScopedLookup>(BaseLayer);
+    // Create a lookup context and ModuleSetInfo for this module set.
+    // For the purposes of symbol resolution the set Ms will be treated as if
+    // the modules it contained had been linked together as a dylib.
+    auto DylibLookup = createCODScopedLookup(BaseLayer, std::move(Resolver));
     ModuleSetHandleT H =
         ModuleSetInfos.insert(ModuleSetInfos.end(), ModuleSetInfo(DylibLookup));
     ModuleSetInfo &MSI = ModuleSetInfos.back();
 
-    // Process each of the modules in this module set. All modules share the
-    // same lookup context, but each will get its own TU lookup context.
-    for (auto &M : Ms) {
-
-      // Create a TU lookup context for this module.
-      auto LMH = DylibLookup->createLogicalModule();
-      MSI.LMHandles.push_back(LMH);
-
-      // Create a persistent mangler for this module.
-      MSI.PersistentManglers.emplace_back(*M->getDataLayout());
-
-      // Make all calls to functions defined in this module indirect.
-      JITIndirections Indirections =
-          makeCallsDoubleIndirect(*M, [](const Function &) { return true; },
-                                  JITImplSuffix, JITAddrSuffix);
-
-      // Then carve up the module into a bunch of single-function modules.
-      std::vector<std::unique_ptr<Module>> ExplodedModules =
-          explode(*M, Indirections);
-
-      // Add a resolve-callback handler for this module to look up symbol
-      // addresses when requested via a callback.
-      MSI.JITResolveCallbackHandlers.push_back(
-          createCallbackHandlerFromJITIndirections(
-              Indirections, MSI.PersistentManglers.back(),
-              [=](StringRef S) { return DylibLookup->lookup(LMH, S); }));
-
-      // Insert callback asm code into the first module.
-      InsertCallbackAsm(*ExplodedModules[0],
-                        *MSI.JITResolveCallbackHandlers.back());
-
-      // Now we need to take each of the extracted Modules and add them to
-      // base layer. Each Module will be added individually to make sure they
-      // can be compiled separately, and each will get its own lookaside
-      // memory manager with lookup functors that resolve symbols in sibling
-      // modules first.OA
-      for (auto &M : ExplodedModules) {
-        std::vector<std::unique_ptr<Module>> MSet;
-        MSet.push_back(std::move(M));
-
-        BaseLayerModuleSetHandleT H = BaseLayer.addModuleSet(
-            std::move(MSet),
-            createLookasideRTDyldMM<SectionMemoryManager>(
-                [=](const std::string &Name) {
-                  if (uint64_t Addr = DylibLookup->lookup(LMH, Name))
-                    return Addr;
-                  return getSymbolAddress(Name, true);
-                },
-                [=](const std::string &Name) {
-                  return DylibLookup->lookup(LMH, Name);
-                }));
-        DylibLookup->addToLogicalModule(LMH, H);
-        MSI.BaseLayerModuleSetHandles.push_back(H);
-      }
-
-      initializeFuncAddrs(*MSI.JITResolveCallbackHandlers.back(), Indirections,
-                          MSI.PersistentManglers.back(), [=](StringRef S) {
-                            return DylibLookup->lookup(LMH, S);
-                          });
-    }
+    // Process each of the modules in this module set.
+    for (auto &M : Ms)
+      partitionAndAdd(*M, MSI);
 
     return H;
   }
@@ -238,30 +204,179 @@ public:
     ModuleSetInfos.erase(H);
   }
 
-  /// @brief Get the address of a symbol provided by this layer, or some layer
-  ///        below this one.
-  uint64_t getSymbolAddress(const std::string &Name, bool ExportedSymbolsOnly) {
-    return BaseLayer.getSymbolAddress(Name, ExportedSymbolsOnly);
+  /// @brief Search for the given named symbol.
+  /// @param Name The name of the symbol to search for.
+  /// @param ExportedSymbolsOnly If true, search only for exported symbols.
+  /// @return A handle for the given named symbol, if it exists.
+  JITSymbol findSymbol(StringRef Name, bool ExportedSymbolsOnly) {
+    return BaseLayer.findSymbol(Name, ExportedSymbolsOnly);
   }
 
   /// @brief Get the address of a symbol provided by this layer, or some layer
   ///        below this one.
-  uint64_t lookupSymbolAddressIn(ModuleSetHandleT H, const std::string &Name,
-                                 bool ExportedSymbolsOnly) {
-    BaseLayerModuleSetHandleListT &BaseLayerHandles = H->second;
-    for (auto &BH : BaseLayerHandles) {
-      if (uint64_t Addr =
-            BaseLayer.lookupSymbolAddressIn(BH, Name, ExportedSymbolsOnly))
-        return Addr;
+  JITSymbol findSymbolIn(ModuleSetHandleT H, const std::string &Name,
+                         bool ExportedSymbolsOnly) {
+
+    for (auto &BH : H->BaseLayerModuleSetHandles) {
+      if (auto Symbol = BaseLayer.findSymbolIn(BH, Name, ExportedSymbolsOnly))
+        return Symbol;
     }
-    return 0;
+    return nullptr;
   }
 
 private:
+
+  void partitionAndAdd(Module &M, ModuleSetInfo &MSI) {
+    const char *AddrSuffix = "$orc_addr";
+    const char *BodySuffix = "$orc_body";
+
+    // We're going to break M up into a bunch of sub-modules, but we want
+    // internal linkage symbols to still resolve sensibly. CODScopedLookup
+    // provides the "logical module" concept to make this work, so create a
+    // new logical module for M.
+    auto DylibLookup = MSI.Lookup;
+    auto LogicalModule = DylibLookup->createLogicalModule();
+    MSI.LMHandles.push_back(LogicalModule);
+
+    // Partition M into a "globals and stubs" module, a "common symbols" module,
+    // and a list of single-function modules.
+    auto PartitionedModule = fullyPartition(M);
+    auto StubsModule = std::move(PartitionedModule.GlobalVars);
+    auto CommonsModule = std::move(PartitionedModule.Commons);
+    auto FunctionModules = std::move(PartitionedModule.Functions);
+
+    // Emit the commons stright away.
+    auto CommonHandle = addModule(std::move(CommonsModule), MSI, LogicalModule);
+    BaseLayer.emitAndFinalize(CommonHandle);
+
+    // Map of definition names to callback-info data structures. We'll use
+    // this to build the compile actions for the stubs below.
+    typedef std::map<std::string,
+                     typename CompileCallbackMgrT::CompileCallbackInfo>
+      StubInfoMap;
+    StubInfoMap StubInfos;
+
+    // Now we need to take each of the extracted Modules and add them to
+    // base layer. Each Module will be added individually to make sure they
+    // can be compiled separately, and each will get its own lookaside
+    // memory manager that will resolve within this logical module first.
+    for (auto &SubM : FunctionModules) {
+
+      // Keep track of the stubs we create for this module so that we can set
+      // their compile actions.
+      std::vector<typename StubInfoMap::iterator> NewStubInfos;
+
+      // Search for function definitions and insert stubs into the stubs
+      // module.
+      for (auto &F : *SubM) {
+        if (F.isDeclaration())
+          continue;
+
+        std::string Name = F.getName();
+        Function *Proto = StubsModule->getFunction(Name);
+        assert(Proto && "Failed to clone function decl into stubs module.");
+        auto CallbackInfo =
+          CompileCallbackMgr.getCompileCallback(Proto->getContext());
+        GlobalVariable *FunctionBodyPointer =
+          createImplPointer(*Proto->getType(), *Proto->getParent(),
+                            Name + AddrSuffix,
+                            createIRTypedAddress(*Proto->getFunctionType(),
+                                                 CallbackInfo.getAddress()));
+        makeStub(*Proto, *FunctionBodyPointer);
+
+        F.setName(Name + BodySuffix);
+        F.setVisibility(GlobalValue::HiddenVisibility);
+
+        auto KV = std::make_pair(std::move(Name), std::move(CallbackInfo));
+        NewStubInfos.push_back(StubInfos.insert(StubInfos.begin(), KV));
+      }
+
+      auto H = addModule(std::move(SubM), MSI, LogicalModule);
+
+      // Set the compile actions for this module:
+      for (auto &KVPair : NewStubInfos) {
+        std::string BodyName = Mangle(KVPair->first + BodySuffix,
+                                      M.getDataLayout());
+        auto &CCInfo = KVPair->second;
+        CCInfo.setCompileAction(
+          [=](){
+            return BaseLayer.findSymbolIn(H, BodyName, false).getAddress();
+          });
+      }
+
+    }
+
+    // Ok - we've processed all the partitioned modules. Now add the
+    // stubs/globals module and set the update actions.
+    auto StubsH =
+      addModule(std::move(StubsModule), MSI, LogicalModule);
+
+    for (auto &KVPair : StubInfos) {
+      std::string AddrName = Mangle(KVPair.first + AddrSuffix,
+                                    M.getDataLayout());
+      auto &CCInfo = KVPair.second;
+      CCInfo.setUpdateAction(
+        getLocalFPUpdater(BaseLayer, StubsH, AddrName));
+    }
+  }
+
+  // Add the given Module to the base layer using a memory manager that will
+  // perform the appropriate scoped lookup (i.e. will look first with in the
+  // module from which it was extracted, then into the set to which that module
+  // belonged, and finally externally).
+  BaseLayerModuleSetHandleT addModule(
+                               std::unique_ptr<Module> M,
+                               ModuleSetInfo &MSI,
+                               typename CODScopedLookup::LMHandle LogicalModule) {
+
+    // Add this module to the JIT with a memory manager that uses the
+    // DylibLookup to resolve symbols.
+    std::vector<std::unique_ptr<Module>> MSet;
+    MSet.push_back(std::move(M));
+
+    auto DylibLookup = MSI.Lookup;
+    auto Resolver =
+      createLambdaResolver(
+        [=](const std::string &Name) {
+          if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return DylibLookup->externalLookup(Name);
+        },
+        [=](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+          if (auto Symbol = DylibLookup->findSymbol(LogicalModule, Name))
+            return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                           Symbol.getFlags());
+          return nullptr;
+        });
+
+    BaseLayerModuleSetHandleT H =
+      BaseLayer.addModuleSet(std::move(MSet),
+                             make_unique<SectionMemoryManager>(),
+                             std::move(Resolver));
+    // Add this module to the logical module lookup.
+    DylibLookup->addToLogicalModule(LogicalModule, H);
+    MSI.BaseLayerModuleSetHandles.push_back(H);
+
+    return H;
+  }
+
+  static std::string Mangle(StringRef Name, const DataLayout &DL) {
+    Mangler M(&DL);
+    std::string MangledName;
+    {
+      raw_string_ostream MangledNameStream(MangledName);
+      M.getNameWithPrefix(MangledNameStream, Name);
+    }
+    return MangledName;
+  }
+
   BaseLayerT &BaseLayer;
-  InsertCallbackAsmFtor InsertCallbackAsm;
+  CompileCallbackMgrT &CompileCallbackMgr;
   ModuleSetInfoListT ModuleSetInfos;
 };
-}
+
+} // End namespace orc.
+} // End namespace llvm.
 
 #endif // LLVM_EXECUTIONENGINE_ORC_COMPILEONDEMANDLAYER_H