Taints the non-acquire RMW's store address with the load part
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / OrcRemoteTargetServer.h
1 //===---- OrcRemoteTargetServer.h - Orc Remote-target Server ----*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines the OrcRemoteTargetServer class. It can be used to build a
11 // JIT server that can execute code sent from an OrcRemoteTargetClient.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
16 #define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H
17
#include "OrcRemoteTargetRPCAPI.h"
#include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/Process.h"
#include "llvm/Support/raw_ostream.h"
#include <cinttypes>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
25
26 #define DEBUG_TYPE "orc-remote"
27
28 namespace llvm {
29 namespace orc {
30 namespace remote {
31
32 template <typename ChannelT, typename TargetT>
33 class OrcRemoteTargetServer : public OrcRemoteTargetRPCAPI {
34 public:
35   typedef std::function<TargetAddress(const std::string &Name)>
36       SymbolLookupFtor;
37
38   OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup)
39       : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {}
40
41   std::error_code getNextProcId(JITProcId &Id) {
42     return deserialize(Channel, Id);
43   }
44
45   std::error_code handleKnownProcedure(JITProcId Id) {
46     typedef OrcRemoteTargetServer ThisT;
47
48     DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
49
50     switch (Id) {
51     case CallIntVoidId:
52       return handle<CallIntVoid>(Channel, *this, &ThisT::handleCallIntVoid);
53     case CallMainId:
54       return handle<CallMain>(Channel, *this, &ThisT::handleCallMain);
55     case CallVoidVoidId:
56       return handle<CallVoidVoid>(Channel, *this, &ThisT::handleCallVoidVoid);
57     case CreateRemoteAllocatorId:
58       return handle<CreateRemoteAllocator>(Channel, *this,
59                                            &ThisT::handleCreateRemoteAllocator);
60     case CreateIndirectStubsOwnerId:
61       return handle<CreateIndirectStubsOwner>(
62           Channel, *this, &ThisT::handleCreateIndirectStubsOwner);
63     case DestroyRemoteAllocatorId:
64       return handle<DestroyRemoteAllocator>(
65           Channel, *this, &ThisT::handleDestroyRemoteAllocator);
66     case DestroyIndirectStubsOwnerId:
67       return handle<DestroyIndirectStubsOwner>(
68           Channel, *this, &ThisT::handleDestroyIndirectStubsOwner);
69     case EmitIndirectStubsId:
70       return handle<EmitIndirectStubs>(Channel, *this,
71                                        &ThisT::handleEmitIndirectStubs);
72     case EmitResolverBlockId:
73       return handle<EmitResolverBlock>(Channel, *this,
74                                        &ThisT::handleEmitResolverBlock);
75     case EmitTrampolineBlockId:
76       return handle<EmitTrampolineBlock>(Channel, *this,
77                                          &ThisT::handleEmitTrampolineBlock);
78     case GetSymbolAddressId:
79       return handle<GetSymbolAddress>(Channel, *this,
80                                       &ThisT::handleGetSymbolAddress);
81     case GetRemoteInfoId:
82       return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo);
83     case ReadMemId:
84       return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem);
85     case ReserveMemId:
86       return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem);
87     case SetProtectionsId:
88       return handle<SetProtections>(Channel, *this,
89                                     &ThisT::handleSetProtections);
90     case WriteMemId:
91       return handle<WriteMem>(Channel, *this, &ThisT::handleWriteMem);
92     case WritePtrId:
93       return handle<WritePtr>(Channel, *this, &ThisT::handleWritePtr);
94     default:
95       return orcError(OrcErrorCode::UnexpectedRPCCall);
96     }
97
98     llvm_unreachable("Unhandled JIT RPC procedure Id.");
99   }
100
101   std::error_code requestCompile(TargetAddress &CompiledFnAddr,
102                                  TargetAddress TrampolineAddr) {
103     if (auto EC = call<RequestCompile>(Channel, TrampolineAddr))
104       return EC;
105
106     while (1) {
107       JITProcId Id = InvalidId;
108       if (auto EC = getNextProcId(Id))
109         return EC;
110
111       switch (Id) {
112       case RequestCompileResponseId:
113         return handle<RequestCompileResponse>(Channel,
114                                               readArgs(CompiledFnAddr));
115       default:
116         if (auto EC = handleKnownProcedure(Id))
117           return EC;
118       }
119     }
120
121     llvm_unreachable("Fell through request-compile command loop.");
122   }
123
124 private:
125   struct Allocator {
126     Allocator() = default;
127     Allocator(Allocator &&Other) : Allocs(std::move(Other.Allocs)) {}
128     Allocator &operator=(Allocator &&Other) {
129       Allocs = std::move(Other.Allocs);
130       return *this;
131     }
132
133     ~Allocator() {
134       for (auto &Alloc : Allocs)
135         sys::Memory::releaseMappedMemory(Alloc.second);
136     }
137
138     std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) {
139       std::error_code EC;
140       sys::MemoryBlock MB = sys::Memory::allocateMappedMemory(
141           Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC);
142       if (EC)
143         return EC;
144
145       Addr = MB.base();
146       assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc");
147       Allocs[MB.base()] = std::move(MB);
148       return std::error_code();
149     }
150
151     std::error_code setProtections(void *block, unsigned Flags) {
152       auto I = Allocs.find(block);
153       if (I == Allocs.end())
154         return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized);
155       return sys::Memory::protectMappedMemory(I->second, Flags);
156     }
157
158   private:
159     std::map<void *, sys::MemoryBlock> Allocs;
160   };
161
162   static std::error_code doNothing() { return std::error_code(); }
163
164   static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) {
165     TargetAddress CompiledFnAddr = 0;
166
167     auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr);
168     auto EC = T->requestCompile(
169         CompiledFnAddr, static_cast<TargetAddress>(
170                             reinterpret_cast<uintptr_t>(TrampolineAddr)));
171     assert(!EC && "Compile request failed");
172     (void)EC;
173     return CompiledFnAddr;
174   }
175
176   std::error_code handleCallIntVoid(TargetAddress Addr) {
177     typedef int (*IntVoidFnTy)();
178     IntVoidFnTy Fn =
179         reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
180
181     DEBUG(dbgs() << "  Calling "
182                  << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn))
183                  << "\n");
184     int Result = Fn();
185     DEBUG(dbgs() << "  Result = " << Result << "\n");
186
187     return call<CallIntVoidResponse>(Channel, Result);
188   }
189
190   std::error_code handleCallMain(TargetAddress Addr,
191                                  std::vector<std::string> Args) {
192     typedef int (*MainFnTy)(int, const char *[]);
193
194     MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
195     int ArgC = Args.size() + 1;
196     int Idx = 1;
197     std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
198     ArgV[0] = "<jit process>";
199     for (auto &Arg : Args)
200       ArgV[Idx++] = Arg.c_str();
201
202     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
203     int Result = Fn(ArgC, ArgV.get());
204     DEBUG(dbgs() << "  Result = " << Result << "\n");
205
206     return call<CallMainResponse>(Channel, Result);
207   }
208
209   std::error_code handleCallVoidVoid(TargetAddress Addr) {
210     typedef void (*VoidVoidFnTy)();
211     VoidVoidFnTy Fn =
212         reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
213
214     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
215     Fn();
216     DEBUG(dbgs() << "  Complete.\n");
217
218     return call<CallVoidVoidResponse>(Channel);
219   }
220
221   std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
222     auto I = Allocators.find(Id);
223     if (I != Allocators.end())
224       return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
225     DEBUG(dbgs() << "  Created allocator " << Id << "\n");
226     Allocators[Id] = Allocator();
227     return std::error_code();
228   }
229
230   std::error_code handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
231     auto I = IndirectStubsOwners.find(Id);
232     if (I != IndirectStubsOwners.end())
233       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
234     DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
235     IndirectStubsOwners[Id] = ISBlockOwnerList();
236     return std::error_code();
237   }
238
239   std::error_code handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
240     auto I = Allocators.find(Id);
241     if (I == Allocators.end())
242       return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
243     Allocators.erase(I);
244     DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
245     return std::error_code();
246   }
247
248   std::error_code
249   handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
250     auto I = IndirectStubsOwners.find(Id);
251     if (I == IndirectStubsOwners.end())
252       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
253     IndirectStubsOwners.erase(I);
254     return std::error_code();
255   }
256
257   std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
258                                           uint32_t NumStubsRequired) {
259     DEBUG(dbgs() << "  ISMgr " << Id << " request " << NumStubsRequired
260                  << " stubs.\n");
261
262     auto StubOwnerItr = IndirectStubsOwners.find(Id);
263     if (StubOwnerItr == IndirectStubsOwners.end())
264       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
265
266     typename TargetT::IndirectStubsInfo IS;
267     if (auto EC =
268             TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr))
269       return EC;
270
271     TargetAddress StubsBase =
272         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getStub(0)));
273     TargetAddress PtrsBase =
274         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getPtr(0)));
275     uint32_t NumStubsEmitted = IS.getNumStubs();
276
277     auto &BlockList = StubOwnerItr->second;
278     BlockList.push_back(std::move(IS));
279
280     return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase,
281                                            NumStubsEmitted);
282   }
283
284   std::error_code handleEmitResolverBlock() {
285     std::error_code EC;
286     ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
287         TargetT::ResolverCodeSize, nullptr,
288         sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
289     if (EC)
290       return EC;
291
292     TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()),
293                                &reenter, this);
294
295     return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(),
296                                             sys::Memory::MF_READ |
297                                                 sys::Memory::MF_EXEC);
298   }
299
300   std::error_code handleEmitTrampolineBlock() {
301     std::error_code EC;
302     auto TrampolineBlock =
303         sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
304             sys::Process::getPageSize(), nullptr,
305             sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC));
306     if (EC)
307       return EC;
308
309     unsigned NumTrampolines =
310         (sys::Process::getPageSize() - TargetT::PointerSize) /
311         TargetT::TrampolineSize;
312
313     uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base());
314     TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(),
315                               NumTrampolines);
316
317     EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(),
318                                           sys::Memory::MF_READ |
319                                               sys::Memory::MF_EXEC);
320
321     TrampolineBlocks.push_back(std::move(TrampolineBlock));
322
323     return call<EmitTrampolineBlockResponse>(
324         Channel,
325         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)),
326         NumTrampolines);
327   }
328
329   std::error_code handleGetSymbolAddress(const std::string &Name) {
330     TargetAddress Addr = SymbolLookup(Name);
331     DEBUG(dbgs() << "  Symbol '" << Name << "' =  " << format("0x%016x", Addr)
332                  << "\n");
333     return call<GetSymbolAddressResponse>(Channel, Addr);
334   }
335
336   std::error_code handleGetRemoteInfo() {
337     std::string ProcessTriple = sys::getProcessTriple();
338     uint32_t PointerSize = TargetT::PointerSize;
339     uint32_t PageSize = sys::Process::getPageSize();
340     uint32_t TrampolineSize = TargetT::TrampolineSize;
341     uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize;
342     DEBUG(dbgs() << "  Remote info:\n"
343                  << "    triple             = '" << ProcessTriple << "'\n"
344                  << "    pointer size       = " << PointerSize << "\n"
345                  << "    page size          = " << PageSize << "\n"
346                  << "    trampoline size    = " << TrampolineSize << "\n"
347                  << "    indirect stub size = " << IndirectStubSize << "\n");
348     return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize,
349                                        PageSize, TrampolineSize,
350                                        IndirectStubSize);
351   }
352
353   std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
354     char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
355
356     DEBUG(dbgs() << "  Reading " << Size << " bytes from "
357                  << static_cast<void *>(Src) << "\n");
358
359     if (auto EC = call<ReadMemResponse>(Channel))
360       return EC;
361
362     if (auto EC = Channel.appendBytes(Src, Size))
363       return EC;
364
365     return Channel.send();
366   }
367
368   std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
369                                    uint32_t Align) {
370     auto I = Allocators.find(Id);
371     if (I == Allocators.end())
372       return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
373     auto &Allocator = I->second;
374     void *LocalAllocAddr = nullptr;
375     if (auto EC = Allocator.allocate(LocalAllocAddr, Size, Align))
376       return EC;
377
378     DEBUG(dbgs() << "  Allocator " << Id << " reserved " << LocalAllocAddr
379                  << " (" << Size << " bytes, alignment " << Align << ")\n");
380
381     TargetAddress AllocAddr =
382         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
383
384     return call<ReserveMemResponse>(Channel, AllocAddr);
385   }
386
387   std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
388                                        TargetAddress Addr, uint32_t Flags) {
389     auto I = Allocators.find(Id);
390     if (I == Allocators.end())
391       return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
392     auto &Allocator = I->second;
393     void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
394     DEBUG(dbgs() << "  Allocator " << Id << " set permissions on " << LocalAddr
395                  << " to " << (Flags & sys::Memory::MF_READ ? 'R' : '-')
396                  << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
397                  << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
398     return Allocator.setProtections(LocalAddr, Flags);
399   }
400
401   std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
402     char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
403     DEBUG(dbgs() << "  Writing " << Size << " bytes to "
404                  << format("0x%016x", RDst) << "\n");
405     return Channel.readBytes(Dst, Size);
406   }
407
408   std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) {
409     DEBUG(dbgs() << "  Writing pointer *" << format("0x%016x", Addr) << " = "
410                  << format("0x%016x", PtrVal) << "\n");
411     uintptr_t *Ptr =
412         reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
413     *Ptr = static_cast<uintptr_t>(PtrVal);
414     return std::error_code();
415   }
416
417   ChannelT &Channel;
418   SymbolLookupFtor SymbolLookup;
419   std::map<ResourceIdMgr::ResourceId, Allocator> Allocators;
420   typedef std::vector<typename TargetT::IndirectStubsInfo> ISBlockOwnerList;
421   std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners;
422   sys::OwningMemoryBlock ResolverBlock;
423   std::vector<sys::OwningMemoryBlock> TrampolineBlocks;
424 };
425
426 } // end namespace remote
427 } // end namespace orc
428 } // end namespace llvm
429
430 #undef DEBUG_TYPE
431
432 #endif