b9db3890cc70d0e78a735ee874c2399cbf0e6e15
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / OrcRemoteTargetServer.h
1 //===---- OrcRemoteTargetServer.h - Orc Remote-target Server ----*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the OrcRemoteTargetServer class. It can be used to build a
11 // JIT server that can execute code sent from an OrcRemoteTargetClient.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
16 #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
17
18 #include "OrcRemoteTargetRPCAPI.h"
19 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
20 #include "llvm/Support/Debug.h"
21 #include "llvm/Support/Format.h"
22 #include "llvm/Support/Process.h"
23 #include "llvm/Support/raw_ostream.h"
24 #include <map>
25
26 #define DEBUG_TYPE "orc-remote"
27
28 namespace llvm {
29 namespace orc {
30 namespace remote {
31
32 template <typename ChannelT, typename TargetT>
33 class OrcRemoteTargetServer : public OrcRemoteTargetRPCAPI {
34 public:
35   typedef std::function<TargetAddress(const std::string &Name)>
36       SymbolLookupFtor;
37
38   OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup)
39       : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {}
40
41   std::error_code getNextProcId(JITProcId &Id) {
42     return deserialize(Channel, Id);
43   }
44
45   std::error_code handleKnownProcedure(JITProcId Id) {
46     DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
47
48     switch (Id) {
49     case CallIntVoidId:
50       return handleCallIntVoid();
51     case CallMainId:
52       return handleCallMain();
53     case CallVoidVoidId:
54       return handleCallVoidVoid();
55     case CreateRemoteAllocatorId:
56       return handleCreateRemoteAllocator();
57     case CreateIndirectStubsOwnerId:
58       return handleCreateIndirectStubsOwner();
59     case DestroyRemoteAllocatorId:
60       return handleDestroyRemoteAllocator();
61     case EmitIndirectStubsId:
62       return handleEmitIndirectStubs();
63     case EmitResolverBlockId:
64       return handleEmitResolverBlock();
65     case EmitTrampolineBlockId:
66       return handleEmitTrampolineBlock();
67     case GetSymbolAddressId:
68       return handleGetSymbolAddress();
69     case GetRemoteInfoId:
70       return handleGetRemoteInfo();
71     case ReadMemId:
72       return handleReadMem();
73     case ReserveMemId:
74       return handleReserveMem();
75     case SetProtectionsId:
76       return handleSetProtections();
77     case WriteMemId:
78       return handleWriteMem();
79     case WritePtrId:
80       return handleWritePtr();
81     default:
82       return orcError(OrcErrorCode::UnexpectedRPCCall);
83     }
84
85     llvm_unreachable("Unhandled JIT RPC procedure Id.");
86   }
87
88   std::error_code requestCompile(TargetAddress &CompiledFnAddr,
89                                  TargetAddress TrampolineAddr) {
90     if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
91       return EC;
92
93     while (1) {
94       JITProcId Id = InvalidId;
95       if (auto EC = getNextProcId(Id))
96         return EC;
97
98       switch (Id) {
99       case RequestCompileResponseId:
100         return handle<RequestCompileResponse>(Channel,
101                                               readArgs(CompiledFnAddr));
102       default:
103         if (auto EC = handleKnownProcedure(Id))
104           return EC;
105       }
106     }
107
108     llvm_unreachable("Fell through request-compile command loop.");
109   }
110
111 private:
112   struct Allocator {
113     Allocator() = default;
114     Allocator(Allocator &&Other) : Allocs(std::move(Other.Allocs)) {}
115     Allocator &operator=(Allocator &&Other) {
116       Allocs = std::move(Other.Allocs);
117       return *this;
118     }
119
120     ~Allocator() {
121       for (auto &Alloc : Allocs)
122         sys::Memory::releaseMappedMemory(Alloc.second);
123     }
124
125     std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) {
126       std::error_code EC;
127       sys::MemoryBlock MB = sys::Memory::allocateMappedMemory(
128           Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC);
129       if (EC)
130         return EC;
131
132       Addr = MB.base();
133       assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc");
134       Allocs[MB.base()] = std::move(MB);
135       return std::error_code();
136     }
137
138     std::error_code setProtections(void *block, unsigned Flags) {
139       auto I = Allocs.find(block);
140       if (I == Allocs.end())
141         return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized);
142       return sys::Memory::protectMappedMemory(I->second, Flags);
143     }
144
145   private:
146     std::map<void *, sys::MemoryBlock> Allocs;
147   };
148
149   static std::error_code doNothing() { return std::error_code(); }
150
151   static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) {
152     TargetAddress CompiledFnAddr = 0;
153
154     auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
155     auto EC = T->requestCompile(
156         CompiledFnAddr, static_cast<TargetAddress>(
157                             reinterpret_cast<uintptr_t>(TrampolineAddr)));
158     assert(!EC && "Compile request failed");
159     (void)&EC;
160     return CompiledFnAddr;
161   }
162
163   std::error_code handleCallIntVoid() {
164     typedef int (*IntVoidFnTy)();
165
166     IntVoidFnTy Fn = nullptr;
167     if (std::error_code EC =
168             handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
169               Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
170               return std::error_code();
171             }))
172       return EC;
173
174     DEBUG(dbgs() << "  Calling "
175                  << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn))
176                  << "\n");
177     int Result = Fn();
178     DEBUG(dbgs() << "  Result = " << Result << "\n");
179
180     return call<CallIntVoidResponse>(Channel, Result);
181   }
182
183   std::error_code handleCallMain() {
184     typedef int (*MainFnTy)(int, const char *[]);
185
186     MainFnTy Fn = nullptr;
187     std::vector<std::string> Args;
188     if (std::error_code EC = handle<CallMain>(
189             Channel, [&](TargetAddress Addr, std::vector<std::string> &A) {
190               Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
191               Args = std::move(A);
192               return std::error_code();
193             }))
194       return EC;
195
196     int ArgC = Args.size() + 1;
197     int Idx = 1;
198     std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
199     ArgV[0] = "<jit process>";
200     for (auto &Arg : Args)
201       ArgV[Idx++] = Arg.c_str();
202
203     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
204     int Result = Fn(ArgC, ArgV.get());
205     DEBUG(dbgs() << "  Result = " << Result << "\n");
206
207     return call<CallMainResponse>(Channel, Result);
208   }
209
210   std::error_code handleCallVoidVoid() {
211     typedef void (*VoidVoidFnTy)();
212
213     VoidVoidFnTy Fn = nullptr;
214     if (std::error_code EC =
215             handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
216               Fn = reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
217               return std::error_code();
218             }))
219       return EC;
220
221     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
222     Fn();
223     DEBUG(dbgs() << "  Complete.\n");
224
225     return call<CallVoidVoidResponse>(Channel);
226   }
227
228   std::error_code handleCreateRemoteAllocator() {
229     return handle<CreateRemoteAllocator>(
230         Channel, [&](ResourceIdMgr::ResourceId Id) {
231           auto I = Allocators.find(Id);
232           if (I != Allocators.end())
233             return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
234           DEBUG(dbgs() << "  Created allocator " << Id << "\n");
235           Allocators[Id] = Allocator();
236           return std::error_code();
237         });
238   }
239
240   std::error_code handleCreateIndirectStubsOwner() {
241     return handle<CreateIndirectStubsOwner>(
242         Channel, [&](ResourceIdMgr::ResourceId Id) {
243           auto I = IndirectStubsOwners.find(Id);
244           if (I != IndirectStubsOwners.end())
245             return orcError(
246                 OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
247           DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
248           IndirectStubsOwners[Id] = ISBlockOwnerList();
249           return std::error_code();
250         });
251   }
252
253   std::error_code handleDestroyRemoteAllocator() {
254     return handle<DestroyRemoteAllocator>(
255         Channel, [&](ResourceIdMgr::ResourceId Id) {
256           auto I = Allocators.find(Id);
257           if (I == Allocators.end())
258             return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
259           Allocators.erase(I);
260           DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
261           return std::error_code();
262         });
263   }
264
265   std::error_code handleDestroyIndirectStubsOwner() {
266     return handle<DestroyIndirectStubsOwner>(
267         Channel, [&](ResourceIdMgr::ResourceId Id) {
268           auto I = IndirectStubsOwners.find(Id);
269           if (I == IndirectStubsOwners.end())
270             return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
271           IndirectStubsOwners.erase(I);
272           return std::error_code();
273         });
274   }
275
276   std::error_code handleEmitIndirectStubs() {
277     ResourceIdMgr::ResourceId ISOwnerId = ~0U;
278     uint32_t NumStubsRequired = 0;
279
280     if (auto EC = handle<EmitIndirectStubs>(
281             Channel, readArgs(ISOwnerId, NumStubsRequired)))
282       return EC;
283
284     DEBUG(dbgs() << "  ISMgr " << ISOwnerId << " request " << NumStubsRequired
285                  << " stubs.\n");
286
287     auto StubOwnerItr = IndirectStubsOwners.find(ISOwnerId);
288     if (StubOwnerItr == IndirectStubsOwners.end())
289       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
290
291     typename TargetT::IndirectStubsInfo IS;
292     if (auto EC =
293             TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr))
294       return EC;
295
296     TargetAddress StubsBase =
297         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getStub(0)));
298     TargetAddress PtrsBase =
299         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getPtr(0)));
300     uint32_t NumStubsEmitted = IS.getNumStubs();
301
302     auto &BlockList = StubOwnerItr->second;
303     BlockList.push_back(std::move(IS));
304
305     return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
306                                            NumStubsEmitted);
307   }
308
309   std::error_code handleEmitResolverBlock() {
310     if (auto EC = handle<EmitResolverBlock>(Channel, doNothing))
311       return EC;
312
313     std::error_code EC;
314     ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
315         TargetT::ResolverCodeSize, nullptr,
316         sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
317     if (EC)
318       return EC;
319
320     TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()),
321                                &reenter, this);
322
323     return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(),
324                                             sys::Memory::MF_READ |
325                                                 sys::Memory::MF_EXEC);
326   }
327
328   std::error_code handleEmitTrampolineBlock() {
329     if (auto EC = handle<EmitTrampolineBlock>(Channel, doNothing))
330       return EC;
331
332     std::error_code EC;
333
334     auto TrampolineBlock =
335         sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
336             sys::Process::getPageSize(), nullptr,
337             sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
338     if (EC)
339       return EC;
340
341     unsigned NumTrampolines =
342         (sys::Process::getPageSize() - TargetT::PointerSize) /
343         TargetT::TrampolineSize;
344
345     uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base());
346     TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(),
347                               NumTrampolines);
348
349     EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(),
350                                           sys::Memory::MF_READ |
351                                               sys::Memory::MF_EXEC);
352
353     TrampolineBlocks.push_back(std::move(TrampolineBlock));
354
355     return call<EmitTrampolineBlockResponse>(
356         Channel,
357         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
358         NumTrampolines);
359   }
360
361   std::error_code handleGetSymbolAddress() {
362     std::string SymbolName;
363     if (auto EC = handle<GetSymbolAddress>(Channel, readArgs(SymbolName)))
364       return EC;
365
366     TargetAddress SymbolAddr = SymbolLookup(SymbolName);
367     DEBUG(dbgs() << "  Symbol '" << SymbolName
368                  << "' =  " << format("0x%016x", SymbolAddr) << "\n");
369     return call<GetSymbolAddressResponse>(Channel, SymbolAddr);
370   }
371
372   std::error_code handleGetRemoteInfo() {
373     if (auto EC = handle<GetRemoteInfo>(Channel, doNothing))
374       return EC;
375
376     std::string ProcessTriple = sys::getProcessTriple();
377     uint32_t PointerSize = TargetT::PointerSize;
378     uint32_t PageSize = sys::Process::getPageSize();
379     uint32_t TrampolineSize = TargetT::TrampolineSize;
380     uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize;
381     DEBUG(dbgs() << "  Remote info:\n"
382                  << "    triple             = '" << ProcessTriple << "'\n"
383                  << "    pointer size       = " << PointerSize << "\n"
384                  << "    page size          = " << PageSize << "\n"
385                  << "    trampoline size    = " << TrampolineSize << "\n"
386                  << "    indirect stub size = " << IndirectStubSize << "\n");
387     return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize,
388                                        PageSize, TrampolineSize,
389                                        IndirectStubSize);
390   }
391
392   std::error_code handleReadMem() {
393     char *Src = nullptr;
394     uint64_t Size = 0;
395     if (std::error_code EC =
396             handle<ReadMem>(Channel, [&](TargetAddress RSrc, uint64_t RSize) {
397               Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
398               Size = RSize;
399               return std::error_code();
400             }))
401       return EC;
402
403     DEBUG(dbgs() << "  Reading " << Size << " bytes from "
404                  << static_cast<void *>(Src) << "\n");
405
406     if (auto EC = call<ReadMemResponse>(Channel))
407       return EC;
408
409     if (auto EC = Channel.appendBytes(Src, Size))
410       return EC;
411
412     return Channel.send();
413   }
414
415   std::error_code handleReserveMem() {
416     void *LocalAllocAddr = nullptr;
417
418     if (std::error_code EC =
419             handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
420                                             uint64_t Size, uint32_t Align) {
421               auto I = Allocators.find(Id);
422               if (I == Allocators.end())
423                 return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
424               auto &Allocator = I->second;
425               auto EC2 = Allocator.allocate(LocalAllocAddr, Size, Align);
426               DEBUG(dbgs() << "  Allocator " << Id << " reserved "
427                            << LocalAllocAddr << " (" << Size
428                            << " bytes, alignment " << Align << ")\n");
429               return EC2;
430             }))
431       return EC;
432
433     TargetAddress AllocAddr =
434         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
435
436     return call<ReserveMemResponse>(Channel, AllocAddr);
437   }
438
439   std::error_code handleSetProtections() {
440     return handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
441                                            TargetAddress Addr, uint32_t Flags) {
442       auto I = Allocators.find(Id);
443       if (I == Allocators.end())
444         return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
445       auto &Allocator = I->second;
446       void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
447       DEBUG(dbgs() << "  Allocator " << Id << " set permissions on "
448                    << LocalAddr << " to "
449                    << (Flags & sys::Memory::MF_READ ? 'R' : '-')
450                    << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
451                    << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
452       return Allocator.setProtections(LocalAddr, Flags);
453     });
454   }
455
456   std::error_code handleWriteMem() {
457     return handle<WriteMem>(Channel, [&](TargetAddress RDst, uint64_t Size) {
458       char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
459       return Channel.readBytes(Dst, Size);
460     });
461   }
462
463   std::error_code handleWritePtr() {
464     return handle<WritePtr>(
465         Channel, [&](TargetAddress Addr, TargetAddress PtrVal) {
466           uintptr_t *Ptr =
467               reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
468           *Ptr = static_cast<uintptr_t>(PtrVal);
469           return std::error_code();
470         });
471   }
472
473   ChannelT &Channel;
474   SymbolLookupFtor SymbolLookup;
475   std::map<ResourceIdMgr::ResourceId, Allocator> Allocators;
476   typedef std::vector<typename TargetT::IndirectStubsInfo> ISBlockOwnerList;
477   std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners;
478   sys::OwningMemoryBlock ResolverBlock;
479   std::vector<sys::OwningMemoryBlock> TrampolineBlocks;
480 };
481
482 } // end namespace remote
483 } // end namespace orc
484 } // end namespace llvm
485
486 #undef DEBUG_TYPE
487
488 #endif