4b4ecfc1ad2b2311e8522865f791607455c2834d
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / OrcRemoteTargetServer.h
1 //===---- OrcRemoteTargetServer.h - Orc Remote-target Server ----*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the OrcRemoteTargetServer class. It can be used to build a
11 // JIT server that can execute code sent from an OrcRemoteTargetClient.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
16 #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
17
18 #include "OrcRemoteTargetRPCAPI.h"
19 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/Format.h"
22 #include "llvm/Support/Process.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include <map>
25
26 #define DEBUG_TYPE "orc-remote"
27
28 namespace llvm {
29 namespace orc {
30 namespace remote {
31
32 template <typename ChannelT, typename TargetT>
33 class OrcRemoteTargetServer : public OrcRemoteTargetRPCAPI {
34 public:
35   typedef std::function<TargetAddress(const std::string &Name)>
36       SymbolLookupFtor;
37
38   OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup)
39       : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {}
40
41   std::error_code getNextProcId(JITProcId &Id) {
42     return deserialize(Channel, Id);
43   }
44
45   std::error_code handleKnownProcedure(JITProcId Id) {
46     DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
47
48     switch (Id) {
49     case CallIntVoidId:
50       return handleCallIntVoid();
51     case CallMainId:
52       return handleCallMain();
53     case CallVoidVoidId:
54       return handleCallVoidVoid();
55     case CreateRemoteAllocatorId:
56       return handleCreateRemoteAllocator();
57     case CreateIndirectStubsOwnerId:
58       return handleCreateIndirectStubsOwner();
59     case DestroyRemoteAllocatorId:
60       return handleDestroyRemoteAllocator();
61     case EmitIndirectStubsId:
62       return handleEmitIndirectStubs();
63     case EmitResolverBlockId:
64       return handleEmitResolverBlock();
65     case EmitTrampolineBlockId:
66       return handleEmitTrampolineBlock();
67     case GetSymbolAddressId:
68       return handleGetSymbolAddress();
69     case GetRemoteInfoId:
70       return handleGetRemoteInfo();
71     case ReadMemId:
72       return handleReadMem();
73     case ReserveMemId:
74       return handleReserveMem();
75     case SetProtectionsId:
76       return handleSetProtections();
77     case WriteMemId:
78       return handleWriteMem();
79     case WritePtrId:
80       return handleWritePtr();
81     default:
82       return orcError(OrcErrorCode::UnexpectedRPCCall);
83     }
84
85     llvm_unreachable("Unhandled JIT RPC procedure Id.");
86   }
87
88   std::error_code requestCompile(TargetAddress &CompiledFnAddr,
89                                  TargetAddress TrampolineAddr) {
90     if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
91       return EC;
92
93     while (1) {
94       JITProcId Id = InvalidId;
95       if (auto EC = getNextProcId(Id))
96         return EC;
97
98       switch (Id) {
99       case RequestCompileResponseId:
100         return handle<RequestCompileResponse>(Channel,
101                                               readArgs(CompiledFnAddr));
102       default:
103         if (auto EC = handleKnownProcedure(Id))
104           return EC;
105       }
106     }
107
108     llvm_unreachable("Fell through request-compile command loop.");
109   }
110
111 private:
112   struct Allocator {
113     Allocator() = default;
114     Allocator(Allocator &&) = default;
115     Allocator &operator=(Allocator &&) = default;
116
117     ~Allocator() {
118       for (auto &Alloc : Allocs)
119         sys::Memory::releaseMappedMemory(Alloc.second);
120     }
121
122     std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) {
123       std::error_code EC;
124       sys::MemoryBlock MB = sys::Memory::allocateMappedMemory(
125           Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC);
126       if (EC)
127         return EC;
128
129       Addr = MB.base();
130       assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc");
131       Allocs[MB.base()] = std::move(MB);
132       return std::error_code();
133     }
134
135     std::error_code setProtections(void *block, unsigned Flags) {
136       auto I = Allocs.find(block);
137       if (I == Allocs.end())
138         return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized);
139       return sys::Memory::protectMappedMemory(I->second, Flags);
140     }
141
142   private:
143     std::map<void *, sys::MemoryBlock> Allocs;
144   };
145
146   static std::error_code doNothing() { return std::error_code(); }
147
148   static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) {
149     TargetAddress CompiledFnAddr = 0;
150
151     auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
152     auto EC = T->requestCompile(
153         CompiledFnAddr, static_cast<TargetAddress>(
154                             reinterpret_cast<uintptr_t>(TrampolineAddr)));
155     assert(!EC && "Compile request failed");
156     return CompiledFnAddr;
157   }
158
159   std::error_code handleCallIntVoid() {
160     typedef int (*IntVoidFnTy)();
161
162     IntVoidFnTy Fn = nullptr;
163     if (auto EC = handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
164           Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
165           return std::error_code();
166         }))
167       return EC;
168
169     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
170     int Result = Fn();
171     DEBUG(dbgs() << "  Result = " << Result << "\n");
172
173     return call<CallIntVoidResponse>(Channel, Result);
174   }
175
176   std::error_code handleCallMain() {
177     typedef int (*MainFnTy)(int, const char *[]);
178
179     MainFnTy Fn = nullptr;
180     std::vector<std::string> Args;
181     if (auto EC = handle<CallMain>(
182             Channel, [&](TargetAddress Addr, std::vector<std::string> &A) {
183               Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
184               Args = std::move(A);
185               return std::error_code();
186             }))
187       return EC;
188
189     int ArgC = Args.size() + 1;
190     int Idx = 1;
191     std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
192     ArgV[0] = "<jit process>";
193     for (auto &Arg : Args)
194       ArgV[Idx++] = Arg.c_str();
195
196     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
197     int Result = Fn(ArgC, ArgV.get());
198     DEBUG(dbgs() << "  Result = " << Result << "\n");
199
200     return call<CallMainResponse>(Channel, Result);
201   }
202
203   std::error_code handleCallVoidVoid() {
204     typedef void (*VoidVoidFnTy)();
205
206     VoidVoidFnTy Fn = nullptr;
207     if (auto EC = handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
208           Fn = reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
209           return std::error_code();
210         }))
211       return EC;
212
213     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
214     Fn();
215     DEBUG(dbgs() << "  Complete.\n");
216
217     return call<CallVoidVoidResponse>(Channel);
218   }
219
220   std::error_code handleCreateRemoteAllocator() {
221     return handle<CreateRemoteAllocator>(
222         Channel, [&](ResourceIdMgr::ResourceId Id) {
223           auto I = Allocators.find(Id);
224           if (I != Allocators.end())
225             return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
226           DEBUG(dbgs() << "  Created allocator " << Id << "\n");
227           Allocators[Id] = Allocator();
228           return std::error_code();
229         });
230   }
231
232   std::error_code handleCreateIndirectStubsOwner() {
233     return handle<CreateIndirectStubsOwner>(
234         Channel, [&](ResourceIdMgr::ResourceId Id) {
235           auto I = IndirectStubsOwners.find(Id);
236           if (I != IndirectStubsOwners.end())
237             return orcError(
238                 OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
239           DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
240           IndirectStubsOwners[Id] = ISBlockOwnerList();
241           return std::error_code();
242         });
243   }
244
245   std::error_code handleDestroyRemoteAllocator() {
246     return handle<DestroyRemoteAllocator>(
247         Channel, [&](ResourceIdMgr::ResourceId Id) {
248           auto I = Allocators.find(Id);
249           if (I == Allocators.end())
250             return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
251           Allocators.erase(I);
252           DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
253           return std::error_code();
254         });
255   }
256
257   std::error_code handleDestroyIndirectStubsOwner() {
258     return handle<DestroyIndirectStubsOwner>(
259         Channel, [&](ResourceIdMgr::ResourceId Id) {
260           auto I = IndirectStubsOwners.find(Id);
261           if (I == IndirectStubsOwners.end())
262             return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
263           IndirectStubsOwners.erase(I);
264           return std::error_code();
265         });
266   }
267
268   std::error_code handleEmitIndirectStubs() {
269     ResourceIdMgr::ResourceId ISOwnerId = ~0U;
270     uint32_t NumStubsRequired = 0;
271
272     if (auto EC = handle<EmitIndirectStubs>(
273             Channel, readArgs(ISOwnerId, NumStubsRequired)))
274       return EC;
275
276     DEBUG(dbgs() << "  ISMgr " << ISOwnerId << " request " << NumStubsRequired
277                  << " stubs.\n");
278
279     auto StubOwnerItr = IndirectStubsOwners.find(ISOwnerId);
280     if (StubOwnerItr == IndirectStubsOwners.end())
281       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
282
283     typename TargetT::IndirectStubsInfo IS;
284     if (auto EC =
285             TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr))
286       return EC;
287
288     TargetAddress StubsBase =
289         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getStub(0)));
290     TargetAddress PtrsBase =
291         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getPtr(0)));
292     uint32_t NumStubsEmitted = IS.getNumStubs();
293
294     auto &BlockList = StubOwnerItr->second;
295     BlockList.push_back(std::move(IS));
296
297     return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
298                                            NumStubsEmitted);
299   }
300
301   std::error_code handleEmitResolverBlock() {
302     if (auto EC = handle<EmitResolverBlock>(Channel, doNothing))
303       return EC;
304
305     std::error_code EC;
306     ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
307         TargetT::ResolverCodeSize, nullptr,
308         sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
309     if (EC)
310       return EC;
311
312     TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()),
313                                &reenter, this);
314
315     return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(),
316                                             sys::Memory::MF_READ |
317                                                 sys::Memory::MF_EXEC);
318   }
319
320   std::error_code handleEmitTrampolineBlock() {
321     if (auto EC = handle<EmitTrampolineBlock>(Channel, doNothing))
322       return EC;
323
324     std::error_code EC;
325
326     auto TrampolineBlock =
327         sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
328             TargetT::PageSize, nullptr,
329             sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
330     if (EC)
331       return EC;
332
333     unsigned NumTrampolines =
334         (TargetT::PageSize - TargetT::PointerSize) / TargetT::TrampolineSize;
335
336     uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base());
337     TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(),
338                               NumTrampolines);
339
340     EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(),
341                                           sys::Memory::MF_READ |
342                                               sys::Memory::MF_EXEC);
343
344     TrampolineBlocks.push_back(std::move(TrampolineBlock));
345
346     return call<EmitTrampolineBlockResponse>(
347         Channel,
348         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
349         NumTrampolines);
350   }
351
352   std::error_code handleGetSymbolAddress() {
353     std::string SymbolName;
354     if (auto EC = handle<GetSymbolAddress>(Channel, readArgs(SymbolName)))
355       return EC;
356
357     TargetAddress SymbolAddr = SymbolLookup(SymbolName);
358     DEBUG(dbgs() << "  Symbol '" << SymbolName
359                  << "' =  " << format("0x%016x", SymbolAddr) << "\n");
360     return call<GetSymbolAddressResponse>(Channel, SymbolAddr);
361   }
362
363   std::error_code handleGetRemoteInfo() {
364     if (auto EC = handle<GetRemoteInfo>(Channel, doNothing))
365       return EC;
366
367     std::string ProcessTriple = sys::getProcessTriple();
368     uint32_t PointerSize = TargetT::PointerSize;
369     uint32_t PageSize = sys::Process::getPageSize();
370     uint32_t TrampolineSize = TargetT::TrampolineSize;
371     uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize;
372     DEBUG(dbgs() << "  Remote info:\n"
373                  << "    triple             = '" << ProcessTriple << "'\n"
374                  << "    pointer size       = " << PointerSize << "\n"
375                  << "    page size          = " << PageSize << "\n"
376                  << "    trampoline size    = " << TrampolineSize << "\n"
377                  << "    indirect stub size = " << IndirectStubSize << "\n");
378     return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize,
379                                        PageSize, TrampolineSize,
380                                        IndirectStubSize);
381   }
382
383   std::error_code handleReadMem() {
384     char *Src = nullptr;
385     uint64_t Size = 0;
386     if (auto EC =
387             handle<ReadMem>(Channel, [&](TargetAddress RSrc, uint64_t RSize) {
388               Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
389               Size = RSize;
390               return std::error_code();
391             }))
392       return EC;
393
394     DEBUG(dbgs() << "  Reading " << Size << " bytes from "
395                  << static_cast<void *>(Src) << "\n");
396
397     if (auto EC = call<ReadMemResponse>(Channel))
398       return EC;
399
400     if (auto EC = Channel.appendBytes(Src, Size))
401       return EC;
402
403     return Channel.send();
404   }
405
406   std::error_code handleReserveMem() {
407     void *LocalAllocAddr = nullptr;
408
409     if (auto EC =
410             handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
411                                             uint64_t Size, uint32_t Align) {
412               auto I = Allocators.find(Id);
413               if (I == Allocators.end())
414                 return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
415               auto &Allocator = I->second;
416               auto EC2 = Allocator.allocate(LocalAllocAddr, Size, Align);
417               DEBUG(dbgs() << "  Allocator " << Id << " reserved "
418                            << LocalAllocAddr << " (" << Size
419                            << " bytes, alignment " << Align << ")\n");
420               return EC2;
421             }))
422       return EC;
423
424     TargetAddress AllocAddr =
425         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
426
427     return call<ReserveMemResponse>(Channel, AllocAddr);
428   }
429
430   std::error_code handleSetProtections() {
431     return handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
432                                            TargetAddress Addr, uint32_t Flags) {
433       auto I = Allocators.find(Id);
434       if (I == Allocators.end())
435         return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
436       auto &Allocator = I->second;
437       void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
438       DEBUG(dbgs() << "  Allocator " << Id << " set permissions on "
439                    << LocalAddr << " to "
440                    << (Flags & sys::Memory::MF_READ ? 'R' : '-')
441                    << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
442                    << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
443       return Allocator.setProtections(LocalAddr, Flags);
444     });
445   }
446
447   std::error_code handleWriteMem() {
448     return handle<WriteMem>(Channel, [&](TargetAddress RDst, uint64_t Size) {
449       char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
450       return Channel.readBytes(Dst, Size);
451     });
452   }
453
454   std::error_code handleWritePtr() {
455     return handle<WritePtr>(
456         Channel, [&](TargetAddress Addr, TargetAddress PtrVal) {
457           uintptr_t *Ptr =
458               reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
459           *Ptr = static_cast<uintptr_t>(PtrVal);
460           return std::error_code();
461         });
462   }
463
464   ChannelT &Channel;
465   SymbolLookupFtor SymbolLookup;
466   std::map<ResourceIdMgr::ResourceId, Allocator> Allocators;
467   typedef std::vector<typename TargetT::IndirectStubsInfo> ISBlockOwnerList;
468   std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners;
469   sys::OwningMemoryBlock ResolverBlock;
470   std::vector<sys::OwningMemoryBlock> TrampolineBlocks;
471 };
472
473 } // end namespace remote
474 } // end namespace orc
475 } // end namespace llvm
476
477 #undef DEBUG_TYPE
478
479 #endif