Enable auto-deps in all of Folly
[folly.git] / folly / portability / Sockets.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/portability/Sockets.h>
18
19 #ifdef _MSC_VER
20
21 #include <errno.h>
22 #include <fcntl.h>
23
24 #include <event2/util.h> // @manual
25
26 #include <MSWSock.h> // @manual
27
28 #include <folly/ScopeGuard.h>
29
30 namespace folly {
31 namespace portability {
32 namespace sockets {
33
34 // We have to startup WSA.
35 static struct FSPInit {
36   FSPInit() {
37     WSADATA dat;
38     WSAStartup(MAKEWORD(2, 2), &dat);
39   }
40   ~FSPInit() {
41     WSACleanup();
42   }
43 } fspInit;
44
45 bool is_fh_socket(int fh) {
46   SOCKET h = fd_to_socket(fh);
47   constexpr long kDummyEvents = 0xABCDEF12;
48   WSANETWORKEVENTS e;
49   e.lNetworkEvents = kDummyEvents;
50   WSAEnumNetworkEvents(h, nullptr, &e);
51   return e.lNetworkEvents != kDummyEvents;
52 }
53
54 SOCKET fd_to_socket(int fd) {
55   if (fd == -1) {
56     return INVALID_SOCKET;
57   }
58   // We do this in a roundabout way to allow us to compile even if
59   // we're doing a bit of trickery to ensure that things aren't
60   // being implicitly converted to a SOCKET by temporarily
61   // adjusting the windows headers to define SOCKET as a
62   // structure.
63   static_assert(sizeof(HANDLE) == sizeof(SOCKET), "Handle size mismatch.");
64   HANDLE tmp = (HANDLE)_get_osfhandle(fd);
65   return *(SOCKET*)&tmp;
66 }
67
68 int socket_to_fd(SOCKET s) {
69   if (s == INVALID_SOCKET) {
70     return -1;
71   }
72   return _open_osfhandle((intptr_t)s, O_RDWR | O_BINARY);
73 }
74
75 int translate_wsa_error(int wsaErr) {
76   switch (wsaErr) {
77     case WSAEWOULDBLOCK:
78       return EAGAIN;
79     default:
80       return wsaErr;
81   }
82 }
83
84 template <class R, class F, class... Args>
85 static R wrapSocketFunction(F f, int s, Args... args) {
86   SOCKET h = fd_to_socket(s);
87   R ret = f(h, args...);
88   errno = translate_wsa_error(WSAGetLastError());
89   return ret;
90 }
91
92 int accept(int s, struct sockaddr* addr, socklen_t* addrlen) {
93   return socket_to_fd(wrapSocketFunction<SOCKET>(::accept, s, addr, addrlen));
94 }
95
96 int bind(int s, const struct sockaddr* name, socklen_t namelen) {
97   return wrapSocketFunction<int>(::bind, s, name, namelen);
98 }
99
100 int connect(int s, const struct sockaddr* name, socklen_t namelen) {
101   auto r = wrapSocketFunction<int>(::connect, s, name, namelen);
102   if (r == -1 && WSAGetLastError() == WSAEWOULDBLOCK) {
103     errno = EINPROGRESS;
104   }
105   return r;
106 }
107
108 int getpeername(int s, struct sockaddr* name, socklen_t* namelen) {
109   return wrapSocketFunction<int>(::getpeername, s, name, namelen);
110 }
111
112 int getsockname(int s, struct sockaddr* name, socklen_t* namelen) {
113   return wrapSocketFunction<int>(::getsockname, s, name, namelen);
114 }
115
116 int getsockopt(int s, int level, int optname, char* optval, socklen_t* optlen) {
117   return getsockopt(s, level, optname, (void*)optval, optlen);
118 }
119
120 int getsockopt(int s, int level, int optname, void* optval, socklen_t* optlen) {
121   auto ret = wrapSocketFunction<int>(
122       ::getsockopt, s, level, optname, (char*)optval, (int*)optlen);
123   if (optname == TCP_NODELAY && *optlen == 1) {
124     // Windows is weird about this value, and documents it as a
125     // BOOL (ie. int) but expects the variable to be bool (1-byte),
126     // so we get to adapt the interface to work that way.
127     *(int*)optval = *(uint8_t*)optval;
128     *optlen = sizeof(int);
129   }
130   return ret;
131 }
132
133 int inet_aton(const char* cp, struct in_addr* inp) {
134   inp->s_addr = inet_addr(cp);
135   return inp->s_addr == INADDR_NONE ? 0 : 1;
136 }
137
138 const char* inet_ntop(int af, const void* src, char* dst, socklen_t size) {
139   return ::inet_ntop(af, (char*)src, dst, size_t(size));
140 }
141
142 int listen(int s, int backlog) {
143   return wrapSocketFunction<int>(::listen, s, backlog);
144 }
145
146 int poll(struct pollfd fds[], nfds_t nfds, int timeout) {
147   // TODO: Allow both file descriptors and SOCKETs in this.
148   for (int i = 0; i < nfds; i++) {
149     fds[i].fd = fd_to_socket((int)fds[i].fd);
150   }
151   return ::WSAPoll(fds, (ULONG)nfds, timeout);
152 }
153
154 ssize_t recv(int s, void* buf, size_t len, int flags) {
155   if ((flags & MSG_DONTWAIT) == MSG_DONTWAIT) {
156     flags &= ~MSG_DONTWAIT;
157
158     u_long pendingRead = 0;
159     if (ioctlsocket(fd_to_socket(s), FIONREAD, &pendingRead)) {
160       errno = translate_wsa_error(WSAGetLastError());
161       return -1;
162     }
163
164     fd_set readSet;
165     FD_ZERO(&readSet);
166     FD_SET(fd_to_socket(s), &readSet);
167     timeval timeout{0, 0};
168     auto ret = select(1, &readSet, nullptr, nullptr, &timeout);
169     if (ret == 0) {
170       errno = EWOULDBLOCK;
171       return -1;
172     }
173   }
174   return wrapSocketFunction<ssize_t>(::recv, s, (char*)buf, (int)len, flags);
175 }
176
177 ssize_t recv(int s, char* buf, int len, int flags) {
178   return recv(s, (void*)buf, (size_t)len, flags);
179 }
180
181 ssize_t recv(int s, void* buf, int len, int flags) {
182   return recv(s, (void*)buf, (size_t)len, flags);
183 }
184
185 ssize_t recvfrom(
186     int s,
187     void* buf,
188     size_t len,
189     int flags,
190     struct sockaddr* from,
191     socklen_t* fromlen) {
192   if ((flags & MSG_TRUNC) == MSG_TRUNC) {
193     SOCKET h = fd_to_socket(s);
194
195     WSABUF wBuf{};
196     wBuf.buf = (CHAR*)buf;
197     wBuf.len = (ULONG)len;
198     WSAMSG wMsg{};
199     wMsg.dwBufferCount = 1;
200     wMsg.lpBuffers = &wBuf;
201     wMsg.name = from;
202     if (fromlen != nullptr) {
203       wMsg.namelen = *fromlen;
204     }
205
206     // WSARecvMsg is an extension, so we don't get
207     // the convenience of being able to call it directly, even though
208     // WSASendMsg is part of the normal API -_-...
209     LPFN_WSARECVMSG WSARecvMsg;
210     GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
211     DWORD recMsgBytes;
212     WSAIoctl(
213         h,
214         SIO_GET_EXTENSION_FUNCTION_POINTER,
215         &WSARecgMsg_GUID,
216         sizeof(WSARecgMsg_GUID),
217         &WSARecvMsg,
218         sizeof(WSARecvMsg),
219         &recMsgBytes,
220         nullptr,
221         nullptr);
222
223     DWORD bytesReceived;
224     int res = WSARecvMsg(h, &wMsg, &bytesReceived, nullptr, nullptr);
225     errno = translate_wsa_error(WSAGetLastError());
226     if (res == 0) {
227       return bytesReceived;
228     }
229     if (fromlen != nullptr) {
230       *fromlen = wMsg.namelen;
231     }
232     if ((wMsg.dwFlags & MSG_TRUNC) == MSG_TRUNC) {
233       return wBuf.len + 1;
234     }
235     return -1;
236   }
237   return wrapSocketFunction<ssize_t>(
238       ::recvfrom, s, (char*)buf, (int)len, flags, from, (int*)fromlen);
239 }
240
241 ssize_t recvfrom(
242     int s,
243     char* buf,
244     int len,
245     int flags,
246     struct sockaddr* from,
247     socklen_t* fromlen) {
248   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
249 }
250
251 ssize_t recvfrom(
252     int s,
253     void* buf,
254     int len,
255     int flags,
256     struct sockaddr* from,
257     socklen_t* fromlen) {
258   return recvfrom(s, (void*)buf, (size_t)len, flags, from, fromlen);
259 }
260
261 ssize_t recvmsg(int s, struct msghdr* message, int /* flags */) {
262   SOCKET h = fd_to_socket(s);
263
264   // Don't currently support the name translation.
265   if (message->msg_name != nullptr || message->msg_namelen != 0) {
266     return (ssize_t)-1;
267   }
268   WSAMSG msg;
269   msg.name = nullptr;
270   msg.namelen = 0;
271   msg.Control.buf = (CHAR*)message->msg_control;
272   msg.Control.len = (ULONG)message->msg_controllen;
273   msg.dwFlags = 0;
274   msg.dwBufferCount = (DWORD)message->msg_iovlen;
275   msg.lpBuffers = new WSABUF[message->msg_iovlen];
276   SCOPE_EXIT {
277     delete[] msg.lpBuffers;
278   };
279   for (size_t i = 0; i < message->msg_iovlen; i++) {
280     msg.lpBuffers[i].buf = (CHAR*)message->msg_iov[i].iov_base;
281     msg.lpBuffers[i].len = (ULONG)message->msg_iov[i].iov_len;
282   }
283
284   // WSARecvMsg is an extension, so we don't get
285   // the convenience of being able to call it directly, even though
286   // WSASendMsg is part of the normal API -_-...
287   LPFN_WSARECVMSG WSARecvMsg;
288   GUID WSARecgMsg_GUID = WSAID_WSARECVMSG;
289   DWORD recMsgBytes;
290   WSAIoctl(
291       h,
292       SIO_GET_EXTENSION_FUNCTION_POINTER,
293       &WSARecgMsg_GUID,
294       sizeof(WSARecgMsg_GUID),
295       &WSARecvMsg,
296       sizeof(WSARecvMsg),
297       &recMsgBytes,
298       nullptr,
299       nullptr);
300
301   DWORD bytesReceived;
302   int res = WSARecvMsg(h, &msg, &bytesReceived, nullptr, nullptr);
303   errno = translate_wsa_error(WSAGetLastError());
304   return res == 0 ? (ssize_t)bytesReceived : -1;
305 }
306
307 ssize_t send(int s, const void* buf, size_t len, int flags) {
308   return wrapSocketFunction<ssize_t>(
309       ::send, s, (const char*)buf, (int)len, flags);
310 }
311
312 ssize_t send(int s, const char* buf, int len, int flags) {
313   return send(s, (const void*)buf, (size_t)len, flags);
314 }
315
316 ssize_t send(int s, const void* buf, int len, int flags) {
317   return send(s, (const void*)buf, (size_t)len, flags);
318 }
319
320 ssize_t sendmsg(int s, const struct msghdr* message, int /* flags */) {
321   SOCKET h = fd_to_socket(s);
322
323   // Unfortunately, WSASendMsg requires the socket to have been opened
324   // as either SOCK_DGRAM or SOCK_RAW, but sendmsg has no such requirement,
325   // so we have to implement it based on send instead :(
326   ssize_t bytesSent = 0;
327   for (size_t i = 0; i < message->msg_iovlen; i++) {
328     int r = -1;
329     if (message->msg_name != nullptr) {
330       r = ::sendto(
331           h,
332           (const char*)message->msg_iov[i].iov_base,
333           (int)message->msg_iov[i].iov_len,
334           message->msg_flags,
335           (const sockaddr*)message->msg_name,
336           (int)message->msg_namelen);
337     } else {
338       r = ::send(
339           h,
340           (const char*)message->msg_iov[i].iov_base,
341           (int)message->msg_iov[i].iov_len,
342           message->msg_flags);
343     }
344     if (r == -1 || size_t(r) != message->msg_iov[i].iov_len) {
345       errno = translate_wsa_error(WSAGetLastError());
346       if (WSAGetLastError() == WSAEWOULDBLOCK && bytesSent > 0) {
347         return bytesSent;
348       }
349       return -1;
350     }
351     bytesSent += r;
352   }
353   return bytesSent;
354 }
355
356 ssize_t sendto(
357     int s,
358     const void* buf,
359     size_t len,
360     int flags,
361     const sockaddr* to,
362     socklen_t tolen) {
363   return wrapSocketFunction<ssize_t>(
364       ::sendto, s, (const char*)buf, (int)len, flags, to, (int)tolen);
365 }
366
367 ssize_t sendto(
368     int s,
369     const char* buf,
370     int len,
371     int flags,
372     const sockaddr* to,
373     socklen_t tolen) {
374   return sendto(s, (const void*)buf, (size_t)len, flags, to, tolen);
375 }
376
377 ssize_t sendto(
378     int s,
379     const void* buf,
380     int len,
381     int flags,
382     const sockaddr* to,
383     socklen_t tolen) {
384   return sendto(s, buf, (size_t)len, flags, to, tolen);
385 }
386
387 int setsockopt(
388     int s,
389     int level,
390     int optname,
391     const void* optval,
392     socklen_t optlen) {
393   if (optname == SO_REUSEADDR) {
394     // We don't have an equivelent to the Linux & OSX meaning of this
395     // on Windows, so ignore it.
396     return 0;
397   } else if (optname == SO_REUSEPORT) {
398     // Windows's SO_REUSEADDR option is closer to SO_REUSEPORT than
399     // it is to the Linux & OSX meaning of SO_REUSEADDR.
400     return -1;
401   }
402   return wrapSocketFunction<int>(
403       ::setsockopt, s, level, optname, (char*)optval, optlen);
404 }
405
406 int setsockopt(
407     int s,
408     int level,
409     int optname,
410     const char* optval,
411     socklen_t optlen) {
412   return setsockopt(s, level, optname, (const void*)optval, optlen);
413 }
414
415 int shutdown(int s, int how) {
416   return wrapSocketFunction<int>(::shutdown, s, how);
417 }
418
419 int socket(int af, int type, int protocol) {
420   return socket_to_fd(::socket(af, type, protocol));
421 }
422
423 int socketpair(int domain, int type, int protocol, int sv[2]) {
424   if (domain != PF_UNIX || type != SOCK_STREAM || protocol != 0) {
425     return -1;
426   }
427   intptr_t pair[2];
428   auto r = evutil_socketpair(AF_INET, type, protocol, pair);
429   if (r == -1) {
430     return r;
431   }
432   sv[0] = _open_osfhandle(pair[0], O_RDWR | O_BINARY);
433   sv[1] = _open_osfhandle(pair[1], O_RDWR | O_BINARY);
434   return 0;
435 }
436 }
437 }
438 }
439 #endif