[Orc] It's not valid to pass a null resolver to addModuleSet. Use a no-op
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / IndirectionUtils.h
index 0bc71bfdf33015cf8e9129c908a134fe9b0d2bfd..13ba125e80b4b627b84350451e89b673238b148d 100644 (file)
 #define LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H
 
 #include "JITSymbol.h"
+#include "LambdaResolver.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ExecutionEngine/RuntimeDyld.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Mangler.h"
 #include "llvm/IR/Module.h"
 #include <sstream>
 
 namespace llvm {
+namespace orc {
 
-/// @brief Persistent name mangling.
-///
-///   This class provides name mangling that can outlive a Module (and its
-/// DataLayout).
-class PersistentMangler {
+/// @brief Base class for JITLayer independent aspects of
+///        JITCompileCallbackManager.
+class JITCompileCallbackManagerBase {
 public:
-  PersistentMangler(DataLayout DL) : DL(std::move(DL)), M(&this->DL) {}
 
-  std::string getMangledName(StringRef Name) const {
-    std::string MangledName;
-    {
-      raw_string_ostream MangledNameStream(MangledName);
-      M.getNameWithPrefix(MangledNameStream, Name);
+  typedef std::function<TargetAddress()> CompileFtor;
+  typedef std::function<void(TargetAddress)> UpdateFtor;
+
+  /// @brief Handle to a newly created compile callback. Can be used to get an
+  ///        IR constant representing the address of the trampoline, and to set
+  ///        the compile and update actions for the callback.
+  class CompileCallbackInfo {
+  public:
+    CompileCallbackInfo(TargetAddress Addr, CompileFtor &Compile,
+                        UpdateFtor &Update)
+      : Addr(Addr), Compile(Compile), Update(Update) {}
+
+    TargetAddress getAddress() const { return Addr; }
+    void setCompileAction(CompileFtor Compile) {
+      this->Compile = std::move(Compile);
+    }
+    void setUpdateAction(UpdateFtor Update) {
+      this->Update = std::move(Update);
     }
-    return MangledName;
+  private:
+    TargetAddress Addr;
+    CompileFtor &Compile;
+    UpdateFtor &Update;
+  };
+
+  /// @brief Construct a JITCompileCallbackManagerBase.
+  /// @param ErrorHandlerAddress The address of an error handler in the target
+  ///                            process to be used if a compile callback fails.
+  /// @param NumTrampolinesPerBlock Number of trampolines to emit if there is no
+  ///                             available trampoline when getCompileCallback is
+  ///                             called.
+  JITCompileCallbackManagerBase(TargetAddress ErrorHandlerAddress,
+                                unsigned NumTrampolinesPerBlock)
+    : ErrorHandlerAddress(ErrorHandlerAddress),
+      NumTrampolinesPerBlock(NumTrampolinesPerBlock) {}
+
+  virtual ~JITCompileCallbackManagerBase() {}
+
+  /// @brief Execute the callback for the given trampoline id. Called by the JIT
+  ///        to compile functions on demand.
+  TargetAddress executeCompileCallback(TargetAddress TrampolineID) {
+    TrampolineMapT::iterator I = ActiveTrampolines.find(TrampolineID);
+    // FIXME: Also raise an error in the Orc error-handler when we finally have
+    //        one.
+    if (I == ActiveTrampolines.end())
+      return ErrorHandlerAddress;
+
+    // Found a callback handler. Yank this trampoline out of the active list and
+    // put it back in the available trampolines list, then try to run the
+    // handler's compile and update actions.
+    // Moving the trampoline ID back to the available list first means there's at
+    // least one available trampoline if the compile action triggers a request for
+    // a new one.
+    AvailableTrampolines.push_back(I->first);
+    auto CallbackHandler = std::move(I->second);
+    ActiveTrampolines.erase(I);
+
+    if (auto Addr = CallbackHandler.Compile()) {
+      CallbackHandler.Update(Addr);
+      return Addr;
+    }
+    return ErrorHandlerAddress;
   }
 
-private:
-  DataLayout DL;
-  Mangler M;
-};
+  /// @brief Get/create a compile callback with the given signature.
+  virtual CompileCallbackInfo getCompileCallback(LLVMContext &Context) = 0;
 
-/// @brief Handle callbacks from the JIT process requesting the definitions of
-///        symbols.
-///
-///   This utility is intended to be used to support compile-on-demand for
-/// functions.
-class JITResolveCallbackHandler {
-private:
-  typedef std::vector<std::string> FuncNameList;
+protected:
 
-public:
-  typedef FuncNameList::size_type StubIndex;
+  struct CallbackHandler {
+    CompileFtor Compile;
+    UpdateFtor Update;
+  };
 
-public:
-  /// @brief Create a JITResolveCallbackHandler with the given functors for
-  ///        looking up symbols and updating their use-sites.
-  ///
-  /// @return A JITResolveCallbackHandler instance that will invoke the
-  ///         Lookup and Update functors as needed to resolve missing symbol
-  ///         definitions.
-  template <typename LookupFtor, typename UpdateFtor>
-  static std::unique_ptr<JITResolveCallbackHandler> create(LookupFtor Lookup,
-                                                           UpdateFtor Update);
-
-  /// @brief Destroy instance. Does not modify existing emitted symbols.
-  ///
-  ///   Not-yet-emitted symbols will need to be resolved some other way after
-  /// this class is destroyed.
-  virtual ~JITResolveCallbackHandler() {}
-
-  /// @brief Add a function to be resolved on demand.
-  void addFuncName(std::string Name) { FuncNames.push_back(std::move(Name)); }
-
-  /// @brief Get the name associated with the given index.
-  const std::string &getFuncName(StubIndex Idx) const { return FuncNames[Idx]; }
-
-  /// @brief Returns the number of symbols being managed by this instance.
-  StubIndex getNumFuncs() const { return FuncNames.size(); }
-
-  /// @brief Get the address for the symbol associated with the given index.
-  ///
-  ///   This is expected to be called by code in the JIT process itself, in
-  /// order to resolve a function.
-  virtual TargetAddress resolve(StubIndex StubIdx) = 0;
+  TargetAddress ErrorHandlerAddress;
+  unsigned NumTrampolinesPerBlock;
 
-private:
-  FuncNameList FuncNames;
+  typedef std::map<TargetAddress, CallbackHandler> TrampolineMapT;
+  TrampolineMapT ActiveTrampolines;
+  std::vector<TargetAddress> AvailableTrampolines;
 };
 
-// Implementation class for JITResolveCallbackHandler.
-template <typename LookupFtor, typename UpdateFtor>
-class JITResolveCallbackHandlerImpl : public JITResolveCallbackHandler {
+/// @brief Manage compile callbacks.
+template <typename JITLayerT, typename TargetT>
+class JITCompileCallbackManager : public JITCompileCallbackManagerBase {
 public:
-  JITResolveCallbackHandlerImpl(LookupFtor Lookup, UpdateFtor Update)
-      : Lookup(std::move(Lookup)), Update(std::move(Update)) {}
-
-  TargetAddress resolve(StubIndex StubIdx) override {
-    const std::string &FuncName = getFuncName(StubIdx);
-    TargetAddress Addr = Lookup(FuncName);
-    Update(FuncName, Addr);
-    return Addr;
+
+  /// @brief Construct a JITCompileCallbackManager.
+  /// @param JIT JIT layer to emit callback trampolines, etc. into.
+  /// @param Context LLVMContext to use for trampoline & resolve block modules.
+  /// @param ErrorHandlerAddress The address of an error handler in the target
+  ///                            process to be used if a compile callback fails.
+  /// @param NumTrampolinesPerBlock Number of trampolines to allocate whenever
+  ///                               there is no existing callback trampoline.
+  ///                               (Trampolines are allocated in blocks for
+  ///                               efficiency.)
+  JITCompileCallbackManager(JITLayerT &JIT, RuntimeDyld::MemoryManager &MemMgr,
+                            LLVMContext &Context,
+                            TargetAddress ErrorHandlerAddress,
+                            unsigned NumTrampolinesPerBlock)
+    : JITCompileCallbackManagerBase(ErrorHandlerAddress,
+                                    NumTrampolinesPerBlock),
+      JIT(JIT), MemMgr(MemMgr) {
+    emitResolverBlock(Context);
+  }
+
+  /// @brief Get/create a compile callback with the given signature.
+  CompileCallbackInfo getCompileCallback(LLVMContext &Context) final {
+    TargetAddress TrampolineAddr = getAvailableTrampolineAddr(Context);
+    auto &CallbackHandler =
+      this->ActiveTrampolines[TrampolineAddr];
+
+    return CompileCallbackInfo(TrampolineAddr, CallbackHandler.Compile,
+                               CallbackHandler.Update);
   }
 
 private:
-  LookupFtor Lookup;
-  UpdateFtor Update;
-};
 
-template <typename LookupFtor, typename UpdateFtor>
-std::unique_ptr<JITResolveCallbackHandler>
-JITResolveCallbackHandler::create(LookupFtor Lookup, UpdateFtor Update) {
-  typedef JITResolveCallbackHandlerImpl<LookupFtor, UpdateFtor> Impl;
-  return make_unique<Impl>(std::move(Lookup), std::move(Update));
-}
-
-/// @brief Holds a list of the function names that were indirected, plus
-///        mappings from each of these names to (a) the name of function
-///        providing the implementation for that name (GetImplNames), and
-///        (b) the name of the global variable holding the address of the
-///        implementation.
-///
-///   This data structure can be used with a JITCallbackHandler to look up and
-/// update function implementations when lazily compiling.
-class JITIndirections {
-public:
-  JITIndirections(std::vector<std::string> IndirectedNames,
-                  std::function<std::string(StringRef)> GetImplName,
-                  std::function<std::string(StringRef)> GetAddrName)
-      : IndirectedNames(std::move(IndirectedNames)),
-        GetImplName(std::move(GetImplName)),
-        GetAddrName(std::move(GetAddrName)) {}
-
-  std::vector<std::string> IndirectedNames;
-  std::function<std::string(StringRef Name)> GetImplName;
-  std::function<std::string(StringRef Name)> GetAddrName;
-};
+  std::vector<std::unique_ptr<Module>>
+  SingletonSet(std::unique_ptr<Module> M) {
+    std::vector<std::unique_ptr<Module>> Ms;
+    Ms.push_back(std::move(M));
+    return Ms;
+  }
 
-/// @brief Indirect all calls to functions matching the predicate
-///        ShouldIndirect through a global variable containing the address
-///        of the implementation.
-///
-/// @return An indirection structure containing the functions that had their
-///         call-sites re-written.
-///
-///   For each function 'F' that meets the ShouldIndirect predicate, and that
-/// is called in this Module, add a common-linkage global variable to the
-/// module that will hold the address of the implementation of that function.
-/// Rewrite all call-sites of 'F' to be indirect calls (via the global).
-/// This allows clients, either directly or via a JITCallbackHandler, to
-/// change the address of the implementation of 'F' at runtime.
-///
-/// Important notes:
-///
-///   Single indirection does not preserve pointer equality for 'F'. If the
-/// program was already calling 'F' indirectly through function pointers, or
-/// if it was taking the address of 'F' for the purpose of pointer comparisons
-/// or arithmetic double indirection should be used instead.
-///
-///   This method does *not* initialize the function implementation addresses.
-/// The client must do this prior to running any call-sites that have been
-/// indirected.
-JITIndirections makeCallsSingleIndirect(
-    llvm::Module &M,
-    const std::function<bool(const Function &)> &ShouldIndirect,
-    const char *JITImplSuffix, const char *JITAddrSuffix);
-
-/// @brief Replace the body of functions matching the predicate ShouldIndirect
-///        with indirect calls to the implementation.
-///
-/// @return An indirections structure containing the functions that had their
-///         implementations re-written.
-///
-///   For each function 'F' that meets the ShouldIndirect predicate, add a
-/// common-linkage global variable to the module that will hold the address of
-/// the implementation of that function and rewrite the implementation of 'F'
-/// to call through to the implementation indirectly (via the global).
-/// This allows clients, either directly or via a JITCallbackHandler, to
-/// change the address of the implementation of 'F' at runtime.
-///
-/// Important notes:
-///
-///   Double indirection is slower than single indirection, but preserves
-/// function pointer relation tests and correct behavior for function pointers
-/// (all calls to 'F', direct or indirect) go the address stored in the global
-/// variable at the time of the call.
-///
-///   This method does *not* initialize the function implementation addresses.
-/// The client must do this prior to running any call-sites that have been
-/// indirected.
-JITIndirections makeCallsDoubleIndirect(
-    llvm::Module &M,
-    const std::function<bool(const Function &)> &ShouldIndirect,
-    const char *JITImplSuffix, const char *JITAddrSuffix);
-
-/// @brief Given a set of indirections and a symbol lookup functor, create a
-///        JITResolveCallbackHandler instance that will resolve the
-///        implementations for the indirected symbols on demand.
-template <typename SymbolLookupFtor>
-std::unique_ptr<JITResolveCallbackHandler>
-createCallbackHandlerFromJITIndirections(const JITIndirections &Indirs,
-                                         const PersistentMangler &NM,
-                                         SymbolLookupFtor Lookup) {
-  auto GetImplName = Indirs.GetImplName;
-  auto GetAddrName = Indirs.GetAddrName;
-
-  std::unique_ptr<JITResolveCallbackHandler> J =
-      JITResolveCallbackHandler::create(
-          [=](const std::string &S) {
-            return Lookup(NM.getMangledName(GetImplName(S)));
+  void emitResolverBlock(LLVMContext &Context) {
+    std::unique_ptr<Module> M(new Module("resolver_block_module",
+                                         Context));
+    TargetT::insertResolverBlock(*M, *this);
+    auto NonResolver =
+      createLambdaResolver(
+          [](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+            llvm_unreachable("External symbols in resolver block?");
           },
-          [=](const std::string &S, TargetAddress Addr) {
-            void *ImplPtr = reinterpret_cast<void *>(
-                Lookup(NM.getMangledName(GetAddrName(S))));
-            memcpy(ImplPtr, &Addr, sizeof(TargetAddress));
+          [](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+            llvm_unreachable("Dylib symbols in resolver block?");
           });
+    auto H = JIT.addModuleSet(SingletonSet(std::move(M)), &MemMgr,
+                              std::move(NonResolver));
+    JIT.emitAndFinalize(H);
+    auto ResolverBlockSymbol =
+      JIT.findSymbolIn(H, TargetT::ResolverBlockName, false);
+    assert(ResolverBlockSymbol && "Failed to insert resolver block");
+    ResolverBlockAddr = ResolverBlockSymbol.getAddress();
+  }
 
-  for (const auto &FuncName : Indirs.IndirectedNames)
-    J->addFuncName(FuncName);
-
-  return J;
-}
-
-/// @brief Insert callback asm into module M for the symbols managed by
-///        JITResolveCallbackHandler J.
-void insertX86CallbackAsm(Module &M, JITResolveCallbackHandler &J);
-
-/// @brief Initialize global indirects to point into the callback asm.
-template <typename LookupFtor>
-void initializeFuncAddrs(JITResolveCallbackHandler &J,
-                         const JITIndirections &Indirs,
-                         const PersistentMangler &NM, LookupFtor Lookup) {
-  // Forward declare so that we can access this, even though it's an
-  // implementation detail.
-  std::string getJITResolveCallbackIndexLabel(unsigned I);
-
-  if (J.getNumFuncs() == 0)
-    return;
-
-  //   Force a look up one of the global addresses for a function that has been
-  // indirected. We need to do this to trigger the emission of the module
-  // holding the callback asm. We can't rely on that emission happening
-  // automatically when we look up the callback asm symbols, since lazy-emitting
-  // layers can't see those.
-  Lookup(NM.getMangledName(Indirs.GetAddrName(J.getFuncName(0))));
-
-  // Now update indirects to point to the JIT resolve callback asm.
-  for (JITResolveCallbackHandler::StubIndex I = 0; I < J.getNumFuncs(); ++I) {
-    TargetAddress ResolveCallbackIdxAddr =
-        Lookup(getJITResolveCallbackIndexLabel(I));
-    void *AddrPtr = reinterpret_cast<void *>(
-        Lookup(NM.getMangledName(Indirs.GetAddrName(J.getFuncName(I)))));
-    assert(AddrPtr && "Can't find stub addr global to initialize.");
-    memcpy(AddrPtr, &ResolveCallbackIdxAddr, sizeof(TargetAddress));
+  TargetAddress getAvailableTrampolineAddr(LLVMContext &Context) {
+    if (this->AvailableTrampolines.empty())
+      grow(Context);
+    assert(!this->AvailableTrampolines.empty() &&
+           "Failed to grow available trampolines.");
+    TargetAddress TrampolineAddr = this->AvailableTrampolines.back();
+    this->AvailableTrampolines.pop_back();
+    return TrampolineAddr;
   }
-}
 
-/// @brief Extract all functions matching the predicate ShouldExtract in to
-///        their own modules. (Does not modify the original module.)
-///
-/// @return A set of modules, the first containing all symbols (including
-///         globals and aliases) that did not pass ShouldExtract, and each
-///         subsequent module containing one of the functions that did meet
-///         ShouldExtract.
-///
-///   By adding the resulting modules separately (not as a set) to a
-/// LazyEmittingLayer instance, compilation can be deferred until symbols are
-/// actually needed.
-std::vector<std::unique_ptr<llvm::Module>>
-explode(const llvm::Module &OrigMod,
-        const std::function<bool(const Function &)> &ShouldExtract);
-
-/// @brief Given a module that has been indirectified, break each function
-///        that has been indirected out into its own module. (Does not modify
-///        the original module).
+  void grow(LLVMContext &Context) {
+    assert(this->AvailableTrampolines.empty() && "Growing prematurely?");
+    std::unique_ptr<Module> M(new Module("trampoline_block", Context));
+    auto GetLabelName =
+      TargetT::insertCompileCallbackTrampolines(*M, ResolverBlockAddr,
+                                                this->NumTrampolinesPerBlock,
+                                                this->ActiveTrampolines.size());
+    auto NonResolver =
+      createLambdaResolver(
+          [](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+            llvm_unreachable("External symbols in trampoline block?");
+          },
+          [](const std::string &Name) -> RuntimeDyld::SymbolInfo {
+            llvm_unreachable("Dylib symbols in trampoline block?");
+          });
+    auto H = JIT.addModuleSet(SingletonSet(std::move(M)), &MemMgr,
+                              std::move(NonResolver));
+    JIT.emitAndFinalize(H);
+    for (unsigned I = 0; I < this->NumTrampolinesPerBlock; ++I) {
+      std::string Name = GetLabelName(I);
+      auto TrampolineSymbol = JIT.findSymbolIn(H, Name, false);
+      assert(TrampolineSymbol && "Failed to emit trampoline.");
+      this->AvailableTrampolines.push_back(TrampolineSymbol.getAddress());
+    }
+  }
+
+  JITLayerT &JIT;
+  RuntimeDyld::MemoryManager &MemMgr;
+  TargetAddress ResolverBlockAddr;
+};
+
+/// @brief Get an update functor that updates the value of a named function
+///        pointer.
+template <typename JITLayerT>
+JITCompileCallbackManagerBase::UpdateFtor
+getLocalFPUpdater(JITLayerT &JIT, typename JITLayerT::ModuleSetHandleT H,
+                  std::string Name) {
+    // FIXME: Move-capture Name once we can use C++14.
+    return [=,&JIT](TargetAddress Addr) {
+      auto FPSym = JIT.findSymbolIn(H, Name, true);
+      assert(FPSym && "Cannot find function pointer to update.");
+      void *FPAddr = reinterpret_cast<void*>(
+                       static_cast<uintptr_t>(FPSym.getAddress()));
+      memcpy(FPAddr, &Addr, sizeof(uintptr_t));
+    };
+  }
+
+/// @brief Build a function pointer of FunctionType with the given constant
+///        address.
 ///
-/// @returns A set of modules covering the symbols provided by OrigMod.
-std::vector<std::unique_ptr<llvm::Module>>
-explode(const llvm::Module &OrigMod, const JITIndirections &Indirections);
-}
+///   Usage example: Turn a trampoline address into a function pointer constant
+/// for use in a stub.
+Constant* createIRTypedAddress(FunctionType &FT, TargetAddress Addr);
+
+/// @brief Create a function pointer with the given type, name, and initializer
+///        in the given Module.
+GlobalVariable* createImplPointer(PointerType &PT, Module &M,
+                                  const Twine &Name, Constant *Initializer);
+
+/// @brief Turn a function declaration into a stub function that makes an
+///        indirect call using the given function pointer.
+void makeStub(Function &F, GlobalVariable &ImplPointer);
+
+typedef std::map<Module*, DenseSet<const GlobalValue*>> ModulePartitionMap;
+
+/// @brief Extract subsections of a Module into the given Module according to
+///        the given ModulePartitionMap.
+void partition(Module &M, const ModulePartitionMap &PMap);
+
+/// @brief Struct for trivial "complete" partitioning of a module.
+class FullyPartitionedModule {
+public:
+  std::unique_ptr<Module> GlobalVars;
+  std::unique_ptr<Module> Commons;
+  std::vector<std::unique_ptr<Module>> Functions;
+
+  FullyPartitionedModule() = default;
+  FullyPartitionedModule(FullyPartitionedModule &&S)
+      : GlobalVars(std::move(S.GlobalVars)), Commons(std::move(S.Commons)),
+        Functions(std::move(S.Functions)) {}
+};
+
+/// @brief Extract every function in M into a separate module.
+FullyPartitionedModule fullyPartition(Module &M);
+
+} // End namespace orc.
+} // End namespace llvm.
 
 #endif // LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H