[Orc] Add overloads of RPC::handle and RPC::expect that take member functions as
[oota-llvm.git] / include / llvm / ExecutionEngine / Orc / OrcRemoteTargetServer.h
index b9db3890cc70d0e78a735ee874c2399cbf0e6e15..4d846986e8dbfc306e426cf7a4031666e516d977 100644 (file)
@@ -43,41 +43,54 @@ public:
   }
 
   std::error_code handleKnownProcedure(JITProcId Id) {
+    typedef OrcRemoteTargetServer ThisT;
+
     DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n");
 
     switch (Id) {
     case CallIntVoidId:
-      return handleCallIntVoid();
+      return handle<CallIntVoid>(Channel, *this, &ThisT::handleCallIntVoid);
     case CallMainId:
-      return handleCallMain();
+      return handle<CallMain>(Channel, *this, &ThisT::handleCallMain);
     case CallVoidVoidId:
-      return handleCallVoidVoid();
+      return handle<CallVoidVoid>(Channel, *this, &ThisT::handleCallVoidVoid);
     case CreateRemoteAllocatorId:
-      return handleCreateRemoteAllocator();
+      return handle<CreateRemoteAllocator>(Channel, *this,
+                                           &ThisT::handleCreateRemoteAllocator);
     case CreateIndirectStubsOwnerId:
-      return handleCreateIndirectStubsOwner();
+      return handle<CreateIndirectStubsOwner>(
+          Channel, *this, &ThisT::handleCreateIndirectStubsOwner);
     case DestroyRemoteAllocatorId:
-      return handleDestroyRemoteAllocator();
+      return handle<DestroyRemoteAllocator>(
+          Channel, *this, &ThisT::handleDestroyRemoteAllocator);
+    case DestroyIndirectStubsOwnerId:
+      return handle<DestroyIndirectStubsOwner>(
+          Channel, *this, &ThisT::handleDestroyIndirectStubsOwner);
     case EmitIndirectStubsId:
-      return handleEmitIndirectStubs();
+      return handle<EmitIndirectStubs>(Channel, *this,
+                                       &ThisT::handleEmitIndirectStubs);
     case EmitResolverBlockId:
-      return handleEmitResolverBlock();
+      return handle<EmitResolverBlock>(Channel, *this,
+                                       &ThisT::handleEmitResolverBlock);
     case EmitTrampolineBlockId:
-      return handleEmitTrampolineBlock();
+      return handle<EmitTrampolineBlock>(Channel, *this,
+                                         &ThisT::handleEmitTrampolineBlock);
     case GetSymbolAddressId:
-      return handleGetSymbolAddress();
+      return handle<GetSymbolAddress>(Channel, *this,
+                                      &ThisT::handleGetSymbolAddress);
     case GetRemoteInfoId:
-      return handleGetRemoteInfo();
+      return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo);
     case ReadMemId:
-      return handleReadMem();
+      return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem);
     case ReserveMemId:
-      return handleReserveMem();
+      return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem);
     case SetProtectionsId:
-      return handleSetProtections();
+      return handle<SetProtections>(Channel, *this,
+                                    &ThisT::handleSetProtections);
     case WriteMemId:
-      return handleWriteMem();
+      return handle<WriteMem>(Channel, *this, &ThisT::handleWriteMem);
     case WritePtrId:
-      return handleWritePtr();
+      return handle<WritePtr>(Channel, *this, &ThisT::handleWritePtr);
     default:
       return orcError(OrcErrorCode::UnexpectedRPCCall);
     }
@@ -160,16 +173,10 @@ private:
     return CompiledFnAddr;
   }
 
-  std::error_code handleCallIntVoid() {
+  std::error_code handleCallIntVoid(TargetAddress Addr) {
     typedef int (*IntVoidFnTy)();
-
-    IntVoidFnTy Fn = nullptr;
-    if (std::error_code EC =
-            handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
-              Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
-              return std::error_code();
-            }))
-      return EC;
+    IntVoidFnTy Fn =
+        reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr));
 
     DEBUG(dbgs() << "  Calling "
                  << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn))
