[Orc][Kaleidoscope] Remove dead AST map in SessionContext.
[oota-llvm.git] / examples / Kaleidoscope / Orc / fully_lazy / toy.cpp
index 9210dd1f3bb85e76c4960d20cd7346fd1c62355e..56123bb41e3612a3c956dc277229200e4c36650c 100644 (file)
@@ -20,7 +20,9 @@
 #include <sstream>
 #include <string>
 #include <vector>
+
 using namespace llvm;
+using namespace llvm::orc;
 
 //===----------------------------------------------------------------------===//
 // Lexer
@@ -684,7 +686,6 @@ public:
   LLVMContext& getLLVMContext() const { return Context; }
   void addPrototypeAST(std::unique_ptr<PrototypeAST> P);
   PrototypeAST* getPrototypeAST(const std::string &Name);
-  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs; 
 private:
   typedef std::map<std::string, std::unique_ptr<PrototypeAST>> PrototypeMap;
   LLVMContext &Context;
@@ -1184,27 +1185,14 @@ public:
     // We need a memory manager to allocate memory and resolve symbols for this
     // new module. Create one that resolves symbols by looking back into the JIT.
     auto MM = createLookasideRTDyldMM<SectionMemoryManager>(
-                [&](const std::string &Name) -> uint64_t {
+                [&](const std::string &Name) {
                   // First try to find 'Name' within the JIT.
                   if (auto Symbol = findMangledSymbol(Name))
                     return Symbol.getAddress();
 
-                  // If we don't find 'Name' in the JIT, see if we have some AST
-                  // for it.
-                  auto DefI = Session.FunctionDefs.find(Name);
-                  if (DefI == Session.FunctionDefs.end())
-                    return 0;
-
-                  // We have AST for 'Name'. IRGen it, add it to the JIT, and
-                  // return the address for it.
-                  // FIXME: What happens if IRGen fails?
-                  addModule(IRGen(Session, *DefI->second));
-
-                  // Remove the function definition's AST now that we've
-                  // finished with it.
-                  Session.FunctionDefs.erase(DefI);
-
-                  return findMangledSymbol(Name).getAddress();
+                  // If we don't already have a definition of 'Name' then search
+                  // the ASTs.
+                  return searchUncompiledASTs(Name);
                 },
                 [](const std::string &S) { return 0; } );
 
@@ -1223,7 +1211,7 @@ public:
 
   JITSymbol findMangledSymbolIn(LazyEmitLayerT::ModuleSetHandleT H,
                                 const std::string &Name) {
-    return LazyEmitLayer.findSymbolIn(H, Name, true); 
+    return LazyEmitLayer.findSymbolIn(H, Name, true);
   }
 
   JITSymbol findSymbolIn(LazyEmitLayerT::ModuleSetHandleT H,
@@ -1232,19 +1220,49 @@ public:
   }
 
   void addFunctionDefinition(std::unique_ptr<FunctionAST> FnAST) {
-    // Step 1) IRGen a prototype for this function:
+    FunctionDefs[Mangle(FnAST->Proto->Name)] = std::move(FnAST);
+  }
+
+private:
+
+  // This method searches the FunctionDefs map for a definition of 'Name'. If it
+  // finds one it generates a stub for it and returns the address of the stub.
+  TargetAddress searchUncompiledASTs(const std::string &Name) {
+    auto DefI = FunctionDefs.find(Name);
+    if (DefI == FunctionDefs.end())
+      return 0;
+
+    // We have AST for 'Name'. IRGen a stub for it and add it to the JIT.
+    // FIXME: What happens if IRGen fails?
+    auto H = irGenStub(std::move(DefI->second));
+
+    // Remove the map entry now that we're done with it.
+    FunctionDefs.erase(DefI);
+
+    // Return the address of the stub.
+    return findMangledSymbolIn(H, Name).getAddress();
+  }
+
+  // This method will take the AST for a function definition and IR-gen a stub
+  // for that function that will, on first call, IR-gen the actual body of the
+  // function.
+  ModuleHandleT irGenStub(std::unique_ptr<FunctionAST> FnAST) {
+    // Step 1) IRGen a prototype for the stub. This will have the same type as
+    //         the function.
     IRGenContext C(Session);
     Function *F = FnAST->Proto->IRGen(C);
     C.getM().setDataLayout(TM->getDataLayout());
 
-    // Step 2) Create a compile callback that will be used to compile this
-    //         function when it is first called.
+    // Step 2) Get a compile callback that can be used to compile the body of
+    //         the function. The resulting CallbackInfo type will let us set the
+    //         compile and update actions for the callback, and get a pointer to
+    //         the jit trampoline that we need to call to trigger those actions.
     auto CallbackInfo =
       CompileCallbacks.getCompileCallback(*F->getFunctionType());
 
     // Step 3) Create a stub that will indirectly call the body of this
     //         function once it is compiled. Initially, set the function
-    //         pointer for the indirection to point at the compile callback.
+    //         pointer for the indirection to point at the trampoline.
     std::string BodyPtrName = (F->getName() + "$address").str();
     GlobalVariable *FunctionBodyPointer =
       createImplPointer(*F, BodyPtrName, CallbackInfo.getAddress());
@@ -1253,7 +1271,7 @@ public:
     // Step 4) Add the module containing the stub to the JIT.
     auto H = addModule(C.takeM());
 
-    // Step 5) Set the compile and update actions for the callback.
+    // Step 5) Set the compile and update actions.
     //
     //   The compile action will IRGen the function and add it to the JIT, then
     // request its address, which will trigger codegen. Since we don't need the
@@ -1263,16 +1281,16 @@ public:
     //
     //   The update action will update FunctionBodyPointer to point at the newly
     // compiled function.
-    CallbackInfo.setCompileAction(
-      [this,Fn = std::shared_ptr<FunctionAST>(std::move(FnAST))](){
-        auto H = addModule(IRGen(Session, *Fn));
-        return findSymbolIn(H, Fn->Proto->Name).getAddress();
-      });
+    std::shared_ptr<FunctionAST> Fn = std::move(FnAST);
+    CallbackInfo.setCompileAction([this, Fn]() {
+      auto H = addModule(IRGen(Session, *Fn));
+      return findSymbolIn(H, Fn->Proto->Name).getAddress();
+    });
     CallbackInfo.setUpdateAction(
       CompileCallbacks.getLocalFPUpdater(H, Mangle(BodyPtrName)));
-  }
 
-private:
+    return H;
+  }
 
   std::unique_ptr<TargetMachine> TM;
   Mangler Mang;
@@ -1283,6 +1301,8 @@ private:
   LazyEmitLayerT LazyEmitLayer;
 
   JITCompileCallbackManager<LazyEmitLayerT, OrcX86_64> CompileCallbacks;
+
+  std::map<std::string, std::unique_ptr<FunctionAST>> FunctionDefs;
 };
 
 static void HandleDefinition(SessionContext &S, KaleidoscopeJIT &J) {