Fix build with Windows SDK v10.0.16232
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/test/AsyncSocketTest2.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/Random.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSocket.h>
23 #include <folly/io/async/AsyncTimeout.h>
24 #include <folly/io/async/EventBase.h>
25
26 #include <folly/experimental/TestUtil.h>
27 #include <folly/io/IOBuf.h>
28 #include <folly/io/async/test/AsyncSocketTest.h>
29 #include <folly/io/async/test/Util.h>
30 #include <folly/portability/GMock.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
34 #include <folly/test/SocketAddressTestHelper.h>
35
36 #include <boost/scoped_array.hpp>
37 #include <fcntl.h>
38 #include <sys/types.h>
39 #include <iostream>
40 #include <thread>
41
42 using namespace boost;
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::unique_ptr;
50 using std::chrono::milliseconds;
51 using boost::scoped_array;
52
53 using namespace folly;
54 using namespace folly::test;
55 using namespace testing;
56
57 namespace fsp = folly::portability::sockets;
58
59 class DelayedWrite: public AsyncTimeout {
60  public:
61   DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
62       unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
63       bool cork, bool lastWrite = false):
64     AsyncTimeout(socket->getEventBase()),
65     socket_(socket),
66     bufs_(std::move(bufs)),
67     wcb_(wcb),
68     cork_(cork),
69     lastWrite_(lastWrite) {}
70
71  private:
72   void timeoutExpired() noexcept override {
73     WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
74     socket_->writeChain(wcb_, std::move(bufs_), flags);
75     if (lastWrite_) {
76       socket_->shutdownWrite();
77     }
78   }
79
80   std::shared_ptr<AsyncSocket> socket_;
81   unique_ptr<IOBuf> bufs_;
82   AsyncTransportWrapper::WriteCallback* wcb_;
83   bool cork_;
84   bool lastWrite_;
85 };
86
87 ///////////////////////////////////////////////////////////////////////////
88 // connect() tests
89 ///////////////////////////////////////////////////////////////////////////
90
91 /**
92  * Test connecting to a server
93  */
94 TEST(AsyncSocketTest, Connect) {
95   // Start listening on a local port
96   TestServer server;
97
98   // Connect using a AsyncSocket
99   EventBase evb;
100   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
101   ConnCallback cb;
102   socket->connect(&cb, server.getAddress(), 30);
103
104   evb.loop();
105
106   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
107   EXPECT_LE(0, socket->getConnectTime().count());
108   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
109 }
110
111 enum class TFOState {
112   DISABLED,
113   ENABLED,
114 };
115
116 class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
117
118 std::vector<TFOState> getTestingValues() {
119   std::vector<TFOState> vals;
120   vals.emplace_back(TFOState::DISABLED);
121
122 #if FOLLY_ALLOW_TFO
123   vals.emplace_back(TFOState::ENABLED);
124 #endif
125   return vals;
126 }
127
128 INSTANTIATE_TEST_CASE_P(
129     ConnectTests,
130     AsyncSocketConnectTest,
131     ::testing::ValuesIn(getTestingValues()));
132
133 /**
134  * Test connecting to a server that isn't listening
135  */
136 TEST(AsyncSocketTest, ConnectRefused) {
137   EventBase evb;
138
139   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
140
141   // Hopefully nothing is actually listening on this address
142   folly::SocketAddress addr("127.0.0.1", 65535);
143   ConnCallback cb;
144   socket->connect(&cb, addr, 30);
145
146   evb.loop();
147
148   EXPECT_EQ(STATE_FAILED, cb.state);
149   EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
150   EXPECT_LE(0, socket->getConnectTime().count());
151   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
152 }
153
154 /**
155  * Test connection timeout
156  */
157 TEST(AsyncSocketTest, ConnectTimeout) {
158   EventBase evb;
159
160   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
161
162   // Try connecting to server that won't respond.
163   //
164   // This depends somewhat on the network where this test is run.
165   // Hopefully this IP will be routable but unresponsive.
166   // (Alternatively, we could try listening on a local raw socket, but that
167   // normally requires root privileges.)
168   auto host =
169       SocketAddressTestHelper::isIPv6Enabled() ?
170       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 :
171       SocketAddressTestHelper::isIPv4Enabled() ?
172       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 :
173       nullptr;
174   SocketAddress addr(host, 65535);
175   ConnCallback cb;
176   socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
177
178   evb.loop();
179
180   ASSERT_EQ(cb.state, STATE_FAILED);
181   ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
182
183   // Verify that we can still get the peer address after a timeout.
184   // Use case is if the client was created from a client pool, and we want
185   // to log which peer failed.
186   folly::SocketAddress peer;
187   socket->getPeerAddress(&peer);
188   ASSERT_EQ(peer, addr);
189   EXPECT_LE(0, socket->getConnectTime().count());
190   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
191 }
192
193 /**
194  * Test writing immediately after connecting, without waiting for connect
195  * to finish.
196  */
197 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
198   TestServer server;
199
200   // connect()
201   EventBase evb;
202   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
203
204   if (GetParam() == TFOState::ENABLED) {
205     socket->enableTFO();
206   }
207
208   ConnCallback ccb;
209   socket->connect(&ccb, server.getAddress(), 30);
210
211   // write()
212   char buf[128];
213   memset(buf, 'a', sizeof(buf));
214   WriteCallback wcb;
215   socket->write(&wcb, buf, sizeof(buf));
216
217   // Loop.  We don't bother accepting on the server socket yet.
218   // The kernel should be able to buffer the write request so it can succeed.
219   evb.loop();
220
221   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
222   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
223
224   // Make sure the server got a connection and received the data
225   socket->close();
226   server.verifyConnection(buf, sizeof(buf));
227
228   ASSERT_TRUE(socket->isClosedBySelf());
229   ASSERT_FALSE(socket->isClosedByPeer());
230   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
231 }
232
233 /**
234  * Test connecting using a nullptr connect callback.
235  */
236 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
237   TestServer server;
238
239   // connect()
240   EventBase evb;
241   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
242   if (GetParam() == TFOState::ENABLED) {
243     socket->enableTFO();
244   }
245
246   socket->connect(nullptr, server.getAddress(), 30);
247
248   // write some data, just so we have some way of verifing
249   // that the socket works correctly after connecting
250   char buf[128];
251   memset(buf, 'a', sizeof(buf));
252   WriteCallback wcb;
253   socket->write(&wcb, buf, sizeof(buf));
254
255   evb.loop();
256
257   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
258
259   // Make sure the server got a connection and received the data
260   socket->close();
261   server.verifyConnection(buf, sizeof(buf));
262
263   ASSERT_TRUE(socket->isClosedBySelf());
264   ASSERT_FALSE(socket->isClosedByPeer());
265 }
266
267 /**
268  * Test calling both write() and close() immediately after connecting, without
269  * waiting for connect to finish.
270  *
271  * This exercises the STATE_CONNECTING_CLOSING code.
272  */
273 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
274   TestServer server;
275
276   // connect()
277   EventBase evb;
278   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
279   if (GetParam() == TFOState::ENABLED) {
280     socket->enableTFO();
281   }
282   ConnCallback ccb;
283   socket->connect(&ccb, server.getAddress(), 30);
284
285   // write()
286   char buf[128];
287   memset(buf, 'a', sizeof(buf));
288   WriteCallback wcb;
289   socket->write(&wcb, buf, sizeof(buf));
290
291   // close()
292   socket->close();
293
294   // Loop.  We don't bother accepting on the server socket yet.
295   // The kernel should be able to buffer the write request so it can succeed.
296   evb.loop();
297
298   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
299   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
300
301   // Make sure the server got a connection and received the data
302   server.verifyConnection(buf, sizeof(buf));
303
304   ASSERT_TRUE(socket->isClosedBySelf());
305   ASSERT_FALSE(socket->isClosedByPeer());
306 }
307
308 /**
309  * Test calling close() immediately after connect()
310  */
311 TEST(AsyncSocketTest, ConnectAndClose) {
312   TestServer server;
313
314   // Connect using a AsyncSocket
315   EventBase evb;
316   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
317   ConnCallback ccb;
318   socket->connect(&ccb, server.getAddress(), 30);
319
320   // Hopefully the connect didn't succeed immediately.
321   // If it did, we can't exercise the close-while-connecting code path.
322   if (ccb.state == STATE_SUCCEEDED) {
323     LOG(INFO) << "connect() succeeded immediately; aborting test "
324                        "of close-during-connect behavior";
325     return;
326   }
327
328   socket->close();
329
330   // Loop, although there shouldn't be anything to do.
331   evb.loop();
332
333   // Make sure the connection was aborted
334   ASSERT_EQ(ccb.state, STATE_FAILED);
335
336   ASSERT_TRUE(socket->isClosedBySelf());
337   ASSERT_FALSE(socket->isClosedByPeer());
338 }
339
340 /**
341  * Test calling closeNow() immediately after connect()
342  *
343  * This should be identical to the normal close behavior.
344  */
345 TEST(AsyncSocketTest, ConnectAndCloseNow) {
346   TestServer server;
347
348   // Connect using a AsyncSocket
349   EventBase evb;
350   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
351   ConnCallback ccb;
352   socket->connect(&ccb, server.getAddress(), 30);
353
354   // Hopefully the connect didn't succeed immediately.
355   // If it did, we can't exercise the close-while-connecting code path.
356   if (ccb.state == STATE_SUCCEEDED) {
357     LOG(INFO) << "connect() succeeded immediately; aborting test "
358                        "of closeNow()-during-connect behavior";
359     return;
360   }
361
362   socket->closeNow();
363
364   // Loop, although there shouldn't be anything to do.
365   evb.loop();
366
367   // Make sure the connection was aborted
368   ASSERT_EQ(ccb.state, STATE_FAILED);
369
370   ASSERT_TRUE(socket->isClosedBySelf());
371   ASSERT_FALSE(socket->isClosedByPeer());
372 }
373
374 /**
375  * Test calling both write() and closeNow() immediately after connecting,
376  * without waiting for connect to finish.
377  *
378  * This should abort the pending write.
379  */
380 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
381   TestServer server;
382
383   // connect()
384   EventBase evb;
385   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
386   ConnCallback ccb;
387   socket->connect(&ccb, server.getAddress(), 30);
388
389   // Hopefully the connect didn't succeed immediately.
390   // If it did, we can't exercise the close-while-connecting code path.
391   if (ccb.state == STATE_SUCCEEDED) {
392     LOG(INFO) << "connect() succeeded immediately; aborting test "
393                        "of write-during-connect behavior";
394     return;
395   }
396
397   // write()
398   char buf[128];
399   memset(buf, 'a', sizeof(buf));
400   WriteCallback wcb;
401   socket->write(&wcb, buf, sizeof(buf));
402
403   // close()
404   socket->closeNow();
405
406   // Loop, although there shouldn't be anything to do.
407   evb.loop();
408
409   ASSERT_EQ(ccb.state, STATE_FAILED);
410   ASSERT_EQ(wcb.state, STATE_FAILED);
411
412   ASSERT_TRUE(socket->isClosedBySelf());
413   ASSERT_FALSE(socket->isClosedByPeer());
414 }
415
416 /**
417  * Test installing a read callback immediately, before connect() finishes.
418  */
419 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
420   TestServer server;
421
422   // connect()
423   EventBase evb;
424   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
425   if (GetParam() == TFOState::ENABLED) {
426     socket->enableTFO();
427   }
428
429   ConnCallback ccb;
430   socket->connect(&ccb, server.getAddress(), 30);
431
432   ReadCallback rcb;
433   socket->setReadCB(&rcb);
434
435   if (GetParam() == TFOState::ENABLED) {
436     // Trigger a connection
437     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
438   }
439
440   // Even though we haven't looped yet, we should be able to accept
441   // the connection and send data to it.
442   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
443   uint8_t buf[128];
444   memset(buf, 'a', sizeof(buf));
445   acceptedSocket->write(buf, sizeof(buf));
446   acceptedSocket->flush();
447   acceptedSocket->close();
448
449   // Loop, although there shouldn't be anything to do.
450   evb.loop();
451
452   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
453   ASSERT_EQ(rcb.buffers.size(), 1);
454   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
455   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
456
457   ASSERT_FALSE(socket->isClosedBySelf());
458   ASSERT_FALSE(socket->isClosedByPeer());
459 }
460
461 /**
462  * Test installing a read callback and then closing immediately before the
463  * connect attempt finishes.
464  */
465 TEST(AsyncSocketTest, ConnectReadAndClose) {
466   TestServer server;
467
468   // connect()
469   EventBase evb;
470   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
471   ConnCallback ccb;
472   socket->connect(&ccb, server.getAddress(), 30);
473
474   // Hopefully the connect didn't succeed immediately.
475   // If it did, we can't exercise the close-while-connecting code path.
476   if (ccb.state == STATE_SUCCEEDED) {
477     LOG(INFO) << "connect() succeeded immediately; aborting test "
478                        "of read-during-connect behavior";
479     return;
480   }
481
482   ReadCallback rcb;
483   socket->setReadCB(&rcb);
484
485   // close()
486   socket->close();
487
488   // Loop, although there shouldn't be anything to do.
489   evb.loop();
490
491   ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
492   ASSERT_EQ(rcb.buffers.size(), 0);
493   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
494
495   ASSERT_TRUE(socket->isClosedBySelf());
496   ASSERT_FALSE(socket->isClosedByPeer());
497 }
498
499 /**
500  * Test both writing and installing a read callback immediately,
501  * before connect() finishes.
502  */
503 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
504   TestServer server;
505
506   // connect()
507   EventBase evb;
508   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
509   if (GetParam() == TFOState::ENABLED) {
510     socket->enableTFO();
511   }
512   ConnCallback ccb;
513   socket->connect(&ccb, server.getAddress(), 30);
514
515   // write()
516   char buf1[128];
517   memset(buf1, 'a', sizeof(buf1));
518   WriteCallback wcb;
519   socket->write(&wcb, buf1, sizeof(buf1));
520
521   // set a read callback
522   ReadCallback rcb;
523   socket->setReadCB(&rcb);
524
525   // Even though we haven't looped yet, we should be able to accept
526   // the connection and send data to it.
527   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
528   uint8_t buf2[128];
529   memset(buf2, 'b', sizeof(buf2));
530   acceptedSocket->write(buf2, sizeof(buf2));
531   acceptedSocket->flush();
532
533   // shut down the write half of acceptedSocket, so that the AsyncSocket
534   // will stop reading and we can break out of the event loop.
535   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
536
537   // Loop
538   evb.loop();
539
540   // Make sure the connect succeeded
541   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
542
543   // Make sure the AsyncSocket read the data written by the accepted socket
544   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
545   ASSERT_EQ(rcb.buffers.size(), 1);
546   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
547   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
548
549   // Close the AsyncSocket so we'll see EOF on acceptedSocket
550   socket->close();
551
552   // Make sure the accepted socket saw the data written by the AsyncSocket
553   uint8_t readbuf[sizeof(buf1)];
554   acceptedSocket->readAll(readbuf, sizeof(readbuf));
555   ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
556   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
557   ASSERT_EQ(bytesRead, 0);
558
559   ASSERT_FALSE(socket->isClosedBySelf());
560   ASSERT_TRUE(socket->isClosedByPeer());
561 }
562
563 /**
564  * Test writing to the socket then shutting down writes before the connect
565  * attempt finishes.
566  */
567 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
568   TestServer server;
569
570   // connect()
571   EventBase evb;
572   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
573   ConnCallback ccb;
574   socket->connect(&ccb, server.getAddress(), 30);
575
576   // Hopefully the connect didn't succeed immediately.
577   // If it did, we can't exercise the write-while-connecting code path.
578   if (ccb.state == STATE_SUCCEEDED) {
579     LOG(INFO) << "connect() succeeded immediately; skipping test";
580     return;
581   }
582
583   // Ask to write some data
584   char wbuf[128];
585   memset(wbuf, 'a', sizeof(wbuf));
586   WriteCallback wcb;
587   socket->write(&wcb, wbuf, sizeof(wbuf));
588   socket->shutdownWrite();
589
590   // Shutdown writes
591   socket->shutdownWrite();
592
593   // Even though we haven't looped yet, we should be able to accept
594   // the connection.
595   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
596
597   // Since the connection is still in progress, there should be no data to
598   // read yet.  Verify that the accepted socket is not readable.
599   struct pollfd fds[1];
600   fds[0].fd = acceptedSocket->getSocketFD();
601   fds[0].events = POLLIN;
602   fds[0].revents = 0;
603   int rc = poll(fds, 1, 0);
604   ASSERT_EQ(rc, 0);
605
606   // Write data to the accepted socket
607   uint8_t acceptedWbuf[192];
608   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
609   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
610   acceptedSocket->flush();
611
612   // Loop
613   evb.loop();
614
615   // The loop should have completed the connection, written the queued data,
616   // and shutdown writes on the socket.
617   //
618   // Check that the connection was completed successfully and that the write
619   // callback succeeded.
620   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
621   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
622
623   // Check that we can read the data that was written to the socket, and that
624   // we see an EOF, since its socket was half-shutdown.
625   uint8_t readbuf[sizeof(wbuf)];
626   acceptedSocket->readAll(readbuf, sizeof(readbuf));
627   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
628   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
629   ASSERT_EQ(bytesRead, 0);
630
631   // Close the accepted socket.  This will cause it to see EOF
632   // and uninstall the read callback when we loop next.
633   acceptedSocket->close();
634
635   // Install a read callback, then loop again.
636   ReadCallback rcb;
637   socket->setReadCB(&rcb);
638   evb.loop();
639
640   // This loop should have read the data and seen the EOF
641   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
642   ASSERT_EQ(rcb.buffers.size(), 1);
643   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
644   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
645                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
646
647   ASSERT_FALSE(socket->isClosedBySelf());
648   ASSERT_FALSE(socket->isClosedByPeer());
649 }
650
651 /**
652  * Test reading, writing, and shutting down writes before the connect attempt
653  * finishes.
654  */
655 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
656   TestServer server;
657
658   // connect()
659   EventBase evb;
660   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
661   ConnCallback ccb;
662   socket->connect(&ccb, server.getAddress(), 30);
663
664   // Hopefully the connect didn't succeed immediately.
665   // If it did, we can't exercise the write-while-connecting code path.
666   if (ccb.state == STATE_SUCCEEDED) {
667     LOG(INFO) << "connect() succeeded immediately; skipping test";
668     return;
669   }
670
671   // Install a read callback
672   ReadCallback rcb;
673   socket->setReadCB(&rcb);
674
675   // Ask to write some data
676   char wbuf[128];
677   memset(wbuf, 'a', sizeof(wbuf));
678   WriteCallback wcb;
679   socket->write(&wcb, wbuf, sizeof(wbuf));
680
681   // Shutdown writes
682   socket->shutdownWrite();
683
684   // Even though we haven't looped yet, we should be able to accept
685   // the connection.
686   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
687
688   // Since the connection is still in progress, there should be no data to
689   // read yet.  Verify that the accepted socket is not readable.
690   struct pollfd fds[1];
691   fds[0].fd = acceptedSocket->getSocketFD();
692   fds[0].events = POLLIN;
693   fds[0].revents = 0;
694   int rc = poll(fds, 1, 0);
695   ASSERT_EQ(rc, 0);
696
697   // Write data to the accepted socket
698   uint8_t acceptedWbuf[192];
699   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
700   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
701   acceptedSocket->flush();
702   // Shutdown writes to the accepted socket.  This will cause it to see EOF
703   // and uninstall the read callback.
704   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
705
706   // Loop
707   evb.loop();
708
709   // The loop should have completed the connection, written the queued data,
710   // shutdown writes on the socket, read the data we wrote to it, and see the
711   // EOF.
712   //
713   // Check that the connection was completed successfully and that the read
714   // and write callbacks were invoked as expected.
715   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
716   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
717   ASSERT_EQ(rcb.buffers.size(), 1);
718   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
719   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
720                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
721   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
722
723   // Check that we can read the data that was written to the socket, and that
724   // we see an EOF, since its socket was half-shutdown.
725   uint8_t readbuf[sizeof(wbuf)];
726   acceptedSocket->readAll(readbuf, sizeof(readbuf));
727   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
728   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
729   ASSERT_EQ(bytesRead, 0);
730
731   // Fully close both sockets
732   acceptedSocket->close();
733   socket->close();
734
735   ASSERT_FALSE(socket->isClosedBySelf());
736   ASSERT_TRUE(socket->isClosedByPeer());
737 }
738
739 /**
740  * Test reading, writing, and calling shutdownWriteNow() before the
741  * connect attempt finishes.
742  */
743 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
744   TestServer server;
745
746   // connect()
747   EventBase evb;
748   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
749   ConnCallback ccb;
750   socket->connect(&ccb, server.getAddress(), 30);
751
752   // Hopefully the connect didn't succeed immediately.
753   // If it did, we can't exercise the write-while-connecting code path.
754   if (ccb.state == STATE_SUCCEEDED) {
755     LOG(INFO) << "connect() succeeded immediately; skipping test";
756     return;
757   }
758
759   // Install a read callback
760   ReadCallback rcb;
761   socket->setReadCB(&rcb);
762
763   // Ask to write some data
764   char wbuf[128];
765   memset(wbuf, 'a', sizeof(wbuf));
766   WriteCallback wcb;
767   socket->write(&wcb, wbuf, sizeof(wbuf));
768
769   // Shutdown writes immediately.
770   // This should immediately discard the data that we just tried to write.
771   socket->shutdownWriteNow();
772
773   // Verify that writeError() was invoked on the write callback.
774   ASSERT_EQ(wcb.state, STATE_FAILED);
775   ASSERT_EQ(wcb.bytesWritten, 0);
776
777   // Even though we haven't looped yet, we should be able to accept
778   // the connection.
779   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
780
781   // Since the connection is still in progress, there should be no data to
782   // read yet.  Verify that the accepted socket is not readable.
783   struct pollfd fds[1];
784   fds[0].fd = acceptedSocket->getSocketFD();
785   fds[0].events = POLLIN;
786   fds[0].revents = 0;
787   int rc = poll(fds, 1, 0);
788   ASSERT_EQ(rc, 0);
789
790   // Write data to the accepted socket
791   uint8_t acceptedWbuf[192];
792   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
793   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
794   acceptedSocket->flush();
795   // Shutdown writes to the accepted socket.  This will cause it to see EOF
796   // and uninstall the read callback.
797   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
798
799   // Loop
800   evb.loop();
801
802   // The loop should have completed the connection, written the queued data,
803   // shutdown writes on the socket, read the data we wrote to it, and see the
804   // EOF.
805   //
806   // Check that the connection was completed successfully and that the read
807   // callback was invoked as expected.
808   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
809   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
810   ASSERT_EQ(rcb.buffers.size(), 1);
811   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
812   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
813                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
814
815   // Since we used shutdownWriteNow(), it should have discarded all pending
816   // write data.  Verify we see an immediate EOF when reading from the accepted
817   // socket.
818   uint8_t readbuf[sizeof(wbuf)];
819   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
820   ASSERT_EQ(bytesRead, 0);
821
822   // Fully close both sockets
823   acceptedSocket->close();
824   socket->close();
825
826   ASSERT_FALSE(socket->isClosedBySelf());
827   ASSERT_TRUE(socket->isClosedByPeer());
828 }
829
830 // Helper function for use in testConnectOptWrite()
831 // Temporarily disable the read callback
832 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
833   // Uninstall the read callback
834   socket->setReadCB(nullptr);
835   // Schedule the read callback to be reinstalled after 1ms
836   socket->getEventBase()->runInLoop(
837       std::bind(&AsyncSocket::setReadCB, socket, rcb));
838 }
839
840 /**
841  * Test connect+write, then have the connect callback perform another write.
842  *
843  * This tests interaction of the optimistic writing after connect with
844  * additional write attempts that occur in the connect callback.
845  */
846 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
847   TestServer server;
848   EventBase evb;
849   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
850
851   // connect()
852   ConnCallback ccb;
853   socket->connect(&ccb, server.getAddress(), 30);
854
855   // Hopefully the connect didn't succeed immediately.
856   // If it did, we can't exercise the optimistic write code path.
857   if (ccb.state == STATE_SUCCEEDED) {
858     LOG(INFO) << "connect() succeeded immediately; aborting test "
859                        "of optimistic write behavior";
860     return;
861   }
862
863   // Tell the connect callback to perform a write when the connect succeeds
864   WriteCallback wcb2;
865   scoped_array<char> buf2(new char[size2]);
866   memset(buf2.get(), 'b', size2);
867   if (size2 > 0) {
868     ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
869     // Tell the second write callback to close the connection when it is done
870     wcb2.successCallback = [&] { socket->closeNow(); };
871   }
872
873   // Schedule one write() immediately, before the connect finishes
874   scoped_array<char> buf1(new char[size1]);
875   memset(buf1.get(), 'a', size1);
876   WriteCallback wcb1;
877   if (size1 > 0) {
878     socket->write(&wcb1, buf1.get(), size1);
879   }
880
881   if (close) {
882     // immediately perform a close, before connect() completes
883     socket->close();
884   }
885
886   // Start reading from the other endpoint after 10ms.
887   // If we're using large buffers, we have to read so that the writes don't
888   // block forever.
889   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
890   ReadCallback rcb;
891   rcb.dataAvailableCallback = std::bind(tmpDisableReads,
892                                         acceptedSocket.get(), &rcb);
893   socket->getEventBase()->tryRunAfterDelay(
894       std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
895       10);
896
897   // Loop.  We don't bother accepting on the server socket yet.
898   // The kernel should be able to buffer the write request so it can succeed.
899   evb.loop();
900
901   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
902   if (size1 > 0) {
903     ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
904   }
905   if (size2 > 0) {
906     ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
907   }
908
909   socket->close();
910
911   // Make sure the read callback received all of the data
912   size_t bytesRead = 0;
913   for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
914        it != rcb.buffers.end();
915        ++it) {
916     size_t start = bytesRead;
917     bytesRead += it->length;
918     size_t end = bytesRead;
919     if (start < size1) {
920       size_t cmpLen = min(size1, end) - start;
921       ASSERT_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
922     }
923     if (end > size1 && end <= size1 + size2) {
924       size_t itOffset;
925       size_t buf2Offset;
926       size_t cmpLen;
927       if (start >= size1) {
928         itOffset = 0;
929         buf2Offset = start - size1;
930         cmpLen = end - start;
931       } else {
932         itOffset = size1 - start;
933         buf2Offset = 0;
934         cmpLen = end - size1;
935       }
936       ASSERT_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
937                                cmpLen),
938                         0);
939     }
940   }
941   ASSERT_EQ(bytesRead, size1 + size2);
942 }
943
944 TEST(AsyncSocketTest, ConnectCallbackWrite) {
945   // Test using small writes that should both succeed immediately
946   testConnectOptWrite(100, 200);
947
948   // Test using a large buffer in the connect callback, that should block
949   const size_t largeSize = 32 * 1024 * 1024;
950   testConnectOptWrite(100, largeSize);
951
952   // Test using a large initial write
953   testConnectOptWrite(largeSize, 100);
954
955   // Test using two large buffers
956   testConnectOptWrite(largeSize, largeSize);
957
958   // Test a small write in the connect callback,
959   // but no immediate write before connect completes
960   testConnectOptWrite(0, 64);
961
962   // Test a large write in the connect callback,
963   // but no immediate write before connect completes
964   testConnectOptWrite(0, largeSize);
965
966   // Test connect, a small write, then immediately call close() before connect
967   // completes
968   testConnectOptWrite(211, 0, true);
969
970   // Test connect, a large immediate write (that will block), then immediately
971   // call close() before connect completes
972   testConnectOptWrite(largeSize, 0, true);
973 }
974
975 ///////////////////////////////////////////////////////////////////////////
976 // write() related tests
977 ///////////////////////////////////////////////////////////////////////////
978
979 /**
980  * Test writing using a nullptr callback
981  */
982 TEST(AsyncSocketTest, WriteNullCallback) {
983   TestServer server;
984
985   // connect()
986   EventBase evb;
987   std::shared_ptr<AsyncSocket> socket =
988     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
989   evb.loop(); // loop until the socket is connected
990
991   // write() with a nullptr callback
992   char buf[128];
993   memset(buf, 'a', sizeof(buf));
994   socket->write(nullptr, buf, sizeof(buf));
995
996   evb.loop(); // loop until the data is sent
997
998   // Make sure the server got a connection and received the data
999   socket->close();
1000   server.verifyConnection(buf, sizeof(buf));
1001
1002   ASSERT_TRUE(socket->isClosedBySelf());
1003   ASSERT_FALSE(socket->isClosedByPeer());
1004 }
1005
1006 /**
1007  * Test writing with a send timeout
1008  */
1009 TEST(AsyncSocketTest, WriteTimeout) {
1010   TestServer server;
1011
1012   // connect()
1013   EventBase evb;
1014   std::shared_ptr<AsyncSocket> socket =
1015     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1016   evb.loop(); // loop until the socket is connected
1017
1018   // write() a large chunk of data, with no-one on the other end reading.
1019   // Tricky: the kernel caches the connection metrics for recently-used
1020   // routes (see tcp_no_metrics_save) so a freshly opened connection can
1021   // have a send buffer size bigger than wmem_default.  This makes the test
1022   // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
1023   size_t writeLength = 32 * 1024 * 1024;
1024   uint32_t timeout = 200;
1025   socket->setSendTimeout(timeout);
1026   scoped_array<char> buf(new char[writeLength]);
1027   memset(buf.get(), 'a', writeLength);
1028   WriteCallback wcb;
1029   socket->write(&wcb, buf.get(), writeLength);
1030
1031   TimePoint start;
1032   evb.loop();
1033   TimePoint end;
1034
1035   // Make sure the write attempt timed out as requested
1036   ASSERT_EQ(wcb.state, STATE_FAILED);
1037   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1038
1039   // Check that the write timed out within a reasonable period of time.
1040   // We don't check for exactly the specified timeout, since AsyncSocket only
1041   // times out when it hasn't made progress for that period of time.
1042   //
1043   // On linux, the first write sends a few hundred kb of data, then blocks for
1044   // writability, and then unblocks again after 40ms and is able to write
1045   // another smaller of data before blocking permanently.  Therefore it doesn't
1046   // time out until 40ms + timeout.
1047   //
1048   // I haven't fully verified the cause of this, but I believe it probably
1049   // occurs because the receiving end delays sending an ack for up to 40ms.
1050   // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
1051   // the ack, it can send some more data.  However, after that point the
1052   // receiver's kernel buffer is full.  This 40ms delay happens even with
1053   // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
1054   // kernel may be automatically disabling TCP_QUICKACK after receiving some
1055   // data.
1056   //
1057   // For now, we simply check that the timeout occurred within 160ms of
1058   // the requested value.
1059   T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1060 }
1061
1062 /**
1063  * Test writing to a socket that the remote endpoint has closed
1064  */
1065 TEST(AsyncSocketTest, WritePipeError) {
1066   TestServer server;
1067
1068   // connect()
1069   EventBase evb;
1070   std::shared_ptr<AsyncSocket> socket =
1071     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1072   socket->setSendTimeout(1000);
1073   evb.loop(); // loop until the socket is connected
1074
1075   // accept and immediately close the socket
1076   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1077   acceptedSocket->close();
1078
1079   // write() a large chunk of data
1080   size_t writeLength = 32 * 1024 * 1024;
1081   scoped_array<char> buf(new char[writeLength]);
1082   memset(buf.get(), 'a', writeLength);
1083   WriteCallback wcb;
1084   socket->write(&wcb, buf.get(), writeLength);
1085
1086   evb.loop();
1087
1088   // Make sure the write failed.
1089   // It would be nice if AsyncSocketException could convey the errno value,
1090   // so that we could check for EPIPE
1091   ASSERT_EQ(wcb.state, STATE_FAILED);
1092   ASSERT_EQ(wcb.exception.getType(),
1093                     AsyncSocketException::INTERNAL_ERROR);
1094
1095   ASSERT_FALSE(socket->isClosedBySelf());
1096   ASSERT_FALSE(socket->isClosedByPeer());
1097 }
1098
1099 /**
1100  * Test writing to a socket that has its read side closed
1101  */
1102 TEST(AsyncSocketTest, WriteAfterReadEOF) {
1103   TestServer server;
1104
1105   // connect()
1106   EventBase evb;
1107   std::shared_ptr<AsyncSocket> socket =
1108       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1109   evb.loop(); // loop until the socket is connected
1110
1111   // Accept the connection
1112   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1113   ReadCallback rcb;
1114   acceptedSocket->setReadCB(&rcb);
1115
1116   // Shutdown the write side of client socket (read side of server socket)
1117   socket->shutdownWrite();
1118   evb.loop();
1119
1120   // Check that accepted socket is still writable
1121   ASSERT_FALSE(acceptedSocket->good());
1122   ASSERT_TRUE(acceptedSocket->writable());
1123
1124   // Write data to accepted socket
1125   constexpr size_t simpleBufLength = 5;
1126   char simpleBuf[simpleBufLength];
1127   memset(simpleBuf, 'a', simpleBufLength);
1128   WriteCallback wcb;
1129   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1130   evb.loop();
1131
1132   // Make sure we were able to write even after getting a read EOF
1133   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
1134   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1135 }
1136
1137 /**
1138  * Test that bytes written is correctly computed in case of write failure
1139  */
1140 TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
1141   // Send and receive buffer sizes for the sockets.
1142   const int sockBufSize = 8 * 1024;
1143
1144   TestServer server(false, sockBufSize);
1145
1146   AsyncSocket::OptionMap options{
1147       {{SOL_SOCKET, SO_SNDBUF}, sockBufSize},
1148       {{SOL_SOCKET, SO_RCVBUF}, sockBufSize},
1149       {{IPPROTO_TCP, TCP_NODELAY}, 1},
1150   };
1151
1152   // The current thread will be used by the receiver - use a separate thread
1153   // for the sender.
1154   EventBase senderEvb;
1155   std::thread senderThread([&]() { senderEvb.loopForever(); });
1156
1157   ConnCallback ccb;
1158   std::shared_ptr<AsyncSocket> socket;
1159
1160   senderEvb.runInEventBaseThreadAndWait([&]() {
1161     socket = AsyncSocket::newSocket(&senderEvb);
1162     socket->connect(&ccb, server.getAddress(), 30, options);
1163   });
1164
1165   // accept the socket on the server side
1166   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1167
1168   // Send a big (45KB) write so that it is partially written. The first write
1169   // is 16KB (8KB on both sides) and subsequent writes are 8KB each. Reading
1170   // just under 24KB would cause 3-4 writes for the total of 32-40KB in the
1171   // following sequence: 16KB + 8KB + 8KB (+ 8KB). This ensures that not all
1172   // bytes are written when the socket is reset. Having at least 3 writes
1173   // ensures that the total size (45KB) would be exceeed in case of overcounting
1174   // based on the initial write size of 16KB.
1175   constexpr size_t sendSize = 45 * 1024;
1176   auto const sendBuf = std::vector<char>(sendSize, 'a');
1177
1178   WriteCallback wcb;
1179
1180   senderEvb.runInEventBaseThreadAndWait(
1181       [&]() { socket->write(&wcb, sendBuf.data(), sendSize); });
1182
1183   // Reading 20KB would cause three additional writes of 8KB, but less
1184   // than 45KB total, so the socket is reset before all bytes are written.
1185   constexpr size_t recvSize = 20 * 1024;
1186   uint8_t recvBuf[recvSize];
1187   int bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
1188
1189   acceptedSocket->closeWithReset();
1190
1191   senderEvb.terminateLoopSoon();
1192   senderThread.join();
1193
1194   LOG(INFO) << "Bytes written: " << wcb.bytesWritten;
1195
1196   ASSERT_EQ(STATE_FAILED, wcb.state);
1197   ASSERT_GE(wcb.bytesWritten, bytesRead);
1198   ASSERT_LE(wcb.bytesWritten, sendSize);
1199   ASSERT_EQ(recvSize, bytesRead);
1200   ASSERT(32 * 1024 == wcb.bytesWritten || 40 * 1024 == wcb.bytesWritten);
1201 }
1202
1203 /**
1204  * Test writing a mix of simple buffers and IOBufs
1205  */
1206 TEST(AsyncSocketTest, WriteIOBuf) {
1207   TestServer server;
1208
1209   // connect()
1210   EventBase evb;
1211   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1212   ConnCallback ccb;
1213   socket->connect(&ccb, server.getAddress(), 30);
1214
1215   // Accept the connection
1216   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1217   ReadCallback rcb;
1218   acceptedSocket->setReadCB(&rcb);
1219
1220   // Check if EOR tracking flag can be set and reset.
1221   EXPECT_FALSE(socket->isEorTrackingEnabled());
1222   socket->setEorTracking(true);
1223   EXPECT_TRUE(socket->isEorTrackingEnabled());
1224   socket->setEorTracking(false);
1225   EXPECT_FALSE(socket->isEorTrackingEnabled());
1226
1227   // Write a simple buffer to the socket
1228   constexpr size_t simpleBufLength = 5;
1229   char simpleBuf[simpleBufLength];
1230   memset(simpleBuf, 'a', simpleBufLength);
1231   WriteCallback wcb;
1232   socket->write(&wcb, simpleBuf, simpleBufLength);
1233
1234   // Write a single-element IOBuf chain
1235   size_t buf1Length = 7;
1236   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1237   memset(buf1->writableData(), 'b', buf1Length);
1238   buf1->append(buf1Length);
1239   unique_ptr<IOBuf> buf1Copy(buf1->clone());
1240   WriteCallback wcb2;
1241   socket->writeChain(&wcb2, std::move(buf1));
1242
1243   // Write a multiple-element IOBuf chain
1244   size_t buf2Length = 11;
1245   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1246   memset(buf2->writableData(), 'c', buf2Length);
1247   buf2->append(buf2Length);
1248   size_t buf3Length = 13;
1249   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1250   memset(buf3->writableData(), 'd', buf3Length);
1251   buf3->append(buf3Length);
1252   buf2->appendChain(std::move(buf3));
1253   unique_ptr<IOBuf> buf2Copy(buf2->clone());
1254   buf2Copy->coalesce();
1255   WriteCallback wcb3;
1256   socket->writeChain(&wcb3, std::move(buf2));
1257   socket->shutdownWrite();
1258
1259   // Let the reads and writes run to completion
1260   evb.loop();
1261
1262   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1263   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1264   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1265
1266   // Make sure the reader got the right data in the right order
1267   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1268   ASSERT_EQ(rcb.buffers.size(), 1);
1269   ASSERT_EQ(rcb.buffers[0].length,
1270       simpleBufLength + buf1Length + buf2Length + buf3Length);
1271   ASSERT_EQ(
1272       memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1273   ASSERT_EQ(
1274       memcmp(rcb.buffers[0].buffer + simpleBufLength,
1275           buf1Copy->data(), buf1Copy->length()), 0);
1276   ASSERT_EQ(
1277       memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1278           buf2Copy->data(), buf2Copy->length()), 0);
1279
1280   acceptedSocket->close();
1281   socket->close();
1282
1283   ASSERT_TRUE(socket->isClosedBySelf());
1284   ASSERT_FALSE(socket->isClosedByPeer());
1285 }
1286
1287 TEST(AsyncSocketTest, WriteIOBufCorked) {
1288   TestServer server;
1289
1290   // connect()
1291   EventBase evb;
1292   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1293   ConnCallback ccb;
1294   socket->connect(&ccb, server.getAddress(), 30);
1295
1296   // Accept the connection
1297   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1298   ReadCallback rcb;
1299   acceptedSocket->setReadCB(&rcb);
1300
1301   // Do three writes, 100ms apart, with the "cork" flag set
1302   // on the second write.  The reader should see the first write
1303   // arrive by itself, followed by the second and third writes
1304   // arriving together.
1305   size_t buf1Length = 5;
1306   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1307   memset(buf1->writableData(), 'a', buf1Length);
1308   buf1->append(buf1Length);
1309   size_t buf2Length = 7;
1310   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1311   memset(buf2->writableData(), 'b', buf2Length);
1312   buf2->append(buf2Length);
1313   size_t buf3Length = 11;
1314   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1315   memset(buf3->writableData(), 'c', buf3Length);
1316   buf3->append(buf3Length);
1317   WriteCallback wcb1;
1318   socket->writeChain(&wcb1, std::move(buf1));
1319   WriteCallback wcb2;
1320   DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
1321   write2.scheduleTimeout(100);
1322   WriteCallback wcb3;
1323   DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
1324   write3.scheduleTimeout(140);
1325
1326   evb.loop();
1327   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1328   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1329   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1330   if (wcb3.state != STATE_SUCCEEDED) {
1331     throw(wcb3.exception);
1332   }
1333   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1334
1335   // Make sure the reader got the data with the right grouping
1336   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1337   ASSERT_EQ(rcb.buffers.size(), 2);
1338   ASSERT_EQ(rcb.buffers[0].length, buf1Length);
1339   ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
1340
1341   acceptedSocket->close();
1342   socket->close();
1343
1344   ASSERT_TRUE(socket->isClosedBySelf());
1345   ASSERT_FALSE(socket->isClosedByPeer());
1346 }
1347
1348 /**
1349  * Test performing a zero-length write
1350  */
1351 TEST(AsyncSocketTest, ZeroLengthWrite) {
1352   TestServer server;
1353
1354   // connect()
1355   EventBase evb;
1356   std::shared_ptr<AsyncSocket> socket =
1357     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1358   evb.loop(); // loop until the socket is connected
1359
1360   auto acceptedSocket = server.acceptAsync(&evb);
1361   ReadCallback rcb;
1362   acceptedSocket->setReadCB(&rcb);
1363
1364   size_t len1 = 1024*1024;
1365   size_t len2 = 1024*1024;
1366   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1367   memset(buf.get(), 'a', len1);
1368   memset(buf.get(), 'b', len2);
1369
1370   WriteCallback wcb1;
1371   WriteCallback wcb2;
1372   WriteCallback wcb3;
1373   WriteCallback wcb4;
1374   socket->write(&wcb1, buf.get(), 0);
1375   socket->write(&wcb2, buf.get(), len1);
1376   socket->write(&wcb3, buf.get() + len1, 0);
1377   socket->write(&wcb4, buf.get() + len1, len2);
1378   socket->close();
1379
1380   evb.loop(); // loop until the data is sent
1381
1382   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1383   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1384   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1385   ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1386   rcb.verifyData(buf.get(), len1 + len2);
1387
1388   ASSERT_TRUE(socket->isClosedBySelf());
1389   ASSERT_FALSE(socket->isClosedByPeer());
1390 }
1391
1392 TEST(AsyncSocketTest, ZeroLengthWritev) {
1393   TestServer server;
1394
1395   // connect()
1396   EventBase evb;
1397   std::shared_ptr<AsyncSocket> socket =
1398     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1399   evb.loop(); // loop until the socket is connected
1400
1401   auto acceptedSocket = server.acceptAsync(&evb);
1402   ReadCallback rcb;
1403   acceptedSocket->setReadCB(&rcb);
1404
1405   size_t len1 = 1024*1024;
1406   size_t len2 = 1024*1024;
1407   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1408   memset(buf.get(), 'a', len1);
1409   memset(buf.get(), 'b', len2);
1410
1411   WriteCallback wcb;
1412   constexpr size_t iovCount = 4;
1413   struct iovec iov[iovCount];
1414   iov[0].iov_base = buf.get();
1415   iov[0].iov_len = len1;
1416   iov[1].iov_base = buf.get() + len1;
1417   iov[1].iov_len = 0;
1418   iov[2].iov_base = buf.get() + len1;
1419   iov[2].iov_len = len2;
1420   iov[3].iov_base = buf.get() + len1 + len2;
1421   iov[3].iov_len = 0;
1422
1423   socket->writev(&wcb, iov, iovCount);
1424   socket->close();
1425   evb.loop(); // loop until the data is sent
1426
1427   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1428   rcb.verifyData(buf.get(), len1 + len2);
1429
1430   ASSERT_TRUE(socket->isClosedBySelf());
1431   ASSERT_FALSE(socket->isClosedByPeer());
1432 }
1433
1434 ///////////////////////////////////////////////////////////////////////////
1435 // close() related tests
1436 ///////////////////////////////////////////////////////////////////////////
1437
1438 /**
1439  * Test calling close() with pending writes when the socket is already closing.
1440  */
1441 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1442   TestServer server;
1443
1444   // connect()
1445   EventBase evb;
1446   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1447   ConnCallback ccb;
1448   socket->connect(&ccb, server.getAddress(), 30);
1449
1450   // accept the socket on the server side
1451   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1452
1453   // Loop to ensure the connect has completed
1454   evb.loop();
1455
1456   // Make sure we are connected
1457   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1458
1459   // Schedule pending writes, until several write attempts have blocked
1460   char buf[128];
1461   memset(buf, 'a', sizeof(buf));
1462   typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
1463   WriteCallbackVector writeCallbacks;
1464
1465   writeCallbacks.reserve(5);
1466   while (writeCallbacks.size() < 5) {
1467     std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1468
1469     socket->write(wcb.get(), buf, sizeof(buf));
1470     if (wcb->state == STATE_SUCCEEDED) {
1471       // Succeeded immediately.  Keep performing more writes
1472       continue;
1473     }
1474
1475     // This write is blocked.
1476     // Have the write callback call close() when writeError() is invoked
1477     wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1478     writeCallbacks.push_back(wcb);
1479   }
1480
1481   // Call closeNow() to immediately fail the pending writes
1482   socket->closeNow();
1483
1484   // Make sure writeError() was invoked on all of the pending write callbacks
1485   for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
1486        it != writeCallbacks.end();
1487        ++it) {
1488     ASSERT_EQ((*it)->state, STATE_FAILED);
1489   }
1490
1491   ASSERT_TRUE(socket->isClosedBySelf());
1492   ASSERT_FALSE(socket->isClosedByPeer());
1493 }
1494
1495 ///////////////////////////////////////////////////////////////////////////
1496 // ImmediateRead related tests
1497 ///////////////////////////////////////////////////////////////////////////
1498
1499 /* AsyncSocket use to verify immediate read works */
1500 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1501  public:
1502   bool immediateReadCalled = false;
1503   explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1504  protected:
1505   void checkForImmediateRead() noexcept override {
1506     immediateReadCalled = true;
1507     AsyncSocket::handleRead();
1508   }
1509 };
1510
1511 TEST(AsyncSocket, ConnectReadImmediateRead) {
1512   TestServer server;
1513
1514   const size_t maxBufferSz = 100;
1515   const size_t maxReadsPerEvent = 1;
1516   const size_t expectedDataSz = maxBufferSz * 3;
1517   char expectedData[expectedDataSz];
1518   memset(expectedData, 'j', expectedDataSz);
1519
1520   EventBase evb;
1521   ReadCallback rcb(maxBufferSz);
1522   AsyncSocketImmediateRead socket(&evb);
1523   socket.connect(nullptr, server.getAddress(), 30);
1524
1525   evb.loop(); // loop until the socket is connected
1526
1527   socket.setReadCB(&rcb);
1528   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1529   socket.immediateReadCalled = false;
1530
1531   auto acceptedSocket = server.acceptAsync(&evb);
1532
1533   ReadCallback rcbServer;
1534   WriteCallback wcbServer;
1535   rcbServer.dataAvailableCallback = [&]() {
1536     if (rcbServer.dataRead() == expectedDataSz) {
1537       // write back all data read
1538       rcbServer.verifyData(expectedData, expectedDataSz);
1539       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1540       acceptedSocket->close();
1541     }
1542   };
1543   acceptedSocket->setReadCB(&rcbServer);
1544
1545   // write data
1546   WriteCallback wcb1;
1547   socket.write(&wcb1, expectedData, expectedDataSz);
1548   evb.loop();
1549   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1550   rcb.verifyData(expectedData, expectedDataSz);
1551   ASSERT_EQ(socket.immediateReadCalled, true);
1552
1553   ASSERT_FALSE(socket.isClosedBySelf());
1554   ASSERT_FALSE(socket.isClosedByPeer());
1555 }
1556
1557 TEST(AsyncSocket, ConnectReadUninstallRead) {
1558   TestServer server;
1559
1560   const size_t maxBufferSz = 100;
1561   const size_t maxReadsPerEvent = 1;
1562   const size_t expectedDataSz = maxBufferSz * 3;
1563   char expectedData[expectedDataSz];
1564   memset(expectedData, 'k', expectedDataSz);
1565
1566   EventBase evb;
1567   ReadCallback rcb(maxBufferSz);
1568   AsyncSocketImmediateRead socket(&evb);
1569   socket.connect(nullptr, server.getAddress(), 30);
1570
1571   evb.loop(); // loop until the socket is connected
1572
1573   socket.setReadCB(&rcb);
1574   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1575   socket.immediateReadCalled = false;
1576
1577   auto acceptedSocket = server.acceptAsync(&evb);
1578
1579   ReadCallback rcbServer;
1580   WriteCallback wcbServer;
1581   rcbServer.dataAvailableCallback = [&]() {
1582     if (rcbServer.dataRead() == expectedDataSz) {
1583       // write back all data read
1584       rcbServer.verifyData(expectedData, expectedDataSz);
1585       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1586       acceptedSocket->close();
1587     }
1588   };
1589   acceptedSocket->setReadCB(&rcbServer);
1590
1591   rcb.dataAvailableCallback = [&]() {
1592     // we read data and reset readCB
1593     socket.setReadCB(nullptr);
1594   };
1595
1596   // write data
1597   WriteCallback wcb;
1598   socket.write(&wcb, expectedData, expectedDataSz);
1599   evb.loop();
1600   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1601
1602   /* we shoud've only read maxBufferSz data since readCallback_
1603    * was reset in dataAvailableCallback */
1604   ASSERT_EQ(rcb.dataRead(), maxBufferSz);
1605   ASSERT_EQ(socket.immediateReadCalled, false);
1606
1607   ASSERT_FALSE(socket.isClosedBySelf());
1608   ASSERT_FALSE(socket.isClosedByPeer());
1609 }
1610
1611 // TODO:
1612 // - Test connect() and have the connect callback set the read callback
1613 // - Test connect() and have the connect callback unset the read callback
1614 // - Test reading/writing/closing/destroying the socket in the connect callback
1615 // - Test reading/writing/closing/destroying the socket in the read callback
1616 // - Test reading/writing/closing/destroying the socket in the write callback
1617 // - Test one-way shutdown behavior
1618 // - Test changing the EventBase
1619 //
1620 // - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
1621 //   in connectSuccess(), readDataAvailable(), writeSuccess()
1622
1623
1624 ///////////////////////////////////////////////////////////////////////////
1625 // AsyncServerSocket tests
1626 ///////////////////////////////////////////////////////////////////////////
1627
1628 /**
1629  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1630  */
1631 TEST(AsyncSocketTest, ServerAcceptOptions) {
1632   EventBase eventBase;
1633
1634   // Create a server socket
1635   std::shared_ptr<AsyncServerSocket> serverSocket(
1636       AsyncServerSocket::newSocket(&eventBase));
1637   serverSocket->bind(0);
1638   serverSocket->listen(16);
1639   folly::SocketAddress serverAddress;
1640   serverSocket->getAddress(&serverAddress);
1641
1642   // Add a callback to accept one connection then stop the loop
1643   TestAcceptCallback acceptCallback;
1644   acceptCallback.setConnectionAcceptedFn(
1645       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1646         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1647       });
1648   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1649     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1650   });
1651   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
1652   serverSocket->startAccepting();
1653
1654   // Connect to the server socket
1655   std::shared_ptr<AsyncSocket> socket(
1656       AsyncSocket::newSocket(&eventBase, serverAddress));
1657
1658   eventBase.loop();
1659
1660   // Verify that the server accepted a connection
1661   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1662   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1663                     TestAcceptCallback::TYPE_START);
1664   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1665                     TestAcceptCallback::TYPE_ACCEPT);
1666   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1667                     TestAcceptCallback::TYPE_STOP);
1668   int fd = acceptCallback.getEvents()->at(1).fd;
1669
1670   // The accepted connection should already be in non-blocking mode
1671   int flags = fcntl(fd, F_GETFL, 0);
1672   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
1673
1674 #ifndef TCP_NOPUSH
1675   // The accepted connection should already have TCP_NODELAY set
1676   int value;
1677   socklen_t valueLength = sizeof(value);
1678   int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
1679   ASSERT_EQ(rc, 0);
1680   ASSERT_EQ(value, 1);
1681 #endif
1682 }
1683
1684 /**
1685  * Test AsyncServerSocket::removeAcceptCallback()
1686  */
1687 TEST(AsyncSocketTest, RemoveAcceptCallback) {
1688   // Create a new AsyncServerSocket
1689   EventBase eventBase;
1690   std::shared_ptr<AsyncServerSocket> serverSocket(
1691       AsyncServerSocket::newSocket(&eventBase));
1692   serverSocket->bind(0);
1693   serverSocket->listen(16);
1694   folly::SocketAddress serverAddress;
1695   serverSocket->getAddress(&serverAddress);
1696
1697   // Add several accept callbacks
1698   TestAcceptCallback cb1;
1699   TestAcceptCallback cb2;
1700   TestAcceptCallback cb3;
1701   TestAcceptCallback cb4;
1702   TestAcceptCallback cb5;
1703   TestAcceptCallback cb6;
1704   TestAcceptCallback cb7;
1705
1706   // Test having callbacks remove other callbacks before them on the list,
1707   // after them on the list, or removing themselves.
1708   //
1709   // Have callback 2 remove callback 3 and callback 5 the first time it is
1710   // called.
1711   int cb2Count = 0;
1712   cb1.setConnectionAcceptedFn([&](int /* fd */,
1713                                   const folly::SocketAddress& /* addr */) {
1714     std::shared_ptr<AsyncSocket> sock2(
1715         AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
1716   });
1717   cb3.setConnectionAcceptedFn(
1718       [&](int /* fd */, const folly::SocketAddress& /* addr */) {});
1719   cb4.setConnectionAcceptedFn(
1720       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1721         std::shared_ptr<AsyncSocket> sock3(
1722             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
1723       });
1724   cb5.setConnectionAcceptedFn(
1725       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1726         std::shared_ptr<AsyncSocket> sock5(
1727             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
1728
1729       });
1730   cb2.setConnectionAcceptedFn(
1731       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1732         if (cb2Count == 0) {
1733           serverSocket->removeAcceptCallback(&cb3, nullptr);
1734           serverSocket->removeAcceptCallback(&cb5, nullptr);
1735         }
1736         ++cb2Count;
1737       });
1738   // Have callback 6 remove callback 4 the first time it is called,
1739   // and destroy the server socket the second time it is called
1740   int cb6Count = 0;
1741   cb6.setConnectionAcceptedFn(
1742       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1743         if (cb6Count == 0) {
1744           serverSocket->removeAcceptCallback(&cb4, nullptr);
1745           std::shared_ptr<AsyncSocket> sock6(
1746               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1747           std::shared_ptr<AsyncSocket> sock7(
1748               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
1749           std::shared_ptr<AsyncSocket> sock8(
1750               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
1751
1752         } else {
1753           serverSocket.reset();
1754         }
1755         ++cb6Count;
1756       });
1757   // Have callback 7 remove itself
1758   cb7.setConnectionAcceptedFn(
1759       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1760         serverSocket->removeAcceptCallback(&cb7, nullptr);
1761       });
1762
1763   serverSocket->addAcceptCallback(&cb1, &eventBase);
1764   serverSocket->addAcceptCallback(&cb2, &eventBase);
1765   serverSocket->addAcceptCallback(&cb3, &eventBase);
1766   serverSocket->addAcceptCallback(&cb4, &eventBase);
1767   serverSocket->addAcceptCallback(&cb5, &eventBase);
1768   serverSocket->addAcceptCallback(&cb6, &eventBase);
1769   serverSocket->addAcceptCallback(&cb7, &eventBase);
1770   serverSocket->startAccepting();
1771
1772   // Make several connections to the socket
1773   std::shared_ptr<AsyncSocket> sock1(
1774       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1775   std::shared_ptr<AsyncSocket> sock4(
1776       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
1777
1778   // Loop until we are stopped
1779   eventBase.loop();
1780
1781   // Check to make sure that the expected callbacks were invoked.
1782   //
1783   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1784   // the AcceptCallbacks in round-robin fashion, in the order that they were
1785   // added.  The code is implemented this way right now, but the API doesn't
1786   // explicitly require it be done this way.  If we change the code not to be
1787   // exactly round robin in the future, we can simplify the test checks here.
1788   // (We'll also need to update the termination code, since we expect cb6 to
1789   // get called twice to terminate the loop.)
1790   ASSERT_EQ(cb1.getEvents()->size(), 4);
1791   ASSERT_EQ(cb1.getEvents()->at(0).type,
1792                     TestAcceptCallback::TYPE_START);
1793   ASSERT_EQ(cb1.getEvents()->at(1).type,
1794                     TestAcceptCallback::TYPE_ACCEPT);
1795   ASSERT_EQ(cb1.getEvents()->at(2).type,
1796                     TestAcceptCallback::TYPE_ACCEPT);
1797   ASSERT_EQ(cb1.getEvents()->at(3).type,
1798                     TestAcceptCallback::TYPE_STOP);
1799
1800   ASSERT_EQ(cb2.getEvents()->size(), 4);
1801   ASSERT_EQ(cb2.getEvents()->at(0).type,
1802                     TestAcceptCallback::TYPE_START);
1803   ASSERT_EQ(cb2.getEvents()->at(1).type,
1804                     TestAcceptCallback::TYPE_ACCEPT);
1805   ASSERT_EQ(cb2.getEvents()->at(2).type,
1806                     TestAcceptCallback::TYPE_ACCEPT);
1807   ASSERT_EQ(cb2.getEvents()->at(3).type,
1808                     TestAcceptCallback::TYPE_STOP);
1809
1810   ASSERT_EQ(cb3.getEvents()->size(), 2);
1811   ASSERT_EQ(cb3.getEvents()->at(0).type,
1812                     TestAcceptCallback::TYPE_START);
1813   ASSERT_EQ(cb3.getEvents()->at(1).type,
1814                     TestAcceptCallback::TYPE_STOP);
1815
1816   ASSERT_EQ(cb4.getEvents()->size(), 3);
1817   ASSERT_EQ(cb4.getEvents()->at(0).type,
1818                     TestAcceptCallback::TYPE_START);
1819   ASSERT_EQ(cb4.getEvents()->at(1).type,
1820                     TestAcceptCallback::TYPE_ACCEPT);
1821   ASSERT_EQ(cb4.getEvents()->at(2).type,
1822                     TestAcceptCallback::TYPE_STOP);
1823
1824   ASSERT_EQ(cb5.getEvents()->size(), 2);
1825   ASSERT_EQ(cb5.getEvents()->at(0).type,
1826                     TestAcceptCallback::TYPE_START);
1827   ASSERT_EQ(cb5.getEvents()->at(1).type,
1828                     TestAcceptCallback::TYPE_STOP);
1829
1830   ASSERT_EQ(cb6.getEvents()->size(), 4);
1831   ASSERT_EQ(cb6.getEvents()->at(0).type,
1832                     TestAcceptCallback::TYPE_START);
1833   ASSERT_EQ(cb6.getEvents()->at(1).type,
1834                     TestAcceptCallback::TYPE_ACCEPT);
1835   ASSERT_EQ(cb6.getEvents()->at(2).type,
1836                     TestAcceptCallback::TYPE_ACCEPT);
1837   ASSERT_EQ(cb6.getEvents()->at(3).type,
1838                     TestAcceptCallback::TYPE_STOP);
1839
1840   ASSERT_EQ(cb7.getEvents()->size(), 3);
1841   ASSERT_EQ(cb7.getEvents()->at(0).type,
1842                     TestAcceptCallback::TYPE_START);
1843   ASSERT_EQ(cb7.getEvents()->at(1).type,
1844                     TestAcceptCallback::TYPE_ACCEPT);
1845   ASSERT_EQ(cb7.getEvents()->at(2).type,
1846                     TestAcceptCallback::TYPE_STOP);
1847 }
1848
1849 /**
1850  * Test AsyncServerSocket::removeAcceptCallback()
1851  */
1852 TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
1853   // Create a new AsyncServerSocket
1854   EventBase eventBase;
1855   std::shared_ptr<AsyncServerSocket> serverSocket(
1856       AsyncServerSocket::newSocket(&eventBase));
1857   serverSocket->bind(0);
1858   serverSocket->listen(16);
1859   folly::SocketAddress serverAddress;
1860   serverSocket->getAddress(&serverAddress);
1861
1862   // Add several accept callbacks
1863   TestAcceptCallback cb1;
1864   auto thread_id = std::this_thread::get_id();
1865   cb1.setAcceptStartedFn([&](){
1866     CHECK_NE(thread_id, std::this_thread::get_id());
1867     thread_id = std::this_thread::get_id();
1868   });
1869   cb1.setConnectionAcceptedFn(
1870       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1871         ASSERT_EQ(thread_id, std::this_thread::get_id());
1872         serverSocket->removeAcceptCallback(&cb1, &eventBase);
1873       });
1874   cb1.setAcceptStoppedFn([&](){
1875     ASSERT_EQ(thread_id, std::this_thread::get_id());
1876   });
1877
1878   // Test having callbacks remove other callbacks before them on the list,
1879   serverSocket->addAcceptCallback(&cb1, &eventBase);
1880   serverSocket->startAccepting();
1881
1882   // Make several connections to the socket
1883   std::shared_ptr<AsyncSocket> sock1(
1884       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1885
1886   // Loop in another thread
1887   auto other = std::thread([&](){
1888     eventBase.loop();
1889   });
1890   other.join();
1891
1892   // Check to make sure that the expected callbacks were invoked.
1893   //
1894   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1895   // the AcceptCallbacks in round-robin fashion, in the order that they were
1896   // added.  The code is implemented this way right now, but the API doesn't
1897   // explicitly require it be done this way.  If we change the code not to be
1898   // exactly round robin in the future, we can simplify the test checks here.
1899   // (We'll also need to update the termination code, since we expect cb6 to
1900   // get called twice to terminate the loop.)
1901   ASSERT_EQ(cb1.getEvents()->size(), 3);
1902   ASSERT_EQ(cb1.getEvents()->at(0).type,
1903                     TestAcceptCallback::TYPE_START);
1904   ASSERT_EQ(cb1.getEvents()->at(1).type,
1905                     TestAcceptCallback::TYPE_ACCEPT);
1906   ASSERT_EQ(cb1.getEvents()->at(2).type,
1907                     TestAcceptCallback::TYPE_STOP);
1908
1909 }
1910
1911 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
1912   EventBase* eventBase = serverSocket->getEventBase();
1913   CHECK(eventBase);
1914
1915   // Add a callback to accept one connection then stop accepting
1916   TestAcceptCallback acceptCallback;
1917   acceptCallback.setConnectionAcceptedFn(
1918       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1919         serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1920       });
1921   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1922     serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1923   });
1924   serverSocket->addAcceptCallback(&acceptCallback, eventBase);
1925   serverSocket->startAccepting();
1926
1927   // Connect to the server socket
1928   folly::SocketAddress serverAddress;
1929   serverSocket->getAddress(&serverAddress);
1930   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
1931
1932   // Loop to process all events
1933   eventBase->loop();
1934
1935   // Verify that the server accepted a connection
1936   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1937   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1938                     TestAcceptCallback::TYPE_START);
1939   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1940                     TestAcceptCallback::TYPE_ACCEPT);
1941   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1942                     TestAcceptCallback::TYPE_STOP);
1943 }
1944
1945 /* Verify that we don't leak sockets if we are destroyed()
1946  * and there are still writes pending
1947  *
1948  * If destroy() only calls close() instead of closeNow(),
1949  * it would shutdown(writes) on the socket, but it would
1950  * never be close()'d, and the socket would leak
1951  */
1952 TEST(AsyncSocketTest, DestroyCloseTest) {
1953   TestServer server;
1954
1955   // connect()
1956   EventBase clientEB;
1957   EventBase serverEB;
1958   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
1959   ConnCallback ccb;
1960   socket->connect(&ccb, server.getAddress(), 30);
1961
1962   // Accept the connection
1963   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
1964   ReadCallback rcb;
1965   acceptedSocket->setReadCB(&rcb);
1966
1967   // Write a large buffer to the socket that is larger than kernel buffer
1968   size_t simpleBufLength = 5000000;
1969   char* simpleBuf = new char[simpleBufLength];
1970   memset(simpleBuf, 'a', simpleBufLength);
1971   WriteCallback wcb;
1972
1973   // Let the reads and writes run to completion
1974   int fd = acceptedSocket->getFd();
1975
1976   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1977   socket.reset();
1978   acceptedSocket.reset();
1979
1980   // Test that server socket was closed
1981   folly::test::msvcSuppressAbortOnInvalidParams([&] {
1982     ssize_t sz = read(fd, simpleBuf, simpleBufLength);
1983     ASSERT_EQ(sz, -1);
1984     ASSERT_EQ(errno, EBADF);
1985   });
1986   delete[] simpleBuf;
1987 }
1988
1989 /**
1990  * Test AsyncServerSocket::useExistingSocket()
1991  */
1992 TEST(AsyncSocketTest, ServerExistingSocket) {
1993   EventBase eventBase;
1994
1995   // Test creating a socket, and letting AsyncServerSocket bind and listen
1996   {
1997     // Manually create a socket
1998     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
1999     ASSERT_GE(fd, 0);
2000
2001     // Create a server socket
2002     AsyncServerSocket::UniquePtr serverSocket(
2003         new AsyncServerSocket(&eventBase));
2004     serverSocket->useExistingSocket(fd);
2005     folly::SocketAddress address;
2006     serverSocket->getAddress(&address);
2007     address.setPort(0);
2008     serverSocket->bind(address);
2009     serverSocket->listen(16);
2010
2011     // Make sure the socket works
2012     serverSocketSanityTest(serverSocket.get());
2013   }
2014
2015   // Test creating a socket and binding manually,
2016   // then letting AsyncServerSocket listen
2017   {
2018     // Manually create a socket
2019     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2020     ASSERT_GE(fd, 0);
2021     // bind
2022     struct sockaddr_in addr;
2023     addr.sin_family = AF_INET;
2024     addr.sin_port = 0;
2025     addr.sin_addr.s_addr = INADDR_ANY;
2026     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2027                              sizeof(addr)), 0);
2028     // Look up the address that we bound to
2029     folly::SocketAddress boundAddress;
2030     boundAddress.setFromLocalAddress(fd);
2031
2032     // Create a server socket
2033     AsyncServerSocket::UniquePtr serverSocket(
2034         new AsyncServerSocket(&eventBase));
2035     serverSocket->useExistingSocket(fd);
2036     serverSocket->listen(16);
2037
2038     // Make sure AsyncServerSocket reports the same address that we bound to
2039     folly::SocketAddress serverSocketAddress;
2040     serverSocket->getAddress(&serverSocketAddress);
2041     ASSERT_EQ(boundAddress, serverSocketAddress);
2042
2043     // Make sure the socket works
2044     serverSocketSanityTest(serverSocket.get());
2045   }
2046
2047   // Test creating a socket, binding and listening manually,
2048   // then giving it to AsyncServerSocket
2049   {
2050     // Manually create a socket
2051     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2052     ASSERT_GE(fd, 0);
2053     // bind
2054     struct sockaddr_in addr;
2055     addr.sin_family = AF_INET;
2056     addr.sin_port = 0;
2057     addr.sin_addr.s_addr = INADDR_ANY;
2058     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2059                              sizeof(addr)), 0);
2060     // Look up the address that we bound to
2061     folly::SocketAddress boundAddress;
2062     boundAddress.setFromLocalAddress(fd);
2063     // listen
2064     ASSERT_EQ(listen(fd, 16), 0);
2065
2066     // Create a server socket
2067     AsyncServerSocket::UniquePtr serverSocket(
2068         new AsyncServerSocket(&eventBase));
2069     serverSocket->useExistingSocket(fd);
2070
2071     // Make sure AsyncServerSocket reports the same address that we bound to
2072     folly::SocketAddress serverSocketAddress;
2073     serverSocket->getAddress(&serverSocketAddress);
2074     ASSERT_EQ(boundAddress, serverSocketAddress);
2075
2076     // Make sure the socket works
2077     serverSocketSanityTest(serverSocket.get());
2078   }
2079 }
2080
2081 TEST(AsyncSocketTest, UnixDomainSocketTest) {
2082   EventBase eventBase;
2083
2084   // Create a server socket
2085   std::shared_ptr<AsyncServerSocket> serverSocket(
2086       AsyncServerSocket::newSocket(&eventBase));
2087   string path(1, 0);
2088   path.append(folly::to<string>("/anonymous", folly::Random::rand64()));
2089   folly::SocketAddress serverAddress;
2090   serverAddress.setFromPath(path);
2091   serverSocket->bind(serverAddress);
2092   serverSocket->listen(16);
2093
2094   // Add a callback to accept one connection then stop the loop
2095   TestAcceptCallback acceptCallback;
2096   acceptCallback.setConnectionAcceptedFn(
2097       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2098         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2099       });
2100   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2101     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2102   });
2103   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2104   serverSocket->startAccepting();
2105
2106   // Connect to the server socket
2107   std::shared_ptr<AsyncSocket> socket(
2108       AsyncSocket::newSocket(&eventBase, serverAddress));
2109
2110   eventBase.loop();
2111
2112   // Verify that the server accepted a connection
2113   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2114   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
2115                     TestAcceptCallback::TYPE_START);
2116   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
2117                     TestAcceptCallback::TYPE_ACCEPT);
2118   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
2119                     TestAcceptCallback::TYPE_STOP);
2120   int fd = acceptCallback.getEvents()->at(1).fd;
2121
2122   // The accepted connection should already be in non-blocking mode
2123   int flags = fcntl(fd, F_GETFL, 0);
2124   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
2125 }
2126
2127 TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
2128   EventBase eventBase;
2129   TestConnectionEventCallback connectionEventCallback;
2130
2131   // Create a server socket
2132   std::shared_ptr<AsyncServerSocket> serverSocket(
2133       AsyncServerSocket::newSocket(&eventBase));
2134   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2135   serverSocket->bind(0);
2136   serverSocket->listen(16);
2137   folly::SocketAddress serverAddress;
2138   serverSocket->getAddress(&serverAddress);
2139
2140   // Add a callback to accept one connection then stop the loop
2141   TestAcceptCallback acceptCallback;
2142   acceptCallback.setConnectionAcceptedFn(
2143       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2144         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2145       });
2146   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2147     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2148   });
2149   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2150   serverSocket->startAccepting();
2151
2152   // Connect to the server socket
2153   std::shared_ptr<AsyncSocket> socket(
2154       AsyncSocket::newSocket(&eventBase, serverAddress));
2155
2156   eventBase.loop();
2157
2158   // Validate the connection event counters
2159   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2160   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2161   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2162   ASSERT_EQ(
2163       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
2164   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
2165   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2166   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2167   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2168 }
2169
2170 TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
2171   EventBase eventBase;
2172   TestConnectionEventCallback connectionEventCallback;
2173
2174   // Create a server socket
2175   std::shared_ptr<AsyncServerSocket> serverSocket(
2176       AsyncServerSocket::newSocket(&eventBase));
2177   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2178   serverSocket->bind(0);
2179   serverSocket->listen(16);
2180   folly::SocketAddress serverAddress;
2181   serverSocket->getAddress(&serverAddress);
2182
2183   // Add a callback to accept one connection then stop the loop
2184   TestAcceptCallback acceptCallback;
2185   acceptCallback.setConnectionAcceptedFn(
2186       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2187         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2188       });
2189   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2190     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2191   });
2192   bool acceptStartedFlag{false};
2193   acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
2194     acceptStartedFlag = true;
2195   });
2196   bool acceptStoppedFlag{false};
2197   acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
2198     acceptStoppedFlag = true;
2199   });
2200   serverSocket->addAcceptCallback(&acceptCallback, nullptr);
2201   serverSocket->startAccepting();
2202
2203   // Connect to the server socket
2204   std::shared_ptr<AsyncSocket> socket(
2205       AsyncSocket::newSocket(&eventBase, serverAddress));
2206
2207   eventBase.loop();
2208
2209   ASSERT_TRUE(acceptStartedFlag);
2210   ASSERT_TRUE(acceptStoppedFlag);
2211   // Validate the connection event counters
2212   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2213   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2214   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2215   ASSERT_EQ(
2216       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
2217   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
2218   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2219   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2220   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2221 }
2222
2223
2224
2225 /**
2226  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
2227  */
2228 TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
2229   EventBase eventBase;
2230
2231   // Counter of how many connections have been accepted
2232   int count = 0;
2233
2234   // Create a server socket
2235   auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
2236   serverSocket->bind(0);
2237   serverSocket->listen(16);
2238   folly::SocketAddress serverAddress;
2239   serverSocket->getAddress(&serverAddress);
2240
2241   // Add a callback to accept connections
2242   TestAcceptCallback acceptCallback;
2243   acceptCallback.setConnectionAcceptedFn(
2244       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2245         count++;
2246         ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
2247
2248         if (count == 4) {
2249           // all messages are processed, remove accept callback
2250           serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2251         }
2252       });
2253   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2254     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2255   });
2256   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2257   serverSocket->startAccepting();
2258
2259   // Connect to the server socket, 4 clients, there are 4 connections
2260   auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
2261   auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
2262   auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
2263   auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));
2264
2265   eventBase.loop();
2266 }
2267
2268 /**
2269  * Test AsyncTransport::BufferCallback
2270  */
2271 TEST(AsyncSocketTest, BufferTest) {
2272   TestServer server;
2273
2274   EventBase evb;
2275   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2276   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2277   ConnCallback ccb;
2278   socket->connect(&ccb, server.getAddress(), 30, option);
2279
2280   char buf[100 * 1024];
2281   memset(buf, 'c', sizeof(buf));
2282   WriteCallback wcb;
2283   BufferCallback bcb;
2284   socket->setBufferCallback(&bcb);
2285   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2286
2287   evb.loop();
2288   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2289   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2290
2291   ASSERT_TRUE(bcb.hasBuffered());
2292   ASSERT_TRUE(bcb.hasBufferCleared());
2293
2294   socket->close();
2295   server.verifyConnection(buf, sizeof(buf));
2296
2297   ASSERT_TRUE(socket->isClosedBySelf());
2298   ASSERT_FALSE(socket->isClosedByPeer());
2299 }
2300
2301 TEST(AsyncSocketTest, BufferCallbackKill) {
2302   TestServer server;
2303   EventBase evb;
2304   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2305   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2306   ConnCallback ccb;
2307   socket->connect(&ccb, server.getAddress(), 30, option);
2308   evb.loopOnce();
2309
2310   char buf[100 * 1024];
2311   memset(buf, 'c', sizeof(buf));
2312   BufferCallback bcb;
2313   socket->setBufferCallback(&bcb);
2314   WriteCallback wcb;
2315   wcb.successCallback = [&] {
2316     ASSERT_TRUE(socket.unique());
2317     socket.reset();
2318   };
2319
2320   // This will trigger AsyncSocket::handleWrite,
2321   // which calls WriteCallback::writeSuccess,
2322   // which calls wcb.successCallback above,
2323   // which tries to delete socket
2324   // Then, the socket will also try to use this BufferCallback
2325   // And that should crash us, if there is no DestructorGuard on the stack
2326   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2327
2328   evb.loop();
2329   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2330 }
2331
2332 #if FOLLY_ALLOW_TFO
2333 TEST(AsyncSocketTest, ConnectTFO) {
2334   // Start listening on a local port
2335   TestServer server(true);
2336
2337   // Connect using a AsyncSocket
2338   EventBase evb;
2339   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2340   socket->enableTFO();
2341   ConnCallback cb;
2342   socket->connect(&cb, server.getAddress(), 30);
2343
2344   std::array<uint8_t, 128> buf;
2345   memset(buf.data(), 'a', buf.size());
2346
2347   std::array<uint8_t, 3> readBuf;
2348   auto sendBuf = IOBuf::copyBuffer("hey");
2349
2350   std::thread t([&] {
2351     auto acceptedSocket = server.accept();
2352     acceptedSocket->write(buf.data(), buf.size());
2353     acceptedSocket->flush();
2354     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2355     acceptedSocket->close();
2356   });
2357
2358   evb.loop();
2359
2360   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2361   EXPECT_LE(0, socket->getConnectTime().count());
2362   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2363   EXPECT_TRUE(socket->getTFOAttempted());
2364
2365   // Should trigger the connect
2366   WriteCallback write;
2367   ReadCallback rcb;
2368   socket->writeChain(&write, sendBuf->clone());
2369   socket->setReadCB(&rcb);
2370   evb.loop();
2371
2372   t.join();
2373
2374   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2375   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2376   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2377   ASSERT_EQ(1, rcb.buffers.size());
2378   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2379   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2380   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2381 }
2382
2383 TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
2384   // Start listening on a local port
2385   TestServer server(true);
2386
2387   // Connect using a AsyncSocket
2388   EventBase evb;
2389   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2390   socket->enableTFO();
2391   ConnCallback cb;
2392   socket->connect(&cb, server.getAddress(), 30);
2393   ReadCallback rcb;
2394   socket->setReadCB(&rcb);
2395
2396   std::array<uint8_t, 128> buf;
2397   memset(buf.data(), 'a', buf.size());
2398
2399   std::array<uint8_t, 3> readBuf;
2400   auto sendBuf = IOBuf::copyBuffer("hey");
2401
2402   std::thread t([&] {
2403     auto acceptedSocket = server.accept();
2404     acceptedSocket->write(buf.data(), buf.size());
2405     acceptedSocket->flush();
2406     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2407     acceptedSocket->close();
2408   });
2409
2410   evb.loop();
2411
2412   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2413   EXPECT_LE(0, socket->getConnectTime().count());
2414   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2415   EXPECT_TRUE(socket->getTFOAttempted());
2416
2417   // Should trigger the connect
2418   WriteCallback write;
2419   socket->writeChain(&write, sendBuf->clone());
2420   evb.loop();
2421
2422   t.join();
2423
2424   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2425   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2426   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2427   ASSERT_EQ(1, rcb.buffers.size());
2428   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2429   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2430   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2431 }
2432
2433 /**
2434  * Test connecting to a server that isn't listening
2435  */
2436 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2437   EventBase evb;
2438
2439   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2440
2441   socket->enableTFO();
2442
2443   // Hopefully nothing is actually listening on this address
2444   folly::SocketAddress addr("::1", 65535);
2445   ConnCallback cb;
2446   socket->connect(&cb, addr, 30);
2447
2448   evb.loop();
2449
2450   WriteCallback write1;
2451   // Trigger the connect if TFO attempt is supported.
2452   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2453   WriteCallback write2;
2454   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2455   evb.loop();
2456
2457   if (!socket->getTFOFinished()) {
2458     EXPECT_EQ(STATE_FAILED, write1.state);
2459   } else {
2460     EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2461     EXPECT_FALSE(socket->getTFOSucceded());
2462   }
2463
2464   EXPECT_EQ(STATE_FAILED, write2.state);
2465
2466   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2467   EXPECT_LE(0, socket->getConnectTime().count());
2468   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2469   EXPECT_TRUE(socket->getTFOAttempted());
2470 }
2471
2472 /**
2473  * Test calling closeNow() immediately after connecting.
2474  */
2475 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2476   TestServer server(true);
2477
2478   // connect()
2479   EventBase evb;
2480   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2481   socket->enableTFO();
2482
2483   ConnCallback ccb;
2484   socket->connect(&ccb, server.getAddress(), 30);
2485
2486   // write()
2487   std::array<char, 128> buf;
2488   memset(buf.data(), 'a', buf.size());
2489
2490   // close()
2491   socket->closeNow();
2492
2493   // Loop, although there shouldn't be anything to do.
2494   evb.loop();
2495
2496   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2497
2498   ASSERT_TRUE(socket->isClosedBySelf());
2499   ASSERT_FALSE(socket->isClosedByPeer());
2500 }
2501
2502 /**
2503  * Test calling close() immediately after connect()
2504  */
2505 TEST(AsyncSocketTest, ConnectAndCloseTFO) {
2506   TestServer server(true);
2507
2508   // Connect using a AsyncSocket
2509   EventBase evb;
2510   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2511   socket->enableTFO();
2512
2513   ConnCallback ccb;
2514   socket->connect(&ccb, server.getAddress(), 30);
2515
2516   socket->close();
2517
2518   // Loop, although there shouldn't be anything to do.
2519   evb.loop();
2520
2521   // Make sure the connection was aborted
2522   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2523
2524   ASSERT_TRUE(socket->isClosedBySelf());
2525   ASSERT_FALSE(socket->isClosedByPeer());
2526 }
2527
2528 class MockAsyncTFOSocket : public AsyncSocket {
2529  public:
2530   using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
2531
2532   explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
2533
2534   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
2535 };
2536
2537 TEST(AsyncSocketTest, TestTFOUnsupported) {
2538   TestServer server(true);
2539
2540   // Connect using a AsyncSocket
2541   EventBase evb;
2542   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2543   socket->enableTFO();
2544
2545   ConnCallback ccb;
2546   socket->connect(&ccb, server.getAddress(), 30);
2547   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2548
2549   ReadCallback rcb;
2550   socket->setReadCB(&rcb);
2551
2552   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2553       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2554   WriteCallback write;
2555   auto sendBuf = IOBuf::copyBuffer("hey");
2556   socket->writeChain(&write, sendBuf->clone());
2557   EXPECT_EQ(STATE_WAITING, write.state);
2558
2559   std::array<uint8_t, 128> buf;
2560   memset(buf.data(), 'a', buf.size());
2561
2562   std::array<uint8_t, 3> readBuf;
2563
2564   std::thread t([&] {
2565     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2566     acceptedSocket->write(buf.data(), buf.size());
2567     acceptedSocket->flush();
2568     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2569     acceptedSocket->close();
2570   });
2571
2572   evb.loop();
2573
2574   t.join();
2575   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2576   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2577
2578   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2579   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2580   ASSERT_EQ(1, rcb.buffers.size());
2581   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2582   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2583   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2584 }
2585
2586 TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
2587   EventBase evb;
2588
2589   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2590   socket->enableTFO();
2591
2592   // Hopefully this fails
2593   folly::SocketAddress fakeAddr("127.0.0.1", 65535);
2594   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2595       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2596         sockaddr_storage addr;
2597         auto len = fakeAddr.getAddress(&addr);
2598         int ret = connect(fd, (const struct sockaddr*)&addr, len);
2599         LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
2600                   << errno;
2601         return ret;
2602       }));
2603
2604   // Hopefully nothing is actually listening on this address
2605   ConnCallback cb;
2606   socket->connect(&cb, fakeAddr, 30);
2607
2608   WriteCallback write1;
2609   // Trigger the connect if TFO attempt is supported.
2610   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2611
2612   if (socket->getTFOFinished()) {
2613     // This test is useless now.
2614     return;
2615   }
2616   WriteCallback write2;
2617   // Trigger the connect if TFO attempt is supported.
2618   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2619   evb.loop();
2620
2621   EXPECT_EQ(STATE_FAILED, write1.state);
2622   EXPECT_EQ(STATE_FAILED, write2.state);
2623   EXPECT_FALSE(socket->getTFOSucceded());
2624
2625   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2626   EXPECT_LE(0, socket->getConnectTime().count());
2627   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2628   EXPECT_TRUE(socket->getTFOAttempted());
2629 }
2630
2631 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
2632   // Try connecting to server that won't respond.
2633   //
2634   // This depends somewhat on the network where this test is run.
2635   // Hopefully this IP will be routable but unresponsive.
2636   // (Alternatively, we could try listening on a local raw socket, but that
2637   // normally requires root privileges.)
2638   auto host = SocketAddressTestHelper::isIPv6Enabled()
2639       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2640       : SocketAddressTestHelper::isIPv4Enabled()
2641           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2642           : nullptr;
2643   SocketAddress addr(host, 65535);
2644
2645   // Connect using a AsyncSocket
2646   EventBase evb;
2647   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2648   socket->enableTFO();
2649
2650   ConnCallback ccb;
2651   // Set a very small timeout
2652   socket->connect(&ccb, addr, 1);
2653   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2654
2655   ReadCallback rcb;
2656   socket->setReadCB(&rcb);
2657
2658   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2659       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2660   WriteCallback write;
2661   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2662
2663   evb.loop();
2664
2665   EXPECT_EQ(STATE_FAILED, write.state);
2666 }
2667
2668 TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
2669   TestServer server(true);
2670
2671   // Connect using a AsyncSocket
2672   EventBase evb;
2673   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2674   socket->enableTFO();
2675
2676   ConnCallback ccb;
2677   socket->connect(&ccb, server.getAddress(), 30);
2678   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2679
2680   ReadCallback rcb;
2681   socket->setReadCB(&rcb);
2682
2683   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2684       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2685         sockaddr_storage addr;
2686         auto len = server.getAddress().getAddress(&addr);
2687         return connect(fd, (const struct sockaddr*)&addr, len);
2688       }));
2689   WriteCallback write;
2690   auto sendBuf = IOBuf::copyBuffer("hey");
2691   socket->writeChain(&write, sendBuf->clone());
2692   EXPECT_EQ(STATE_WAITING, write.state);
2693
2694   std::array<uint8_t, 128> buf;
2695   memset(buf.data(), 'a', buf.size());
2696
2697   std::array<uint8_t, 3> readBuf;
2698
2699   std::thread t([&] {
2700     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2701     acceptedSocket->write(buf.data(), buf.size());
2702     acceptedSocket->flush();
2703     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2704     acceptedSocket->close();
2705   });
2706
2707   evb.loop();
2708
2709   t.join();
2710   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2711
2712   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2713   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2714
2715   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2716   ASSERT_EQ(1, rcb.buffers.size());
2717   ASSERT_EQ(buf.size(), rcb.buffers[0].length);
2718   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2719 }
2720
2721 TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
2722   // Try connecting to server that won't respond.
2723   //
2724   // This depends somewhat on the network where this test is run.
2725   // Hopefully this IP will be routable but unresponsive.
2726   // (Alternatively, we could try listening on a local raw socket, but that
2727   // normally requires root privileges.)
2728   auto host = SocketAddressTestHelper::isIPv6Enabled()
2729       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2730       : SocketAddressTestHelper::isIPv4Enabled()
2731           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2732           : nullptr;
2733   SocketAddress addr(host, 65535);
2734
2735   // Connect using a AsyncSocket
2736   EventBase evb;
2737   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2738   socket->enableTFO();
2739
2740   ConnCallback ccb;
2741   // Set a very small timeout
2742   socket->connect(&ccb, addr, 1);
2743   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2744
2745   ReadCallback rcb;
2746   socket->setReadCB(&rcb);
2747
2748   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2749       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2750         sockaddr_storage addr2;
2751         auto len = addr.getAddress(&addr2);
2752         return connect(fd, (const struct sockaddr*)&addr2, len);
2753       }));
2754   WriteCallback write;
2755   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2756
2757   evb.loop();
2758
2759   EXPECT_EQ(STATE_FAILED, write.state);
2760 }
2761
2762 TEST(AsyncSocketTest, TestTFOEagain) {
2763   TestServer server(true);
2764
2765   // Connect using a AsyncSocket
2766   EventBase evb;
2767   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2768   socket->enableTFO();
2769
2770   ConnCallback ccb;
2771   socket->connect(&ccb, server.getAddress(), 30);
2772
2773   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2774       .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
2775   WriteCallback write;
2776   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2777
2778   evb.loop();
2779
2780   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2781   EXPECT_EQ(STATE_FAILED, write.state);
2782 }
2783
2784 // Sending a large amount of data in the first write which will
2785 // definitely not fit into MSS.
2786 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
2787   // Start listening on a local port
2788   TestServer server(true);
2789
2790   // Connect using a AsyncSocket
2791   EventBase evb;
2792   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2793   socket->enableTFO();
2794   ConnCallback cb;
2795   socket->connect(&cb, server.getAddress(), 30);
2796
2797   std::array<uint8_t, 128> buf;
2798   memset(buf.data(), 'a', buf.size());
2799
2800   constexpr size_t len = 10 * 1024;
2801   auto sendBuf = IOBuf::create(len);
2802   sendBuf->append(len);
2803   std::array<uint8_t, len> readBuf;
2804
2805   std::thread t([&] {
2806     auto acceptedSocket = server.accept();
2807     acceptedSocket->write(buf.data(), buf.size());
2808     acceptedSocket->flush();
2809     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2810     acceptedSocket->close();
2811   });
2812
2813   evb.loop();
2814
2815   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2816   EXPECT_LE(0, socket->getConnectTime().count());
2817   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2818   EXPECT_TRUE(socket->getTFOAttempted());
2819
2820   // Should trigger the connect
2821   WriteCallback write;
2822   ReadCallback rcb;
2823   socket->writeChain(&write, sendBuf->clone());
2824   socket->setReadCB(&rcb);
2825   evb.loop();
2826
2827   t.join();
2828
2829   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2830   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2831   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2832   ASSERT_EQ(1, rcb.buffers.size());
2833   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2834   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2835   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2836 }
2837
2838 #endif // FOLLY_ALLOW_TFO
2839
2840 class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
2841  public:
2842   MOCK_METHOD1(evbAttached, void(AsyncSocket*));
2843   MOCK_METHOD1(evbDetached, void(AsyncSocket*));
2844 };
2845
2846 TEST(AsyncSocketTest, EvbCallbacks) {
2847   auto cb = std::make_unique<MockEvbChangeCallback>();
2848   EventBase evb;
2849   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2850
2851   InSequence seq;
2852   EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
2853   EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
2854
2855   socket->setEvbChangedCallback(std::move(cb));
2856   socket->detachEventBase();
2857   socket->attachEventBase(&evb);
2858 }
2859
2860 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
2861 /* copied from include/uapi/linux/net_tstamp.h */
2862 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
2863 enum SOF_TIMESTAMPING {
2864   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
2865   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
2866   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
2867   SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
2868   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
2869 };
2870
2871 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
2872  public:
2873   TestErrMessageCallback()
2874       : exception_(folly::AsyncSocketException::UNKNOWN, "none") {}
2875
2876   void errMessage(const cmsghdr& cmsg) noexcept override {
2877     if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
2878       gotTimestamp_++;
2879       checkResetCallback();
2880     } else if (
2881         (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
2882         (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
2883       gotByteSeq_++;
2884       checkResetCallback();
2885     }
2886   }
2887
2888   void errMessageError(
2889       const folly::AsyncSocketException& ex) noexcept override {
2890     exception_ = ex;
2891   }
2892
2893   void checkResetCallback() noexcept {
2894     if (socket_ != nullptr && resetAfter_ != -1 &&
2895         gotTimestamp_ + gotByteSeq_ == resetAfter_) {
2896       socket_->setErrMessageCB(nullptr);
2897     }
2898   }
2899
2900   folly::AsyncSocket* socket_{nullptr};
2901   folly::AsyncSocketException exception_;
2902   int gotTimestamp_{0};
2903   int gotByteSeq_{0};
2904   int resetAfter_{-1};
2905 };
2906
2907 TEST(AsyncSocketTest, ErrMessageCallback) {
2908   TestServer server;
2909
2910   // connect()
2911   EventBase evb;
2912   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2913
2914   ConnCallback ccb;
2915   socket->connect(&ccb, server.getAddress(), 30);
2916   LOG(INFO) << "Client socket fd=" << socket->getFd();
2917
2918   // Let the socket
2919   evb.loop();
2920
2921   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2922
2923   // Set read callback to keep the socket subscribed for event
2924   // notifications. Though we're no planning to read anything from
2925   // this side of the connection.
2926   ReadCallback rcb(1);
2927   socket->setReadCB(&rcb);
2928
2929   // Set up timestamp callbacks
2930   TestErrMessageCallback errMsgCB;
2931   socket->setErrMessageCB(&errMsgCB);
2932   ASSERT_EQ(socket->getErrMessageCallback(),
2933             static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
2934
2935   errMsgCB.socket_ = socket.get();
2936   errMsgCB.resetAfter_ = 3;
2937
2938   // Enable timestamp notifications
2939   ASSERT_GT(socket->getFd(), 0);
2940   int flags = SOF_TIMESTAMPING_OPT_ID
2941               | SOF_TIMESTAMPING_OPT_TSONLY
2942               | SOF_TIMESTAMPING_SOFTWARE
2943               | SOF_TIMESTAMPING_OPT_CMSG
2944               | SOF_TIMESTAMPING_TX_SCHED;
2945   AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
2946   EXPECT_EQ(tstampingOpt.apply(socket->getFd(), flags), 0);
2947
2948   // write()
2949   std::vector<uint8_t> wbuf(128, 'a');
2950   WriteCallback wcb;
2951   // Send two packets to get two EOM notifications
2952   socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
2953   socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);
2954
2955   // Accept the connection.
2956   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2957   LOG(INFO) << "Server socket fd=" << acceptedSocket->getSocketFD();
2958
2959   // Loop
2960   evb.loopOnce();
2961   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2962
2963   // Check that we can read the data that was written to the socket
2964   std::vector<uint8_t> rbuf(1 + wbuf.size(), 0);
2965   uint32_t bytesRead = acceptedSocket->read(rbuf.data(), rbuf.size());
2966   ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
2967   ASSERT_EQ(bytesRead, wbuf.size());
2968
2969   // Close both sockets
2970   acceptedSocket->close();
2971   socket->close();
2972
2973   ASSERT_TRUE(socket->isClosedBySelf());
2974   ASSERT_FALSE(socket->isClosedByPeer());
2975
2976   // Check for the timestamp notifications.
2977   ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
2978   ASSERT_GT(errMsgCB.gotByteSeq_, 0);
2979   ASSERT_GT(errMsgCB.gotTimestamp_, 0);
2980   ASSERT_EQ(
2981       errMsgCB.gotByteSeq_ + errMsgCB.gotTimestamp_, errMsgCB.resetAfter_);
2982 }
2983 #endif // FOLLY_HAVE_MSG_ERRQUEUE
2984
2985 TEST(AsyncSocket, PreReceivedData) {
2986   TestServer server;
2987
2988   EventBase evb;
2989   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2990   socket->connect(nullptr, server.getAddress(), 30);
2991   evb.loop();
2992
2993   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
2994
2995   auto acceptedSocket = server.acceptAsync(&evb);
2996
2997   ReadCallback peekCallback(2);
2998   ReadCallback readCallback;
2999   peekCallback.dataAvailableCallback = [&]() {
3000     peekCallback.verifyData("he", 2);
3001     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
3002     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
3003     acceptedSocket->setReadCB(nullptr);
3004     acceptedSocket->setReadCB(&readCallback);
3005   };
3006   readCallback.dataAvailableCallback = [&]() {
3007     if (readCallback.dataRead() == 5) {
3008       readCallback.verifyData("hello", 5);
3009       acceptedSocket->setReadCB(nullptr);
3010     }
3011   };
3012
3013   acceptedSocket->setReadCB(&peekCallback);
3014
3015   evb.loop();
3016 }
3017
3018 TEST(AsyncSocket, PreReceivedDataOnly) {
3019   TestServer server;
3020
3021   EventBase evb;
3022   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3023   socket->connect(nullptr, server.getAddress(), 30);
3024   evb.loop();
3025
3026   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3027
3028   auto acceptedSocket = server.acceptAsync(&evb);
3029
3030   ReadCallback peekCallback;
3031   ReadCallback readCallback;
3032   peekCallback.dataAvailableCallback = [&]() {
3033     peekCallback.verifyData("hello", 5);
3034     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3035     acceptedSocket->setReadCB(&readCallback);
3036   };
3037   readCallback.dataAvailableCallback = [&]() {
3038     readCallback.verifyData("hello", 5);
3039     acceptedSocket->setReadCB(nullptr);
3040   };
3041
3042   acceptedSocket->setReadCB(&peekCallback);
3043
3044   evb.loop();
3045 }
3046
3047 TEST(AsyncSocket, PreReceivedDataPartial) {
3048   TestServer server;
3049
3050   EventBase evb;
3051   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3052   socket->connect(nullptr, server.getAddress(), 30);
3053   evb.loop();
3054
3055   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3056
3057   auto acceptedSocket = server.acceptAsync(&evb);
3058
3059   ReadCallback peekCallback;
3060   ReadCallback smallReadCallback(3);
3061   ReadCallback normalReadCallback;
3062   peekCallback.dataAvailableCallback = [&]() {
3063     peekCallback.verifyData("hello", 5);
3064     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3065     acceptedSocket->setReadCB(&smallReadCallback);
3066   };
3067   smallReadCallback.dataAvailableCallback = [&]() {
3068     smallReadCallback.verifyData("hel", 3);
3069     acceptedSocket->setReadCB(&normalReadCallback);
3070   };
3071   normalReadCallback.dataAvailableCallback = [&]() {
3072     normalReadCallback.verifyData("lo", 2);
3073     acceptedSocket->setReadCB(nullptr);
3074   };
3075
3076   acceptedSocket->setReadCB(&peekCallback);
3077
3078   evb.loop();
3079 }
3080
3081 TEST(AsyncSocket, PreReceivedDataTakeover) {
3082   TestServer server;
3083
3084   EventBase evb;
3085   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3086   socket->connect(nullptr, server.getAddress(), 30);
3087   evb.loop();
3088
3089   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3090
3091   auto acceptedSocket =
3092       AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
3093   AsyncSocket::UniquePtr takeoverSocket;
3094
3095   ReadCallback peekCallback(3);
3096   ReadCallback readCallback;
3097   peekCallback.dataAvailableCallback = [&]() {
3098     peekCallback.verifyData("hel", 3);
3099     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3100     acceptedSocket->setReadCB(nullptr);
3101     takeoverSocket =
3102         AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
3103     takeoverSocket->setReadCB(&readCallback);
3104   };
3105   readCallback.dataAvailableCallback = [&]() {
3106     readCallback.verifyData("hello", 5);
3107     takeoverSocket->setReadCB(nullptr);
3108   };
3109
3110   acceptedSocket->setReadCB(&peekCallback);
3111
3112   evb.loop();
3113 }
3114
3115 #ifdef MSG_NOSIGNAL
3116 TEST(AsyncSocketTest, SendMessageFlags) {
3117   TestServer server;
3118   TestSendMsgParamsCallback sendMsgCB(
3119       MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
3120
3121   // connect()
3122   EventBase evb;
3123   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3124
3125   ConnCallback ccb;
3126   socket->connect(&ccb, server.getAddress(), 30);
3127   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
3128
3129   evb.loop();
3130   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
3131
3132   // Set SendMsgParamsCallback
3133   socket->setSendMsgParamCB(&sendMsgCB);
3134   ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
3135
3136   // Write the first portion of data. This data is expected to be
3137   // sent out immediately.
3138   std::vector<uint8_t> buf(128, 'a');
3139   WriteCallback wcb;
3140   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
3141   socket->write(&wcb, buf.data(), buf.size());
3142   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3143   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3144   ASSERT_FALSE(sendMsgCB.queriedData_);
3145
3146   // Using different flags for the second write operation.
3147   // MSG_MORE flag is expected to delay sending this
3148   // data to the wire.
3149   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
3150   socket->write(&wcb, buf.data(), buf.size());
3151   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3152   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3153   ASSERT_FALSE(sendMsgCB.queriedData_);
3154
3155   // Make sure the accepted socket saw only the data from
3156   // the first write request.
3157   std::vector<uint8_t> readbuf(2 * buf.size());
3158   uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
3159   ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
3160   ASSERT_EQ(bytesRead, buf.size());
3161
3162   // Make sure the server got a connection and received the data
3163   acceptedSocket->close();
3164   socket->close();
3165
3166   ASSERT_TRUE(socket->isClosedBySelf());
3167   ASSERT_FALSE(socket->isClosedByPeer());
3168 }
3169
3170 TEST(AsyncSocketTest, SendMessageAncillaryData) {
3171   int fds[2];
3172   EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);
3173
3174   // "Client" socket
3175   int cfd = fds[0];
3176   ASSERT_NE(cfd, -1);
3177
3178   // "Server" socket
3179   int sfd = fds[1];
3180   ASSERT_NE(sfd, -1);
3181   SCOPE_EXIT { close(sfd); };
3182
3183   // Instantiate AsyncSocket object for the connected socket
3184   EventBase evb;
3185   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);
3186
3187   // Open a temporary file and write a magic string to it
3188   // We'll transfer the file handle to test the message parameters
3189   // callback logic.
3190   TemporaryFile file(StringPiece(),
3191                      fs::path(),
3192                      TemporaryFile::Scope::UNLINK_IMMEDIATELY);
3193   int tmpfd = file.fd();
3194   ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
3195   std::string magicString("Magic string");
3196   ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
3197             magicString.length());
3198
3199   // Send message
3200   union {
3201     // Space large enough to hold an 'int'
3202     char control[CMSG_SPACE(sizeof(int))];
3203     struct cmsghdr cmh;
3204   } s_u;
3205   s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
3206   s_u.cmh.cmsg_level = SOL_SOCKET;
3207   s_u.cmh.cmsg_type = SCM_RIGHTS;
3208   memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
3209
3210   // Set up the callback providing message parameters
3211   TestSendMsgParamsCallback sendMsgCB(
3212       MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
3213   socket->setSendMsgParamCB(&sendMsgCB);
3214
3215   // We must transmit at least 1 byte of real data in order
3216   // to send ancillary data
3217   int s_data = 12345;
3218   WriteCallback wcb;
3219   socket->write(&wcb, &s_data, sizeof(s_data));
3220   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3221
3222   // Receive the message
3223   union {
3224     // Space large enough to hold an 'int'
3225     char control[CMSG_SPACE(sizeof(int))];
3226     struct cmsghdr cmh;
3227   } r_u;
3228   struct msghdr msgh;
3229   struct iovec iov;
3230   int r_data = 0;
3231
3232   msgh.msg_control = r_u.control;
3233   msgh.msg_controllen = sizeof(r_u.control);
3234   msgh.msg_name = nullptr;
3235   msgh.msg_namelen = 0;
3236   msgh.msg_iov = &iov;
3237   msgh.msg_iovlen = 1;
3238   iov.iov_base = &r_data;
3239   iov.iov_len = sizeof(r_data);
3240
3241   // Receive data
3242   ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
3243
3244   // Validate the received message
3245   ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
3246   ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
3247   ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
3248   ASSERT_EQ(r_data, s_data);
3249   int fd = 0;
3250   memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
3251   ASSERT_NE(fd, 0);
3252   SCOPE_EXIT { close(fd); };
3253
3254   std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
3255
3256   // Reposition to the beginning of the file
3257   ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
3258
3259   // Read the magic string back, and compare it with the original
3260   ASSERT_EQ(
3261       magicString.length(),
3262       read(fd, transferredMagicString.data(), transferredMagicString.size()));
3263   ASSERT_TRUE(std::equal(
3264       magicString.begin(),
3265       magicString.end(),
3266       transferredMagicString.begin()));
3267 }
3268 #endif