Implement more of the sockets API
[folly.git] / folly / portability / Sockets.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/portability/Sockets.h>
18
19 #ifdef _MSC_VER
20
21 #include <errno.h>
22 #include <fcntl.h>
23
24 #include <event2/util.h>
25
26 #include <MSWSock.h>
27
28 #include <folly/ScopeGuard.h>
29
30 namespace folly {
31 namespace portability {
32 namespace sockets {
33
34 // We have to startup WSA.
35 static struct FSPInit {
36   FSPInit() {
37     WSADATA dat;
38     WSAStartup(MAKEWORD(2, 2), &dat);
39   }
40   ~FSPInit() { WSACleanup(); }
41 } fspInit;
42
43 bool is_fh_socket(int fh) {
44   SOCKET h = fd_to_socket(fh);
45   constexpr long kDummyEvents = 0xABCDEF12;
46   WSANETWORKEVENTS e;
47   e.lNetworkEvents = kDummyEvents;
48   WSAEnumNetworkEvents(h, nullptr, &e);
49   return e.lNetworkEvents != kDummyEvents;
50 }
51
52 SOCKET fd_to_socket(int fd) {
53   if (fd == -1) {
54     return INVALID_SOCKET;
55   }
56   // We do this in a roundabout way to allow us to compile even if
57   // we're doing a bit of trickery to ensure that things aren't
58   // being implicitly converted to a SOCKET by temporarily
59   // adjusting the windows headers to define SOCKET as a
60   // structure.
61   static_assert(sizeof(HANDLE) == sizeof(SOCKET), "Handle size mismatch.");
62   HANDLE tmp = (HANDLE)_get_osfhandle(fd);
63   return *(SOCKET*)&tmp;
64 }
65
66 int socket_to_fd(SOCKET s) {
67   if (s == INVALID_SOCKET) {
68     return -1;
69   }
70   return _open_osfhandle((intptr_t)s, O_RDWR | O_BINARY);
71 }
72
73 int translate_wsa_error(int wsaErr) {
74   switch (wsaErr) {
75     case WSAEWOULDBLOCK:
76       return EAGAIN;
77     default:
78       return wsaErr;
79   }
80 }
81
82 template <class R, class F, class... Args>
83 static R wrapSocketFunction(F f, int s, Args... args) {
84   SOCKET h = fd_to_socket(s);
85   R ret = f(h, args...);
86   errno = translate_wsa_error(WSAGetLastError());
87   return ret;
88 }
89
90 int accept(int s, struct sockaddr* addr, socklen_t* addrlen) {
91   return socket_to_fd(wrapSocketFunction<SOCKET>(::accept, s, addr, addrlen));
92 }
93
94 int bind(int s, const struct sockaddr* name, socklen_t namelen) {
95   return wrapSocketFunction<int>(::bind, s, name, namelen);
96 }
97
98 int connect(int s, const struct sockaddr* name, socklen_t namelen) {
99   auto r = wrapSocketFunction<int>(::connect, s, name, namelen);
100   if (r == -1 && WSAGetLastError() == WSAEWOULDBLOCK) {
101     errno = EINPROGRESS;
102   }
103   return r;
104 }
105
106 int getpeername(int s, struct sockaddr* name, socklen_t* namelen) {
107   return wrapSocketFunction<int>(::getpeername, s, name, namelen);
108 }
109
110 int getsockname(int s, struct sockaddr* name, socklen_t* namelen) {
111   return wrapSocketFunction<int>(::getsockname, s, name, namelen);
112 }
113
114 int getsockopt(int s, int level, int optname, char* optval, socklen_t* optlen) {
115   return getsockopt(s, level, optname, (void*)optval, optlen);
116 }
117
118 int getsockopt(int s, int level, int optname, void* optval, socklen_t* optlen) {
119   auto ret = wrapSocketFunction<int>(
120       ::getsockopt, s, level, optname, (char*)optval, (int*)optlen);
121   if (optname == TCP_NODELAY && *optlen == 1) {
122     // Windows is weird about this value, and documents it as a
123     // BOOL (ie. int) but expects the variable to be bool (1-byte),
124     // so we get to adapt the interface to work that way.
125     *(int*)optval = *(uint8_t*)optval;
126     *optlen = sizeof(int);
127   }
128   return ret;
129 }
130
131 int inet_aton(const char* cp, struct in_addr* inp) {
132   inp->s_addr = inet_addr(cp);
133   return inp->s_addr == INADDR_NONE ? 0 : 1;
134 }
135
136 const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
137   return ::inet_ntop(af, (char*)src, dst, size);
138 }
139
140 int listen(int s, int backlog) {
141   return wrapSocketFunction<int>(::listen, s, backlog);
142 }
143
144 int poll(struct pollfd fds[], nfds_t nfds, int timeout) {
145   // TODO: Allow both file descriptors and SOCKETs in this.
146   for (int i = 0; i < nfds; i++) {
147     fds[i].fd = fd_to_socket(fds[i].fd);
148   }
149   return ::WSAPoll(fds, (ULONG)nfds, timeout);
150 }
151
152 ssize_t recv(int s, void* buf, size_t len, int flags) {
153   if ((flags & MSG_DONTWAIT) == MSG_DONTWAIT) {
154     flags &= ~MSG_DONTWAIT;
155
156     u_long pendingRead = 0;
157     if (ioctlsocket(fd_to_socket(s), FIONREAD, &pendingRead)) {
158       errno = translate_wsa_error(WSAGetLastError());
159       return -1;
160     }
161
162     fd_set readSet;
163     FD_ZERO(&readSet);
164     FD_SET(fd_to_socket(s), &readSet);
165     timeval timeout{0, 0};
166     auto ret = select(1, &readSet, nullptr, nullptr, &timeout);
167     if (ret == 0) {
168       errno = EWOULDBLOCK;
169       return -1;
170     }
171   }
172   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
173 }
174
175 ssize_t recv(int s, char* buf, int len, int flags) {
176   return recv(s, (void*)buf, (size_t)len, flags);
177 }
178
179 ssize_t recv(int s, void* buf, int len, int flags) {
180   return recv(s, (void*)buf, (size_t)len, flags);
181 }
182
183 ssize_t recvfrom(
184     int s,
185     void* buf,
186     size_t len,
187     int flags,
188     struct sockaddr* from,
189     socklen_t* fromlen) {
190   if ((flags & MSG_TRUNC) == MSG_TRUNC) {
191     SOCKET h = fd_to_socket(s);
192
193     WSABUF wBuf{};
194     wBuf.buf = (CHAR*)buf;
195     wBuf.len = len;
196     WSAMSG wMsg{};
197     wMsg.dwBufferCount = 1;
198     wMsg.lpBuffers = &wBuf;
199     wMsg.name = from;
200     if (fromlen != nullptr) {
201       wMsg.namelen = *fromlen;
202     }
203
204     // WSARecvMsg is an extension, so we don't get
205     // the convenience of being able to call it directly, even though
206     // WSASendMsg is part of the normal API -_-...
207     LPFN_WSARECVMSG WSARecvMsg;
208     GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
209     DWORD recMsgBytes;
210     WSAIoctl(
211         h,
212         SIO_GET_EXTENSION_FUNCTION_POINTER,
213         &WSARecgMsg_GUID,
214         sizeof(WSARecgMsg_GUID),
215         &WSARecvMsg,
216         sizeof(WSARecvMsg),
217         &recMsgBytes,
218         nullptr,
219         nullptr);
220
221     DWORD bytesReceived;
222     int res = WSARecvMsg(h, &wMsg, &bytesReceived, nullptr, nullptr);
223     errno = translate_wsa_error(WSAGetLastError());
224     if (res == 0) {
225       return bytesReceived;
226     }
227     if (fromlen != nullptr) {
228       *fromlen = wMsg.namelen;
229     }
230     if ((wMsg.dwFlags & MSG_TRUNC) == MSG_TRUNC) {
231       return wBuf.len + 1;
232     }
233     return -1;
234   }
235   return wrapSocketFunction<ssize_t>(
236       ::recvfrom, s, (char*)buf, (int)len, flags, from, (int*)fromlen);
237 }
238
239 ssize_t recvfrom(
240     int s,
241     char* buf,
242     int len,
243     int flags,
244     struct sockaddr* from,
245     socklen_t* fromlen) {
246   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
247 }
248
249 ssize_t recvfrom(
250     int s,
251     void* buf,
252     int len,
253     int flags,
254     struct sockaddr* from,
255     socklen_t* fromlen) {
256   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
257 }
258
259 ssize_t recvmsg(int s, struct msghdr* message, int fl) {
260   SOCKET h = fd_to_socket(s);
261
262   // Don't currently support the name translation.
263   if (message->msg_name != nullptr || message->msg_namelen != 0) {
264     return (ssize_t)-1;
265   }
266   WSAMSG msg;
267   msg.name = nullptr;
268   msg.namelen = 0;
269   msg.Control.buf = (CHAR*)message->msg_control;
270   msg.Control.len = (ULONG)message->msg_controllen;
271   msg.dwFlags = 0;
272   msg.dwBufferCount = (DWORD)message->msg_iovlen;
273   msg.lpBuffers = new WSABUF[message->msg_iovlen];
274   SCOPE_EXIT { delete[] msg.lpBuffers; };
275   for (size_t i = 0; i < message->msg_iovlen; i++) {
276     msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
277     msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
278   }
279
280   // WSARecvMsg is an extension, so we don't get
281   // the convenience of being able to call it directly, even though
282   // WSASendMsg is part of the normal API -_-...
283   LPFN_WSARECVMSG WSARecvMsg;
284   GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
285   DWORD recMsgBytes;
286   WSAIoctl(
287       h,
288       SIO_GET_EXTENSION_FUNCTION_POINTER,
289       &WSARecgMsg_GUID,
290       sizeof(WSARecgMsg_GUID),
291       &WSARecvMsg,
292       sizeof(WSARecvMsg),
293       &recMsgBytes,
294       nullptr,
295       nullptr);
296
297   DWORD bytesReceived;
298   int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
299   errno = translate_wsa_error(WSAGetLastError());
300   return res == 0 ? (ssize_t)bytesReceived : -1;
301 }
302
303 ssize_t send(int s, const void* buf, size_t len, int flags) {
304   return wrapSocketFunction<ssize_t>(
305       ::send, s, (const char*)buf, (int)len, flags);
306 }
307
308 ssize_t send(int s, const char* buf, int len, int flags) {
309   return send(s, (const void*)buf, (size_t)len, flags);
310 }
311
312 ssize_t send(int s, const void* buf, int len, int flags) {
313   return send(s, (const void*)buf, (size_t)len, flags);
314 }
315
316 ssize_t sendmsg(int s, const struct msghdr* message, int fl) {
317   SOCKET h = fd_to_socket(s);
318
319   // Unfortunately, WSASendMsg requires the socket to have been opened
320   // as either SOCK_DGRAM or SOCK_RAW, but sendmsg has no such requirement,
321   // so we have to implement it based on send instead :(
322   ssize_t bytesSent = 0;
323   for (size_t i = 0; i < message->msg_iovlen; i++) {
324     int r = -1;
325     if (message->msg_name != nullptr) {
326       r = ::sendto(
327           h,
328           (const char*)message->msg_iov[i].iov_base,
329           (int)message->msg_iov[i].iov_len,
330           message->msg_flags,
331           (const sockaddr*)message->msg_name,
332           (int)message->msg_namelen);
333     } else {
334       r = ::send(
335           h,
336           (const char*)message->msg_iov[i].iov_base,
337           (int)message->msg_iov[i].iov_len,
338           message->msg_flags);
339     }
340     if (r == -1 || r != message->msg_iov[i].iov_len) {
341       return -1;
342     }
343     bytesSent += r;
344   }
345   return bytesSent;
346 }
347
348 ssize_t sendto(
349     int s,
350     const void* buf,
351     size_t len,
352     int flags,
353     const sockaddr* to,
354     socklen_t tolen) {
355   return wrapSocketFunction<ssize_t>(
356       ::sendto, s, (const char*)buf, (int)len, flags, to, (int)tolen);
357 }
358
359 ssize_t sendto(
360     int s,
361     const char* buf,
362     int len,
363     int flags,
364     const sockaddr* to,
365     socklen_t tolen) {
366   return sendto(s, (const void*)buf, (size_t)len, flags, to, tolen);
367 }
368
369 ssize_t sendto(
370     int s,
371     const void* buf,
372     int len,
373     int flags,
374     const sockaddr* to,
375     socklen_t tolen) {
376   return sendto(s, buf, (size_t)len, flags, to, tolen);
377 }
378
379 int setsockopt(
380     int s,
381     int level,
382     int optname,
383     const void* optval,
384     socklen_t optlen) {
385   if (optname == SO_REUSEADDR) {
386     // We don't have an equivelent to the Linux & OSX meaning of this
387     // on Windows, so ignore it.
388     return 0;
389   } else if (optname == SO_REUSEPORT) {
390     // Windows's SO_REUSEADDR option is closer to SO_REUSEPORT than
391     // it is to the Linux & OSX meaning of SO_REUSEADDR.
392     return -1;
393   }
394   return wrapSocketFunction<int>(
395       ::setsockopt, s, level, optname, (char*)optval, optlen);
396 }
397
398 int setsockopt(
399     int s,
400     int level,
401     int optname,
402     const char* optval,
403     socklen_t optlen) {
404   return setsockopt(s, level, optname, (const void*)optval, optlen);
405 }
406
407 int shutdown(int s, int how) {
408   return wrapSocketFunction<int>(::shutdown, s, how);
409 }
410
411 int socket(int af, int type, int protocol) {
412   return socket_to_fd(::socket(af, type, protocol));
413 }
414
415 int socketpair(int domain, int type, int protocol, int sv[2]) {
416   if (domain != PF_UNIX || type != SOCK_STREAM || protocol != 0) {
417     return -1;
418   }
419   intptr_t pair[2];
420   auto r = evutil_socketpair(AF_INET, type, protocol, pair);
421   if (r == -1) {
422     return r;
423   }
424   sv[0] = _open_osfhandle(pair[0], O_RDWR | O_BINARY);
425   sv[1] = _open_osfhandle(pair[1], O_RDWR | O_BINARY);
426   return 0;
427 }
428 }
429 }
430 }
431 #endif