From: Lang Hames Date: Mon, 11 Jan 2016 01:40:11 +0000 (+0000) Subject: [Orc] Add support for remote JITing to the ORC API. X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=commitdiff_plain;h=51c60258a4ad2ad2181734448eeb6dea16d53e0b [Orc] Add support for remote JITing to the ORC API. This patch adds utilities to ORC for managing a remote JIT target. It consists of: 1. A very primitive RPC system for making calls over a byte-stream. See RPCChannel.h, RPCUtils.h. 2. An RPC API defined in the above system for managing memory, looking up symbols, creating stubs, etc. on a remote target. See OrcRemoteTargetRPCAPI.h. 3. An interface for creating high-level JIT components (memory managers, callback managers, stub managers, etc.) that operate over the RPC API. See OrcRemoteTargetClient.h. 4. A helper class for building servers that can handle the RPC calls. See OrcRemoteTargetServer.h. The system is designed to work neatly with the existing ORC components and functionality. In particular, the ORC callback API (and consequently the CompileOnDemandLayer) is supported, enabling lazy compilation of remote code. Assuming this doesn't trigger any builder failures, a follow-up patch will be committed which tests these utilities by using them to replace LLI's existing remote-JITing demo code. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@257305 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h new file mode 100644 index 00000000000..8512fd2e125 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -0,0 +1,743 @@ +//===---- OrcRemoteTargetClient.h - Orc Remote-target Client ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the OrcRemoteTargetClient class and helpers. This class +// can be used to communicate over an RPCChannel with an OrcRemoteTargetServer +// instance to support remote-JITing. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H + +#include "OrcRemoteTargetRPCAPI.h" + +#define DEBUG_TYPE "orc-remote" + +namespace llvm { +namespace orc { +namespace remote { + +/// This class provides utilities (including memory manager, indirect stubs +/// manager, and compile callback manager types) that support remote JITing +/// in ORC. +/// +/// Each of the utility classes talks to a JIT server (an instance of the +/// OrcRemoteTargetServer class) via an RPC system (see RPCUtils.h) to carry out +/// its actions. +template +class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { +public: + /// Remote memory manager. + class RCMemoryManager : public RuntimeDyld::MemoryManager { + public: + RCMemoryManager(OrcRemoteTargetClient &Client, ResourceIdMgr::ResourceId Id) + : Client(Client), Id(Id) { + DEBUG(dbgs() << "Created remote allocator " << Id << "\n"); + } + + ~RCMemoryManager() { + Client.destroyRemoteAllocator(Id); + DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n"); + } + + uint8_t *allocateCodeSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, + StringRef SectionName) override { + Unmapped.back().CodeAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast( + Unmapped.back().CodeAllocs.back().getLocalAddress()); + DEBUG(dbgs() << "Allocator " << Id << " allocated code for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << ")\n"); + return Alloc; + } + + uint8_t *allocateDataSection(uintptr_t Size, unsigned Alignment, + unsigned SectionID, StringRef SectionName, + bool IsReadOnly) override { + if (IsReadOnly) { + Unmapped.back().RODataAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast( + Unmapped.back().RODataAllocs.back().getLocalAddress()); + DEBUG(dbgs() << "Allocator " << Id << " allocated ro-data for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << ")\n"); + return Alloc; + } // else... + + Unmapped.back().RWDataAllocs.emplace_back(Size, Alignment); + uint8_t *Alloc = reinterpret_cast( + Unmapped.back().RWDataAllocs.back().getLocalAddress()); + DEBUG(dbgs() << "Allocator " << Id << " allocated rw-data for " + << SectionName << ": " << Alloc << " (" << Size + << " bytes, alignment " << Alignment << "\n"); + return Alloc; + } + + void reserveAllocationSpace(uintptr_t CodeSize, uint32_t CodeAlign, + uintptr_t RODataSize, uint32_t RODataAlign, + uintptr_t RWDataSize, + uint32_t RWDataAlign) override { + Unmapped.push_back(ObjectAllocs()); + + DEBUG(dbgs() << "Allocator " << Id << " reserved:\n"); + + if (CodeSize != 0) { + if (auto EC = Client.reserveMem(Unmapped.back().RemoteCodeAddr, Id, + CodeSize, CodeAlign)) { + // FIXME; Add error to poll. + llvm_unreachable("Failed reserving remote memory."); + } + DEBUG(dbgs() << " code: " + << format("0x%016x", Unmapped.back().RemoteCodeAddr) + << " (" << CodeSize << " bytes, alignment " << CodeAlign + << ")\n"); + } + + if (RODataSize != 0) { + if (auto EC = Client.reserveMem(Unmapped.back().RemoteRODataAddr, Id, + RODataSize, RODataAlign)) { + // FIXME; Add error to poll. + llvm_unreachable("Failed reserving remote memory."); + } + DEBUG(dbgs() << " ro-data: " + << format("0x%016x", Unmapped.back().RemoteRODataAddr) + << " (" << RODataSize << " bytes, alignment " + << RODataAlign << ")\n"); + } + + if (RWDataSize != 0) { + if (auto EC = Client.reserveMem(Unmapped.back().RemoteRWDataAddr, Id, + RWDataSize, RWDataAlign)) { + // FIXME; Add error to poll. + llvm_unreachable("Failed reserving remote memory."); + } + DEBUG(dbgs() << " rw-data: " + << format("0x%016x", Unmapped.back().RemoteRWDataAddr) + << " (" << RWDataSize << " bytes, alignment " + << RWDataAlign << ")\n"); + } + } + + bool needsToReserveAllocationSpace() override { return true; } + + void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr, + size_t Size) override {} + + void deregisterEHFrames(uint8_t *addr, uint64_t LoadAddr, + size_t Size) override {} + + void notifyObjectLoaded(RuntimeDyld &Dyld, + const object::ObjectFile &Obj) override { + DEBUG(dbgs() << "Allocator " << Id << " applied mappings:\n"); + for (auto &ObjAllocs : Unmapped) { + { + TargetAddress NextCodeAddr = ObjAllocs.RemoteCodeAddr; + for (auto &Alloc : ObjAllocs.CodeAllocs) { + NextCodeAddr = RoundUpToAlignment(NextCodeAddr, Alloc.getAlign()); + Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextCodeAddr); + DEBUG(dbgs() << " code: " + << static_cast(Alloc.getLocalAddress()) + << " -> " << format("0x%016x", NextCodeAddr) << "\n"); + Alloc.setRemoteAddress(NextCodeAddr); + NextCodeAddr += Alloc.getSize(); + } + } + { + TargetAddress NextRODataAddr = ObjAllocs.RemoteRODataAddr; + for (auto &Alloc : ObjAllocs.RODataAllocs) { + NextRODataAddr = + RoundUpToAlignment(NextRODataAddr, Alloc.getAlign()); + Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextRODataAddr); + DEBUG(dbgs() << " ro-data: " + << static_cast(Alloc.getLocalAddress()) + << " -> " << format("0x%016x", NextRODataAddr) + << "\n"); + Alloc.setRemoteAddress(NextRODataAddr); + NextRODataAddr += Alloc.getSize(); + } + } + { + TargetAddress NextRWDataAddr = ObjAllocs.RemoteRWDataAddr; + for (auto &Alloc : ObjAllocs.RWDataAllocs) { + NextRWDataAddr = + RoundUpToAlignment(NextRWDataAddr, Alloc.getAlign()); + Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextRWDataAddr); + DEBUG(dbgs() << " rw-data: " + << static_cast(Alloc.getLocalAddress()) + << " -> " << format("0x%016x", NextRWDataAddr) + << "\n"); + Alloc.setRemoteAddress(NextRWDataAddr); + NextRWDataAddr += Alloc.getSize(); + } + } + Unfinalized.push_back(std::move(ObjAllocs)); + } + Unmapped.clear(); + } + + bool finalizeMemory(std::string *ErrMsg = nullptr) override { + DEBUG(dbgs() << "Allocator " << Id << " finalizing:\n"); + + for (auto &ObjAllocs : Unfinalized) { + + for (auto &Alloc : ObjAllocs.CodeAllocs) { + DEBUG(dbgs() << " copying code: " + << static_cast(Alloc.getLocalAddress()) << " -> " + << format("0x%016x", Alloc.getRemoteAddress()) << " (" + << Alloc.getSize() << " bytes)\n"); + Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), + Alloc.getSize()); + } + + if (ObjAllocs.RemoteCodeAddr) { + DEBUG(dbgs() << " setting R-X permissions on code block: " + << format("0x%016x", ObjAllocs.RemoteCodeAddr) << "\n"); + Client.setProtections(Id, ObjAllocs.RemoteCodeAddr, + sys::Memory::MF_READ | sys::Memory::MF_EXEC); + } + + for (auto &Alloc : ObjAllocs.RODataAllocs) { + DEBUG(dbgs() << " copying ro-data: " + << static_cast(Alloc.getLocalAddress()) << " -> " + << format("0x%016x", Alloc.getRemoteAddress()) << " (" + << Alloc.getSize() << " bytes)\n"); + Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), + Alloc.getSize()); + } + + if (ObjAllocs.RemoteRODataAddr) { + DEBUG(dbgs() << " setting R-- permissions on ro-data block: " + << format("0x%016x", ObjAllocs.RemoteRODataAddr) + << "\n"); + Client.setProtections(Id, ObjAllocs.RemoteRODataAddr, + sys::Memory::MF_READ); + } + + for (auto &Alloc : ObjAllocs.RWDataAllocs) { + DEBUG(dbgs() << " copying rw-data: " + << static_cast(Alloc.getLocalAddress()) << " -> " + << format("0x%016x", Alloc.getRemoteAddress()) << " (" + << Alloc.getSize() << " bytes)\n"); + Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), + Alloc.getSize()); + } + + if (ObjAllocs.RemoteRWDataAddr) { + DEBUG(dbgs() << " setting RW- permissions on rw-data block: " + << format("0x%016x", ObjAllocs.RemoteRWDataAddr) + << "\n"); + Client.setProtections(Id, ObjAllocs.RemoteRWDataAddr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE); + } + } + Unfinalized.clear(); + + return false; + } + + private: + class Alloc { + public: + Alloc(uint64_t Size, unsigned Align) + : Size(Size), Align(Align), Contents(new char[Size + Align - 1]), + RemoteAddr(0) {} + + uint64_t getSize() const { return Size; } + + unsigned getAlign() const { return Align; } + + char *getLocalAddress() const { + uintptr_t LocalAddr = reinterpret_cast(Contents.get()); + LocalAddr = RoundUpToAlignment(LocalAddr, Align); + return reinterpret_cast(LocalAddr); + } + + void setRemoteAddress(TargetAddress RemoteAddr) { + this->RemoteAddr = RemoteAddr; + } + + TargetAddress getRemoteAddress() const { return RemoteAddr; } + + private: + uint64_t Size; + unsigned Align; + std::unique_ptr Contents; + TargetAddress RemoteAddr; + }; + + struct ObjectAllocs { + ObjectAllocs() + : RemoteCodeAddr(0), RemoteRODataAddr(0), RemoteRWDataAddr(0) {} + TargetAddress RemoteCodeAddr; + TargetAddress RemoteRODataAddr; + TargetAddress RemoteRWDataAddr; + std::vector CodeAllocs, RODataAllocs, RWDataAllocs; + }; + + OrcRemoteTargetClient &Client; + ResourceIdMgr::ResourceId Id; + std::vector Unmapped; + std::vector Unfinalized; + }; + + /// Remote indirect stubs manager. + class RCIndirectStubsManager : public IndirectStubsManager { + public: + RCIndirectStubsManager(OrcRemoteTargetClient &Remote, + ResourceIdMgr::ResourceId Id) + : Remote(Remote), Id(Id) {} + + ~RCIndirectStubsManager() { Remote.destroyIndirectStubsManager(Id); } + + std::error_code createStub(StringRef StubName, TargetAddress StubAddr, + JITSymbolFlags StubFlags) override { + if (auto EC = reserveStubs(1)) + return EC; + + return createStubInternal(StubName, StubAddr, StubFlags); + } + + std::error_code createStubs(const StubInitsMap &StubInits) override { + if (auto EC = reserveStubs(StubInits.size())) + return EC; + + for (auto &Entry : StubInits) + if (auto EC = createStubInternal(Entry.first(), Entry.second.first, + Entry.second.second)) + return EC; + + return std::error_code(); + } + + JITSymbol findStub(StringRef Name, bool ExportedStubsOnly) override { + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + auto Flags = I->second.second; + auto StubSymbol = JITSymbol(getStubAddr(Key), Flags); + if (ExportedStubsOnly && !StubSymbol.isExported()) + return nullptr; + return StubSymbol; + } + + JITSymbol findPointer(StringRef Name) override { + auto I = StubIndexes.find(Name); + if (I == StubIndexes.end()) + return nullptr; + auto Key = I->second.first; + auto Flags = I->second.second; + return JITSymbol(getPtrAddr(Key), Flags); + } + + std::error_code updatePointer(StringRef Name, + TargetAddress NewAddr) override { + auto I = StubIndexes.find(Name); + assert(I != StubIndexes.end() && "No stub pointer for symbol"); + auto Key = I->second.first; + return Remote.writePointer(getPtrAddr(Key), NewAddr); + } + + private: + struct RemoteIndirectStubsInfo { + RemoteIndirectStubsInfo(TargetAddress StubBase, TargetAddress PtrBase, + unsigned NumStubs) + : StubBase(StubBase), PtrBase(PtrBase), NumStubs(NumStubs) {} + TargetAddress StubBase; + TargetAddress PtrBase; + unsigned NumStubs; + }; + + OrcRemoteTargetClient &Remote; + ResourceIdMgr::ResourceId Id; + std::vector RemoteIndirectStubsInfos; + typedef std::pair StubKey; + std::vector FreeStubs; + StringMap> StubIndexes; + + std::error_code reserveStubs(unsigned NumStubs) { + if (NumStubs <= FreeStubs.size()) + return std::error_code(); + + unsigned NewStubsRequired = NumStubs - FreeStubs.size(); + TargetAddress StubBase; + TargetAddress PtrBase; + unsigned NumStubsEmitted; + + Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id, + NewStubsRequired); + + unsigned NewBlockId = RemoteIndirectStubsInfos.size(); + RemoteIndirectStubsInfos.push_back( + RemoteIndirectStubsInfo(StubBase, PtrBase, NumStubsEmitted)); + + for (unsigned I = 0; I < NumStubsEmitted; ++I) + FreeStubs.push_back(std::make_pair(NewBlockId, I)); + + return std::error_code(); + } + + std::error_code createStubInternal(StringRef StubName, + TargetAddress InitAddr, + JITSymbolFlags StubFlags) { + auto Key = FreeStubs.back(); + FreeStubs.pop_back(); + StubIndexes[StubName] = std::make_pair(Key, StubFlags); + return Remote.writePointer(getPtrAddr(Key), InitAddr); + } + + TargetAddress getStubAddr(StubKey K) { + assert(RemoteIndirectStubsInfos[K.first].StubBase != 0 && + "Missing stub address"); + return RemoteIndirectStubsInfos[K.first].StubBase + + K.second * Remote.getIndirectStubSize(); + } + + TargetAddress getPtrAddr(StubKey K) { + assert(RemoteIndirectStubsInfos[K.first].PtrBase != 0 && + "Missing pointer address"); + return RemoteIndirectStubsInfos[K.first].PtrBase + + K.second * Remote.getPointerSize(); + } + }; + + /// Remote compile callback manager. + class RCCompileCallbackManager : public JITCompileCallbackManager { + public: + RCCompileCallbackManager(TargetAddress ErrorHandlerAddress, + OrcRemoteTargetClient &Remote) + : JITCompileCallbackManager(ErrorHandlerAddress), Remote(Remote) { + assert(!Remote.CompileCallback && "Compile callback already set"); + Remote.CompileCallback = [this](TargetAddress TrampolineAddr) { + return executeCompileCallback(TrampolineAddr); + }; + Remote.emitResolverBlock(); + } + + private: + void grow() { + TargetAddress BlockAddr = 0; + uint32_t NumTrampolines = 0; + auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines); + assert(!EC && "Failed to create trampolines"); + + uint32_t TrampolineSize = Remote.getTrampolineSize(); + for (unsigned I = 0; I < NumTrampolines; ++I) + this->AvailableTrampolines.push_back(BlockAddr + (I * TrampolineSize)); + } + + OrcRemoteTargetClient &Remote; + }; + + /// Create an OrcRemoteTargetClient. + /// Channel is the ChannelT instance to communicate on. It is assumed that + /// the channel is ready to be read from and written to. + static ErrorOr Create(ChannelT &Channel) { + std::error_code EC; + OrcRemoteTargetClient H(Channel, EC); + if (EC) + return EC; + return H; + } + + /// Call the int(void) function at the given address in the target and return + /// its result. + std::error_code callIntVoid(int &Result, TargetAddress Addr) { + DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); + + if (auto EC = call(Channel, Addr)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallIntVoidResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle(Channel, [&](int R) { + Result = R; + DEBUG(dbgs() << "Result: " << R << "\n"); + return std::error_code(); + }); + } + + /// Call the int(int, char*[]) function at the given address in the target and + /// return its result. + std::error_code callMain(int &Result, TargetAddress Addr, + const std::vector &Args) { + DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) + << "\n"); + + if (auto EC = call(Channel, Addr, Args)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallMainResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle(Channel, [&](int R) { + Result = R; + DEBUG(dbgs() << "Result: " << R << "\n"); + return std::error_code(); + }); + } + + /// Call the void() function at the given address in the target and wait for + /// it to finish. + std::error_code callVoidVoid(TargetAddress Addr) { + DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) + << "\n"); + + if (auto EC = call(Channel, Addr)) + return EC; + + unsigned NextProcId; + if (auto EC = listenForCompileRequests(NextProcId)) + return EC; + + if (NextProcId != CallVoidVoidResponseId) + return orcError(OrcErrorCode::UnexpectedRPCCall); + + return handle(Channel, doNothing); + } + + /// Create an RCMemoryManager which will allocate its memory on the remote + /// target. + std::error_code + createRemoteMemoryManager(std::unique_ptr &MM) { + assert(!MM && "MemoryManager should be null before creation."); + + auto Id = AllocatorIds.getNext(); + if (auto EC = call(Channel, Id)) + return EC; + MM = llvm::make_unique(*this, Id); + return std::error_code(); + } + + /// Create an RCIndirectStubsManager that will allocate stubs on the remote + /// target. + std::error_code + createIndirectStubsManager(std::unique_ptr &I) { + assert(!I && "Indirect stubs manager should be null before creation."); + auto Id = IndirectStubOwnerIds.getNext(); + if (auto EC = call(Channel, Id)) + return EC; + I = llvm::make_unique(*this, Id); + return std::error_code(); + } + + /// Search for symbols in the remote process. Note: This should be used by + /// symbol resolvers *after* they've searched the local symbol table in the + /// JIT stack. + std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + // Request remote symbol address. + if (auto EC = call(Channel, Name)) + return EC; + + return expect(Channel, [&](TargetAddress &A) { + Addr = A; + DEBUG(dbgs() << "Remote address lookup " << Name << " = " + << format("0x%016x", Addr) << "\n"); + return std::error_code(); + }); + } + + /// Get the triple for the remote target. + const std::string &getTargetTriple() const { return RemoteTargetTriple; } + + std::error_code terminateSession() { return call(Channel); } + +private: + OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC) + : Channel(Channel), RemotePointerSize(0), RemotePageSize(0), + RemoteTrampolineSize(0), RemoteIndirectStubSize(0) { + if ((EC = call(Channel))) + return; + + EC = expect( + Channel, readArgs(RemoteTargetTriple, RemotePointerSize, RemotePageSize, + RemoteTrampolineSize, RemoteIndirectStubSize)); + } + + void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { + if (auto EC = call(Channel, Id)) { + // FIXME: This will be triggered by a removeModuleSet call: Propagate + // error return up through that. + llvm_unreachable("Failed to destroy remote allocator."); + AllocatorIds.release(Id); + } + } + + std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { + IndirectStubOwnerIds.release(Id); + return call(Channel, Id); + } + + std::error_code emitIndirectStubs(TargetAddress &StubBase, + TargetAddress &PtrBase, + uint32_t &NumStubsEmitted, + ResourceIdMgr::ResourceId Id, + uint32_t NumStubsRequired) { + if (auto EC = call(Channel, Id, NumStubsRequired)) + return EC; + + return expect( + Channel, readArgs(StubBase, PtrBase, NumStubsEmitted)); + } + + std::error_code emitResolverBlock() { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + return call(Channel); + } + + std::error_code emitTrampolineBlock(TargetAddress &BlockAddr, + uint32_t &NumTrampolines) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + if (auto EC = call(Channel)) + return EC; + + return expect( + Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) { + BlockAddr = BAddr; + NumTrampolines = NTrampolines; + return std::error_code(); + }); + } + + uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } + uint32_t getPageSize() const { return RemotePageSize; } + uint32_t getPointerSize() const { return RemotePointerSize; } + + uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } + + std::error_code listenForCompileRequests(uint32_t &NextId) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + if (auto EC = getNextProcId(Channel, NextId)) + return EC; + + while (NextId == RequestCompileId) { + TargetAddress TrampolineAddr = 0; + if (auto EC = handle(Channel, readArgs(TrampolineAddr))) + return EC; + + TargetAddress ImplAddr = CompileCallback(TrampolineAddr); + if (auto EC = call(Channel, ImplAddr)) + return EC; + + if (auto EC = getNextProcId(Channel, NextId)) + return EC; + } + + return std::error_code(); + } + + std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + if (auto EC = call(Channel, Src, Size)) + return EC; + + if (auto EC = expect( + Channel, [&]() { return Channel.readBytes(Dst, Size); })) + return EC; + + return std::error_code(); + } + + std::error_code reserveMem(TargetAddress &RemoteAddr, + ResourceIdMgr::ResourceId Id, uint64_t Size, + uint32_t Align) { + + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + if (auto EC = call(Channel, Id, Size, Align)) + return EC; + + if (auto EC = expect(Channel, [&](TargetAddress Addr) { + RemoteAddr = Addr; + return std::error_code(); + })) + return EC; + + return std::error_code(); + } + + std::error_code setProtections(ResourceIdMgr::ResourceId Id, + TargetAddress RemoteSegAddr, + unsigned ProtFlags) { + return call(Channel, Id, RemoteSegAddr, ProtFlags); + } + + std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t Size) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + // Make the send call. + if (auto EC = call(Channel, Addr, Size)) + return EC; + + // Follow this up with the section contents. + if (auto EC = Channel.appendBytes(Src, Size)) + return EC; + + return Channel.send(); + } + + std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return ExistingError; + + return call(Channel, Addr, PtrVal); + } + + static std::error_code doNothing() { return std::error_code(); } + + ChannelT &Channel; + std::error_code ExistingError; + std::string RemoteTargetTriple; + uint32_t RemotePointerSize; + uint32_t RemotePageSize; + uint32_t RemoteTrampolineSize; + uint32_t RemoteIndirectStubSize; + ResourceIdMgr AllocatorIds, IndirectStubOwnerIds; + std::function CompileCallback; +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#undef DEBUG_TYPE + +#endif diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h new file mode 100644 index 00000000000..96dc2425102 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -0,0 +1,185 @@ +//===--- OrcRemoteTargetRPCAPI.h - Orc Remote-target RPC API ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the Orc remote-target RPC API. It should not be used +// directly, but is used by the RemoteTargetClient and RemoteTargetServer +// classes. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETRPCAPI_H + +#include "JITSymbol.h" +#include "RPCChannel.h" +#include "RPCUtils.h" + +namespace llvm { +namespace orc { +namespace remote { + +class OrcRemoteTargetRPCAPI : public RPC { +protected: + class ResourceIdMgr { + public: + typedef uint64_t ResourceId; + ResourceIdMgr() : NextId(0) {} + ResourceId getNext() { + if (!FreeIds.empty()) { + ResourceId I = FreeIds.back(); + FreeIds.pop_back(); + return I; + } + return NextId++; + } + void release(ResourceId I) { FreeIds.push_back(I); } + + private: + ResourceId NextId; + std::vector FreeIds; + }; + +public: + enum JITProcId : uint32_t { + InvalidId = 0, + CallIntVoidId, + CallIntVoidResponseId, + CallMainId, + CallMainResponseId, + CallVoidVoidId, + CallVoidVoidResponseId, + CreateRemoteAllocatorId, + CreateIndirectStubsOwnerId, + DestroyRemoteAllocatorId, + DestroyIndirectStubsOwnerId, + EmitIndirectStubsId, + EmitIndirectStubsResponseId, + EmitResolverBlockId, + EmitTrampolineBlockId, + EmitTrampolineBlockResponseId, + GetSymbolAddressId, + GetSymbolAddressResponseId, + GetRemoteInfoId, + GetRemoteInfoResponseId, + ReadMemId, + ReadMemResponseId, + ReserveMemId, + ReserveMemResponseId, + RequestCompileId, + RequestCompileResponseId, + SetProtectionsId, + TerminateSessionId, + WriteMemId, + WritePtrId + }; + + static const char *getJITProcIdName(JITProcId Id); + + typedef Procedure CallIntVoid; + + typedef Procedure + CallIntVoidResponse; + + typedef Procedure /* Args */> + CallMain; + + typedef Procedure CallMainResponse; + + typedef Procedure CallVoidVoid; + + typedef Procedure CallVoidVoidResponse; + + typedef Procedure + CreateRemoteAllocator; + + typedef Procedure + CreateIndirectStubsOwner; + + typedef Procedure + DestroyRemoteAllocator; + + typedef Procedure + DestroyIndirectStubsOwner; + + typedef Procedure + EmitIndirectStubs; + + typedef Procedure< + EmitIndirectStubsResponseId, TargetAddress /* StubsBaseAddr */, + TargetAddress /* PtrsBaseAddr */, uint32_t /* NumStubsEmitted */> + EmitIndirectStubsResponse; + + typedef Procedure EmitResolverBlock; + + typedef Procedure EmitTrampolineBlock; + + typedef Procedure + EmitTrampolineBlockResponse; + + typedef Procedure + GetSymbolAddress; + + typedef Procedure + GetSymbolAddressResponse; + + typedef Procedure GetRemoteInfo; + + typedef Procedure + GetRemoteInfoResponse; + + typedef Procedure + ReadMem; + + typedef Procedure ReadMemResponse; + + typedef Procedure + ReserveMem; + + typedef Procedure + ReserveMemResponse; + + typedef Procedure + RequestCompile; + + typedef Procedure + RequestCompileResponse; + + typedef Procedure + SetProtections; + + typedef Procedure TerminateSession; + + typedef Procedure + WriteMem; + + typedef Procedure + WritePtr; +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#endif diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h new file mode 100644 index 00000000000..4b4ecfc1ad2 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -0,0 +1,479 @@ +//===---- OrcRemoteTargetServer.h - Orc Remote-target Server ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines the OrcRemoteTargetServer class. It can be used to build a +// JIT server that can execute code sent from an OrcRemoteTargetClient. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H +#define LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETSERVER_H + +#include "OrcRemoteTargetRPCAPI.h" +#include "llvm/ExecutionEngine/RTDyldMemoryManager.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "orc-remote" + +namespace llvm { +namespace orc { +namespace remote { + +template +class OrcRemoteTargetServer : public OrcRemoteTargetRPCAPI { +public: + typedef std::function + SymbolLookupFtor; + + OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup) + : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {} + + std::error_code getNextProcId(JITProcId &Id) { + return deserialize(Channel, Id); + } + + std::error_code handleKnownProcedure(JITProcId Id) { + DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n"); + + switch (Id) { + case CallIntVoidId: + return handleCallIntVoid(); + case CallMainId: + return handleCallMain(); + case CallVoidVoidId: + return handleCallVoidVoid(); + case CreateRemoteAllocatorId: + return handleCreateRemoteAllocator(); + case CreateIndirectStubsOwnerId: + return handleCreateIndirectStubsOwner(); + case DestroyRemoteAllocatorId: + return handleDestroyRemoteAllocator(); + case EmitIndirectStubsId: + return handleEmitIndirectStubs(); + case EmitResolverBlockId: + return handleEmitResolverBlock(); + case EmitTrampolineBlockId: + return handleEmitTrampolineBlock(); + case GetSymbolAddressId: + return handleGetSymbolAddress(); + case GetRemoteInfoId: + return handleGetRemoteInfo(); + case ReadMemId: + return handleReadMem(); + case ReserveMemId: + return handleReserveMem(); + case SetProtectionsId: + return handleSetProtections(); + case WriteMemId: + return handleWriteMem(); + case WritePtrId: + return handleWritePtr(); + default: + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + + llvm_unreachable("Unhandled JIT RPC procedure Id."); + } + + std::error_code requestCompile(TargetAddress &CompiledFnAddr, + TargetAddress TrampolineAddr) { + if (auto EC = call(Channel, TrampolineAddr)) + return EC; + + while (1) { + JITProcId Id = InvalidId; + if (auto EC = getNextProcId(Id)) + return EC; + + switch (Id) { + case RequestCompileResponseId: + return handle(Channel, + readArgs(CompiledFnAddr)); + default: + if (auto EC = handleKnownProcedure(Id)) + return EC; + } + } + + llvm_unreachable("Fell through request-compile command loop."); + } + +private: + struct Allocator { + Allocator() = default; + Allocator(Allocator &&) = default; + Allocator &operator=(Allocator &&) = default; + + ~Allocator() { + for (auto &Alloc : Allocs) + sys::Memory::releaseMappedMemory(Alloc.second); + } + + std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) { + std::error_code EC; + sys::MemoryBlock MB = sys::Memory::allocateMappedMemory( + Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC); + if (EC) + return EC; + + Addr = MB.base(); + assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc"); + Allocs[MB.base()] = std::move(MB); + return std::error_code(); + } + + std::error_code setProtections(void *block, unsigned Flags) { + auto I = Allocs.find(block); + if (I == Allocs.end()) + return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized); + return sys::Memory::protectMappedMemory(I->second, Flags); + } + + private: + std::map Allocs; + }; + + static std::error_code doNothing() { return std::error_code(); } + + static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) { + TargetAddress CompiledFnAddr = 0; + + auto T = static_cast(JITTargetAddr); + auto EC = T->requestCompile( + CompiledFnAddr, static_cast( + reinterpret_cast(TrampolineAddr))); + assert(!EC && "Compile request failed"); + return CompiledFnAddr; + } + + std::error_code handleCallIntVoid() { + typedef int (*IntVoidFnTy)(); + + IntVoidFnTy Fn = nullptr; + if (auto EC = handle(Channel, [&](TargetAddress Addr) { + Fn = reinterpret_cast(static_cast(Addr)); + return std::error_code(); + })) + return EC; + + DEBUG(dbgs() << " Calling " << reinterpret_cast(Fn) << "\n"); + int Result = Fn(); + DEBUG(dbgs() << " Result = " << Result << "\n"); + + return call(Channel, Result); + } + + std::error_code handleCallMain() { + typedef int (*MainFnTy)(int, const char *[]); + + MainFnTy Fn = nullptr; + std::vector Args; + if (auto EC = handle( + Channel, [&](TargetAddress Addr, std::vector &A) { + Fn = reinterpret_cast(static_cast(Addr)); + Args = std::move(A); + return std::error_code(); + })) + return EC; + + int ArgC = Args.size() + 1; + int Idx = 1; + std::unique_ptr ArgV(new const char *[ArgC + 1]); + ArgV[0] = ""; + for (auto &Arg : Args) + ArgV[Idx++] = Arg.c_str(); + + DEBUG(dbgs() << " Calling " << reinterpret_cast(Fn) << "\n"); + int Result = Fn(ArgC, ArgV.get()); + DEBUG(dbgs() << " Result = " << Result << "\n"); + + return call(Channel, Result); + } + + std::error_code handleCallVoidVoid() { + typedef void (*VoidVoidFnTy)(); + + VoidVoidFnTy Fn = nullptr; + if (auto EC = handle(Channel, [&](TargetAddress Addr) { + Fn = reinterpret_cast(static_cast(Addr)); + return std::error_code(); + })) + return EC; + + DEBUG(dbgs() << " Calling " << reinterpret_cast(Fn) << "\n"); + Fn(); + DEBUG(dbgs() << " Complete.\n"); + + return call(Channel); + } + + std::error_code handleCreateRemoteAllocator() { + return handle( + Channel, [&](ResourceIdMgr::ResourceId Id) { + auto I = Allocators.find(Id); + if (I != Allocators.end()) + return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse); + DEBUG(dbgs() << " Created allocator " << Id << "\n"); + Allocators[Id] = Allocator(); + return std::error_code(); + }); + } + + std::error_code handleCreateIndirectStubsOwner() { + return handle( + Channel, [&](ResourceIdMgr::ResourceId Id) { + auto I = IndirectStubsOwners.find(Id); + if (I != IndirectStubsOwners.end()) + return orcError( + OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse); + DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n"); + IndirectStubsOwners[Id] = ISBlockOwnerList(); + return std::error_code(); + }); + } + + std::error_code handleDestroyRemoteAllocator() { + return handle( + Channel, [&](ResourceIdMgr::ResourceId Id) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + Allocators.erase(I); + DEBUG(dbgs() << " Destroyed allocator " << Id << "\n"); + return std::error_code(); + }); + } + + std::error_code handleDestroyIndirectStubsOwner() { + return handle( + Channel, [&](ResourceIdMgr::ResourceId Id) { + auto I = IndirectStubsOwners.find(Id); + if (I == IndirectStubsOwners.end()) + return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); + IndirectStubsOwners.erase(I); + return std::error_code(); + }); + } + + std::error_code handleEmitIndirectStubs() { + ResourceIdMgr::ResourceId ISOwnerId = ~0U; + uint32_t NumStubsRequired = 0; + + if (auto EC = handle( + Channel, readArgs(ISOwnerId, NumStubsRequired))) + return EC; + + DEBUG(dbgs() << " ISMgr " << ISOwnerId << " request " << NumStubsRequired + << " stubs.\n"); + + auto StubOwnerItr = IndirectStubsOwners.find(ISOwnerId); + if (StubOwnerItr == IndirectStubsOwners.end()) + return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); + + typename TargetT::IndirectStubsInfo IS; + if (auto EC = + TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr)) + return EC; + + TargetAddress StubsBase = + static_cast(reinterpret_cast(IS.getStub(0))); + TargetAddress PtrsBase = + static_cast(reinterpret_cast(IS.getPtr(0))); + uint32_t NumStubsEmitted = IS.getNumStubs(); + + auto &BlockList = StubOwnerItr->second; + BlockList.push_back(std::move(IS)); + + return call(Channel, StubsBase, PtrsBase, + NumStubsEmitted); + } + + std::error_code handleEmitResolverBlock() { + if (auto EC = handle(Channel, doNothing)) + return EC; + + std::error_code EC; + ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + TargetT::ResolverCodeSize, nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) + return EC; + + TargetT::writeResolverCode(static_cast(ResolverBlock.base()), + &reenter, this); + + return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(), + sys::Memory::MF_READ | + sys::Memory::MF_EXEC); + } + + std::error_code handleEmitTrampolineBlock() { + if (auto EC = handle(Channel, doNothing)) + return EC; + + std::error_code EC; + + auto TrampolineBlock = + sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + TargetT::PageSize, nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); + if (EC) + return EC; + + unsigned NumTrampolines = + (TargetT::PageSize - TargetT::PointerSize) / TargetT::TrampolineSize; + + uint8_t *TrampolineMem = static_cast(TrampolineBlock.base()); + TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(), + NumTrampolines); + + EC = sys::Memory::protectMappedMemory(TrampolineBlock.getMemoryBlock(), + sys::Memory::MF_READ | + sys::Memory::MF_EXEC); + + TrampolineBlocks.push_back(std::move(TrampolineBlock)); + + return call( + Channel, + static_cast(reinterpret_cast(TrampolineMem)), + NumTrampolines); + } + + std::error_code handleGetSymbolAddress() { + std::string SymbolName; + if (auto EC = handle(Channel, readArgs(SymbolName))) + return EC; + + TargetAddress SymbolAddr = SymbolLookup(SymbolName); + DEBUG(dbgs() << " Symbol '" << SymbolName + << "' = " << format("0x%016x", SymbolAddr) << "\n"); + return call(Channel, SymbolAddr); + } + + std::error_code handleGetRemoteInfo() { + if (auto EC = handle(Channel, doNothing)) + return EC; + + std::string ProcessTriple = sys::getProcessTriple(); + uint32_t PointerSize = TargetT::PointerSize; + uint32_t PageSize = sys::Process::getPageSize(); + uint32_t TrampolineSize = TargetT::TrampolineSize; + uint32_t IndirectStubSize = TargetT::IndirectStubsInfo::StubSize; + DEBUG(dbgs() << " Remote info:\n" + << " triple = '" << ProcessTriple << "'\n" + << " pointer size = " << PointerSize << "\n" + << " page size = " << PageSize << "\n" + << " trampoline size = " << TrampolineSize << "\n" + << " indirect stub size = " << IndirectStubSize << "\n"); + return call(Channel, ProcessTriple, PointerSize, + PageSize, TrampolineSize, + IndirectStubSize); + } + + std::error_code handleReadMem() { + char *Src = nullptr; + uint64_t Size = 0; + if (auto EC = + handle(Channel, [&](TargetAddress RSrc, uint64_t RSize) { + Src = reinterpret_cast(static_cast(RSrc)); + Size = RSize; + return std::error_code(); + })) + return EC; + + DEBUG(dbgs() << " Reading " << Size << " bytes from " + << static_cast(Src) << "\n"); + + if (auto EC = call(Channel)) + return EC; + + if (auto EC = Channel.appendBytes(Src, Size)) + return EC; + + return Channel.send(); + } + + std::error_code handleReserveMem() { + void *LocalAllocAddr = nullptr; + + if (auto EC = + handle(Channel, [&](ResourceIdMgr::ResourceId Id, + uint64_t Size, uint32_t Align) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + auto &Allocator = I->second; + auto EC2 = Allocator.allocate(LocalAllocAddr, Size, Align); + DEBUG(dbgs() << " Allocator " << Id << " reserved " + << LocalAllocAddr << " (" << Size + << " bytes, alignment " << Align << ")\n"); + return EC2; + })) + return EC; + + TargetAddress AllocAddr = + static_cast(reinterpret_cast(LocalAllocAddr)); + + return call(Channel, AllocAddr); + } + + std::error_code handleSetProtections() { + return handle(Channel, [&](ResourceIdMgr::ResourceId Id, + TargetAddress Addr, uint32_t Flags) { + auto I = Allocators.find(Id); + if (I == Allocators.end()) + return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); + auto &Allocator = I->second; + void *LocalAddr = reinterpret_cast(static_cast(Addr)); + DEBUG(dbgs() << " Allocator " << Id << " set permissions on " + << LocalAddr << " to " + << (Flags & sys::Memory::MF_READ ? 'R' : '-') + << (Flags & sys::Memory::MF_WRITE ? 'W' : '-') + << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n"); + return Allocator.setProtections(LocalAddr, Flags); + }); + } + + std::error_code handleWriteMem() { + return handle(Channel, [&](TargetAddress RDst, uint64_t Size) { + char *Dst = reinterpret_cast(static_cast(RDst)); + return Channel.readBytes(Dst, Size); + }); + } + + std::error_code handleWritePtr() { + return handle( + Channel, [&](TargetAddress Addr, TargetAddress PtrVal) { + uintptr_t *Ptr = + reinterpret_cast(static_cast(Addr)); + *Ptr = static_cast(PtrVal); + return std::error_code(); + }); + } + + ChannelT &Channel; + SymbolLookupFtor SymbolLookup; + std::map Allocators; + typedef std::vector ISBlockOwnerList; + std::map IndirectStubsOwners; + sys::OwningMemoryBlock ResolverBlock; + std::vector TrampolineBlocks; +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#undef DEBUG_TYPE + +#endif diff --git a/include/llvm/ExecutionEngine/Orc/RPCChannel.h b/include/llvm/ExecutionEngine/Orc/RPCChannel.h new file mode 100644 index 00000000000..5ebd40d6051 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/RPCChannel.h @@ -0,0 +1,207 @@ +// -*- c++ -*- + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCCHANNEL_H +#define LLVM_EXECUTIONENGINE_ORC_RPCCHANNEL_H + +#include "OrcError.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/Endian.h" + +#include +#include + +namespace llvm { +namespace orc { +namespace remote { + +/// Interface for byte-streams to be used with RPC. +class RPCChannel { +public: + virtual ~RPCChannel() {} + + /// Read Size bytes from the stream into *Dst. + virtual std::error_code readBytes(char *Dst, unsigned Size) = 0; + + /// Read size bytes from *Src and append them to the stream. + virtual std::error_code appendBytes(const char *Src, unsigned Size) = 0; + + /// Flush the stream if possible. + virtual std::error_code send() = 0; +}; + +/// RPC channel that reads from and writes from file descriptors. +class FDRPCChannel : public RPCChannel { +public: + FDRPCChannel(int InFD, int OutFD) : InFD(InFD), OutFD(OutFD) {} + + std::error_code readBytes(char *Dst, unsigned Size) override { + assert(Dst && "Attempt to read into null."); + ssize_t ReadResult = ::read(InFD, Dst, Size); + if (ReadResult != Size) + return std::error_code(errno, std::generic_category()); + return std::error_code(); + } + + std::error_code appendBytes(const char *Src, unsigned Size) override { + assert(Src && "Attempt to append from null."); + ssize_t WriteResult = ::write(OutFD, Src, Size); + if (WriteResult != Size) + std::error_code(errno, std::generic_category()); + return std::error_code(); + } + + std::error_code send() override { return std::error_code(); } + +private: + int InFD, OutFD; +}; + +/// RPC channel serialization for a variadic list of arguments. +template +std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &... Args) { + if (auto EC = serialize(C, Arg)) + return EC; + return serialize_seq(C, Args...); +} + +/// RPC channel serialization for an (empty) variadic list of arguments. +inline std::error_code serialize_seq(RPCChannel &C) { + return std::error_code(); +} + +/// RPC channel deserialization for a variadic list of arguments. +template +std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) { + if (auto EC = deserialize(C, Arg)) + return EC; + return deserialize_seq(C, Args...); +} + +/// RPC channel serialization for an (empty) variadic list of arguments. +inline std::error_code deserialize_seq(RPCChannel &C) { + return std::error_code(); +} + +/// RPC channel serialization for integer primitives. +template +typename std::enable_if< + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + std::error_code>::type +serialize(RPCChannel &C, T V) { + support::endian::byte_swap(V); + return C.appendBytes(reinterpret_cast(&V), sizeof(T)); +} + +/// RPC channel deserialization for integer primitives. +template +typename std::enable_if< + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, + std::error_code>::type +deserialize(RPCChannel &C, T &V) { + if (auto EC = C.readBytes(reinterpret_cast(&V), sizeof(T))) + return EC; + support::endian::byte_swap(V); + return std::error_code(); +} + +/// RPC channel serialization for enums. +template +typename std::enable_if::value, std::error_code>::type +serialize(RPCChannel &C, T V) { + return serialize(C, static_cast::type>(V)); +} + +/// RPC channel deserialization for enums. +template +typename std::enable_if::value, std::error_code>::type +deserialize(RPCChannel &C, T &V) { + typename std::underlying_type::type Tmp; + std::error_code EC = deserialize(C, Tmp); + V = static_cast(Tmp); + return EC; +} + +/// RPC channel serialization for bools. +inline std::error_code serialize(RPCChannel &C, bool V) { + uint8_t VN = V ? 1 : 0; + return C.appendBytes(reinterpret_cast(&VN), 1); +} + +/// RPC channel deserialization for bools. +inline std::error_code deserialize(RPCChannel &C, bool &V) { + uint8_t VN = 0; + if (auto EC = C.readBytes(reinterpret_cast(&VN), 1)) + return EC; + + V = (VN != 0) ? true : false; + return std::error_code(); +} + +/// RPC channel serialization for StringRefs. +/// Note: There is no corresponding deseralization for this, as StringRef +/// doesn't own its memory and so can't hold the deserialized data. +inline std::error_code serialize(RPCChannel &C, StringRef S) { + if (auto EC = serialize(C, static_cast(S.size()))) + return EC; + return C.appendBytes((const char *)S.bytes_begin(), S.size()); +} + +/// RPC channel serialization for std::strings. +inline std::error_code serialize(RPCChannel &C, const std::string &S) { + return serialize(C, StringRef(S)); +} + +/// RPC channel deserialization for std::strings. +inline std::error_code deserialize(RPCChannel &C, std::string &S) { + uint64_t Count; + if (auto EC = deserialize(C, Count)) + return EC; + S.resize(Count); + return C.readBytes(&S[0], Count); +} + +/// RPC channel serialization for ArrayRef. +template +std::error_code serialize(RPCChannel &C, const ArrayRef &A) { + if (auto EC = serialize(C, static_cast(A.size()))) + return EC; + + for (const auto &E : A) + if (auto EC = serialize(C, E)) + return EC; + + return std::error_code(); +} + +/// RPC channel serialization for std::array. +template +std::error_code serialize(RPCChannel &C, const std::vector &V) { + return serialize(C, ArrayRef(V)); +} + +/// RPC channel deserialization for std::array. +template +std::error_code deserialize(RPCChannel &C, std::vector &V) { + uint64_t Count = 0; + if (auto EC = deserialize(C, Count)) + return EC; + + V.resize(Count); + for (auto &E : V) + if (auto EC = deserialize(C, E)) + return EC; + + return std::error_code(); +} + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#endif diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h new file mode 100644 index 00000000000..d275f642308 --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -0,0 +1,222 @@ +//===----- RPCUTils.h - Basic tilities for building RPC APIs ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Basic utilities for building RPC APIs. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H + +#include "llvm/ADT/STLExtras.h" + +namespace llvm { +namespace orc { +namespace remote { + +/// Contains primitive utilities for defining, calling and handling calls to +/// remote procedures. ChannelT is a bidirectional stream conforming to the +/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure +/// identifier type that must be serializable on ChannelT. +/// +/// These utilities support the construction of very primitive RPC utilities. +/// Their intent is to ensure correct serialization and deserialization of +/// procedure arguments, and to keep the client and server's view of the API in +/// sync. +/// +/// These utilities do not support return values. These can be handled by +/// declaring a corresponding '.*Response' procedure and expecting it after a +/// call). They also do not support versioning: the client and server *must* be +/// compiled with the same procedure definitions. +/// +/// +/// +/// Overview (see comments individual types/methods for details): +/// +/// Procedure : +/// +/// associates a unique serializable id with an argument list. +/// +/// +/// call(Channel, Args...) : +/// +/// Calls the remote procedure 'Proc' by serializing Proc's id followed by its +/// arguments and sending the resulting bytes to 'Channel'. +/// +/// +/// handle(Channel, : +/// +/// Handles a call to 'Proc' by deserializing its arguments and calling the +/// given functor. This assumes that the id for 'Proc' has already been +/// deserialized. +/// +/// expect(Channel, : +/// +/// The same as 'handle', except that the procedure id should not have been +/// read yet. Expect will deserialize the id and assert that it matches Proc's +/// id. If it does not, and unexpected RPC call error is returned. + +template class RPC { +public: + /// Utility class for defining/referring to RPC procedures. + /// + /// Typedefs of this utility are used when calling/handling remote procedures. + /// + /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any + /// other Procedure typedef in the RPC API being defined. + /// + /// the template argument Ts... gives the argument list for the remote + /// procedure. + /// + /// E.g. + /// + /// typedef Procedure<0, bool> Proc1; + /// typedef Procedure<1, std::string, std::vector> Proc2; + /// + /// if (auto EC = call(Channel, true)) + /// /* handle EC */; + /// + /// if (auto EC = expect(Channel, + /// [](std::string &S, std::vector &V) { + /// // Stuff. + /// return std::error_code(); + /// }) + /// /* handle EC */; + /// + template class Procedure { + public: + static const ProcedureIdT Id = ProcId; + }; + +private: + template class CallHelper {}; + + template + class CallHelper> { + public: + static std::error_code call(ChannelT &C, const ArgTs &... Args) { + if (auto EC = serialize(C, ProcId)) + return EC; + // If you see a compile-error on this line you're probably calling a + // function with the wrong signature. + return serialize_seq(C, Args...); + } + }; + + template class HandlerHelper {}; + + template + class HandlerHelper> { + public: + template + static std::error_code handle(ChannelT &C, HandlerT Handler) { + return readAndHandle(C, Handler, llvm::index_sequence_for()); + } + + private: + template + static std::error_code readAndHandle(ChannelT &C, HandlerT Handler, + llvm::index_sequence _) { + std::tuple RPCArgs; + if (auto EC = deserialize_seq(C, std::get(RPCArgs)...)) + return EC; + return Handler(std::get(RPCArgs)...); + } + }; + + template class ReadArgs { + public: + std::error_code operator()() { return std::error_code(); } + }; + + template + class ReadArgs : public ReadArgs { + public: + ReadArgs(ArgT &Arg, ArgTs &... Args) + : ReadArgs(Args...), Arg(Arg) {} + + std::error_code operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + this->Arg = std::move(ArgVal); + return ReadArgs::operator()(ArgVals...); + } + + private: + ArgT &Arg; + }; + +public: + /// Serialize Args... to channel C, but do not call C.send(). + /// + /// For buffered channels, this can be used to queue up several calls before + /// flushing the channel. + template + static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) { + return CallHelper::call(C, Args...); + } + + /// Serialize Args... to channel C and call C.send(). + template + static std::error_code call(ChannelT &C, const ArgTs &... Args) { + if (auto EC = appendCall(C, Args...)) + return EC; + return C.send(); + } + + /// Deserialize and return an enum whose underlying type is ProcedureIdT. + static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) { + return deserialize(C, Id); + } + + /// Deserialize args for Proc from C and call Handler. The signature of + /// handler must conform to 'std::error_code(Args...)' where Args... matches + /// the arguments used in the Proc typedef. + template + static std::error_code handle(ChannelT &C, HandlerT Handler) { + return HandlerHelper::handle(C, Handler); + } + + /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. + /// If the id does match, deserialize the arguments and call the handler + /// (similarly to handle). + /// If the id does not match, return an unexpect RPC call error and do not + /// deserialize any further bytes. + template + static std::error_code expect(ChannelT &C, HandlerT Handler) { + ProcedureIdT ProcId; + if (auto EC = getNextProcId(C, ProcId)) + return EC; + if (ProcId != Proc::Id) + return orcError(OrcErrorCode::UnexpectedRPCCall); + return handle(C, Handler); + } + + /// Helper for handling setter procedures - this method returns a functor that + /// sets the variables referred to by Args... to values deserialized from the + /// channel. + /// E.g. + /// + /// typedef Procedure<0, bool, int> Proc1; + /// + /// ... + /// bool B; + /// int I; + /// if (auto EC = expect(Channel, readArgs(B, I))) + /// /* Handle Args */ ; + /// + template + static ReadArgs readArgs(ArgTs &... Args) { + return ReadArgs(Args...); + } +}; + +} // end namespace remote +} // end namespace orc +} // end namespace llvm + +#endif diff --git a/lib/ExecutionEngine/Orc/CMakeLists.txt b/lib/ExecutionEngine/Orc/CMakeLists.txt index f145be5b688..d26f212e00c 100644 --- a/lib/ExecutionEngine/Orc/CMakeLists.txt +++ b/lib/ExecutionEngine/Orc/CMakeLists.txt @@ -7,6 +7,7 @@ add_llvm_library(LLVMOrcJIT OrcCBindingsStack.cpp OrcError.cpp OrcMCJITReplacement.cpp + OrcRemoteTargetRPCAPI.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/ExecutionEngine/Orc diff --git a/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp b/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp new file mode 100644 index 00000000000..064633b4e49 --- /dev/null +++ b/lib/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.cpp @@ -0,0 +1,83 @@ +//===------- OrcRemoteTargetRPCAPI.cpp - ORC Remote API utilities ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h" + +namespace llvm { +namespace orc { +namespace remote { + +const char *OrcRemoteTargetRPCAPI::getJITProcIdName(JITProcId Id) { + switch (Id) { + case InvalidId: + return "*** Invalid JITProcId ***"; + case CallIntVoidId: + return "CallIntVoid"; + case CallIntVoidResponseId: + return "CallIntVoidResponse"; + case CallMainId: + return "CallMain"; + case CallMainResponseId: + return "CallMainResponse"; + case CallVoidVoidId: + return "CallVoidVoid"; + case CallVoidVoidResponseId: + return "CallVoidVoidResponse"; + case CreateRemoteAllocatorId: + return "CreateRemoteAllocator"; + case CreateIndirectStubsOwnerId: + return "CreateIndirectStubsOwner"; + case DestroyRemoteAllocatorId: + return "DestroyRemoteAllocator"; + case DestroyIndirectStubsOwnerId: + return "DestroyIndirectStubsOwner"; + case EmitIndirectStubsId: + return "EmitIndirectStubs"; + case EmitIndirectStubsResponseId: + return "EmitIndirectStubsResponse"; + case EmitResolverBlockId: + return "EmitResolverBlock"; + case EmitTrampolineBlockId: + return "EmitTrampolineBlock"; + case EmitTrampolineBlockResponseId: + return "EmitTrampolineBlockResponse"; + case GetSymbolAddressId: + return "GetSymbolAddress"; + case GetSymbolAddressResponseId: + return "GetSymbolAddressResponse"; + case GetRemoteInfoId: + return "GetRemoteInfo"; + case GetRemoteInfoResponseId: + return "GetRemoteInfoResponse"; + case ReadMemId: + return "ReadMem"; + case ReadMemResponseId: + return "ReadMemResponse"; + case ReserveMemId: + return "ReserveMem"; + case ReserveMemResponseId: + return "ReserveMemResponse"; + case RequestCompileId: + return "RequestCompile"; + case RequestCompileResponseId: + return "RequestCompileResponse"; + case SetProtectionsId: + return "SetProtections"; + case TerminateSessionId: + return "TerminateSession"; + case WriteMemId: + return "WriteMem"; + case WritePtrId: + return "WritePtr"; + }; + return nullptr; +} +} +} +} diff --git a/unittests/ExecutionEngine/Orc/CMakeLists.txt b/unittests/ExecutionEngine/Orc/CMakeLists.txt index 74cc5b57015..41fef24556b 100644 --- a/unittests/ExecutionEngine/Orc/CMakeLists.txt +++ b/unittests/ExecutionEngine/Orc/CMakeLists.txt @@ -18,4 +18,5 @@ add_llvm_unittest(OrcJITTests ObjectTransformLayerTest.cpp OrcCAPITest.cpp OrcTestCommon.cpp + RPCUtilsTest.cpp ) diff --git a/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp new file mode 100644 index 00000000000..8215144a514 --- /dev/null +++ b/unittests/ExecutionEngine/Orc/RPCUtilsTest.cpp @@ -0,0 +1,147 @@ +//===----------- RPCUtilsTest.cpp - Unit tests the Orc RPC utils ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ExecutionEngine/Orc/RPCChannel.h" +#include "llvm/ExecutionEngine/Orc/RPCUtils.h" +#include "gtest/gtest.h" + +#include + +using namespace llvm; +using namespace llvm::orc; +using namespace llvm::orc::remote; + +class QueueChannel : public RPCChannel { +public: + QueueChannel(std::queue &Queue) : Queue(Queue) {} + + std::error_code readBytes(char *Dst, unsigned Size) override { + while (Size--) { + *Dst++ = Queue.front(); + Queue.pop(); + } + return std::error_code(); + } + + std::error_code appendBytes(const char *Src, unsigned Size) override { + while (Size--) + Queue.push(*Src++); + return std::error_code(); + } + + std::error_code send() override { return std::error_code(); } + +private: + std::queue &Queue; +}; + +class DummyRPC : public testing::Test, + public RPC { +public: + typedef Procedure<1, bool> Proc1; + typedef Procedure<2, int8_t, + uint8_t, + int16_t, + uint16_t, + int32_t, + uint32_t, + int64_t, + uint64_t, + bool, + std::string, + std::vector> AllTheTypes; +}; + + +TEST_F(DummyRPC, TestBasic) { + std::queue Queue; + QueueChannel C(Queue); + + { + // Make a call to Proc1. + auto EC = call(C, true); + EXPECT_FALSE(EC) << "Simple call over queue failed"; + } + + { + // Expect a call to Proc1. + auto EC = expect(C, + [&](bool &B) { + EXPECT_EQ(B, true) + << "Bool serialization broken"; + return std::error_code(); + }); + EXPECT_FALSE(EC) << "Simple expect over queue failed"; + } +} + +TEST_F(DummyRPC, TestSerialization) { + std::queue Queue; + QueueChannel C(Queue); + + { + // Make a call to Proc1. + std::vector v({42, 7}); + auto EC = call(C, + -101, + 250, + -10000, + 10000, + -1000000000, + 1000000000, + -10000000000, + 10000000000, + true, + "foo", + v); + EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; + } + + { + // Expect a call to Proc1. + auto EC = expect(C, + [&](int8_t &s8, + uint8_t &u8, + int16_t &s16, + uint16_t &u16, + int32_t &s32, + uint32_t &u32, + int64_t &s64, + uint64_t &u64, + bool &b, + std::string &s, + std::vector &v) { + + EXPECT_EQ(s8, -101) + << "int8_t serialization broken"; + EXPECT_EQ(u8, 250) + << "uint8_t serialization broken"; + EXPECT_EQ(s16, -10000) + << "int16_t serialization broken"; + EXPECT_EQ(u16, 10000) + << "uint16_t serialization broken"; + EXPECT_EQ(s32, -1000000000) + << "int32_t serialization broken"; + EXPECT_EQ(u32, 1000000000ULL) + << "uint32_t serialization broken"; + EXPECT_EQ(s64, -10000000000) + << "int64_t serialization broken"; + EXPECT_EQ(u64, 10000000000ULL) + << "uint64_t serialization broken"; + EXPECT_EQ(b, true) + << "bool serialization broken"; + EXPECT_EQ(s, "foo") + << "std::string serialization broken"; + EXPECT_EQ(v, std::vector({42, 7})) + << "std::vector serialization broken"; + return std::error_code(); + }); + EXPECT_FALSE(EC) << "Big (serialization test) call over queue failed"; + } +}