2017
[folly.git] / folly / portability / Sockets.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/portability/Sockets.h>
18
19 #ifdef _MSC_VER
20
21 #include <errno.h>
22 #include <fcntl.h>
23
24 #include <event2/util.h>
25
26 #include <MSWSock.h>
27
28 #include <folly/ScopeGuard.h>
29
30 namespace folly {
31 namespace portability {
32 namespace sockets {
33
34 // We have to startup WSA.
35 static struct FSPInit {
36   FSPInit() {
37     WSADATA dat;
38     WSAStartup(MAKEWORD(2, 2), &dat);
39   }
40   ~FSPInit() { WSACleanup(); }
41 } fspInit;
42
43 bool is_fh_socket(int fh) {
44   SOCKET h = fd_to_socket(fh);
45   constexpr long kDummyEvents = 0xABCDEF12;
46   WSANETWORKEVENTS e;
47   e.lNetworkEvents = kDummyEvents;
48   WSAEnumNetworkEvents(h, nullptr, &e);
49   return e.lNetworkEvents != kDummyEvents;
50 }
51
52 SOCKET fd_to_socket(int fd) {
53   if (fd == -1) {
54     return INVALID_SOCKET;
55   }
56   // We do this in a roundabout way to allow us to compile even if
57   // we're doing a bit of trickery to ensure that things aren't
58   // being implicitly converted to a SOCKET by temporarily
59   // adjusting the windows headers to define SOCKET as a
60   // structure.
61   static_assert(sizeof(HANDLE) == sizeof(SOCKET), "Handle size mismatch.");
62   HANDLE tmp = (HANDLE)_get_osfhandle(fd);
63   return *(SOCKET*)&tmp;
64 }
65
66 int socket_to_fd(SOCKET s) {
67   if (s == INVALID_SOCKET) {
68     return -1;
69   }
70   return _open_osfhandle((intptr_t)s, O_RDWR | O_BINARY);
71 }
72
73 int translate_wsa_error(int wsaErr) {
74   switch (wsaErr) {
75     case WSAEWOULDBLOCK:
76       return EAGAIN;
77     default:
78       return wsaErr;
79   }
80 }
81
82 template <class R, class F, class... Args>
83 static R wrapSocketFunction(F f, int s, Args... args) {
84   SOCKET h = fd_to_socket(s);
85   R ret = f(h, args...);
86   errno = translate_wsa_error(WSAGetLastError());
87   return ret;
88 }
89
90 int accept(int s, struct sockaddr* addr, socklen_t* addrlen) {
91   return socket_to_fd(wrapSocketFunction<SOCKET>(::accept, s, addr, addrlen));
92 }
93
94 int bind(int s, const struct sockaddr* name, socklen_t namelen) {
95   return wrapSocketFunction<int>(::bind, s, name, namelen);
96 }
97
98 int connect(int s, const struct sockaddr* name, socklen_t namelen) {
99   auto r = wrapSocketFunction<int>(::connect, s, name, namelen);
100   if (r == -1 && WSAGetLastError() == WSAEWOULDBLOCK) {
101     errno = EINPROGRESS;
102   }
103   return r;
104 }
105
106 int getpeername(int s, struct sockaddr* name, socklen_t* namelen) {
107   return wrapSocketFunction<int>(::getpeername, s, name, namelen);
108 }
109
110 int getsockname(int s, struct sockaddr* name, socklen_t* namelen) {
111   return wrapSocketFunction<int>(::getsockname, s, name, namelen);
112 }
113
114 int getsockopt(int s, int level, int optname, char* optval, socklen_t* optlen) {
115   return getsockopt(s, level, optname, (void*)optval, optlen);
116 }
117
118 int getsockopt(int s, int level, int optname, void* optval, socklen_t* optlen) {
119   auto ret = wrapSocketFunction<int>(
120       ::getsockopt, s, level, optname, (char*)optval, (int*)optlen);
121   if (optname == TCP_NODELAY && *optlen == 1) {
122     // Windows is weird about this value, and documents it as a
123     // BOOL (ie. int) but expects the variable to be bool (1-byte),
124     // so we get to adapt the interface to work that way.
125     *(int*)optval = *(uint8_t*)optval;
126     *optlen = sizeof(int);
127   }
128   return ret;
129 }
130
131 int inet_aton(const char* cp, struct in_addr* inp) {
132   inp->s_addr = inet_addr(cp);
133   return inp->s_addr == INADDR_NONE ? 0 : 1;
134 }
135
136 const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
137   return ::inet_ntop(af, (char*)src, dst, size);
138 }
139
140 int listen(int s, int backlog) {
141   return wrapSocketFunction<int>(::listen, s, backlog);
142 }
143
144 int poll(struct pollfd fds[], nfds_t nfds, int timeout) {
145   // TODO: Allow both file descriptors and SOCKETs in this.
146   for (int i = 0; i < nfds; i++) {
147     fds[i].fd = fd_to_socket((int)fds[i].fd);
148   }
149   return ::WSAPoll(fds, (ULONG)nfds, timeout);
150 }
151
152 ssize_t recv(int s, void* buf, size_t len, int flags) {
153   if ((flags & MSG_DONTWAIT) == MSG_DONTWAIT) {
154     flags &= ~MSG_DONTWAIT;
155
156     u_long pendingRead = 0;
157     if (ioctlsocket(fd_to_socket(s), FIONREAD, &pendingRead)) {
158       errno = translate_wsa_error(WSAGetLastError());
159       return -1;
160     }
161
162     fd_set readSet;
163     FD_ZERO(&readSet);
164     FD_SET(fd_to_socket(s), &readSet);
165     timeval timeout{0, 0};
166     auto ret = select(1, &readSet, nullptr, nullptr, &timeout);
167     if (ret == 0) {
168       errno = EWOULDBLOCK;
169       return -1;
170     }
171   }
172   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
173 }
174
175 ssize_t recv(int s, char* buf, int len, int flags) {
176   return recv(s, (void*)buf, (size_t)len, flags);
177 }
178
179 ssize_t recv(int s, void* buf, int len, int flags) {
180   return recv(s, (void*)buf, (size_t)len, flags);
181 }
182
183 ssize_t recvfrom(
184     int s,
185     void* buf,
186     size_t len,
187     int flags,
188     struct sockaddr* from,
189     socklen_t* fromlen) {
190   if ((flags & MSG_TRUNC) == MSG_TRUNC) {
191     SOCKET h = fd_to_socket(s);
192
193     WSABUF wBuf{};
194     wBuf.buf = (CHAR*)buf;
195     wBuf.len = (ULONG)len;
196     WSAMSG wMsg{};
197     wMsg.dwBufferCount = 1;
198     wMsg.lpBuffers = &wBuf;
199     wMsg.name = from;
200     if (fromlen != nullptr) {
201       wMsg.namelen = *fromlen;
202     }
203
204     // WSARecvMsg is an extension, so we don't get
205     // the convenience of being able to call it directly, even though
206     // WSASendMsg is part of the normal API -_-...
207     LPFN_WSARECVMSG WSARecvMsg;
208     GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
209     DWORD recMsgBytes;
210     WSAIoctl(
211         h,
212         SIO_GET_EXTENSION_FUNCTION_POINTER,
213         &WSARecgMsg_GUID,
214         sizeof(WSARecgMsg_GUID),
215         &WSARecvMsg,
216         sizeof(WSARecvMsg),
217         &recMsgBytes,
218         nullptr,
219         nullptr);
220
221     DWORD bytesReceived;
222     int res = WSARecvMsg(h, &wMsg, &bytesReceived, nullptr, nullptr);
223     errno = translate_wsa_error(WSAGetLastError());
224     if (res == 0) {
225       return bytesReceived;
226     }
227     if (fromlen != nullptr) {
228       *fromlen = wMsg.namelen;
229     }
230     if ((wMsg.dwFlags & MSG_TRUNC) == MSG_TRUNC) {
231       return wBuf.len + 1;
232     }
233     return -1;
234   }
235   return wrapSocketFunction<ssize_t>(
236       ::recvfrom, s, (char*)buf, (int)len, flags, from, (int*)fromlen);
237 }
238
239 ssize_t recvfrom(
240     int s,
241     char* buf,
242     int len,
243     int flags,
244     struct sockaddr* from,
245     socklen_t* fromlen) {
246   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
247 }
248
249 ssize_t recvfrom(
250     int s,
251     void* buf,
252     int len,
253     int flags,
254     struct sockaddr* from,
255     socklen_t* fromlen) {
256   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
257 }
258
259 ssize_t recvmsg(int s, struct msghdr* message, int fl) {
260   SOCKET h = fd_to_socket(s);
261
262   // Don't currently support the name translation.
263   if (message->msg_name != nullptr || message->msg_namelen != 0) {
264     return (ssize_t)-1;
265   }
266   WSAMSG msg;
267   msg.name = nullptr;
268   msg.namelen = 0;
269   msg.Control.buf = (CHAR*)message->msg_control;
270   msg.Control.len = (ULONG)message->msg_controllen;
271   msg.dwFlags = 0;
272   msg.dwBufferCount = (DWORD)message->msg_iovlen;
273   msg.lpBuffers = new WSABUF[message->msg_iovlen];
274   SCOPE_EXIT { delete[] msg.lpBuffers; };
275   for (size_t i = 0; i < message->msg_iovlen; i++) {
276     msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
277     msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
278   }
279
280   // WSARecvMsg is an extension, so we don't get
281   // the convenience of being able to call it directly, even though
282   // WSASendMsg is part of the normal API -_-...
283   LPFN_WSARECVMSG WSARecvMsg;
284   GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
285   DWORD recMsgBytes;
286   WSAIoctl(
287       h,
288       SIO_GET_EXTENSION_FUNCTION_POINTER,
289       &WSARecgMsg_GUID,
290       sizeof(WSARecgMsg_GUID),
291       &WSARecvMsg,
292       sizeof(WSARecvMsg),
293       &recMsgBytes,
294       nullptr,
295       nullptr);
296
297   DWORD bytesReceived;
298   int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
299   errno = translate_wsa_error(WSAGetLastError());
300   return res == 0 ? (ssize_t)bytesReceived : -1;
301 }
302
303 ssize_t send(int s, const void* buf, size_t len, int flags) {
304   return wrapSocketFunction<ssize_t>(
305       ::send, s, (const char*)buf, (int)len, flags);
306 }
307
308 ssize_t send(int s, const char* buf, int len, int flags) {
309   return send(s, (const void*)buf, (size_t)len, flags);
310 }
311
312 ssize_t send(int s, const void* buf, int len, int flags) {
313   return send(s, (const void*)buf, (size_t)len, flags);
314 }
315
316 ssize_t sendmsg(int s, const struct msghdr* message, int fl) {
317   SOCKET h = fd_to_socket(s);
318
319   // Unfortunately, WSASendMsg requires the socket to have been opened
320   // as either SOCK_DGRAM or SOCK_RAW, but sendmsg has no such requirement,
321   // so we have to implement it based on send instead :(
322   ssize_t bytesSent = 0;
323   for (size_t i = 0; i < message->msg_iovlen; i++) {
324     int r = -1;
325     if (message->msg_name != nullptr) {
326       r = ::sendto(
327           h,
328           (const char*)message->msg_iov[i].iov_base,
329           (int)message->msg_iov[i].iov_len,
330           message->msg_flags,
331           (const sockaddr*)message->msg_name,
332           (int)message->msg_namelen);
333     } else {
334       r = ::send(
335           h,
336           (const char*)message->msg_iov[i].iov_base,
337           (int)message->msg_iov[i].iov_len,
338           message->msg_flags);
339     }
340     if (r == -1 || size_t(r) != message->msg_iov[i].iov_len) {
341       errno = translate_wsa_error(WSAGetLastError());
342       if (WSAGetLastError() == WSAEWOULDBLOCK && bytesSent > 0) {
343         return bytesSent;
344       }
345       return -1;
346     }
347     bytesSent += r;
348   }
349   return bytesSent;
350 }
351
352 ssize_t sendto(
353     int s,
354     const void* buf,
355     size_t len,
356     int flags,
357     const sockaddr* to,
358     socklen_t tolen) {
359   return wrapSocketFunction<ssize_t>(
360       ::sendto, s, (const char*)buf, (int)len, flags, to, (int)tolen);
361 }
362
363 ssize_t sendto(
364     int s,
365     const char* buf,
366     int len,
367     int flags,
368     const sockaddr* to,
369     socklen_t tolen) {
370   return sendto(s, (const void*)buf, (size_t)len, flags, to, tolen);
371 }
372
373 ssize_t sendto(
374     int s,
375     const void* buf,
376     int len,
377     int flags,
378     const sockaddr* to,
379     socklen_t tolen) {
380   return sendto(s, buf, (size_t)len, flags, to, tolen);
381 }
382
383 int setsockopt(
384     int s,
385     int level,
386     int optname,
387     const void* optval,
388     socklen_t optlen) {
389   if (optname == SO_REUSEADDR) {
390     // We don't have an equivelent to the Linux & OSX meaning of this
391     // on Windows, so ignore it.
392     return 0;
393   } else if (optname == SO_REUSEPORT) {
394     // Windows's SO_REUSEADDR option is closer to SO_REUSEPORT than
395     // it is to the Linux & OSX meaning of SO_REUSEADDR.
396     return -1;
397   }
398   return wrapSocketFunction<int>(
399       ::setsockopt, s, level, optname, (char*)optval, optlen);
400 }
401
402 int setsockopt(
403     int s,
404     int level,
405     int optname,
406     const char* optval,
407     socklen_t optlen) {
408   return setsockopt(s, level, optname, (const void*)optval, optlen);
409 }
410
411 int shutdown(int s, int how) {
412   return wrapSocketFunction<int>(::shutdown, s, how);
413 }
414
415 int socket(int af, int type, int protocol) {
416   return socket_to_fd(::socket(af, type, protocol));
417 }
418
419 int socketpair(int domain, int type, int protocol, int sv[2]) {
420   if (domain != PF_UNIX || type != SOCK_STREAM || protocol != 0) {
421     return -1;
422   }
423   intptr_t pair[2];
424   auto r = evutil_socketpair(AF_INET, type, protocol, pair);
425   if (r == -1) {
426     return r;
427   }
428   sv[0] = _open_osfhandle(pair[0], O_RDWR | O_BINARY);
429   sv[1] = _open_osfhandle(pair[1], O_RDWR | O_BINARY);
430   return 0;
431 }
432 }
433 }
434 }
435 #endif