[Orc] New JIT APIs.
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / IndirectionUtils.h
1 //===-- IndirectionUtils.h - Utilities for adding indirections --*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // Contains utilities for adding indirections and breaking up modules.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #ifndef LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H
15 #define LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H
16
17 #include "llvm/IR/Mangler.h"
18 #include "llvm/IR/Module.h"
19 #include <sstream>
20
21 namespace llvm {
22
23 /// @brief Persistent name mangling.
24 ///
25 ///   This class provides name mangling that can outlive a Module (and its
26 /// DataLayout).
27 class PersistentMangler {
28 public:
29   PersistentMangler(DataLayout DL) : DL(std::move(DL)), M(&this->DL) {}
30
31   std::string getMangledName(StringRef Name) const {
32     std::string MangledName;
33     {
34       raw_string_ostream MangledNameStream(MangledName);
35       M.getNameWithPrefix(MangledNameStream, Name);
36     }
37     return MangledName;
38   }
39
40 private:
41   DataLayout DL;
42   Mangler M;
43 };
44
45 /// @brief Handle callbacks from the JIT process requesting the definitions of
46 ///        symbols.
47 ///
48 ///   This utility is intended to be used to support compile-on-demand for
49 /// functions.
50 class JITResolveCallbackHandler {
51 private:
52   typedef std::vector<std::string> FuncNameList;
53
54 public:
55   typedef FuncNameList::size_type StubIndex;
56
57 public:
58   /// @brief Create a JITResolveCallbackHandler with the given functors for
59   ///        looking up symbols and updating their use-sites.
60   ///
61   /// @return A JITResolveCallbackHandler instance that will invoke the
62   ///         Lookup and Update functors as needed to resolve missing symbol
63   ///         definitions.
64   template <typename LookupFtor, typename UpdateFtor>
65   static std::unique_ptr<JITResolveCallbackHandler> create(LookupFtor Lookup,
66                                                            UpdateFtor Update);
67
68   /// @brief Destroy instance. Does not modify existing emitted symbols.
69   ///
70   ///   Not-yet-emitted symbols will need to be resolved some other way after
71   /// this class is destroyed.
72   virtual ~JITResolveCallbackHandler() {}
73
74   /// @brief Add a function to be resolved on demand.
75   void addFuncName(std::string Name) { FuncNames.push_back(std::move(Name)); }
76
77   /// @brief Get the name associated with the given index.
78   const std::string &getFuncName(StubIndex Idx) const { return FuncNames[Idx]; }
79
80   /// @brief Returns the number of symbols being managed by this instance.
81   StubIndex getNumFuncs() const { return FuncNames.size(); }
82
83   /// @brief Get the address for the symbol associated with the given index.
84   ///
85   ///   This is expected to be called by code in the JIT process itself, in
86   /// order to resolve a function.
87   virtual uint64_t resolve(StubIndex StubIdx) = 0;
88
89 private:
90   FuncNameList FuncNames;
91 };
92
93 // Implementation class for JITResolveCallbackHandler.
94 template <typename LookupFtor, typename UpdateFtor>
95 class JITResolveCallbackHandlerImpl : public JITResolveCallbackHandler {
96 public:
97   JITResolveCallbackHandlerImpl(LookupFtor Lookup, UpdateFtor Update)
98       : Lookup(std::move(Lookup)), Update(std::move(Update)) {}
99
100   uint64_t resolve(StubIndex StubIdx) override {
101     const std::string &FuncName = getFuncName(StubIdx);
102     uint64_t Addr = Lookup(FuncName);
103     Update(FuncName, Addr);
104     return Addr;
105   }
106
107 private:
108   LookupFtor Lookup;
109   UpdateFtor Update;
110 };
111
112 template <typename LookupFtor, typename UpdateFtor>
113 std::unique_ptr<JITResolveCallbackHandler>
114 JITResolveCallbackHandler::create(LookupFtor Lookup, UpdateFtor Update) {
115   typedef JITResolveCallbackHandlerImpl<LookupFtor, UpdateFtor> Impl;
116   return make_unique<Impl>(std::move(Lookup), std::move(Update));
117 }
118
119 /// @brief Holds a list of the function names that were indirected, plus
120 ///        mappings from each of these names to (a) the name of function
121 ///        providing the implementation for that name (GetImplNames), and
122 ///        (b) the name of the global variable holding the address of the
123 ///        implementation.
124 ///
125 ///   This data structure can be used with a JITCallbackHandler to look up and
126 /// update function implementations when lazily compiling.
127 class JITIndirections {
128 public:
129   JITIndirections(std::vector<std::string> IndirectedNames,
130                   std::function<std::string(StringRef)> GetImplName,
131                   std::function<std::string(StringRef)> GetAddrName)
132       : IndirectedNames(std::move(IndirectedNames)),
133         GetImplName(std::move(GetImplName)),
134         GetAddrName(std::move(GetAddrName)) {}
135
136   std::vector<std::string> IndirectedNames;
137   std::function<std::string(StringRef Name)> GetImplName;
138   std::function<std::string(StringRef Name)> GetAddrName;
139 };
140
141 /// @brief Indirect all calls to functions matching the predicate
142 ///        ShouldIndirect through a global variable containing the address
143 ///        of the implementation.
144 ///
145 /// @return An indirection structure containing the functions that had their
146 ///         call-sites re-written.
147 ///
148 ///   For each function 'F' that meets the ShouldIndirect predicate, and that
149 /// is called in this Module, add a common-linkage global variable to the
150 /// module that will hold the address of the implementation of that function.
151 /// Rewrite all call-sites of 'F' to be indirect calls (via the global).
152 /// This allows clients, either directly or via a JITCallbackHandler, to
153 /// change the address of the implementation of 'F' at runtime.
154 ///
155 /// Important notes:
156 ///
157 ///   Single indirection does not preserve pointer equality for 'F'. If the
158 /// program was already calling 'F' indirectly through function pointers, or
159 /// if it was taking the address of 'F' for the purpose of pointer comparisons
160 /// or arithmetic double indirection should be used instead.
161 ///
162 ///   This method does *not* initialize the function implementation addresses.
163 /// The client must do this prior to running any call-sites that have been
164 /// indirected.
165 JITIndirections makeCallsSingleIndirect(
166     llvm::Module &M,
167     const std::function<bool(const Function &)> &ShouldIndirect,
168     const char *JITImplSuffix, const char *JITAddrSuffix);
169
170 /// @brief Replace the body of functions matching the predicate ShouldIndirect
171 ///        with indirect calls to the implementation.
172 ///
173 /// @return An indirections structure containing the functions that had their
174 ///         implementations re-written.
175 ///
176 ///   For each function 'F' that meets the ShouldIndirect predicate, add a
177 /// common-linkage global variable to the module that will hold the address of
178 /// the implementation of that function and rewrite the implementation of 'F'
179 /// to call through to the implementation indirectly (via the global).
180 /// This allows clients, either directly or via a JITCallbackHandler, to
181 /// change the address of the implementation of 'F' at runtime.
182 ///
183 /// Important notes:
184 ///
185 ///   Double indirection is slower than single indirection, but preserves
186 /// function pointer relation tests and correct behavior for function pointers
187 /// (all calls to 'F', direct or indirect) go the address stored in the global
188 /// variable at the time of the call.
189 ///
190 ///   This method does *not* initialize the function implementation addresses.
191 /// The client must do this prior to running any call-sites that have been
192 /// indirected.
193 JITIndirections makeCallsDoubleIndirect(
194     llvm::Module &M,
195     const std::function<bool(const Function &)> &ShouldIndirect,
196     const char *JITImplSuffix, const char *JITAddrSuffix);
197
198 /// @brief Given a set of indirections and a symbol lookup functor, create a
199 ///        JITResolveCallbackHandler instance that will resolve the
200 ///        implementations for the indirected symbols on demand.
201 template <typename SymbolLookupFtor>
202 std::unique_ptr<JITResolveCallbackHandler>
203 createCallbackHandlerFromJITIndirections(const JITIndirections &Indirs,
204                                          const PersistentMangler &NM,
205                                          SymbolLookupFtor Lookup) {
206   auto GetImplName = Indirs.GetImplName;
207   auto GetAddrName = Indirs.GetAddrName;
208
209   std::unique_ptr<JITResolveCallbackHandler> J =
210       JITResolveCallbackHandler::create(
211           [=](const std::string &S) {
212             return Lookup(NM.getMangledName(GetImplName(S)));
213           },
214           [=](const std::string &S, uint64_t Addr) {
215             void *ImplPtr = reinterpret_cast<void *>(
216                 Lookup(NM.getMangledName(GetAddrName(S))));
217             memcpy(ImplPtr, &Addr, sizeof(uint64_t));
218           });
219
220   for (const auto &FuncName : Indirs.IndirectedNames)
221     J->addFuncName(FuncName);
222
223   return J;
224 }
225
226 /// @brief Insert callback asm into module M for the symbols managed by
227 ///        JITResolveCallbackHandler J.
228 void insertX86CallbackAsm(Module &M, JITResolveCallbackHandler &J);
229
230 /// @brief Initialize global indirects to point into the callback asm.
231 template <typename LookupFtor>
232 void initializeFuncAddrs(JITResolveCallbackHandler &J,
233                          const JITIndirections &Indirs,
234                          const PersistentMangler &NM, LookupFtor Lookup) {
235   // Forward declare so that we can access this, even though it's an
236   // implementation detail.
237   std::string getJITResolveCallbackIndexLabel(unsigned I);
238
239   if (J.getNumFuncs() == 0)
240     return;
241
242   //   Force a look up one of the global addresses for a function that has been
243   // indirected. We need to do this to trigger the emission of the module
244   // holding the callback asm. We can't rely on that emission happening
245   // automatically when we look up the callback asm symbols, since lazy-emitting
246   // layers can't see those.
247   Lookup(NM.getMangledName(Indirs.GetAddrName(J.getFuncName(0))));
248
249   // Now update indirects to point to the JIT resolve callback asm.
250   for (JITResolveCallbackHandler::StubIndex I = 0; I < J.getNumFuncs(); ++I) {
251     uint64_t ResolveCallbackIdxAddr =
252         Lookup(getJITResolveCallbackIndexLabel(I));
253     void *AddrPtr = reinterpret_cast<void *>(
254         Lookup(NM.getMangledName(Indirs.GetAddrName(J.getFuncName(I)))));
255     assert(AddrPtr && "Can't find stub addr global to initialize.");
256     memcpy(AddrPtr, &ResolveCallbackIdxAddr, sizeof(uint64_t));
257   }
258 }
259
260 /// @brief Extract all functions matching the predicate ShouldExtract in to
261 ///        their own modules. (Does not modify the original module.)
262 ///
263 /// @return A set of modules, the first containing all symbols (including
264 ///         globals and aliases) that did not pass ShouldExtract, and each
265 ///         subsequent module containing one of the functions that did meet
266 ///         ShouldExtract.
267 ///
268 ///   By adding the resulting modules separately (not as a set) to a
269 /// LazyEmittingLayer instance, compilation can be deferred until symbols are
270 /// actually needed.
271 std::vector<std::unique_ptr<llvm::Module>>
272 explode(const llvm::Module &OrigMod,
273         const std::function<bool(const Function &)> &ShouldExtract);
274
275 /// @brief Given a module that has been indirectified, break each function
276 ///        that has been indirected out into its own module. (Does not modify
277 ///        the original module).
278 ///
279 /// @returns A set of modules covering the symbols provided by OrigMod.
280 std::vector<std::unique_ptr<llvm::Module>>
281 explode(const llvm::Module &OrigMod, const JITIndirections &Indirections);
282 }
283
284 #endif // LLVM_EXECUTIONENGINE_ORC_INDIRECTIONUTILS_H