@@ -180,19 +187,11 @@ private:
     return call<CallIntVoidResponse>(Channel, Result);
   }
 
-  std::error_code handleCallMain() {
+  std::error_code handleCallMain(TargetAddress Addr,
+                                 std::vector<std::string> Args) {
     typedef int (*MainFnTy)(int, const char *[]);
 
-    MainFnTy Fn = nullptr;
-    std::vector<std::string> Args;
-    if (std::error_code EC = handle<CallMain>(
-            Channel, [&](TargetAddress Addr, std::vector<std::string> &A) {
-              Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
-              Args = std::move(A);
-              return std::error_code();
-            }))
-      return EC;
-
+    MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr));
     int ArgC = Args.size() + 1;
     int Idx = 1;
     std::unique_ptr<const char *[]> ArgV(new const char *[ArgC + 1]);
@@ -207,16 +206,10 @@ private:
     return call<CallMainResponse>(Channel, Result);
   }
 
-  std::error_code handleCallVoidVoid() {
+  std::error_code handleCallVoidVoid(TargetAddress Addr) {
     typedef void (*VoidVoidFnTy)();
-
-    VoidVoidFnTy Fn = nullptr;
-    if (std::error_code EC =
-            handle<CallIntVoid>(Channel, [&](TargetAddress Addr) {
-              Fn = reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
-              return std::error_code();
-            }))
-      return EC;
+    VoidVoidFnTy Fn =
+        reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr));
 
     DEBUG(dbgs() << "  Calling " << reinterpret_cast<void *>(Fn) << "\n");
     Fn();
@@ -225,66 +218,48 @@ private:
     return call<CallVoidVoidResponse>(Channel);
   }
 
-  std::error_code handleCreateRemoteAllocator() {
-    return handle<CreateRemoteAllocator>(
-        Channel, [&](ResourceIdMgr::ResourceId Id) {
-          auto I = Allocators.find(Id);
-          if (I != Allocators.end())
-            return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
-          DEBUG(dbgs() << "  Created allocator " << Id << "\n");
-          Allocators[Id] = Allocator();
-          return std::error_code();
-        });
+  std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) {
+    auto I = Allocators.find(Id);
+    if (I != Allocators.end())
+      return orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse);
+    DEBUG(dbgs() << "  Created allocator " << Id << "\n");
+    Allocators[Id] = Allocator();
+    return std::error_code();
   }
 
-  std::error_code handleCreateIndirectStubsOwner() {
-    return handle<CreateIndirectStubsOwner>(
-        Channel, [&](ResourceIdMgr::ResourceId Id) {
-          auto I = IndirectStubsOwners.find(Id);
-          if (I != IndirectStubsOwners.end())
-            return orcError(
-                OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
-          DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
-          IndirectStubsOwners[Id] = ISBlockOwnerList();
-          return std::error_code();
-        });
+  std::error_code handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
+    auto I = IndirectStubsOwners.find(Id);
+    if (I != IndirectStubsOwners.end())
+      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse);
+    DEBUG(dbgs() << "  Create indirect stubs owner " << Id << "\n");
+    IndirectStubsOwners[Id] = ISBlockOwnerList();
+    return std::error_code();
   }
 
-  std::error_code handleDestroyRemoteAllocator() {
-    return handle<DestroyRemoteAllocator>(
-        Channel, [&](ResourceIdMgr::ResourceId Id) {
-          auto I = Allocators.find(Id);
-          if (I == Allocators.end())
-            return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
-          Allocators.erase(I);
-          DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
-          return std::error_code();
-        });
+  std::error_code handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) {
+    auto I = Allocators.find(Id);
+    if (I == Allocators.end())
+      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+    Allocators.erase(I);
+    DEBUG(dbgs() << "  Destroyed allocator " << Id << "\n");
+    return std::error_code();
   }
 
