[MCJIT][Orc] Refactor RTDyldMemoryManager, weave RuntimeDyld::SymbolInfo through
[oota-llvm.git] / examples / Kaleidoscope / Orc / lazy_irgen / toy.cpp
index 2963f30e2edbea490c103977345468edf94d2e40..c489a450d79da8a3fb253314218a8c855e04d3d1 100644 (file)
@@ -1,15 +1,16 @@
 #include "llvm/Analysis/Passes.h"
 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "llvm/ExecutionEngine/Orc/LambdaResolver.h"
 #include "llvm/ExecutionEngine/Orc/LazyEmittingLayer.h"
 #include "llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
-#include "llvm/PassManager.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Transforms/Scalar.h"
 #include <cctype>
@@ -19,7 +20,9 @@
 #include <sstream>
 #include <string>
 #include <vector>
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -679,14 +682,18 @@ std::string MakeLegalFunctionName(std::string Name)
 
 class SessionContext {
 public:
-  SessionContext(LLVMContext &C) : Context(C) {}
+  SessionContext(LLVMContext &C)
+    : Context(C), TM(EngineBuilder().selectTarget()) {}
   LLVMContext& getLLVMContext() const { return Context; }
+  TargetMachine& getTarget() { return *TM; }
   void addPrototypeAST(std::unique_ptr<PrototypeAST> P);
   PrototypeAST* getPrototypeAST(const std::string &Name);
-  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs; 
 private:
   typedef std::map<std::string, std::unique_ptr<PrototypeAST>> PrototypeMap;
+  
   LLVMContext &Context;
+  std::unique_ptr<TargetMachine> TM;
+      
   PrototypeMap Prototypes;
 };
 
@@ -708,7 +715,9 @@ public:
     : Session(S),
       M(new Module(GenerateUniqueName("jit_module_"),
                    Session.getLLVMContext())),
-      Builder(Session.getLLVMContext()) {}
+      Builder(Session.getLLVMContext()) {
+    M->setDataLayout(*Session.getTarget().getDataLayout());
+  }
 
   SessionContext& getSession() { return Session; }
   Module& getM() const { return *M; }
@@ -1137,15 +1146,27 @@ static std::unique_ptr<llvm::Module> IRGen(SessionContext &S,
   return C.takeM();
 }
 
