[Orc] Fix a bug in the compile callback manager: trampoline ids need to be fixed
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / IndirectionUtils.h
1 //===-- IndirectionUtils.h - Utilities for adding indirections --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Contains utilities for adding indirections and breaking up modules.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H
15 #define LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H
16
17 #include "JITSymbol.h"
18 #include "llvm/ADT/DenseSet.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Mangler.h"
21 #include "llvm/IR/Module.h"
22 #include <sstream>
23
24 namespace llvm {
25
26 /// @brief Base class for JITLayer independent aspects of
27 ///        JITCompileCallbackManager.
28 template <typename TargetT>
29 class JITCompileCallbackManagerBase {
30 public:
31
32   /// @brief Construct a JITCompileCallbackManagerBase.
33   /// @param ErrorHandlerAddress The address of an error handler in the target
34   ///                            process to be used if a compile callback fails.
35   /// @param NumTrampolinesPerBlock Number of trampolines to emit if there is no
36   ///                             available trampoline when getCompileCallback is
37   ///                             called.
38   JITCompileCallbackManagerBase(TargetAddress ErrorHandlerAddress,
39                                 unsigned NumTrampolinesPerBlock)
40     : ErrorHandlerAddress(ErrorHandlerAddress),
41       NumTrampolinesPerBlock(NumTrampolinesPerBlock) {}
42
43   /// @brief Execute the callback for the given trampoline id. Called by the JIT
44   ///        to compile functions on demand.
45   TargetAddress executeCompileCallback(TargetAddress TrampolineID) {
46     typename TrampolineMapT::iterator I = ActiveTrampolines.find(TrampolineID);
47     // FIXME: Also raise an error in the Orc error-handler when we finally have
48     //        one.
49     if (I == ActiveTrampolines.end())
50       return ErrorHandlerAddress;
51
52     // Found a callback handler. Yank this trampoline out of the active list and
53     // put it back in the available trampolines list, then try to run the
54     // handler's compile and update actions.
55     // Moving the trampoline ID back to the available list first means there's at
56     // least one available trampoline if the compile action triggers a request for
57     // a new one.
58     AvailableTrampolines.push_back(I->first - TargetT::CallSize);
59     auto CallbackHandler = std::move(I->second);
60     ActiveTrampolines.erase(I);
61
62     if (auto Addr = CallbackHandler.Compile()) {
63       CallbackHandler.Update(Addr);
64       return Addr;
65     }
66     return ErrorHandlerAddress;
67   }
68
69 protected:
70
71   typedef std::function<TargetAddress()> CompileFtorT;
72   typedef std::function<void(TargetAddress)> UpdateFtorT;
73
74   struct CallbackHandler {
75     CompileFtorT Compile;
76     UpdateFtorT Update;
77   };
78
79   TargetAddress ErrorHandlerAddress;
80   unsigned NumTrampolinesPerBlock;
81
82   typedef std::map<TargetAddress, CallbackHandler> TrampolineMapT;
83   TrampolineMapT ActiveTrampolines;
84   std::vector<TargetAddress> AvailableTrampolines;
85 };
86
87 /// @brief Manage compile callbacks.
88 template <typename JITLayerT, typename TargetT>
89 class JITCompileCallbackManager :
90     public JITCompileCallbackManagerBase<TargetT> {
91 public:
92
93   typedef typename JITCompileCallbackManagerBase<TargetT>::CompileFtorT
94     CompileFtorT;
95   typedef typename JITCompileCallbackManagerBase<TargetT>::UpdateFtorT
96     UpdateFtorT;
97
98   /// @brief Construct a JITCompileCallbackManager.
99   /// @param JIT JIT layer to emit callback trampolines, etc. into.
100   /// @param Context LLVMContext to use for trampoline & resolve block modules.
101   /// @param ErrorHandlerAddress The address of an error handler in the target
102   ///                            process to be used if a compile callback fails.
103   /// @param NumTrampolinesPerBlock Number of trampolines to allocate whenever
104   ///                               there is no existing callback trampoline.
105   ///                               (Trampolines are allocated in blocks for
106   ///                               efficiency.)
107   JITCompileCallbackManager(JITLayerT &JIT, LLVMContext &Context,
108                             TargetAddress ErrorHandlerAddress,
109                             unsigned NumTrampolinesPerBlock)
110     : JITCompileCallbackManagerBase<TargetT>(ErrorHandlerAddress,
111                                              NumTrampolinesPerBlock),
112       JIT(JIT) {
113     emitResolverBlock(Context);
114   }
115
116   /// @brief Handle to a newly created compile callback. Can be used to get an
117   ///        IR constant representing the address of the trampoline, and to set
118   ///        the compile and update actions for the callback.
119   class CompileCallbackInfo {
120   public:
121     CompileCallbackInfo(Constant *Addr, CompileFtorT &Compile,
122                         UpdateFtorT &Update)
123       : Addr(Addr), Compile(Compile), Update(Update) {}
124
125     Constant* getAddress() const { return Addr; }
126     void setCompileAction(CompileFtorT Compile) {
127       this->Compile = std::move(Compile);
128     }
129     void setUpdateAction(UpdateFtorT Update) {
130       this->Update = std::move(Update);
131     }
132   private:
133     Constant *Addr;
134     CompileFtorT &Compile;
135     UpdateFtorT &Update;
136   };
137
138   /// @brief Get/create a compile callback with the given signature.
139   CompileCallbackInfo getCompileCallback(FunctionType &FT) {
140     TargetAddress TrampolineAddr = getAvailableTrampolineAddr(FT.getContext());
141     auto &CallbackHandler =
142       this->ActiveTrampolines[TrampolineAddr + TargetT::CallSize];
143     Constant *AddrIntVal =
144       ConstantInt::get(Type::getInt64Ty(FT.getContext()), TrampolineAddr);
145     Constant *AddrPtrVal =
146       ConstantExpr::getCast(Instruction::IntToPtr, AddrIntVal,
147                             PointerType::get(&FT, 0));
148
149     return CompileCallbackInfo(AddrPtrVal, CallbackHandler.Compile,
150                                CallbackHandler.Update);
151   }
152
153   /// @brief Get a functor for updating the value of a named function pointer.
154   UpdateFtorT getLocalFPUpdater(typename JITLayerT::ModuleSetHandleT H,
155                                 std::string Name) {
156     // FIXME: Move-capture Name once we can use C++14.
157     return [=](TargetAddress Addr) {
158       auto FPSym = JIT.findSymbolIn(H, Name, true);
159       assert(FPSym && "Cannot find function pointer to update.");
160       void *FPAddr = reinterpret_cast<void*>(
161                        static_cast<uintptr_t>(FPSym.getAddress()));
162       memcpy(FPAddr, &Addr, sizeof(uintptr_t));
163     };
164   }
165
166 private:
167
168   std::vector<std::unique_ptr<Module>>
169   SingletonSet(std::unique_ptr<Module> M) {
170     std::vector<std::unique_ptr<Module>> Ms;
171     Ms.push_back(std::move(M));
172     return Ms;
173   }
174
175   void emitResolverBlock(LLVMContext &Context) {
176     std::unique_ptr<Module> M(new Module("resolver_block_module",
177                                          Context));
178     TargetT::insertResolverBlock(*M, *this);
179     auto H = JIT.addModuleSet(SingletonSet(std::move(M)), nullptr);
180     JIT.emitAndFinalize(H);
181     auto ResolverBlockSymbol =
182       JIT.findSymbolIn(H, TargetT::ResolverBlockName, false);
183     assert(ResolverBlockSymbol && "Failed to insert resolver block");
184     ResolverBlockAddr = ResolverBlockSymbol.getAddress();
185   }
186
187   TargetAddress getAvailableTrampolineAddr(LLVMContext &Context) {
188     if (this->AvailableTrampolines.empty())
189       grow(Context);
190     assert(!this->AvailableTrampolines.empty() &&
191            "Failed to grow available trampolines.");
192     TargetAddress TrampolineAddr = this->AvailableTrampolines.back();
193     this->AvailableTrampolines.pop_back();
194     return TrampolineAddr;
195   }
196
197   void grow(LLVMContext &Context) {
198     assert(this->AvailableTrampolines.empty() && "Growing prematurely?");
199     std::unique_ptr<Module> M(new Module("trampoline_block", Context));
200     auto GetLabelName =
201       TargetT::insertCompileCallbackTrampolines(*M, ResolverBlockAddr,
202                                                 this->NumTrampolinesPerBlock,
203                                                 this->ActiveTrampolines.size());
204     auto H = JIT.addModuleSet(SingletonSet(std::move(M)), nullptr);
205     JIT.emitAndFinalize(H);
206     for (unsigned I = 0; I < this->NumTrampolinesPerBlock; ++I) {
207       std::string Name = GetLabelName(I);
208       auto TrampolineSymbol = JIT.findSymbolIn(H, Name, false);
209       assert(TrampolineSymbol && "Failed to emit trampoline.");
210       this->AvailableTrampolines.push_back(TrampolineSymbol.getAddress());
211     }
212   }
213
214   JITLayerT &JIT;
215   TargetAddress ResolverBlockAddr;
216 };
217
218 GlobalVariable* createImplPointer(Function &F, const Twine &Name,
219                                   Constant *Initializer);
220
221 void makeStub(Function &F, GlobalVariable &ImplPointer);
222
223 typedef std::map<Module*, DenseSet<const GlobalValue*>> ModulePartitionMap;
224
225 void partition(Module &M, const ModulePartitionMap &PMap);
226
227 /// @brief Struct for trivial "complete" partitioning of a module.
228 class FullyPartitionedModule {
229 public:
230   std::unique_ptr<Module> GlobalVars;
231   std::unique_ptr<Module> Commons;
232   std::vector<std::unique_ptr<Module>> Functions;
233
234   FullyPartitionedModule() = default;
235   FullyPartitionedModule(FullyPartitionedModule &&S)
236       : GlobalVars(std::move(S.GlobalVars)), Commons(std::move(S.Commons)),
237         Functions(std::move(S.Functions)) {}
238 };
239
240 FullyPartitionedModule fullyPartition(Module &M);
241
242 }
243
244 #endif // LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H