-  std::error_code handleDestroyIndirectStubsOwner() {
-    return handle<DestroyIndirectStubsOwner>(
-        Channel, [&](ResourceIdMgr::ResourceId Id) {
-          auto I = IndirectStubsOwners.find(Id);
-          if (I == IndirectStubsOwners.end())
-            return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
-          IndirectStubsOwners.erase(I);
-          return std::error_code();
-        });
+  std::error_code
+  handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) {
+    auto I = IndirectStubsOwners.find(Id);
+    if (I == IndirectStubsOwners.end())
+      return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
+    IndirectStubsOwners.erase(I);
+    return std::error_code();
   }
 
-  std::error_code handleEmitIndirectStubs() {
-    ResourceIdMgr::ResourceId ISOwnerId = ~0U;
-    uint32_t NumStubsRequired = 0;
-
-    if (auto EC = handle<EmitIndirectStubs>(
-            Channel, readArgs(ISOwnerId, NumStubsRequired)))
-      return EC;
-
-    DEBUG(dbgs() << "  ISMgr " << ISOwnerId << " request " << NumStubsRequired
+  std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id,
+                                          uint32_t NumStubsRequired) {
+    DEBUG(dbgs() << "  ISMgr " << Id << " request " << NumStubsRequired
                  << " stubs.\n");
 
-    auto StubOwnerItr = IndirectStubsOwners.find(ISOwnerId);
+    auto StubOwnerItr = IndirectStubsOwners.find(Id);
     if (StubOwnerItr == IndirectStubsOwners.end())
       return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist);
 
@@ -307,9 +282,6 @@ private:
   }
 
   std::error_code handleEmitResolverBlock() {
-    if (auto EC = handle<EmitResolverBlock>(Channel, doNothing))
-      return EC;
-
     std::error_code EC;
     ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
         TargetT::ResolverCodeSize, nullptr,
@@ -326,11 +298,7 @@ private:
   }
 
   std::error_code handleEmitTrampolineBlock() {
-    if (auto EC = handle<EmitTrampolineBlock>(Channel, doNothing))
-      return EC;
-
     std::error_code EC;
-
     auto TrampolineBlock =
         sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory(
             sys::Process::getPageSize(), nullptr,
@@ -358,21 +326,14 @@ private:
         NumTrampolines);
   }
 
-  std::error_code handleGetSymbolAddress() {
-    std::string SymbolName;
-    if (auto EC = handle<GetSymbolAddress>(Channel, readArgs(SymbolName)))
-      return EC;
-
-    TargetAddress SymbolAddr = SymbolLookup(SymbolName);
-    DEBUG(dbgs() << "  Symbol '" << SymbolName
-                 << "' =  " << format("0x%016x", SymbolAddr) << "\n");
-    return call<GetSymbolAddressResponse>(Channel, SymbolAddr);
+  std::error_code handleGetSymbolAddress(const std::string &Name) {
+    TargetAddress Addr = SymbolLookup(Name);
+    DEBUG(dbgs() << "  Symbol '" << Name << "' =  " << format("0x%016x", Addr)
+                 << "\n");
+    return call<GetSymbolAddressResponse>(Channel, Addr);
   }
 
   std::error_code handleGetRemoteInfo() {
-    if (auto EC = handle<GetRemoteInfo>(Channel, doNothing))
-      return EC;
-
     std::string ProcessTriple = sys::getProcessTriple();
     uint32_t PointerSize = TargetT::PointerSize;
     uint32_t PageSize = sys::Process::getPageSize();
@@ -389,16 +350,8 @@ private:
                                        IndirectStubSize);
   }
 
