[Orc] Rename JITCompileCallbackManagerBase to JITCompileCallbackManager.
[oota-llvm.git] / lib / ExecutionEngine / Orc / OrcCBindingsStack.h
index 5b58c0098781a6c77cf8aa7adb4fc10bd31fd4b3..d2f7fe4ac0ef5b848a0cf6b1cdc3df075c54161e 100644 (file)
 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm-c/OrcBindings.h"
 
 namespace llvm {
 
-class OrcCBindingsStack {
-private:
+class OrcCBindingsStack;
 
-public:
+DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcCBindingsStack, LLVMOrcJITStackRef)
+DEFINE_SIMPLE_CONVERSION_FUNCTIONS(TargetMachine, LLVMTargetMachineRef)
 
-  typedef orc::TargetAddress (*CExternalSymbolResolverFn)(const char *Name,
-                                                          void *Ctx);
+class OrcCBindingsStack {
+public:
 
-  typedef orc::JITCompileCallbackManagerBase CompileCallbackMgr;
+  typedef orc::JITCompileCallbackManager CompileCallbackMgr;
   typedef orc::ObjectLinkingLayer<> ObjLayerT;
   typedef orc::IRCompileLayer<ObjLayerT> CompileLayerT;
   typedef orc::CompileOnDemandLayer<CompileLayerT, CompileCallbackMgr> CODLayerT;
 
-  typedef std::function<
-            std::unique_ptr<CompileCallbackMgr>(CompileLayerT&,
-                                                RuntimeDyld::MemoryManager&,
-                                                LLVMContext&)>
+  typedef std::function<std::unique_ptr<CompileCallbackMgr>()>
     CallbackManagerBuilder;
 
   typedef CODLayerT::IndirectStubsManagerBuilderT IndirectStubsManagerBuilder;
 
 private:
 
-  typedef enum { Invalid, CODLayerHandle, ObjectLayerHandle } HandleType;
-  union RawHandleUnion {
-    RawHandleUnion() { memset(this, 0, sizeof(RawHandleUnion)); }
-    ObjLayerT::ObjSetHandleT Obj;
-    CODLayerT::ModuleSetHandleT COD;
+  class GenericHandle {
+  public:
+    virtual ~GenericHandle() {}
+    virtual orc::JITSymbol findSymbolIn(const std::string &Name,
+                                        bool ExportedSymbolsOnly) = 0;
+    virtual void removeModule() = 0;
   };
 
-  struct ModuleHandleData {
-
-    ModuleHandleData() : Type(Invalid) {}
-
-    ModuleHandleData(ObjLayerT::ObjSetHandleT H)
-        : Type(ObjectLayerHandle) {
-      RawHandle.Obj = std::move(H);
+  template <typename LayerT>
+  class GenericHandleImpl : public GenericHandle {
+  public:
+    GenericHandleImpl(LayerT &Layer, typename LayerT::ModuleSetHandleT Handle)
+      : Layer(Layer), Handle(std::move(Handle)) {}
+
+    orc::JITSymbol findSymbolIn(const std::string &Name,
+                                bool ExportedSymbolsOnly) override {
+      return Layer.findSymbolIn(Handle, Name, ExportedSymbolsOnly);
     }
 
-    ModuleHandleData(CODLayerT::ModuleSetHandleT H)
-      : Type(CODLayerHandle) {
-      RawHandle.COD = std::move(H);
+    void removeModule() override {
+      return Layer.removeModuleSet(Handle);
     }
 
-    HandleType Type;
-    RawHandleUnion RawHandle;
+  private:
+    LayerT &Layer;
+    typename LayerT::ModuleSetHandleT Handle;
   };
 
+  template <typename LayerT>
+  std::unique_ptr<GenericHandleImpl<LayerT>>
+  createGenericHandle(LayerT &Layer, typename LayerT::ModuleSetHandleT Handle) {
+    return llvm::make_unique<GenericHandleImpl<LayerT>>(Layer,
+                                                        std::move(Handle));
+  }
+
 public:
 
   // We need a 'ModuleSetHandleT' to conform to the layer concept.
@@ -75,19 +83,19 @@ public:
 
   typedef unsigned ModuleHandleT;
 
-  static CallbackManagerBuilder createCallbackManagerBuilder(Triple T);
+  static std::unique_ptr<CompileCallbackMgr> createCompileCallbackMgr(Triple T);
   static IndirectStubsManagerBuilder createIndirectStubsMgrBuilder(Triple T);
 
-  OrcCBindingsStack(TargetMachine &TM, LLVMContext &Context,
-                    CallbackManagerBuilder &BuildCallbackMgr,
+  OrcCBindingsStack(TargetMachine &TM,
+                   std::unique_ptr<CompileCallbackMgr> CCMgr, 
                     IndirectStubsManagerBuilder IndirectStubsMgrBuilder)
-    : DL(TM.createDataLayout()),
+    : DL(TM.createDataLayout()), CCMgr(std::move(CCMgr)),
       ObjectLayer(),
       CompileLayer(ObjectLayer, orc::SimpleCompiler(TM)),
-      CCMgr(BuildCallbackMgr(CompileLayer, CCMgrMemMgr, Context)),
       CODLayer(CompileLayer,
                [](Function &F) { std::set<Function*> S; S.insert(&F); return S; },
-               *CCMgr, std::move(IndirectStubsMgrBuilder), false),
+               *this->CCMgr, std::move(IndirectStubsMgrBuilder), false),
+      IndirectStubsMgr(IndirectStubsMgrBuilder()),
       CXXRuntimeOverrides([this](const std::string &S) { return mangle(S); }) {}
 
   ~OrcCBindingsStack() {
@@ -112,8 +120,27 @@ public:
     return reinterpret_cast<PtrTy>(static_cast<uintptr_t>(Addr));
   }
 
+  orc::TargetAddress
+  createLazyCompileCallback(LLVMOrcLazyCompileCallbackFn Callback,
+                            void *CallbackCtx) {
+    auto CCInfo = CCMgr->getCompileCallback();
+    CCInfo.setCompileAction(
+      [=]() -> orc::TargetAddress {
+        return Callback(wrap(this), CallbackCtx);
+      });
+    return CCInfo.getAddress();
+  }
+
+  void createIndirectStub(StringRef StubName, orc::TargetAddress Addr) {
+    IndirectStubsMgr->createStub(StubName, Addr, JITSymbolFlags::Exported);
+  }
+
+  void setIndirectStubPointer(StringRef Name, orc::TargetAddress Addr) {
+    IndirectStubsMgr->updatePointer(Name, Addr);
+  }
+
   std::shared_ptr<RuntimeDyld::SymbolResolver>
-  createResolver(CExternalSymbolResolverFn ExternalResolver,
+  createResolver(LLVMOrcSymbolResolverFn ExternalResolver,
                  void *ExternalResolverCtx) {
     auto Resolver = orc::createLambdaResolver(
       [this, ExternalResolver, ExternalResolverCtx](const std::string &Name) {
@@ -147,7 +174,7 @@ public:
   ModuleHandleT addIRModule(LayerT &Layer,
                             Module *M,
                             std::unique_ptr<RuntimeDyld::MemoryManager> MemMgr,
-                            CExternalSymbolResolverFn ExternalResolver,
+                            LLVMOrcSymbolResolverFn ExternalResolver,
                             void *ExternalResolverCtx) {
 
     // Attach a data-layout if one isn't already present.
@@ -171,7 +198,7 @@ public:
 
     auto LH = Layer.addModuleSet(std::move(S), std::move(MemMgr),
                                  std::move(Resolver));
-    ModuleHandleT H = createHandle(LH);
+    ModuleHandleT H = createHandle(Layer, LH);
 
     // Run the static constructors, and save the static destructor runner for
     // execution when the JIT is torn down.
@@ -184,7 +211,7 @@ public:
   }
 
   ModuleHandleT addIRModuleEager(Module* M,
-                                 CExternalSymbolResolverFn ExternalResolver,
+                                 LLVMOrcSymbolResolverFn ExternalResolver,
                                  void *ExternalResolverCtx) {
     return addIRModule(CompileLayer, std::move(M),
                        llvm::make_unique<SectionMemoryManager>(),
@@ -192,58 +219,43 @@ public:
   }
 
   ModuleHandleT addIRModuleLazy(Module* M,
-                                CExternalSymbolResolverFn ExternalResolver,
+                                LLVMOrcSymbolResolverFn ExternalResolver,
                                 void *ExternalResolverCtx) {
     return addIRModule(CODLayer, std::move(M), nullptr,
                        std::move(ExternalResolver), ExternalResolverCtx);
   }
 
   void removeModule(ModuleHandleT H) {
-    auto &HD = HandleData[H];
-    switch (HD.Type) {
-    case ObjectLayerHandle:
-      ObjectLayer.removeObjectSet(HD.RawHandle.Obj);
-      break;
-    case CODLayerHandle:
-      CODLayer.removeModuleSet(HD.RawHandle.COD);
-      break;
-    default:
-      llvm_unreachable("removeModule called on invalid handle type");
-    }
+    GenericHandles[H]->removeModule();
+    GenericHandles[H] = nullptr;
+    FreeHandleIndexes.push_back(H);
   }
 
   orc::JITSymbol findSymbol(const std::string &Name, bool ExportedSymbolsOnly) {
+    if (auto Sym = IndirectStubsMgr->findStub(Name, ExportedSymbolsOnly))
+      return Sym;
     return CODLayer.findSymbol(mangle(Name), ExportedSymbolsOnly);
   }
 
   orc::JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name,
                               bool ExportedSymbolsOnly) {
-    auto &HD = HandleData[H];
-    switch (HD.Type) {
-    case ObjectLayerHandle:
-      return ObjectLayer.findSymbolIn(HD.RawHandle.Obj, mangle(Name),
-                                      ExportedSymbolsOnly);
-    case CODLayerHandle:
-      return CODLayer.findSymbolIn(HD.RawHandle.COD, mangle(Name),
-                                   ExportedSymbolsOnly);
-    default:
-      llvm_unreachable("removeModule called on invalid handle type");
-    }
+    return GenericHandles[H]->findSymbolIn(Name, ExportedSymbolsOnly);
   }
 
 private:
 
-  template <typename LayerHandleT>
-  unsigned createHandle(LayerHandleT LH) {
+  template <typename LayerT>
+  unsigned createHandle(LayerT &Layer,
+                        typename LayerT::ModuleSetHandleT Handle) {
     unsigned NewHandle;
-    if (!FreeHandles.empty()) {
-      NewHandle = FreeHandles.back();
-      FreeHandles.pop_back();
-      HandleData[NewHandle] = ModuleHandleData(std::move(LH));
+    if (!FreeHandleIndexes.empty()) {
+      NewHandle = FreeHandleIndexes.back();
+      FreeHandleIndexes.pop_back();
+      GenericHandles[NewHandle] = createGenericHandle(Layer, std::move(Handle));
       return NewHandle;
     } else {
-      NewHandle = HandleData.size();
-      HandleData.push_back(ModuleHandleData(std::move(LH)));
+      NewHandle = GenericHandles.size();
+      GenericHandles.push_back(createGenericHandle(Layer, std::move(Handle)));
     }
     return NewHandle;
   }
@@ -251,13 +263,15 @@ private:
   DataLayout DL;
   SectionMemoryManager CCMgrMemMgr;
 
+  std::unique_ptr<CompileCallbackMgr> CCMgr;
   ObjLayerT ObjectLayer;
   CompileLayerT CompileLayer;
-  std::unique_ptr<CompileCallbackMgr> CCMgr;
   CODLayerT CODLayer;
 
-  std::vector<ModuleHandleData> HandleData;
-  std::vector<unsigned> FreeHandles;
+  std::unique_ptr<orc::IndirectStubsManagerBase> IndirectStubsMgr;
+
+  std::vector<std::unique_ptr<GenericHandle>> GenericHandles;
+  std::vector<unsigned> FreeHandleIndexes;
 
   orc::LocalCXXRuntimeOverrides CXXRuntimeOverrides;
   std::vector<orc::CtorDtorRunner<OrcCBindingsStack>> IRStaticDestructorRunners;