+template <typename T>
+static std::vector<T> singletonSet(T t) {
+  std::vector<T> Vec;
+  Vec.push_back(std::move(t));
+  return Vec;
+}
+
 class KaleidoscopeJIT {
 public:
   typedef ObjectLinkingLayer<> ObjLayerT;
   typedef IRCompileLayer<ObjLayerT> CompileLayerT;
   typedef LazyEmittingLayer<CompileLayerT> LazyEmitLayerT;
-
   typedef LazyEmitLayerT::ModuleSetHandleT ModuleHandleT;
 
-  std::string Mangle(const std::string &Name) {
+  KaleidoscopeJIT(SessionContext &Session)
+    : Session(Session),
+      Mang(Session.getTarget().getDataLayout()),
+      CompileLayer(ObjectLayer, SimpleCompiler(Session.getTarget())),
+      LazyEmitLayer(CompileLayer) {}
+
+  std::string mangle(const std::string &Name) {
     std::string MangledName;
     {
       raw_string_ostream MangledNameStream(MangledName);
@@ -1154,76 +1175,79 @@ public:
     return MangledName;
   }
 
-  KaleidoscopeJIT(SessionContext &Session)
-    : TM(EngineBuilder().selectTarget()),
-      Mang(TM->getDataLayout()), Session(Session),
-      CompileLayer(ObjectLayer, SimpleCompiler(*TM)),
-      LazyEmitLayer(CompileLayer) {}
+  void addFunctionAST(std::unique_ptr<FunctionAST> FnAST) {
+    std::cerr << "Adding AST: " << FnAST->Proto->Name << "\n";
+    FunctionDefs[mangle(FnAST->Proto->Name)] = std::move(FnAST);
+  }
 
   ModuleHandleT addModule(std::unique_ptr<Module> M) {
-    if (!M->getDataLayout())
-      M->setDataLayout(TM->getDataLayout());
-
-    // The LazyEmitLayer takes lists of modules, rather than single modules, so
-    // we'll just build a single-element list.
-    std::vector<std::unique_ptr<Module>> S;
-    S.push_back(std::move(M));
-
     // We need a memory manager to allocate memory and resolve symbols for this
-    // new module. Create one that resolves symbols by looking back into the JIT.
-    auto MM = createLookasideRTDyldMM<SectionMemoryManager>(
-                [&](const std::string &Name) -> uint64_t {
-                  // First try to find 'Name' within the JIT.
-                  if (uint64_t Addr = getMangledSymbolAddress(Name))
-                    return Addr;
-
-                  // If we don't find 'Name' in the JIT, see if we have some AST
-                  // for it.
-                  auto DefI = Session.FunctionDefs.find(Name);
-                  if (DefI == Session.FunctionDefs.end())
-                    return 0;
-
-                  // We have AST for 'Name'. IRGen it, add it to the JIT, and
-                  // return the address for it.
-                  // FIXME: What happens if IRGen fails?
-                  addModule(IRGen(Session, *DefI->second));
-
-                  // Remove the function definition's AST now that we've
-                  // finished with it.
-                  Session.FunctionDefs.erase(DefI);
-
-                  return getMangledSymbolAddress(Name);
-                }, 
-                [](const std::string &S) { return 0; } );
-
-    return LazyEmitLayer.addModuleSet(std::move(S), std::move(MM));
+    // new module. Create one that resolves symbols by looking back into the
+    // JIT.
+    auto Resolver = createLambdaResolver(
+                      [&](const std::string &Name) {
+                        // First try to find 'Name' within the JIT.
+                        if (auto Symbol = findSymbol(Name))
+                          return RuntimeDyld::SymbolInfo(Symbol.getAddress(),
+                                                         Symbol.getFlags());
+
+                        // If we don't already have a definition of 'Name' then search
+                        // the ASTs.
+                        return searchFunctionASTs(Name);
+                      },
+                      [](const std::string &S) { return nullptr; } );
+
+    return LazyEmitLayer.addModuleSet(singletonSet(std::move(M)),
+                                      make_unique<SectionMemoryManager>(),
+                                      std::move(Resolver));
   }
 
   void removeModule(ModuleHandleT H) { LazyEmitLayer.removeModuleSet(H); }
 
-  uint64_t getMangledSymbolAddress(const std::string &Name) {
-    return LazyEmitLayer.getSymbolAddress(Name, false);
+  JITSymbol findSymbol(const std::string &Name) {
+    return LazyEmitLayer.findSymbol(Name, true);
+  }
+
+  JITSymbol findSymbolIn(ModuleHandleT H, const std::string &Name) {
+    return LazyEmitLayer.findSymbolIn(H, Name, true);
   }
 
-  uint64_t getSymbolAddress(const std::string &Name) {
-    return getMangledSymbolAddress(Mangle(Name));
+  JITSymbol findUnmangledSymbol(const std::string &Name) {
+    return findSymbol(mangle(Name));
   }
 
 private:
 
-  std::unique_ptr<TargetMachine> TM;
-  Mangler Mang;
-  SessionContext &Session;
+  // This method searches the FunctionDefs map for a definition of 'Name'. If it
+  // finds one it generates a stub for it and returns the address of the stub.
+  RuntimeDyld::SymbolInfo searchFunctionASTs(const std::string &Name) {
+    auto DefI = FunctionDefs.find(Name);
+    if (DefI == FunctionDefs.end())
+      return 0;
+
+    // Take the FunctionAST out of the map.
+    auto FnAST = std::move(DefI->second);
+    FunctionDefs.erase(DefI);
+
+    // IRGen the AST, add it to the JIT, and return the address for it.
+    auto H = addModule(IRGen(Session, *FnAST));
+    auto Sym = findSymbolIn(H, Name);
+    return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags());
+  }
 
+  SessionContext &Session;
+  Mangler Mang;
   ObjLayerT ObjectLayer;
   CompileLayerT CompileLayer;
   LazyEmitLayerT LazyEmitLayer;
+
+  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs;
 };
 
 static void HandleDefinition(SessionContext &S, KaleidoscopeJIT &J) {
   if (auto F = ParseDefinition()) {
     S.addPrototypeAST(llvm::make_unique<PrototypeAST>(*F->Proto));
-    S.FunctionDefs[J.Mangle(F->Proto->Name)] = std::move(F);
+    J.addFunctionAST(std::move(F));
   } else {
     // Skip token for error recovery.
     getNextToken();
@@ -1253,11 +1277,11 @@ static void HandleTopLevelExpression(SessionContext &S, KaleidoscopeJIT &J) {
       auto H = J.addModule(C.takeM());
 
       // Get the address of the JIT'd function in memory.
-      uint64_t ExprFuncAddr = J.getSymbolAddress("__anon_expr");
+      auto ExprSymbol = J.findUnmangledSymbol("__anon_expr");
       
       // Cast it to the right type (takes no arguments, returns a double) so we
       // can call it as a native function.
-      double (*FP)() = (double (*)())(intptr_t)ExprFuncAddr;
+      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
 #ifdef MINIMAL_STDERR_OUTPUT
       FP();
 #else
@@ -1279,16 +1303,16 @@ static void MainLoop() {
   KaleidoscopeJIT J(S);
 
   while (1) {
-#ifndef MINIMAL_STDERR_OUTPUT
-    std::cerr << "ready> ";
-#endif
     switch (CurTok) {
     case tok_eof:    return;
-    case ';':        getNextToken(); break;  // ignore top-level semicolons.
+    case ';':        getNextToken(); continue;  // ignore top-level semicolons.
     case tok_def:    HandleDefinition(S, J); break;
     case tok_extern: HandleExtern(S); break;
     default:         HandleTopLevelExpression(S, J); break;
     }
+#ifndef MINIMAL_STDERR_OUTPUT
+    std::cerr << "ready> ";
+#endif
   }
 }