-  std::error_code handleReadMem() {
-    char *Src = nullptr;
-    uint64_t Size = 0;
-    if (std::error_code EC =
-            handle<ReadMem>(Channel, [&](TargetAddress RSrc, uint64_t RSize) {
-              Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
-              Size = RSize;
-              return std::error_code();
-            }))
-      return EC;
+  std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) {
+    char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc));
 
     DEBUG(dbgs() << "  Reading " << Size << " bytes from "
                  << static_cast<void *>(Src) << "\n");
@@ -412,62 +365,49 @@ private:
     return Channel.send();
   }
 
-  std::error_code handleReserveMem() {
+  std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size,
+                                   uint32_t Align) {
+    auto I = Allocators.find(Id);
+    if (I == Allocators.end())
+      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+    auto &Allocator = I->second;
     void *LocalAllocAddr = nullptr;
-
-    if (std::error_code EC =
-            handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
-                                            uint64_t Size, uint32_t Align) {
-              auto I = Allocators.find(Id);
-              if (I == Allocators.end())
-                return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
-              auto &Allocator = I->second;
-              auto EC2 = Allocator.allocate(LocalAllocAddr, Size, Align);
-              DEBUG(dbgs() << "  Allocator " << Id << " reserved "
-                           << LocalAllocAddr << " (" << Size
-                           << " bytes, alignment " << Align << ")\n");
-              return EC2;
-            }))
+    if (auto EC = Allocator.allocate(LocalAllocAddr, Size, Align))
       return EC;
 
+    DEBUG(dbgs() << "  Allocator " << Id << " reserved " << LocalAllocAddr
+                 << " (" << Size << " bytes, alignment " << Align << ")\n");
+
     TargetAddress AllocAddr =
         static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr));
 
     return call<ReserveMemResponse>(Channel, AllocAddr);
   }
 
-  std::error_code handleSetProtections() {
-    return handle<ReserveMem>(Channel, [&](ResourceIdMgr::ResourceId Id,
-                                           TargetAddress Addr, uint32_t Flags) {
-      auto I = Allocators.find(Id);
-      if (I == Allocators.end())
-        return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
-      auto &Allocator = I->second;
-      void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
-      DEBUG(dbgs() << "  Allocator " << Id << " set permissions on "
-                   << LocalAddr << " to "
-                   << (Flags & sys::Memory::MF_READ ? 'R' : '-')
-                   << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
-                   << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
-      return Allocator.setProtections(LocalAddr, Flags);
-    });
+  std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id,
+                                       TargetAddress Addr, uint32_t Flags) {
+    auto I = Allocators.find(Id);
+    if (I == Allocators.end())
+      return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist);
+    auto &Allocator = I->second;
+    void *LocalAddr = reinterpret_cast<void *>(static_cast<uintptr_t>(Addr));
+    DEBUG(dbgs() << "  Allocator " << Id << " set permissions on " << LocalAddr
+                 << " to " << (Flags & sys::Memory::MF_READ ? 'R' : '-')
+                 << (Flags & sys::Memory::MF_WRITE ? 'W' : '-')
+                 << (Flags & sys::Memory::MF_EXEC ? 'X' : '-') << "\n");
+    return Allocator.setProtections(LocalAddr, Flags);
   }
 
-  std::error_code handleWriteMem() {
-    return handle<WriteMem>(Channel, [&](TargetAddress RDst, uint64_t Size) {
-      char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
-      return Channel.readBytes(Dst, Size);
-    });
+  std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) {
+    char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst));
+    return Channel.readBytes(Dst, Size);
   }
 
-  std::error_code handleWritePtr() {
-    return handle<WritePtr>(
-        Channel, [&](TargetAddress Addr, TargetAddress PtrVal) {
-          uintptr_t *Ptr =
-              reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
-          *Ptr = static_cast<uintptr_t>(PtrVal);
-          return std::error_code();
-        });
+  std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) {
+    uintptr_t *Ptr =
+        reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr));
+    *Ptr = static_cast<uintptr_t>(PtrVal);
+    return std::error_code();
   }
 
   ChannelT &Channel;