e90bca5c728c5788c1038b547118b7abbcf35459
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/test/AsyncSocketTest2.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/Random.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/io/async/AsyncSocket.h>
23 #include <folly/io/async/AsyncTimeout.h>
24 #include <folly/io/async/EventBase.h>
25
26 #include <folly/experimental/TestUtil.h>
27 #include <folly/io/IOBuf.h>
28 #include <folly/io/async/test/AsyncSocketTest.h>
29 #include <folly/io/async/test/Util.h>
30 #include <folly/portability/GMock.h>
31 #include <folly/portability/GTest.h>
32 #include <folly/portability/Sockets.h>
33 #include <folly/portability/Unistd.h>
34 #include <folly/test/SocketAddressTestHelper.h>
35
36 #include <boost/scoped_array.hpp>
37 #include <fcntl.h>
38 #include <sys/types.h>
39 #include <iostream>
40 #include <thread>
41
42 using namespace boost;
43
44 using std::string;
45 using std::vector;
46 using std::min;
47 using std::cerr;
48 using std::endl;
49 using std::unique_ptr;
50 using std::chrono::milliseconds;
51 using boost::scoped_array;
52
53 using namespace folly;
54 using namespace folly::test;
55 using namespace testing;
56
57 namespace fsp = folly::portability::sockets;
58
59 class DelayedWrite: public AsyncTimeout {
60  public:
61   DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
62       unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
63       bool cork, bool lastWrite = false):
64     AsyncTimeout(socket->getEventBase()),
65     socket_(socket),
66     bufs_(std::move(bufs)),
67     wcb_(wcb),
68     cork_(cork),
69     lastWrite_(lastWrite) {}
70
71  private:
72   void timeoutExpired() noexcept override {
73     WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
74     socket_->writeChain(wcb_, std::move(bufs_), flags);
75     if (lastWrite_) {
76       socket_->shutdownWrite();
77     }
78   }
79
80   std::shared_ptr<AsyncSocket> socket_;
81   unique_ptr<IOBuf> bufs_;
82   AsyncTransportWrapper::WriteCallback* wcb_;
83   bool cork_;
84   bool lastWrite_;
85 };
86
87 ///////////////////////////////////////////////////////////////////////////
88 // connect() tests
89 ///////////////////////////////////////////////////////////////////////////
90
91 /**
92  * Test connecting to a server
93  */
94 TEST(AsyncSocketTest, Connect) {
95   // Start listening on a local port
96   TestServer server;
97
98   // Connect using a AsyncSocket
99   EventBase evb;
100   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
101   ConnCallback cb;
102   socket->connect(&cb, server.getAddress(), 30);
103
104   evb.loop();
105
106   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
107   EXPECT_LE(0, socket->getConnectTime().count());
108   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
109 }
110
111 enum class TFOState {
112   DISABLED,
113   ENABLED,
114 };
115
116 class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
117
118 std::vector<TFOState> getTestingValues() {
119   std::vector<TFOState> vals;
120   vals.emplace_back(TFOState::DISABLED);
121
122 #if FOLLY_ALLOW_TFO
123   vals.emplace_back(TFOState::ENABLED);
124 #endif
125   return vals;
126 }
127
128 INSTANTIATE_TEST_CASE_P(
129     ConnectTests,
130     AsyncSocketConnectTest,
131     ::testing::ValuesIn(getTestingValues()));
132
133 /**
134  * Test connecting to a server that isn't listening
135  */
136 TEST(AsyncSocketTest, ConnectRefused) {
137   EventBase evb;
138
139   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
140
141   // Hopefully nothing is actually listening on this address
142   folly::SocketAddress addr("127.0.0.1", 65535);
143   ConnCallback cb;
144   socket->connect(&cb, addr, 30);
145
146   evb.loop();
147
148   EXPECT_EQ(STATE_FAILED, cb.state);
149   EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
150   EXPECT_LE(0, socket->getConnectTime().count());
151   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
152 }
153
154 /**
155  * Test connection timeout
156  */
157 TEST(AsyncSocketTest, ConnectTimeout) {
158   EventBase evb;
159
160   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
161
162   // Try connecting to server that won't respond.
163   //
164   // This depends somewhat on the network where this test is run.
165   // Hopefully this IP will be routable but unresponsive.
166   // (Alternatively, we could try listening on a local raw socket, but that
167   // normally requires root privileges.)
168   auto host =
169       SocketAddressTestHelper::isIPv6Enabled() ?
170       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 :
171       SocketAddressTestHelper::isIPv4Enabled() ?
172       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 :
173       nullptr;
174   SocketAddress addr(host, 65535);
175   ConnCallback cb;
176   socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
177
178   evb.loop();
179
180   ASSERT_EQ(cb.state, STATE_FAILED);
181   ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
182
183   // Verify that we can still get the peer address after a timeout.
184   // Use case is if the client was created from a client pool, and we want
185   // to log which peer failed.
186   folly::SocketAddress peer;
187   socket->getPeerAddress(&peer);
188   ASSERT_EQ(peer, addr);
189   EXPECT_LE(0, socket->getConnectTime().count());
190   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
191 }
192
193 /**
194  * Test writing immediately after connecting, without waiting for connect
195  * to finish.
196  */
197 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
198   TestServer server;
199
200   // connect()
201   EventBase evb;
202   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
203
204   if (GetParam() == TFOState::ENABLED) {
205     socket->enableTFO();
206   }
207
208   ConnCallback ccb;
209   socket->connect(&ccb, server.getAddress(), 30);
210
211   // write()
212   char buf[128];
213   memset(buf, 'a', sizeof(buf));
214   WriteCallback wcb;
215   socket->write(&wcb, buf, sizeof(buf));
216
217   // Loop.  We don't bother accepting on the server socket yet.
218   // The kernel should be able to buffer the write request so it can succeed.
219   evb.loop();
220
221   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
222   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
223
224   // Make sure the server got a connection and received the data
225   socket->close();
226   server.verifyConnection(buf, sizeof(buf));
227
228   ASSERT_TRUE(socket->isClosedBySelf());
229   ASSERT_FALSE(socket->isClosedByPeer());
230   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
231 }
232
233 /**
234  * Test connecting using a nullptr connect callback.
235  */
236 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
237   TestServer server;
238
239   // connect()
240   EventBase evb;
241   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
242   if (GetParam() == TFOState::ENABLED) {
243     socket->enableTFO();
244   }
245
246   socket->connect(nullptr, server.getAddress(), 30);
247
248   // write some data, just so we have some way of verifing
249   // that the socket works correctly after connecting
250   char buf[128];
251   memset(buf, 'a', sizeof(buf));
252   WriteCallback wcb;
253   socket->write(&wcb, buf, sizeof(buf));
254
255   evb.loop();
256
257   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
258
259   // Make sure the server got a connection and received the data
260   socket->close();
261   server.verifyConnection(buf, sizeof(buf));
262
263   ASSERT_TRUE(socket->isClosedBySelf());
264   ASSERT_FALSE(socket->isClosedByPeer());
265 }
266
267 /**
268  * Test calling both write() and close() immediately after connecting, without
269  * waiting for connect to finish.
270  *
271  * This exercises the STATE_CONNECTING_CLOSING code.
272  */
273 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
274   TestServer server;
275
276   // connect()
277   EventBase evb;
278   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
279   if (GetParam() == TFOState::ENABLED) {
280     socket->enableTFO();
281   }
282   ConnCallback ccb;
283   socket->connect(&ccb, server.getAddress(), 30);
284
285   // write()
286   char buf[128];
287   memset(buf, 'a', sizeof(buf));
288   WriteCallback wcb;
289   socket->write(&wcb, buf, sizeof(buf));
290
291   // close()
292   socket->close();
293
294   // Loop.  We don't bother accepting on the server socket yet.
295   // The kernel should be able to buffer the write request so it can succeed.
296   evb.loop();
297
298   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
299   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
300
301   // Make sure the server got a connection and received the data
302   server.verifyConnection(buf, sizeof(buf));
303
304   ASSERT_TRUE(socket->isClosedBySelf());
305   ASSERT_FALSE(socket->isClosedByPeer());
306 }
307
308 /**
309  * Test calling close() immediately after connect()
310  */
311 TEST(AsyncSocketTest, ConnectAndClose) {
312   TestServer server;
313
314   // Connect using a AsyncSocket
315   EventBase evb;
316   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
317   ConnCallback ccb;
318   socket->connect(&ccb, server.getAddress(), 30);
319
320   // Hopefully the connect didn't succeed immediately.
321   // If it did, we can't exercise the close-while-connecting code path.
322   if (ccb.state == STATE_SUCCEEDED) {
323     LOG(INFO) << "connect() succeeded immediately; aborting test "
324                        "of close-during-connect behavior";
325     return;
326   }
327
328   socket->close();
329
330   // Loop, although there shouldn't be anything to do.
331   evb.loop();
332
333   // Make sure the connection was aborted
334   ASSERT_EQ(ccb.state, STATE_FAILED);
335
336   ASSERT_TRUE(socket->isClosedBySelf());
337   ASSERT_FALSE(socket->isClosedByPeer());
338 }
339
340 /**
341  * Test calling closeNow() immediately after connect()
342  *
343  * This should be identical to the normal close behavior.
344  */
345 TEST(AsyncSocketTest, ConnectAndCloseNow) {
346   TestServer server;
347
348   // Connect using a AsyncSocket
349   EventBase evb;
350   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
351   ConnCallback ccb;
352   socket->connect(&ccb, server.getAddress(), 30);
353
354   // Hopefully the connect didn't succeed immediately.
355   // If it did, we can't exercise the close-while-connecting code path.
356   if (ccb.state == STATE_SUCCEEDED) {
357     LOG(INFO) << "connect() succeeded immediately; aborting test "
358                        "of closeNow()-during-connect behavior";
359     return;
360   }
361
362   socket->closeNow();
363
364   // Loop, although there shouldn't be anything to do.
365   evb.loop();
366
367   // Make sure the connection was aborted
368   ASSERT_EQ(ccb.state, STATE_FAILED);
369
370   ASSERT_TRUE(socket->isClosedBySelf());
371   ASSERT_FALSE(socket->isClosedByPeer());
372 }
373
374 /**
375  * Test calling both write() and closeNow() immediately after connecting,
376  * without waiting for connect to finish.
377  *
378  * This should abort the pending write.
379  */
380 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
381   TestServer server;
382
383   // connect()
384   EventBase evb;
385   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
386   ConnCallback ccb;
387   socket->connect(&ccb, server.getAddress(), 30);
388
389   // Hopefully the connect didn't succeed immediately.
390   // If it did, we can't exercise the close-while-connecting code path.
391   if (ccb.state == STATE_SUCCEEDED) {
392     LOG(INFO) << "connect() succeeded immediately; aborting test "
393                        "of write-during-connect behavior";
394     return;
395   }
396
397   // write()
398   char buf[128];
399   memset(buf, 'a', sizeof(buf));
400   WriteCallback wcb;
401   socket->write(&wcb, buf, sizeof(buf));
402
403   // close()
404   socket->closeNow();
405
406   // Loop, although there shouldn't be anything to do.
407   evb.loop();
408
409   ASSERT_EQ(ccb.state, STATE_FAILED);
410   ASSERT_EQ(wcb.state, STATE_FAILED);
411
412   ASSERT_TRUE(socket->isClosedBySelf());
413   ASSERT_FALSE(socket->isClosedByPeer());
414 }
415
416 /**
417  * Test installing a read callback immediately, before connect() finishes.
418  */
419 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
420   TestServer server;
421
422   // connect()
423   EventBase evb;
424   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
425   if (GetParam() == TFOState::ENABLED) {
426     socket->enableTFO();
427   }
428
429   ConnCallback ccb;
430   socket->connect(&ccb, server.getAddress(), 30);
431
432   ReadCallback rcb;
433   socket->setReadCB(&rcb);
434
435   if (GetParam() == TFOState::ENABLED) {
436     // Trigger a connection
437     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
438   }
439
440   // Even though we haven't looped yet, we should be able to accept
441   // the connection and send data to it.
442   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
443   uint8_t buf[128];
444   memset(buf, 'a', sizeof(buf));
445   acceptedSocket->write(buf, sizeof(buf));
446   acceptedSocket->flush();
447   acceptedSocket->close();
448
449   // Loop, although there shouldn't be anything to do.
450   evb.loop();
451
452   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
453   ASSERT_EQ(rcb.buffers.size(), 1);
454   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
455   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
456
457   ASSERT_FALSE(socket->isClosedBySelf());
458   ASSERT_FALSE(socket->isClosedByPeer());
459 }
460
461 /**
462  * Test installing a read callback and then closing immediately before the
463  * connect attempt finishes.
464  */
465 TEST(AsyncSocketTest, ConnectReadAndClose) {
466   TestServer server;
467
468   // connect()
469   EventBase evb;
470   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
471   ConnCallback ccb;
472   socket->connect(&ccb, server.getAddress(), 30);
473
474   // Hopefully the connect didn't succeed immediately.
475   // If it did, we can't exercise the close-while-connecting code path.
476   if (ccb.state == STATE_SUCCEEDED) {
477     LOG(INFO) << "connect() succeeded immediately; aborting test "
478                        "of read-during-connect behavior";
479     return;
480   }
481
482   ReadCallback rcb;
483   socket->setReadCB(&rcb);
484
485   // close()
486   socket->close();
487
488   // Loop, although there shouldn't be anything to do.
489   evb.loop();
490
491   ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
492   ASSERT_EQ(rcb.buffers.size(), 0);
493   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
494
495   ASSERT_TRUE(socket->isClosedBySelf());
496   ASSERT_FALSE(socket->isClosedByPeer());
497 }
498
499 /**
500  * Test both writing and installing a read callback immediately,
501  * before connect() finishes.
502  */
503 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
504   TestServer server;
505
506   // connect()
507   EventBase evb;
508   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
509   if (GetParam() == TFOState::ENABLED) {
510     socket->enableTFO();
511   }
512   ConnCallback ccb;
513   socket->connect(&ccb, server.getAddress(), 30);
514
515   // write()
516   char buf1[128];
517   memset(buf1, 'a', sizeof(buf1));
518   WriteCallback wcb;
519   socket->write(&wcb, buf1, sizeof(buf1));
520
521   // set a read callback
522   ReadCallback rcb;
523   socket->setReadCB(&rcb);
524
525   // Even though we haven't looped yet, we should be able to accept
526   // the connection and send data to it.
527   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
528   uint8_t buf2[128];
529   memset(buf2, 'b', sizeof(buf2));
530   acceptedSocket->write(buf2, sizeof(buf2));
531   acceptedSocket->flush();
532
533   // shut down the write half of acceptedSocket, so that the AsyncSocket
534   // will stop reading and we can break out of the event loop.
535   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
536
537   // Loop
538   evb.loop();
539
540   // Make sure the connect succeeded
541   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
542
543   // Make sure the AsyncSocket read the data written by the accepted socket
544   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
545   ASSERT_EQ(rcb.buffers.size(), 1);
546   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
547   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
548
549   // Close the AsyncSocket so we'll see EOF on acceptedSocket
550   socket->close();
551
552   // Make sure the accepted socket saw the data written by the AsyncSocket
553   uint8_t readbuf[sizeof(buf1)];
554   acceptedSocket->readAll(readbuf, sizeof(readbuf));
555   ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
556   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
557   ASSERT_EQ(bytesRead, 0);
558
559   ASSERT_FALSE(socket->isClosedBySelf());
560   ASSERT_TRUE(socket->isClosedByPeer());
561 }
562
563 /**
564  * Test writing to the socket then shutting down writes before the connect
565  * attempt finishes.
566  */
567 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
568   TestServer server;
569
570   // connect()
571   EventBase evb;
572   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
573   ConnCallback ccb;
574   socket->connect(&ccb, server.getAddress(), 30);
575
576   // Hopefully the connect didn't succeed immediately.
577   // If it did, we can't exercise the write-while-connecting code path.
578   if (ccb.state == STATE_SUCCEEDED) {
579     LOG(INFO) << "connect() succeeded immediately; skipping test";
580     return;
581   }
582
583   // Ask to write some data
584   char wbuf[128];
585   memset(wbuf, 'a', sizeof(wbuf));
586   WriteCallback wcb;
587   socket->write(&wcb, wbuf, sizeof(wbuf));
588   socket->shutdownWrite();
589
590   // Shutdown writes
591   socket->shutdownWrite();
592
593   // Even though we haven't looped yet, we should be able to accept
594   // the connection.
595   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
596
597   // Since the connection is still in progress, there should be no data to
598   // read yet.  Verify that the accepted socket is not readable.
599   struct pollfd fds[1];
600   fds[0].fd = acceptedSocket->getSocketFD();
601   fds[0].events = POLLIN;
602   fds[0].revents = 0;
603   int rc = poll(fds, 1, 0);
604   ASSERT_EQ(rc, 0);
605
606   // Write data to the accepted socket
607   uint8_t acceptedWbuf[192];
608   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
609   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
610   acceptedSocket->flush();
611
612   // Loop
613   evb.loop();
614
615   // The loop should have completed the connection, written the queued data,
616   // and shutdown writes on the socket.
617   //
618   // Check that the connection was completed successfully and that the write
619   // callback succeeded.
620   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
621   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
622
623   // Check that we can read the data that was written to the socket, and that
624   // we see an EOF, since its socket was half-shutdown.
625   uint8_t readbuf[sizeof(wbuf)];
626   acceptedSocket->readAll(readbuf, sizeof(readbuf));
627   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
628   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
629   ASSERT_EQ(bytesRead, 0);
630
631   // Close the accepted socket.  This will cause it to see EOF
632   // and uninstall the read callback when we loop next.
633   acceptedSocket->close();
634
635   // Install a read callback, then loop again.
636   ReadCallback rcb;
637   socket->setReadCB(&rcb);
638   evb.loop();
639
640   // This loop should have read the data and seen the EOF
641   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
642   ASSERT_EQ(rcb.buffers.size(), 1);
643   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
644   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
645                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
646
647   ASSERT_FALSE(socket->isClosedBySelf());
648   ASSERT_FALSE(socket->isClosedByPeer());
649 }
650
651 /**
652  * Test reading, writing, and shutting down writes before the connect attempt
653  * finishes.
654  */
655 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
656   TestServer server;
657
658   // connect()
659   EventBase evb;
660   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
661   ConnCallback ccb;
662   socket->connect(&ccb, server.getAddress(), 30);
663
664   // Hopefully the connect didn't succeed immediately.
665   // If it did, we can't exercise the write-while-connecting code path.
666   if (ccb.state == STATE_SUCCEEDED) {
667     LOG(INFO) << "connect() succeeded immediately; skipping test";
668     return;
669   }
670
671   // Install a read callback
672   ReadCallback rcb;
673   socket->setReadCB(&rcb);
674
675   // Ask to write some data
676   char wbuf[128];
677   memset(wbuf, 'a', sizeof(wbuf));
678   WriteCallback wcb;
679   socket->write(&wcb, wbuf, sizeof(wbuf));
680
681   // Shutdown writes
682   socket->shutdownWrite();
683
684   // Even though we haven't looped yet, we should be able to accept
685   // the connection.
686   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
687
688   // Since the connection is still in progress, there should be no data to
689   // read yet.  Verify that the accepted socket is not readable.
690   struct pollfd fds[1];
691   fds[0].fd = acceptedSocket->getSocketFD();
692   fds[0].events = POLLIN;
693   fds[0].revents = 0;
694   int rc = poll(fds, 1, 0);
695   ASSERT_EQ(rc, 0);
696
697   // Write data to the accepted socket
698   uint8_t acceptedWbuf[192];
699   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
700   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
701   acceptedSocket->flush();
702   // Shutdown writes to the accepted socket.  This will cause it to see EOF
703   // and uninstall the read callback.
704   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
705
706   // Loop
707   evb.loop();
708
709   // The loop should have completed the connection, written the queued data,
710   // shutdown writes on the socket, read the data we wrote to it, and see the
711   // EOF.
712   //
713   // Check that the connection was completed successfully and that the read
714   // and write callbacks were invoked as expected.
715   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
716   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
717   ASSERT_EQ(rcb.buffers.size(), 1);
718   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
719   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
720                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
721   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
722
723   // Check that we can read the data that was written to the socket, and that
724   // we see an EOF, since its socket was half-shutdown.
725   uint8_t readbuf[sizeof(wbuf)];
726   acceptedSocket->readAll(readbuf, sizeof(readbuf));
727   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
728   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
729   ASSERT_EQ(bytesRead, 0);
730
731   // Fully close both sockets
732   acceptedSocket->close();
733   socket->close();
734
735   ASSERT_FALSE(socket->isClosedBySelf());
736   ASSERT_TRUE(socket->isClosedByPeer());
737 }
738
739 /**
740  * Test reading, writing, and calling shutdownWriteNow() before the
741  * connect attempt finishes.
742  */
743 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
744   TestServer server;
745
746   // connect()
747   EventBase evb;
748   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
749   ConnCallback ccb;
750   socket->connect(&ccb, server.getAddress(), 30);
751
752   // Hopefully the connect didn't succeed immediately.
753   // If it did, we can't exercise the write-while-connecting code path.
754   if (ccb.state == STATE_SUCCEEDED) {
755     LOG(INFO) << "connect() succeeded immediately; skipping test";
756     return;
757   }
758
759   // Install a read callback
760   ReadCallback rcb;
761   socket->setReadCB(&rcb);
762
763   // Ask to write some data
764   char wbuf[128];
765   memset(wbuf, 'a', sizeof(wbuf));
766   WriteCallback wcb;
767   socket->write(&wcb, wbuf, sizeof(wbuf));
768
769   // Shutdown writes immediately.
770   // This should immediately discard the data that we just tried to write.
771   socket->shutdownWriteNow();
772
773   // Verify that writeError() was invoked on the write callback.
774   ASSERT_EQ(wcb.state, STATE_FAILED);
775   ASSERT_EQ(wcb.bytesWritten, 0);
776
777   // Even though we haven't looped yet, we should be able to accept
778   // the connection.
779   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
780
781   // Since the connection is still in progress, there should be no data to
782   // read yet.  Verify that the accepted socket is not readable.
783   struct pollfd fds[1];
784   fds[0].fd = acceptedSocket->getSocketFD();
785   fds[0].events = POLLIN;
786   fds[0].revents = 0;
787   int rc = poll(fds, 1, 0);
788   ASSERT_EQ(rc, 0);
789
790   // Write data to the accepted socket
791   uint8_t acceptedWbuf[192];
792   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
793   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
794   acceptedSocket->flush();
795   // Shutdown writes to the accepted socket.  This will cause it to see EOF
796   // and uninstall the read callback.
797   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
798
799   // Loop
800   evb.loop();
801
802   // The loop should have completed the connection, written the queued data,
803   // shutdown writes on the socket, read the data we wrote to it, and see the
804   // EOF.
805   //
806   // Check that the connection was completed successfully and that the read
807   // callback was invoked as expected.
808   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
809   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
810   ASSERT_EQ(rcb.buffers.size(), 1);
811   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
812   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
813                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
814
815   // Since we used shutdownWriteNow(), it should have discarded all pending
816   // write data.  Verify we see an immediate EOF when reading from the accepted
817   // socket.
818   uint8_t readbuf[sizeof(wbuf)];
819   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
820   ASSERT_EQ(bytesRead, 0);
821
822   // Fully close both sockets
823   acceptedSocket->close();
824   socket->close();
825
826   ASSERT_FALSE(socket->isClosedBySelf());
827   ASSERT_TRUE(socket->isClosedByPeer());
828 }
829
830 // Helper function for use in testConnectOptWrite()
831 // Temporarily disable the read callback
832 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
833   // Uninstall the read callback
834   socket->setReadCB(nullptr);
835   // Schedule the read callback to be reinstalled after 1ms
836   socket->getEventBase()->runInLoop(
837       std::bind(&AsyncSocket::setReadCB, socket, rcb));
838 }
839
840 /**
841  * Test connect+write, then have the connect callback perform another write.
842  *
843  * This tests interaction of the optimistic writing after connect with
844  * additional write attempts that occur in the connect callback.
845  */
846 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
847   TestServer server;
848   EventBase evb;
849   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
850
851   // connect()
852   ConnCallback ccb;
853   socket->connect(&ccb, server.getAddress(), 30);
854
855   // Hopefully the connect didn't succeed immediately.
856   // If it did, we can't exercise the optimistic write code path.
857   if (ccb.state == STATE_SUCCEEDED) {
858     LOG(INFO) << "connect() succeeded immediately; aborting test "
859                        "of optimistic write behavior";
860     return;
861   }
862
863   // Tell the connect callback to perform a write when the connect succeeds
864   WriteCallback wcb2;
865   scoped_array<char> buf2(new char[size2]);
866   memset(buf2.get(), 'b', size2);
867   if (size2 > 0) {
868     ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
869     // Tell the second write callback to close the connection when it is done
870     wcb2.successCallback = [&] { socket->closeNow(); };
871   }
872
873   // Schedule one write() immediately, before the connect finishes
874   scoped_array<char> buf1(new char[size1]);
875   memset(buf1.get(), 'a', size1);
876   WriteCallback wcb1;
877   if (size1 > 0) {
878     socket->write(&wcb1, buf1.get(), size1);
879   }
880
881   if (close) {
882     // immediately perform a close, before connect() completes
883     socket->close();
884   }
885
886   // Start reading from the other endpoint after 10ms.
887   // If we're using large buffers, we have to read so that the writes don't
888   // block forever.
889   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
890   ReadCallback rcb;
891   rcb.dataAvailableCallback = std::bind(tmpDisableReads,
892                                         acceptedSocket.get(), &rcb);
893   socket->getEventBase()->tryRunAfterDelay(
894       std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
895       10);
896
897   // Loop.  We don't bother accepting on the server socket yet.
898   // The kernel should be able to buffer the write request so it can succeed.
899   evb.loop();
900
901   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
902   if (size1 > 0) {
903     ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
904   }
905   if (size2 > 0) {
906     ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
907   }
908
909   socket->close();
910
911   // Make sure the read callback received all of the data
912   size_t bytesRead = 0;
913   for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
914        it != rcb.buffers.end();
915        ++it) {
916     size_t start = bytesRead;
917     bytesRead += it->length;
918     size_t end = bytesRead;
919     if (start < size1) {
920       size_t cmpLen = min(size1, end) - start;
921       ASSERT_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
922     }
923     if (end > size1 && end <= size1 + size2) {
924       size_t itOffset;
925       size_t buf2Offset;
926       size_t cmpLen;
927       if (start >= size1) {
928         itOffset = 0;
929         buf2Offset = start - size1;
930         cmpLen = end - start;
931       } else {
932         itOffset = size1 - start;
933         buf2Offset = 0;
934         cmpLen = end - size1;
935       }
936       ASSERT_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
937                                cmpLen),
938                         0);
939     }
940   }
941   ASSERT_EQ(bytesRead, size1 + size2);
942 }
943
944 TEST(AsyncSocketTest, ConnectCallbackWrite) {
945   // Test using small writes that should both succeed immediately
946   testConnectOptWrite(100, 200);
947
948   // Test using a large buffer in the connect callback, that should block
949   const size_t largeSize = 32 * 1024 * 1024;
950   testConnectOptWrite(100, largeSize);
951
952   // Test using a large initial write
953   testConnectOptWrite(largeSize, 100);
954
955   // Test using two large buffers
956   testConnectOptWrite(largeSize, largeSize);
957
958   // Test a small write in the connect callback,
959   // but no immediate write before connect completes
960   testConnectOptWrite(0, 64);
961
962   // Test a large write in the connect callback,
963   // but no immediate write before connect completes
964   testConnectOptWrite(0, largeSize);
965
966   // Test connect, a small write, then immediately call close() before connect
967   // completes
968   testConnectOptWrite(211, 0, true);
969
970   // Test connect, a large immediate write (that will block), then immediately
971   // call close() before connect completes
972   testConnectOptWrite(largeSize, 0, true);
973 }
974
975 ///////////////////////////////////////////////////////////////////////////
976 // write() related tests
977 ///////////////////////////////////////////////////////////////////////////
978
979 /**
980  * Test writing using a nullptr callback
981  */
982 TEST(AsyncSocketTest, WriteNullCallback) {
983   TestServer server;
984
985   // connect()
986   EventBase evb;
987   std::shared_ptr<AsyncSocket> socket =
988     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
989   evb.loop(); // loop until the socket is connected
990
991   // write() with a nullptr callback
992   char buf[128];
993   memset(buf, 'a', sizeof(buf));
994   socket->write(nullptr, buf, sizeof(buf));
995
996   evb.loop(); // loop until the data is sent
997
998   // Make sure the server got a connection and received the data
999   socket->close();
1000   server.verifyConnection(buf, sizeof(buf));
1001
1002   ASSERT_TRUE(socket->isClosedBySelf());
1003   ASSERT_FALSE(socket->isClosedByPeer());
1004 }
1005
1006 /**
1007  * Test writing with a send timeout
1008  */
1009 TEST(AsyncSocketTest, WriteTimeout) {
1010   TestServer server;
1011
1012   // connect()
1013   EventBase evb;
1014   std::shared_ptr<AsyncSocket> socket =
1015     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1016   evb.loop(); // loop until the socket is connected
1017
1018   // write() a large chunk of data, with no-one on the other end reading.
1019   // Tricky: the kernel caches the connection metrics for recently-used
1020   // routes (see tcp_no_metrics_save) so a freshly opened connection can
1021   // have a send buffer size bigger than wmem_default.  This makes the test
1022   // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
1023   size_t writeLength = 32 * 1024 * 1024;
1024   uint32_t timeout = 200;
1025   socket->setSendTimeout(timeout);
1026   scoped_array<char> buf(new char[writeLength]);
1027   memset(buf.get(), 'a', writeLength);
1028   WriteCallback wcb;
1029   socket->write(&wcb, buf.get(), writeLength);
1030
1031   TimePoint start;
1032   evb.loop();
1033   TimePoint end;
1034
1035   // Make sure the write attempt timed out as requested
1036   ASSERT_EQ(wcb.state, STATE_FAILED);
1037   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1038
1039   // Check that the write timed out within a reasonable period of time.
1040   // We don't check for exactly the specified timeout, since AsyncSocket only
1041   // times out when it hasn't made progress for that period of time.
1042   //
1043   // On linux, the first write sends a few hundred kb of data, then blocks for
1044   // writability, and then unblocks again after 40ms and is able to write
1045   // another smaller of data before blocking permanently.  Therefore it doesn't
1046   // time out until 40ms + timeout.
1047   //
1048   // I haven't fully verified the cause of this, but I believe it probably
1049   // occurs because the receiving end delays sending an ack for up to 40ms.
1050   // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
1051   // the ack, it can send some more data.  However, after that point the
1052   // receiver's kernel buffer is full.  This 40ms delay happens even with
1053   // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
1054   // kernel may be automatically disabling TCP_QUICKACK after receiving some
1055   // data.
1056   //
1057   // For now, we simply check that the timeout occurred within 160ms of
1058   // the requested value.
1059   T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1060 }
1061
1062 /**
1063  * Test writing to a socket that the remote endpoint has closed
1064  */
1065 TEST(AsyncSocketTest, WritePipeError) {
1066   TestServer server;
1067
1068   // connect()
1069   EventBase evb;
1070   std::shared_ptr<AsyncSocket> socket =
1071     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1072   socket->setSendTimeout(1000);
1073   evb.loop(); // loop until the socket is connected
1074
1075   // accept and immediately close the socket
1076   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1077   acceptedSocket->close();
1078
1079   // write() a large chunk of data
1080   size_t writeLength = 32 * 1024 * 1024;
1081   scoped_array<char> buf(new char[writeLength]);
1082   memset(buf.get(), 'a', writeLength);
1083   WriteCallback wcb;
1084   socket->write(&wcb, buf.get(), writeLength);
1085
1086   evb.loop();
1087
1088   // Make sure the write failed.
1089   // It would be nice if AsyncSocketException could convey the errno value,
1090   // so that we could check for EPIPE
1091   ASSERT_EQ(wcb.state, STATE_FAILED);
1092   ASSERT_EQ(wcb.exception.getType(),
1093                     AsyncSocketException::INTERNAL_ERROR);
1094
1095   ASSERT_FALSE(socket->isClosedBySelf());
1096   ASSERT_FALSE(socket->isClosedByPeer());
1097 }
1098
1099 /**
1100  * Test that bytes written is correctly computed in case of write failure
1101  */
1102 TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
1103   // Send and receive buffer sizes for the sockets.
1104   const int sockBufSize = 8 * 1024;
1105
1106   TestServer server(false, sockBufSize);
1107
1108   AsyncSocket::OptionMap options{
1109       {{SOL_SOCKET, SO_SNDBUF}, sockBufSize},
1110       {{SOL_SOCKET, SO_RCVBUF}, sockBufSize},
1111       {{IPPROTO_TCP, TCP_NODELAY}, 1},
1112   };
1113
1114   // The current thread will be used by the receiver - use a separate thread
1115   // for the sender.
1116   EventBase senderEvb;
1117   std::thread senderThread([&]() { senderEvb.loopForever(); });
1118
1119   ConnCallback ccb;
1120   std::shared_ptr<AsyncSocket> socket;
1121
1122   senderEvb.runInEventBaseThreadAndWait([&]() {
1123     socket = AsyncSocket::newSocket(&senderEvb);
1124     socket->connect(&ccb, server.getAddress(), 30, options);
1125   });
1126
1127   // accept the socket on the server side
1128   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1129
1130   // Send a big (45KB) write so that it is partially written. The first write
1131   // is 16KB (8KB on both sides) and subsequent writes are 8KB each. Reading
1132   // just under 24KB would cause 3-4 writes for the total of 32-40KB in the
1133   // following sequence: 16KB + 8KB + 8KB (+ 8KB). This ensures that not all
1134   // bytes are written when the socket is reset. Having at least 3 writes
1135   // ensures that the total size (45KB) would be exceeed in case of overcounting
1136   // based on the initial write size of 16KB.
1137   constexpr size_t sendSize = 45 * 1024;
1138   auto const sendBuf = std::vector<char>(sendSize, 'a');
1139
1140   WriteCallback wcb;
1141
1142   senderEvb.runInEventBaseThreadAndWait(
1143       [&]() { socket->write(&wcb, sendBuf.data(), sendSize); });
1144
1145   // Reading 20KB would cause three additional writes of 8KB, but less
1146   // than 45KB total, so the socket is reset before all bytes are written.
1147   constexpr size_t recvSize = 20 * 1024;
1148   uint8_t recvBuf[recvSize];
1149   int bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
1150
1151   acceptedSocket->closeWithReset();
1152
1153   senderEvb.terminateLoopSoon();
1154   senderThread.join();
1155
1156   LOG(INFO) << "Bytes written: " << wcb.bytesWritten;
1157
1158   ASSERT_EQ(STATE_FAILED, wcb.state);
1159   ASSERT_GE(wcb.bytesWritten, bytesRead);
1160   ASSERT_LE(wcb.bytesWritten, sendSize);
1161   ASSERT_EQ(recvSize, bytesRead);
1162   ASSERT(32 * 1024 == wcb.bytesWritten || 40 * 1024 == wcb.bytesWritten);
1163 }
1164
1165 /**
1166  * Test writing a mix of simple buffers and IOBufs
1167  */
1168 TEST(AsyncSocketTest, WriteIOBuf) {
1169   TestServer server;
1170
1171   // connect()
1172   EventBase evb;
1173   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1174   ConnCallback ccb;
1175   socket->connect(&ccb, server.getAddress(), 30);
1176
1177   // Accept the connection
1178   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1179   ReadCallback rcb;
1180   acceptedSocket->setReadCB(&rcb);
1181
1182   // Check if EOR tracking flag can be set and reset.
1183   EXPECT_FALSE(socket->isEorTrackingEnabled());
1184   socket->setEorTracking(true);
1185   EXPECT_TRUE(socket->isEorTrackingEnabled());
1186   socket->setEorTracking(false);
1187   EXPECT_FALSE(socket->isEorTrackingEnabled());
1188
1189   // Write a simple buffer to the socket
1190   constexpr size_t simpleBufLength = 5;
1191   char simpleBuf[simpleBufLength];
1192   memset(simpleBuf, 'a', simpleBufLength);
1193   WriteCallback wcb;
1194   socket->write(&wcb, simpleBuf, simpleBufLength);
1195
1196   // Write a single-element IOBuf chain
1197   size_t buf1Length = 7;
1198   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1199   memset(buf1->writableData(), 'b', buf1Length);
1200   buf1->append(buf1Length);
1201   unique_ptr<IOBuf> buf1Copy(buf1->clone());
1202   WriteCallback wcb2;
1203   socket->writeChain(&wcb2, std::move(buf1));
1204
1205   // Write a multiple-element IOBuf chain
1206   size_t buf2Length = 11;
1207   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1208   memset(buf2->writableData(), 'c', buf2Length);
1209   buf2->append(buf2Length);
1210   size_t buf3Length = 13;
1211   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1212   memset(buf3->writableData(), 'd', buf3Length);
1213   buf3->append(buf3Length);
1214   buf2->appendChain(std::move(buf3));
1215   unique_ptr<IOBuf> buf2Copy(buf2->clone());
1216   buf2Copy->coalesce();
1217   WriteCallback wcb3;
1218   socket->writeChain(&wcb3, std::move(buf2));
1219   socket->shutdownWrite();
1220
1221   // Let the reads and writes run to completion
1222   evb.loop();
1223
1224   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1225   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1226   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1227
1228   // Make sure the reader got the right data in the right order
1229   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1230   ASSERT_EQ(rcb.buffers.size(), 1);
1231   ASSERT_EQ(rcb.buffers[0].length,
1232       simpleBufLength + buf1Length + buf2Length + buf3Length);
1233   ASSERT_EQ(
1234       memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1235   ASSERT_EQ(
1236       memcmp(rcb.buffers[0].buffer + simpleBufLength,
1237           buf1Copy->data(), buf1Copy->length()), 0);
1238   ASSERT_EQ(
1239       memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1240           buf2Copy->data(), buf2Copy->length()), 0);
1241
1242   acceptedSocket->close();
1243   socket->close();
1244
1245   ASSERT_TRUE(socket->isClosedBySelf());
1246   ASSERT_FALSE(socket->isClosedByPeer());
1247 }
1248
1249 TEST(AsyncSocketTest, WriteIOBufCorked) {
1250   TestServer server;
1251
1252   // connect()
1253   EventBase evb;
1254   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1255   ConnCallback ccb;
1256   socket->connect(&ccb, server.getAddress(), 30);
1257
1258   // Accept the connection
1259   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1260   ReadCallback rcb;
1261   acceptedSocket->setReadCB(&rcb);
1262
1263   // Do three writes, 100ms apart, with the "cork" flag set
1264   // on the second write.  The reader should see the first write
1265   // arrive by itself, followed by the second and third writes
1266   // arriving together.
1267   size_t buf1Length = 5;
1268   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1269   memset(buf1->writableData(), 'a', buf1Length);
1270   buf1->append(buf1Length);
1271   size_t buf2Length = 7;
1272   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1273   memset(buf2->writableData(), 'b', buf2Length);
1274   buf2->append(buf2Length);
1275   size_t buf3Length = 11;
1276   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1277   memset(buf3->writableData(), 'c', buf3Length);
1278   buf3->append(buf3Length);
1279   WriteCallback wcb1;
1280   socket->writeChain(&wcb1, std::move(buf1));
1281   WriteCallback wcb2;
1282   DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
1283   write2.scheduleTimeout(100);
1284   WriteCallback wcb3;
1285   DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
1286   write3.scheduleTimeout(140);
1287
1288   evb.loop();
1289   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1290   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1291   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1292   if (wcb3.state != STATE_SUCCEEDED) {
1293     throw(wcb3.exception);
1294   }
1295   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1296
1297   // Make sure the reader got the data with the right grouping
1298   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1299   ASSERT_EQ(rcb.buffers.size(), 2);
1300   ASSERT_EQ(rcb.buffers[0].length, buf1Length);
1301   ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
1302
1303   acceptedSocket->close();
1304   socket->close();
1305
1306   ASSERT_TRUE(socket->isClosedBySelf());
1307   ASSERT_FALSE(socket->isClosedByPeer());
1308 }
1309
1310 /**
1311  * Test performing a zero-length write
1312  */
1313 TEST(AsyncSocketTest, ZeroLengthWrite) {
1314   TestServer server;
1315
1316   // connect()
1317   EventBase evb;
1318   std::shared_ptr<AsyncSocket> socket =
1319     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1320   evb.loop(); // loop until the socket is connected
1321
1322   auto acceptedSocket = server.acceptAsync(&evb);
1323   ReadCallback rcb;
1324   acceptedSocket->setReadCB(&rcb);
1325
1326   size_t len1 = 1024*1024;
1327   size_t len2 = 1024*1024;
1328   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1329   memset(buf.get(), 'a', len1);
1330   memset(buf.get(), 'b', len2);
1331
1332   WriteCallback wcb1;
1333   WriteCallback wcb2;
1334   WriteCallback wcb3;
1335   WriteCallback wcb4;
1336   socket->write(&wcb1, buf.get(), 0);
1337   socket->write(&wcb2, buf.get(), len1);
1338   socket->write(&wcb3, buf.get() + len1, 0);
1339   socket->write(&wcb4, buf.get() + len1, len2);
1340   socket->close();
1341
1342   evb.loop(); // loop until the data is sent
1343
1344   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1345   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1346   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1347   ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1348   rcb.verifyData(buf.get(), len1 + len2);
1349
1350   ASSERT_TRUE(socket->isClosedBySelf());
1351   ASSERT_FALSE(socket->isClosedByPeer());
1352 }
1353
1354 TEST(AsyncSocketTest, ZeroLengthWritev) {
1355   TestServer server;
1356
1357   // connect()
1358   EventBase evb;
1359   std::shared_ptr<AsyncSocket> socket =
1360     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1361   evb.loop(); // loop until the socket is connected
1362
1363   auto acceptedSocket = server.acceptAsync(&evb);
1364   ReadCallback rcb;
1365   acceptedSocket->setReadCB(&rcb);
1366
1367   size_t len1 = 1024*1024;
1368   size_t len2 = 1024*1024;
1369   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1370   memset(buf.get(), 'a', len1);
1371   memset(buf.get(), 'b', len2);
1372
1373   WriteCallback wcb;
1374   constexpr size_t iovCount = 4;
1375   struct iovec iov[iovCount];
1376   iov[0].iov_base = buf.get();
1377   iov[0].iov_len = len1;
1378   iov[1].iov_base = buf.get() + len1;
1379   iov[1].iov_len = 0;
1380   iov[2].iov_base = buf.get() + len1;
1381   iov[2].iov_len = len2;
1382   iov[3].iov_base = buf.get() + len1 + len2;
1383   iov[3].iov_len = 0;
1384
1385   socket->writev(&wcb, iov, iovCount);
1386   socket->close();
1387   evb.loop(); // loop until the data is sent
1388
1389   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1390   rcb.verifyData(buf.get(), len1 + len2);
1391
1392   ASSERT_TRUE(socket->isClosedBySelf());
1393   ASSERT_FALSE(socket->isClosedByPeer());
1394 }
1395
1396 ///////////////////////////////////////////////////////////////////////////
1397 // close() related tests
1398 ///////////////////////////////////////////////////////////////////////////
1399
1400 /**
1401  * Test calling close() with pending writes when the socket is already closing.
1402  */
1403 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1404   TestServer server;
1405
1406   // connect()
1407   EventBase evb;
1408   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1409   ConnCallback ccb;
1410   socket->connect(&ccb, server.getAddress(), 30);
1411
1412   // accept the socket on the server side
1413   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1414
1415   // Loop to ensure the connect has completed
1416   evb.loop();
1417
1418   // Make sure we are connected
1419   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1420
1421   // Schedule pending writes, until several write attempts have blocked
1422   char buf[128];
1423   memset(buf, 'a', sizeof(buf));
1424   typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
1425   WriteCallbackVector writeCallbacks;
1426
1427   writeCallbacks.reserve(5);
1428   while (writeCallbacks.size() < 5) {
1429     std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1430
1431     socket->write(wcb.get(), buf, sizeof(buf));
1432     if (wcb->state == STATE_SUCCEEDED) {
1433       // Succeeded immediately.  Keep performing more writes
1434       continue;
1435     }
1436
1437     // This write is blocked.
1438     // Have the write callback call close() when writeError() is invoked
1439     wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1440     writeCallbacks.push_back(wcb);
1441   }
1442
1443   // Call closeNow() to immediately fail the pending writes
1444   socket->closeNow();
1445
1446   // Make sure writeError() was invoked on all of the pending write callbacks
1447   for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
1448        it != writeCallbacks.end();
1449        ++it) {
1450     ASSERT_EQ((*it)->state, STATE_FAILED);
1451   }
1452
1453   ASSERT_TRUE(socket->isClosedBySelf());
1454   ASSERT_FALSE(socket->isClosedByPeer());
1455 }
1456
1457 ///////////////////////////////////////////////////////////////////////////
1458 // ImmediateRead related tests
1459 ///////////////////////////////////////////////////////////////////////////
1460
1461 /* AsyncSocket use to verify immediate read works */
1462 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1463  public:
1464   bool immediateReadCalled = false;
1465   explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1466  protected:
1467   void checkForImmediateRead() noexcept override {
1468     immediateReadCalled = true;
1469     AsyncSocket::handleRead();
1470   }
1471 };
1472
1473 TEST(AsyncSocket, ConnectReadImmediateRead) {
1474   TestServer server;
1475
1476   const size_t maxBufferSz = 100;
1477   const size_t maxReadsPerEvent = 1;
1478   const size_t expectedDataSz = maxBufferSz * 3;
1479   char expectedData[expectedDataSz];
1480   memset(expectedData, 'j', expectedDataSz);
1481
1482   EventBase evb;
1483   ReadCallback rcb(maxBufferSz);
1484   AsyncSocketImmediateRead socket(&evb);
1485   socket.connect(nullptr, server.getAddress(), 30);
1486
1487   evb.loop(); // loop until the socket is connected
1488
1489   socket.setReadCB(&rcb);
1490   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1491   socket.immediateReadCalled = false;
1492
1493   auto acceptedSocket = server.acceptAsync(&evb);
1494
1495   ReadCallback rcbServer;
1496   WriteCallback wcbServer;
1497   rcbServer.dataAvailableCallback = [&]() {
1498     if (rcbServer.dataRead() == expectedDataSz) {
1499       // write back all data read
1500       rcbServer.verifyData(expectedData, expectedDataSz);
1501       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1502       acceptedSocket->close();
1503     }
1504   };
1505   acceptedSocket->setReadCB(&rcbServer);
1506
1507   // write data
1508   WriteCallback wcb1;
1509   socket.write(&wcb1, expectedData, expectedDataSz);
1510   evb.loop();
1511   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1512   rcb.verifyData(expectedData, expectedDataSz);
1513   ASSERT_EQ(socket.immediateReadCalled, true);
1514
1515   ASSERT_FALSE(socket.isClosedBySelf());
1516   ASSERT_FALSE(socket.isClosedByPeer());
1517 }
1518
1519 TEST(AsyncSocket, ConnectReadUninstallRead) {
1520   TestServer server;
1521
1522   const size_t maxBufferSz = 100;
1523   const size_t maxReadsPerEvent = 1;
1524   const size_t expectedDataSz = maxBufferSz * 3;
1525   char expectedData[expectedDataSz];
1526   memset(expectedData, 'k', expectedDataSz);
1527
1528   EventBase evb;
1529   ReadCallback rcb(maxBufferSz);
1530   AsyncSocketImmediateRead socket(&evb);
1531   socket.connect(nullptr, server.getAddress(), 30);
1532
1533   evb.loop(); // loop until the socket is connected
1534
1535   socket.setReadCB(&rcb);
1536   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1537   socket.immediateReadCalled = false;
1538
1539   auto acceptedSocket = server.acceptAsync(&evb);
1540
1541   ReadCallback rcbServer;
1542   WriteCallback wcbServer;
1543   rcbServer.dataAvailableCallback = [&]() {
1544     if (rcbServer.dataRead() == expectedDataSz) {
1545       // write back all data read
1546       rcbServer.verifyData(expectedData, expectedDataSz);
1547       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1548       acceptedSocket->close();
1549     }
1550   };
1551   acceptedSocket->setReadCB(&rcbServer);
1552
1553   rcb.dataAvailableCallback = [&]() {
1554     // we read data and reset readCB
1555     socket.setReadCB(nullptr);
1556   };
1557
1558   // write data
1559   WriteCallback wcb;
1560   socket.write(&wcb, expectedData, expectedDataSz);
1561   evb.loop();
1562   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1563
1564   /* we shoud've only read maxBufferSz data since readCallback_
1565    * was reset in dataAvailableCallback */
1566   ASSERT_EQ(rcb.dataRead(), maxBufferSz);
1567   ASSERT_EQ(socket.immediateReadCalled, false);
1568
1569   ASSERT_FALSE(socket.isClosedBySelf());
1570   ASSERT_FALSE(socket.isClosedByPeer());
1571 }
1572
1573 // TODO:
1574 // - Test connect() and have the connect callback set the read callback
1575 // - Test connect() and have the connect callback unset the read callback
1576 // - Test reading/writing/closing/destroying the socket in the connect callback
1577 // - Test reading/writing/closing/destroying the socket in the read callback
1578 // - Test reading/writing/closing/destroying the socket in the write callback
1579 // - Test one-way shutdown behavior
1580 // - Test changing the EventBase
1581 //
1582 // - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
1583 //   in connectSuccess(), readDataAvailable(), writeSuccess()
1584
1585
1586 ///////////////////////////////////////////////////////////////////////////
1587 // AsyncServerSocket tests
1588 ///////////////////////////////////////////////////////////////////////////
1589
1590 /**
1591  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1592  */
1593 TEST(AsyncSocketTest, ServerAcceptOptions) {
1594   EventBase eventBase;
1595
1596   // Create a server socket
1597   std::shared_ptr<AsyncServerSocket> serverSocket(
1598       AsyncServerSocket::newSocket(&eventBase));
1599   serverSocket->bind(0);
1600   serverSocket->listen(16);
1601   folly::SocketAddress serverAddress;
1602   serverSocket->getAddress(&serverAddress);
1603
1604   // Add a callback to accept one connection then stop the loop
1605   TestAcceptCallback acceptCallback;
1606   acceptCallback.setConnectionAcceptedFn(
1607       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1608         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1609       });
1610   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1611     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1612   });
1613   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
1614   serverSocket->startAccepting();
1615
1616   // Connect to the server socket
1617   std::shared_ptr<AsyncSocket> socket(
1618       AsyncSocket::newSocket(&eventBase, serverAddress));
1619
1620   eventBase.loop();
1621
1622   // Verify that the server accepted a connection
1623   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1624   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1625                     TestAcceptCallback::TYPE_START);
1626   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1627                     TestAcceptCallback::TYPE_ACCEPT);
1628   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1629                     TestAcceptCallback::TYPE_STOP);
1630   int fd = acceptCallback.getEvents()->at(1).fd;
1631
1632   // The accepted connection should already be in non-blocking mode
1633   int flags = fcntl(fd, F_GETFL, 0);
1634   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
1635
1636 #ifndef TCP_NOPUSH
1637   // The accepted connection should already have TCP_NODELAY set
1638   int value;
1639   socklen_t valueLength = sizeof(value);
1640   int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
1641   ASSERT_EQ(rc, 0);
1642   ASSERT_EQ(value, 1);
1643 #endif
1644 }
1645
1646 /**
1647  * Test AsyncServerSocket::removeAcceptCallback()
1648  */
1649 TEST(AsyncSocketTest, RemoveAcceptCallback) {
1650   // Create a new AsyncServerSocket
1651   EventBase eventBase;
1652   std::shared_ptr<AsyncServerSocket> serverSocket(
1653       AsyncServerSocket::newSocket(&eventBase));
1654   serverSocket->bind(0);
1655   serverSocket->listen(16);
1656   folly::SocketAddress serverAddress;
1657   serverSocket->getAddress(&serverAddress);
1658
1659   // Add several accept callbacks
1660   TestAcceptCallback cb1;
1661   TestAcceptCallback cb2;
1662   TestAcceptCallback cb3;
1663   TestAcceptCallback cb4;
1664   TestAcceptCallback cb5;
1665   TestAcceptCallback cb6;
1666   TestAcceptCallback cb7;
1667
1668   // Test having callbacks remove other callbacks before them on the list,
1669   // after them on the list, or removing themselves.
1670   //
1671   // Have callback 2 remove callback 3 and callback 5 the first time it is
1672   // called.
1673   int cb2Count = 0;
1674   cb1.setConnectionAcceptedFn([&](int /* fd */,
1675                                   const folly::SocketAddress& /* addr */) {
1676     std::shared_ptr<AsyncSocket> sock2(
1677         AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
1678   });
1679   cb3.setConnectionAcceptedFn(
1680       [&](int /* fd */, const folly::SocketAddress& /* addr */) {});
1681   cb4.setConnectionAcceptedFn(
1682       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1683         std::shared_ptr<AsyncSocket> sock3(
1684             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
1685       });
1686   cb5.setConnectionAcceptedFn(
1687       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1688         std::shared_ptr<AsyncSocket> sock5(
1689             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
1690
1691       });
1692   cb2.setConnectionAcceptedFn(
1693       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1694         if (cb2Count == 0) {
1695           serverSocket->removeAcceptCallback(&cb3, nullptr);
1696           serverSocket->removeAcceptCallback(&cb5, nullptr);
1697         }
1698         ++cb2Count;
1699       });
1700   // Have callback 6 remove callback 4 the first time it is called,
1701   // and destroy the server socket the second time it is called
1702   int cb6Count = 0;
1703   cb6.setConnectionAcceptedFn(
1704       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1705         if (cb6Count == 0) {
1706           serverSocket->removeAcceptCallback(&cb4, nullptr);
1707           std::shared_ptr<AsyncSocket> sock6(
1708               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1709           std::shared_ptr<AsyncSocket> sock7(
1710               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
1711           std::shared_ptr<AsyncSocket> sock8(
1712               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
1713
1714         } else {
1715           serverSocket.reset();
1716         }
1717         ++cb6Count;
1718       });
1719   // Have callback 7 remove itself
1720   cb7.setConnectionAcceptedFn(
1721       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1722         serverSocket->removeAcceptCallback(&cb7, nullptr);
1723       });
1724
1725   serverSocket->addAcceptCallback(&cb1, &eventBase);
1726   serverSocket->addAcceptCallback(&cb2, &eventBase);
1727   serverSocket->addAcceptCallback(&cb3, &eventBase);
1728   serverSocket->addAcceptCallback(&cb4, &eventBase);
1729   serverSocket->addAcceptCallback(&cb5, &eventBase);
1730   serverSocket->addAcceptCallback(&cb6, &eventBase);
1731   serverSocket->addAcceptCallback(&cb7, &eventBase);
1732   serverSocket->startAccepting();
1733
1734   // Make several connections to the socket
1735   std::shared_ptr<AsyncSocket> sock1(
1736       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1737   std::shared_ptr<AsyncSocket> sock4(
1738       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
1739
1740   // Loop until we are stopped
1741   eventBase.loop();
1742
1743   // Check to make sure that the expected callbacks were invoked.
1744   //
1745   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1746   // the AcceptCallbacks in round-robin fashion, in the order that they were
1747   // added.  The code is implemented this way right now, but the API doesn't
1748   // explicitly require it be done this way.  If we change the code not to be
1749   // exactly round robin in the future, we can simplify the test checks here.
1750   // (We'll also need to update the termination code, since we expect cb6 to
1751   // get called twice to terminate the loop.)
1752   ASSERT_EQ(cb1.getEvents()->size(), 4);
1753   ASSERT_EQ(cb1.getEvents()->at(0).type,
1754                     TestAcceptCallback::TYPE_START);
1755   ASSERT_EQ(cb1.getEvents()->at(1).type,
1756                     TestAcceptCallback::TYPE_ACCEPT);
1757   ASSERT_EQ(cb1.getEvents()->at(2).type,
1758                     TestAcceptCallback::TYPE_ACCEPT);
1759   ASSERT_EQ(cb1.getEvents()->at(3).type,
1760                     TestAcceptCallback::TYPE_STOP);
1761
1762   ASSERT_EQ(cb2.getEvents()->size(), 4);
1763   ASSERT_EQ(cb2.getEvents()->at(0).type,
1764                     TestAcceptCallback::TYPE_START);
1765   ASSERT_EQ(cb2.getEvents()->at(1).type,
1766                     TestAcceptCallback::TYPE_ACCEPT);
1767   ASSERT_EQ(cb2.getEvents()->at(2).type,
1768                     TestAcceptCallback::TYPE_ACCEPT);
1769   ASSERT_EQ(cb2.getEvents()->at(3).type,
1770                     TestAcceptCallback::TYPE_STOP);
1771
1772   ASSERT_EQ(cb3.getEvents()->size(), 2);
1773   ASSERT_EQ(cb3.getEvents()->at(0).type,
1774                     TestAcceptCallback::TYPE_START);
1775   ASSERT_EQ(cb3.getEvents()->at(1).type,
1776                     TestAcceptCallback::TYPE_STOP);
1777
1778   ASSERT_EQ(cb4.getEvents()->size(), 3);
1779   ASSERT_EQ(cb4.getEvents()->at(0).type,
1780                     TestAcceptCallback::TYPE_START);
1781   ASSERT_EQ(cb4.getEvents()->at(1).type,
1782                     TestAcceptCallback::TYPE_ACCEPT);
1783   ASSERT_EQ(cb4.getEvents()->at(2).type,
1784                     TestAcceptCallback::TYPE_STOP);
1785
1786   ASSERT_EQ(cb5.getEvents()->size(), 2);
1787   ASSERT_EQ(cb5.getEvents()->at(0).type,
1788                     TestAcceptCallback::TYPE_START);
1789   ASSERT_EQ(cb5.getEvents()->at(1).type,
1790                     TestAcceptCallback::TYPE_STOP);
1791
1792   ASSERT_EQ(cb6.getEvents()->size(), 4);
1793   ASSERT_EQ(cb6.getEvents()->at(0).type,
1794                     TestAcceptCallback::TYPE_START);
1795   ASSERT_EQ(cb6.getEvents()->at(1).type,
1796                     TestAcceptCallback::TYPE_ACCEPT);
1797   ASSERT_EQ(cb6.getEvents()->at(2).type,
1798                     TestAcceptCallback::TYPE_ACCEPT);
1799   ASSERT_EQ(cb6.getEvents()->at(3).type,
1800                     TestAcceptCallback::TYPE_STOP);
1801
1802   ASSERT_EQ(cb7.getEvents()->size(), 3);
1803   ASSERT_EQ(cb7.getEvents()->at(0).type,
1804                     TestAcceptCallback::TYPE_START);
1805   ASSERT_EQ(cb7.getEvents()->at(1).type,
1806                     TestAcceptCallback::TYPE_ACCEPT);
1807   ASSERT_EQ(cb7.getEvents()->at(2).type,
1808                     TestAcceptCallback::TYPE_STOP);
1809 }
1810
1811 /**
1812  * Test AsyncServerSocket::removeAcceptCallback()
1813  */
1814 TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
1815   // Create a new AsyncServerSocket
1816   EventBase eventBase;
1817   std::shared_ptr<AsyncServerSocket> serverSocket(
1818       AsyncServerSocket::newSocket(&eventBase));
1819   serverSocket->bind(0);
1820   serverSocket->listen(16);
1821   folly::SocketAddress serverAddress;
1822   serverSocket->getAddress(&serverAddress);
1823
1824   // Add several accept callbacks
1825   TestAcceptCallback cb1;
1826   auto thread_id = std::this_thread::get_id();
1827   cb1.setAcceptStartedFn([&](){
1828     CHECK_NE(thread_id, std::this_thread::get_id());
1829     thread_id = std::this_thread::get_id();
1830   });
1831   cb1.setConnectionAcceptedFn(
1832       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1833         ASSERT_EQ(thread_id, std::this_thread::get_id());
1834         serverSocket->removeAcceptCallback(&cb1, &eventBase);
1835       });
1836   cb1.setAcceptStoppedFn([&](){
1837     ASSERT_EQ(thread_id, std::this_thread::get_id());
1838   });
1839
1840   // Test having callbacks remove other callbacks before them on the list,
1841   serverSocket->addAcceptCallback(&cb1, &eventBase);
1842   serverSocket->startAccepting();
1843
1844   // Make several connections to the socket
1845   std::shared_ptr<AsyncSocket> sock1(
1846       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1847
1848   // Loop in another thread
1849   auto other = std::thread([&](){
1850     eventBase.loop();
1851   });
1852   other.join();
1853
1854   // Check to make sure that the expected callbacks were invoked.
1855   //
1856   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1857   // the AcceptCallbacks in round-robin fashion, in the order that they were
1858   // added.  The code is implemented this way right now, but the API doesn't
1859   // explicitly require it be done this way.  If we change the code not to be
1860   // exactly round robin in the future, we can simplify the test checks here.
1861   // (We'll also need to update the termination code, since we expect cb6 to
1862   // get called twice to terminate the loop.)
1863   ASSERT_EQ(cb1.getEvents()->size(), 3);
1864   ASSERT_EQ(cb1.getEvents()->at(0).type,
1865                     TestAcceptCallback::TYPE_START);
1866   ASSERT_EQ(cb1.getEvents()->at(1).type,
1867                     TestAcceptCallback::TYPE_ACCEPT);
1868   ASSERT_EQ(cb1.getEvents()->at(2).type,
1869                     TestAcceptCallback::TYPE_STOP);
1870
1871 }
1872
1873 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
1874   EventBase* eventBase = serverSocket->getEventBase();
1875   CHECK(eventBase);
1876
1877   // Add a callback to accept one connection then stop accepting
1878   TestAcceptCallback acceptCallback;
1879   acceptCallback.setConnectionAcceptedFn(
1880       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1881         serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1882       });
1883   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1884     serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1885   });
1886   serverSocket->addAcceptCallback(&acceptCallback, eventBase);
1887   serverSocket->startAccepting();
1888
1889   // Connect to the server socket
1890   folly::SocketAddress serverAddress;
1891   serverSocket->getAddress(&serverAddress);
1892   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
1893
1894   // Loop to process all events
1895   eventBase->loop();
1896
1897   // Verify that the server accepted a connection
1898   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1899   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1900                     TestAcceptCallback::TYPE_START);
1901   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1902                     TestAcceptCallback::TYPE_ACCEPT);
1903   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1904                     TestAcceptCallback::TYPE_STOP);
1905 }
1906
1907 /* Verify that we don't leak sockets if we are destroyed()
1908  * and there are still writes pending
1909  *
1910  * If destroy() only calls close() instead of closeNow(),
1911  * it would shutdown(writes) on the socket, but it would
1912  * never be close()'d, and the socket would leak
1913  */
1914 TEST(AsyncSocketTest, DestroyCloseTest) {
1915   TestServer server;
1916
1917   // connect()
1918   EventBase clientEB;
1919   EventBase serverEB;
1920   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
1921   ConnCallback ccb;
1922   socket->connect(&ccb, server.getAddress(), 30);
1923
1924   // Accept the connection
1925   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
1926   ReadCallback rcb;
1927   acceptedSocket->setReadCB(&rcb);
1928
1929   // Write a large buffer to the socket that is larger than kernel buffer
1930   size_t simpleBufLength = 5000000;
1931   char* simpleBuf = new char[simpleBufLength];
1932   memset(simpleBuf, 'a', simpleBufLength);
1933   WriteCallback wcb;
1934
1935   // Let the reads and writes run to completion
1936   int fd = acceptedSocket->getFd();
1937
1938   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1939   socket.reset();
1940   acceptedSocket.reset();
1941
1942   // Test that server socket was closed
1943   folly::test::msvcSuppressAbortOnInvalidParams([&] {
1944     ssize_t sz = read(fd, simpleBuf, simpleBufLength);
1945     ASSERT_EQ(sz, -1);
1946     ASSERT_EQ(errno, EBADF);
1947   });
1948   delete[] simpleBuf;
1949 }
1950
1951 /**
1952  * Test AsyncServerSocket::useExistingSocket()
1953  */
1954 TEST(AsyncSocketTest, ServerExistingSocket) {
1955   EventBase eventBase;
1956
1957   // Test creating a socket, and letting AsyncServerSocket bind and listen
1958   {
1959     // Manually create a socket
1960     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
1961     ASSERT_GE(fd, 0);
1962
1963     // Create a server socket
1964     AsyncServerSocket::UniquePtr serverSocket(
1965         new AsyncServerSocket(&eventBase));
1966     serverSocket->useExistingSocket(fd);
1967     folly::SocketAddress address;
1968     serverSocket->getAddress(&address);
1969     address.setPort(0);
1970     serverSocket->bind(address);
1971     serverSocket->listen(16);
1972
1973     // Make sure the socket works
1974     serverSocketSanityTest(serverSocket.get());
1975   }
1976
1977   // Test creating a socket and binding manually,
1978   // then letting AsyncServerSocket listen
1979   {
1980     // Manually create a socket
1981     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
1982     ASSERT_GE(fd, 0);
1983     // bind
1984     struct sockaddr_in addr;
1985     addr.sin_family = AF_INET;
1986     addr.sin_port = 0;
1987     addr.sin_addr.s_addr = INADDR_ANY;
1988     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
1989                              sizeof(addr)), 0);
1990     // Look up the address that we bound to
1991     folly::SocketAddress boundAddress;
1992     boundAddress.setFromLocalAddress(fd);
1993
1994     // Create a server socket
1995     AsyncServerSocket::UniquePtr serverSocket(
1996         new AsyncServerSocket(&eventBase));
1997     serverSocket->useExistingSocket(fd);
1998     serverSocket->listen(16);
1999
2000     // Make sure AsyncServerSocket reports the same address that we bound to
2001     folly::SocketAddress serverSocketAddress;
2002     serverSocket->getAddress(&serverSocketAddress);
2003     ASSERT_EQ(boundAddress, serverSocketAddress);
2004
2005     // Make sure the socket works
2006     serverSocketSanityTest(serverSocket.get());
2007   }
2008
2009   // Test creating a socket, binding and listening manually,
2010   // then giving it to AsyncServerSocket
2011   {
2012     // Manually create a socket
2013     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2014     ASSERT_GE(fd, 0);
2015     // bind
2016     struct sockaddr_in addr;
2017     addr.sin_family = AF_INET;
2018     addr.sin_port = 0;
2019     addr.sin_addr.s_addr = INADDR_ANY;
2020     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2021                              sizeof(addr)), 0);
2022     // Look up the address that we bound to
2023     folly::SocketAddress boundAddress;
2024     boundAddress.setFromLocalAddress(fd);
2025     // listen
2026     ASSERT_EQ(listen(fd, 16), 0);
2027
2028     // Create a server socket
2029     AsyncServerSocket::UniquePtr serverSocket(
2030         new AsyncServerSocket(&eventBase));
2031     serverSocket->useExistingSocket(fd);
2032
2033     // Make sure AsyncServerSocket reports the same address that we bound to
2034     folly::SocketAddress serverSocketAddress;
2035     serverSocket->getAddress(&serverSocketAddress);
2036     ASSERT_EQ(boundAddress, serverSocketAddress);
2037
2038     // Make sure the socket works
2039     serverSocketSanityTest(serverSocket.get());
2040   }
2041 }
2042
2043 TEST(AsyncSocketTest, UnixDomainSocketTest) {
2044   EventBase eventBase;
2045
2046   // Create a server socket
2047   std::shared_ptr<AsyncServerSocket> serverSocket(
2048       AsyncServerSocket::newSocket(&eventBase));
2049   string path(1, 0);
2050   path.append(folly::to<string>("/anonymous", folly::Random::rand64()));
2051   folly::SocketAddress serverAddress;
2052   serverAddress.setFromPath(path);
2053   serverSocket->bind(serverAddress);
2054   serverSocket->listen(16);
2055
2056   // Add a callback to accept one connection then stop the loop
2057   TestAcceptCallback acceptCallback;
2058   acceptCallback.setConnectionAcceptedFn(
2059       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2060         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2061       });
2062   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2063     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2064   });
2065   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2066   serverSocket->startAccepting();
2067
2068   // Connect to the server socket
2069   std::shared_ptr<AsyncSocket> socket(
2070       AsyncSocket::newSocket(&eventBase, serverAddress));
2071
2072   eventBase.loop();
2073
2074   // Verify that the server accepted a connection
2075   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2076   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
2077                     TestAcceptCallback::TYPE_START);
2078   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
2079                     TestAcceptCallback::TYPE_ACCEPT);
2080   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
2081                     TestAcceptCallback::TYPE_STOP);
2082   int fd = acceptCallback.getEvents()->at(1).fd;
2083
2084   // The accepted connection should already be in non-blocking mode
2085   int flags = fcntl(fd, F_GETFL, 0);
2086   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
2087 }
2088
2089 TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
2090   EventBase eventBase;
2091   TestConnectionEventCallback connectionEventCallback;
2092
2093   // Create a server socket
2094   std::shared_ptr<AsyncServerSocket> serverSocket(
2095       AsyncServerSocket::newSocket(&eventBase));
2096   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2097   serverSocket->bind(0);
2098   serverSocket->listen(16);
2099   folly::SocketAddress serverAddress;
2100   serverSocket->getAddress(&serverAddress);
2101
2102   // Add a callback to accept one connection then stop the loop
2103   TestAcceptCallback acceptCallback;
2104   acceptCallback.setConnectionAcceptedFn(
2105       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2106         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2107       });
2108   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2109     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2110   });
2111   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2112   serverSocket->startAccepting();
2113
2114   // Connect to the server socket
2115   std::shared_ptr<AsyncSocket> socket(
2116       AsyncSocket::newSocket(&eventBase, serverAddress));
2117
2118   eventBase.loop();
2119
2120   // Validate the connection event counters
2121   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2122   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2123   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2124   ASSERT_EQ(
2125       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
2126   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
2127   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2128   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2129   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2130 }
2131
2132 TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
2133   EventBase eventBase;
2134   TestConnectionEventCallback connectionEventCallback;
2135
2136   // Create a server socket
2137   std::shared_ptr<AsyncServerSocket> serverSocket(
2138       AsyncServerSocket::newSocket(&eventBase));
2139   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2140   serverSocket->bind(0);
2141   serverSocket->listen(16);
2142   folly::SocketAddress serverAddress;
2143   serverSocket->getAddress(&serverAddress);
2144
2145   // Add a callback to accept one connection then stop the loop
2146   TestAcceptCallback acceptCallback;
2147   acceptCallback.setConnectionAcceptedFn(
2148       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2149         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2150       });
2151   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2152     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2153   });
2154   bool acceptStartedFlag{false};
2155   acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
2156     acceptStartedFlag = true;
2157   });
2158   bool acceptStoppedFlag{false};
2159   acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
2160     acceptStoppedFlag = true;
2161   });
2162   serverSocket->addAcceptCallback(&acceptCallback, nullptr);
2163   serverSocket->startAccepting();
2164
2165   // Connect to the server socket
2166   std::shared_ptr<AsyncSocket> socket(
2167       AsyncSocket::newSocket(&eventBase, serverAddress));
2168
2169   eventBase.loop();
2170
2171   ASSERT_TRUE(acceptStartedFlag);
2172   ASSERT_TRUE(acceptStoppedFlag);
2173   // Validate the connection event counters
2174   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2175   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2176   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2177   ASSERT_EQ(
2178       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
2179   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
2180   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2181   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2182   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2183 }
2184
2185
2186
2187 /**
2188  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
2189  */
2190 TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
2191   EventBase eventBase;
2192
2193   // Counter of how many connections have been accepted
2194   int count = 0;
2195
2196   // Create a server socket
2197   auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
2198   serverSocket->bind(0);
2199   serverSocket->listen(16);
2200   folly::SocketAddress serverAddress;
2201   serverSocket->getAddress(&serverAddress);
2202
2203   // Add a callback to accept connections
2204   TestAcceptCallback acceptCallback;
2205   acceptCallback.setConnectionAcceptedFn(
2206       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2207         count++;
2208         ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
2209
2210         if (count == 4) {
2211           // all messages are processed, remove accept callback
2212           serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2213         }
2214       });
2215   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2216     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2217   });
2218   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2219   serverSocket->startAccepting();
2220
2221   // Connect to the server socket, 4 clients, there are 4 connections
2222   auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
2223   auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
2224   auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
2225   auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));
2226
2227   eventBase.loop();
2228 }
2229
2230 /**
2231  * Test AsyncTransport::BufferCallback
2232  */
2233 TEST(AsyncSocketTest, BufferTest) {
2234   TestServer server;
2235
2236   EventBase evb;
2237   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2238   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2239   ConnCallback ccb;
2240   socket->connect(&ccb, server.getAddress(), 30, option);
2241
2242   char buf[100 * 1024];
2243   memset(buf, 'c', sizeof(buf));
2244   WriteCallback wcb;
2245   BufferCallback bcb;
2246   socket->setBufferCallback(&bcb);
2247   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2248
2249   evb.loop();
2250   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2251   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2252
2253   ASSERT_TRUE(bcb.hasBuffered());
2254   ASSERT_TRUE(bcb.hasBufferCleared());
2255
2256   socket->close();
2257   server.verifyConnection(buf, sizeof(buf));
2258
2259   ASSERT_TRUE(socket->isClosedBySelf());
2260   ASSERT_FALSE(socket->isClosedByPeer());
2261 }
2262
2263 TEST(AsyncSocketTest, BufferCallbackKill) {
2264   TestServer server;
2265   EventBase evb;
2266   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2267   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2268   ConnCallback ccb;
2269   socket->connect(&ccb, server.getAddress(), 30, option);
2270   evb.loopOnce();
2271
2272   char buf[100 * 1024];
2273   memset(buf, 'c', sizeof(buf));
2274   BufferCallback bcb;
2275   socket->setBufferCallback(&bcb);
2276   WriteCallback wcb;
2277   wcb.successCallback = [&] {
2278     ASSERT_TRUE(socket.unique());
2279     socket.reset();
2280   };
2281
2282   // This will trigger AsyncSocket::handleWrite,
2283   // which calls WriteCallback::writeSuccess,
2284   // which calls wcb.successCallback above,
2285   // which tries to delete socket
2286   // Then, the socket will also try to use this BufferCallback
2287   // And that should crash us, if there is no DestructorGuard on the stack
2288   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2289
2290   evb.loop();
2291   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2292 }
2293
2294 #if FOLLY_ALLOW_TFO
2295 TEST(AsyncSocketTest, ConnectTFO) {
2296   // Start listening on a local port
2297   TestServer server(true);
2298
2299   // Connect using a AsyncSocket
2300   EventBase evb;
2301   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2302   socket->enableTFO();
2303   ConnCallback cb;
2304   socket->connect(&cb, server.getAddress(), 30);
2305
2306   std::array<uint8_t, 128> buf;
2307   memset(buf.data(), 'a', buf.size());
2308
2309   std::array<uint8_t, 3> readBuf;
2310   auto sendBuf = IOBuf::copyBuffer("hey");
2311
2312   std::thread t([&] {
2313     auto acceptedSocket = server.accept();
2314     acceptedSocket->write(buf.data(), buf.size());
2315     acceptedSocket->flush();
2316     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2317     acceptedSocket->close();
2318   });
2319
2320   evb.loop();
2321
2322   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2323   EXPECT_LE(0, socket->getConnectTime().count());
2324   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2325   EXPECT_TRUE(socket->getTFOAttempted());
2326
2327   // Should trigger the connect
2328   WriteCallback write;
2329   ReadCallback rcb;
2330   socket->writeChain(&write, sendBuf->clone());
2331   socket->setReadCB(&rcb);
2332   evb.loop();
2333
2334   t.join();
2335
2336   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2337   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2338   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2339   ASSERT_EQ(1, rcb.buffers.size());
2340   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2341   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2342   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2343 }
2344
2345 TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
2346   // Start listening on a local port
2347   TestServer server(true);
2348
2349   // Connect using a AsyncSocket
2350   EventBase evb;
2351   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2352   socket->enableTFO();
2353   ConnCallback cb;
2354   socket->connect(&cb, server.getAddress(), 30);
2355   ReadCallback rcb;
2356   socket->setReadCB(&rcb);
2357
2358   std::array<uint8_t, 128> buf;
2359   memset(buf.data(), 'a', buf.size());
2360
2361   std::array<uint8_t, 3> readBuf;
2362   auto sendBuf = IOBuf::copyBuffer("hey");
2363
2364   std::thread t([&] {
2365     auto acceptedSocket = server.accept();
2366     acceptedSocket->write(buf.data(), buf.size());
2367     acceptedSocket->flush();
2368     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2369     acceptedSocket->close();
2370   });
2371
2372   evb.loop();
2373
2374   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2375   EXPECT_LE(0, socket->getConnectTime().count());
2376   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2377   EXPECT_TRUE(socket->getTFOAttempted());
2378
2379   // Should trigger the connect
2380   WriteCallback write;
2381   socket->writeChain(&write, sendBuf->clone());
2382   evb.loop();
2383
2384   t.join();
2385
2386   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2387   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2388   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2389   ASSERT_EQ(1, rcb.buffers.size());
2390   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2391   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2392   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2393 }
2394
2395 /**
2396  * Test connecting to a server that isn't listening
2397  */
2398 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2399   EventBase evb;
2400
2401   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2402
2403   socket->enableTFO();
2404
2405   // Hopefully nothing is actually listening on this address
2406   folly::SocketAddress addr("::1", 65535);
2407   ConnCallback cb;
2408   socket->connect(&cb, addr, 30);
2409
2410   evb.loop();
2411
2412   WriteCallback write1;
2413   // Trigger the connect if TFO attempt is supported.
2414   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2415   WriteCallback write2;
2416   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2417   evb.loop();
2418
2419   if (!socket->getTFOFinished()) {
2420     EXPECT_EQ(STATE_FAILED, write1.state);
2421   } else {
2422     EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2423     EXPECT_FALSE(socket->getTFOSucceded());
2424   }
2425
2426   EXPECT_EQ(STATE_FAILED, write2.state);
2427
2428   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2429   EXPECT_LE(0, socket->getConnectTime().count());
2430   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2431   EXPECT_TRUE(socket->getTFOAttempted());
2432 }
2433
2434 /**
2435  * Test calling closeNow() immediately after connecting.
2436  */
2437 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2438   TestServer server(true);
2439
2440   // connect()
2441   EventBase evb;
2442   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2443   socket->enableTFO();
2444
2445   ConnCallback ccb;
2446   socket->connect(&ccb, server.getAddress(), 30);
2447
2448   // write()
2449   std::array<char, 128> buf;
2450   memset(buf.data(), 'a', buf.size());
2451
2452   // close()
2453   socket->closeNow();
2454
2455   // Loop, although there shouldn't be anything to do.
2456   evb.loop();
2457
2458   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2459
2460   ASSERT_TRUE(socket->isClosedBySelf());
2461   ASSERT_FALSE(socket->isClosedByPeer());
2462 }
2463
2464 /**
2465  * Test calling close() immediately after connect()
2466  */
2467 TEST(AsyncSocketTest, ConnectAndCloseTFO) {
2468   TestServer server(true);
2469
2470   // Connect using a AsyncSocket
2471   EventBase evb;
2472   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2473   socket->enableTFO();
2474
2475   ConnCallback ccb;
2476   socket->connect(&ccb, server.getAddress(), 30);
2477
2478   socket->close();
2479
2480   // Loop, although there shouldn't be anything to do.
2481   evb.loop();
2482
2483   // Make sure the connection was aborted
2484   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2485
2486   ASSERT_TRUE(socket->isClosedBySelf());
2487   ASSERT_FALSE(socket->isClosedByPeer());
2488 }
2489
2490 class MockAsyncTFOSocket : public AsyncSocket {
2491  public:
2492   using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
2493
2494   explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
2495
2496   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
2497 };
2498
2499 TEST(AsyncSocketTest, TestTFOUnsupported) {
2500   TestServer server(true);
2501
2502   // Connect using a AsyncSocket
2503   EventBase evb;
2504   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2505   socket->enableTFO();
2506
2507   ConnCallback ccb;
2508   socket->connect(&ccb, server.getAddress(), 30);
2509   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2510
2511   ReadCallback rcb;
2512   socket->setReadCB(&rcb);
2513
2514   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2515       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2516   WriteCallback write;
2517   auto sendBuf = IOBuf::copyBuffer("hey");
2518   socket->writeChain(&write, sendBuf->clone());
2519   EXPECT_EQ(STATE_WAITING, write.state);
2520
2521   std::array<uint8_t, 128> buf;
2522   memset(buf.data(), 'a', buf.size());
2523
2524   std::array<uint8_t, 3> readBuf;
2525
2526   std::thread t([&] {
2527     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2528     acceptedSocket->write(buf.data(), buf.size());
2529     acceptedSocket->flush();
2530     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2531     acceptedSocket->close();
2532   });
2533
2534   evb.loop();
2535
2536   t.join();
2537   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2538   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2539
2540   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2541   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2542   ASSERT_EQ(1, rcb.buffers.size());
2543   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2544   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2545   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2546 }
2547
2548 TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
2549   EventBase evb;
2550
2551   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2552   socket->enableTFO();
2553
2554   // Hopefully this fails
2555   folly::SocketAddress fakeAddr("127.0.0.1", 65535);
2556   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2557       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2558         sockaddr_storage addr;
2559         auto len = fakeAddr.getAddress(&addr);
2560         int ret = connect(fd, (const struct sockaddr*)&addr, len);
2561         LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
2562                   << errno;
2563         return ret;
2564       }));
2565
2566   // Hopefully nothing is actually listening on this address
2567   ConnCallback cb;
2568   socket->connect(&cb, fakeAddr, 30);
2569
2570   WriteCallback write1;
2571   // Trigger the connect if TFO attempt is supported.
2572   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2573
2574   if (socket->getTFOFinished()) {
2575     // This test is useless now.
2576     return;
2577   }
2578   WriteCallback write2;
2579   // Trigger the connect if TFO attempt is supported.
2580   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2581   evb.loop();
2582
2583   EXPECT_EQ(STATE_FAILED, write1.state);
2584   EXPECT_EQ(STATE_FAILED, write2.state);
2585   EXPECT_FALSE(socket->getTFOSucceded());
2586
2587   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2588   EXPECT_LE(0, socket->getConnectTime().count());
2589   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2590   EXPECT_TRUE(socket->getTFOAttempted());
2591 }
2592
2593 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
2594   // Try connecting to server that won't respond.
2595   //
2596   // This depends somewhat on the network where this test is run.
2597   // Hopefully this IP will be routable but unresponsive.
2598   // (Alternatively, we could try listening on a local raw socket, but that
2599   // normally requires root privileges.)
2600   auto host = SocketAddressTestHelper::isIPv6Enabled()
2601       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2602       : SocketAddressTestHelper::isIPv4Enabled()
2603           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2604           : nullptr;
2605   SocketAddress addr(host, 65535);
2606
2607   // Connect using a AsyncSocket
2608   EventBase evb;
2609   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2610   socket->enableTFO();
2611
2612   ConnCallback ccb;
2613   // Set a very small timeout
2614   socket->connect(&ccb, addr, 1);
2615   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2616
2617   ReadCallback rcb;
2618   socket->setReadCB(&rcb);
2619
2620   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2621       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2622   WriteCallback write;
2623   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2624
2625   evb.loop();
2626
2627   EXPECT_EQ(STATE_FAILED, write.state);
2628 }
2629
2630 TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
2631   TestServer server(true);
2632
2633   // Connect using a AsyncSocket
2634   EventBase evb;
2635   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2636   socket->enableTFO();
2637
2638   ConnCallback ccb;
2639   socket->connect(&ccb, server.getAddress(), 30);
2640   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2641
2642   ReadCallback rcb;
2643   socket->setReadCB(&rcb);
2644
2645   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2646       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2647         sockaddr_storage addr;
2648         auto len = server.getAddress().getAddress(&addr);
2649         return connect(fd, (const struct sockaddr*)&addr, len);
2650       }));
2651   WriteCallback write;
2652   auto sendBuf = IOBuf::copyBuffer("hey");
2653   socket->writeChain(&write, sendBuf->clone());
2654   EXPECT_EQ(STATE_WAITING, write.state);
2655
2656   std::array<uint8_t, 128> buf;
2657   memset(buf.data(), 'a', buf.size());
2658
2659   std::array<uint8_t, 3> readBuf;
2660
2661   std::thread t([&] {
2662     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2663     acceptedSocket->write(buf.data(), buf.size());
2664     acceptedSocket->flush();
2665     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2666     acceptedSocket->close();
2667   });
2668
2669   evb.loop();
2670
2671   t.join();
2672   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2673
2674   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2675   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2676
2677   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2678   ASSERT_EQ(1, rcb.buffers.size());
2679   ASSERT_EQ(buf.size(), rcb.buffers[0].length);
2680   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2681 }
2682
2683 TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
2684   // Try connecting to server that won't respond.
2685   //
2686   // This depends somewhat on the network where this test is run.
2687   // Hopefully this IP will be routable but unresponsive.
2688   // (Alternatively, we could try listening on a local raw socket, but that
2689   // normally requires root privileges.)
2690   auto host = SocketAddressTestHelper::isIPv6Enabled()
2691       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2692       : SocketAddressTestHelper::isIPv4Enabled()
2693           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2694           : nullptr;
2695   SocketAddress addr(host, 65535);
2696
2697   // Connect using a AsyncSocket
2698   EventBase evb;
2699   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2700   socket->enableTFO();
2701
2702   ConnCallback ccb;
2703   // Set a very small timeout
2704   socket->connect(&ccb, addr, 1);
2705   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2706
2707   ReadCallback rcb;
2708   socket->setReadCB(&rcb);
2709
2710   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2711       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2712         sockaddr_storage addr2;
2713         auto len = addr.getAddress(&addr2);
2714         return connect(fd, (const struct sockaddr*)&addr2, len);
2715       }));
2716   WriteCallback write;
2717   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2718
2719   evb.loop();
2720
2721   EXPECT_EQ(STATE_FAILED, write.state);
2722 }
2723
2724 TEST(AsyncSocketTest, TestTFOEagain) {
2725   TestServer server(true);
2726
2727   // Connect using a AsyncSocket
2728   EventBase evb;
2729   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2730   socket->enableTFO();
2731
2732   ConnCallback ccb;
2733   socket->connect(&ccb, server.getAddress(), 30);
2734
2735   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2736       .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
2737   WriteCallback write;
2738   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2739
2740   evb.loop();
2741
2742   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2743   EXPECT_EQ(STATE_FAILED, write.state);
2744 }
2745
2746 // Sending a large amount of data in the first write which will
2747 // definitely not fit into MSS.
2748 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
2749   // Start listening on a local port
2750   TestServer server(true);
2751
2752   // Connect using a AsyncSocket
2753   EventBase evb;
2754   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2755   socket->enableTFO();
2756   ConnCallback cb;
2757   socket->connect(&cb, server.getAddress(), 30);
2758
2759   std::array<uint8_t, 128> buf;
2760   memset(buf.data(), 'a', buf.size());
2761
2762   constexpr size_t len = 10 * 1024;
2763   auto sendBuf = IOBuf::create(len);
2764   sendBuf->append(len);
2765   std::array<uint8_t, len> readBuf;
2766
2767   std::thread t([&] {
2768     auto acceptedSocket = server.accept();
2769     acceptedSocket->write(buf.data(), buf.size());
2770     acceptedSocket->flush();
2771     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2772     acceptedSocket->close();
2773   });
2774
2775   evb.loop();
2776
2777   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2778   EXPECT_LE(0, socket->getConnectTime().count());
2779   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2780   EXPECT_TRUE(socket->getTFOAttempted());
2781
2782   // Should trigger the connect
2783   WriteCallback write;
2784   ReadCallback rcb;
2785   socket->writeChain(&write, sendBuf->clone());
2786   socket->setReadCB(&rcb);
2787   evb.loop();
2788
2789   t.join();
2790
2791   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2792   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2793   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2794   ASSERT_EQ(1, rcb.buffers.size());
2795   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2796   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2797   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2798 }
2799
2800 #endif // FOLLY_ALLOW_TFO
2801
2802 class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
2803  public:
2804   MOCK_METHOD1(evbAttached, void(AsyncSocket*));
2805   MOCK_METHOD1(evbDetached, void(AsyncSocket*));
2806 };
2807
2808 TEST(AsyncSocketTest, EvbCallbacks) {
2809   auto cb = folly::make_unique<MockEvbChangeCallback>();
2810   EventBase evb;
2811   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2812
2813   InSequence seq;
2814   EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
2815   EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
2816
2817   socket->setEvbChangedCallback(std::move(cb));
2818   socket->detachEventBase();
2819   socket->attachEventBase(&evb);
2820 }
2821
2822 #ifdef MSG_ERRQUEUE
2823 /* copied from include/uapi/linux/net_tstamp.h */
2824 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
2825 enum SOF_TIMESTAMPING {
2826   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
2827   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
2828   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
2829   SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
2830   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
2831 };
2832 TEST(AsyncSocketTest, ErrMessageCallback) {
2833   TestServer server;
2834
2835   // connect()
2836   EventBase evb;
2837   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2838
2839   ConnCallback ccb;
2840   socket->connect(&ccb, server.getAddress(), 30);
2841   LOG(INFO) << "Client socket fd=" << socket->getFd();
2842
2843   // Let the socket
2844   evb.loop();
2845
2846   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2847
2848   // Set read callback to keep the socket subscribed for event
2849   // notifications. Though we're no planning to read anything from
2850   // this side of the connection.
2851   ReadCallback rcb(1);
2852   socket->setReadCB(&rcb);
2853
2854   // Set up timestamp callbacks
2855   TestErrMessageCallback errMsgCB;
2856   socket->setErrMessageCB(&errMsgCB);
2857   ASSERT_EQ(socket->getErrMessageCallback(),
2858             static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
2859
2860   // Enable timestamp notifications
2861   ASSERT_GT(socket->getFd(), 0);
2862   int flags = SOF_TIMESTAMPING_OPT_ID
2863               | SOF_TIMESTAMPING_OPT_TSONLY
2864               | SOF_TIMESTAMPING_SOFTWARE
2865               | SOF_TIMESTAMPING_OPT_CMSG
2866               | SOF_TIMESTAMPING_TX_SCHED;
2867   AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
2868   EXPECT_EQ(tstampingOpt.apply(socket->getFd(), flags), 0);
2869
2870   // write()
2871   std::vector<uint8_t> wbuf(128, 'a');
2872   WriteCallback wcb;
2873   socket->write(&wcb, wbuf.data(), wbuf.size());
2874
2875   // Accept the connection.
2876   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2877   LOG(INFO) << "Server socket fd=" << acceptedSocket->getSocketFD();
2878
2879   // Loop
2880   evb.loopOnce();
2881   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2882
2883   // Check that we can read the data that was written to the socket
2884   std::vector<uint8_t> rbuf(1 + wbuf.size(), 0);
2885   uint32_t bytesRead = acceptedSocket->read(rbuf.data(), rbuf.size());
2886   ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
2887   ASSERT_EQ(bytesRead, wbuf.size());
2888
2889   // Close both sockets
2890   acceptedSocket->close();
2891   socket->close();
2892
2893   ASSERT_TRUE(socket->isClosedBySelf());
2894   ASSERT_FALSE(socket->isClosedByPeer());
2895
2896   // Check for the timestamp notifications.
2897   ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
2898   ASSERT_TRUE(errMsgCB.gotByteSeq_);
2899   ASSERT_TRUE(errMsgCB.gotTimestamp_);
2900 }
2901 #endif // MSG_ERRQUEUE
2902
2903 TEST(AsyncSocket, PreReceivedData) {
2904   TestServer server;
2905
2906   EventBase evb;
2907   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2908   socket->connect(nullptr, server.getAddress(), 30);
2909   evb.loop();
2910
2911   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
2912
2913   auto acceptedSocket = server.acceptAsync(&evb);
2914
2915   ReadCallback peekCallback(2);
2916   ReadCallback readCallback;
2917   peekCallback.dataAvailableCallback = [&]() {
2918     peekCallback.verifyData("he", 2);
2919     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
2920     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
2921     acceptedSocket->setReadCB(nullptr);
2922     acceptedSocket->setReadCB(&readCallback);
2923   };
2924   readCallback.dataAvailableCallback = [&]() {
2925     if (readCallback.dataRead() == 5) {
2926       readCallback.verifyData("hello", 5);
2927       acceptedSocket->setReadCB(nullptr);
2928     }
2929   };
2930
2931   acceptedSocket->setReadCB(&peekCallback);
2932
2933   evb.loop();
2934 }
2935
2936 TEST(AsyncSocket, PreReceivedDataOnly) {
2937   TestServer server;
2938
2939   EventBase evb;
2940   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2941   socket->connect(nullptr, server.getAddress(), 30);
2942   evb.loop();
2943
2944   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
2945
2946   auto acceptedSocket = server.acceptAsync(&evb);
2947
2948   ReadCallback peekCallback;
2949   ReadCallback readCallback;
2950   peekCallback.dataAvailableCallback = [&]() {
2951     peekCallback.verifyData("hello", 5);
2952     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
2953     acceptedSocket->setReadCB(&readCallback);
2954   };
2955   readCallback.dataAvailableCallback = [&]() {
2956     readCallback.verifyData("hello", 5);
2957     acceptedSocket->setReadCB(nullptr);
2958   };
2959
2960   acceptedSocket->setReadCB(&peekCallback);
2961
2962   evb.loop();
2963 }
2964
2965 TEST(AsyncSocket, PreReceivedDataPartial) {
2966   TestServer server;
2967
2968   EventBase evb;
2969   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2970   socket->connect(nullptr, server.getAddress(), 30);
2971   evb.loop();
2972
2973   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
2974
2975   auto acceptedSocket = server.acceptAsync(&evb);
2976
2977   ReadCallback peekCallback;
2978   ReadCallback smallReadCallback(3);
2979   ReadCallback normalReadCallback;
2980   peekCallback.dataAvailableCallback = [&]() {
2981     peekCallback.verifyData("hello", 5);
2982     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
2983     acceptedSocket->setReadCB(&smallReadCallback);
2984   };
2985   smallReadCallback.dataAvailableCallback = [&]() {
2986     smallReadCallback.verifyData("hel", 3);
2987     acceptedSocket->setReadCB(&normalReadCallback);
2988   };
2989   normalReadCallback.dataAvailableCallback = [&]() {
2990     normalReadCallback.verifyData("lo", 2);
2991     acceptedSocket->setReadCB(nullptr);
2992   };
2993
2994   acceptedSocket->setReadCB(&peekCallback);
2995
2996   evb.loop();
2997 }
2998
2999 TEST(AsyncSocket, PreReceivedDataTakeover) {
3000   TestServer server;
3001
3002   EventBase evb;
3003   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3004   socket->connect(nullptr, server.getAddress(), 30);
3005   evb.loop();
3006
3007   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3008
3009   auto acceptedSocket =
3010       AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
3011   AsyncSocket::UniquePtr takeoverSocket;
3012
3013   ReadCallback peekCallback(3);
3014   ReadCallback readCallback;
3015   peekCallback.dataAvailableCallback = [&]() {
3016     peekCallback.verifyData("hel", 3);
3017     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3018     acceptedSocket->setReadCB(nullptr);
3019     takeoverSocket =
3020         AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
3021     takeoverSocket->setReadCB(&readCallback);
3022   };
3023   readCallback.dataAvailableCallback = [&]() {
3024     readCallback.verifyData("hello", 5);
3025     takeoverSocket->setReadCB(nullptr);
3026   };
3027
3028   acceptedSocket->setReadCB(&peekCallback);
3029
3030   evb.loop();
3031 }
3032
3033 TEST(AsyncSocketTest, SendMessageFlags) {
3034   TestServer server;
3035   TestSendMsgParamsCallback sendMsgCB(
3036       MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
3037
3038   // connect()
3039   EventBase evb;
3040   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3041
3042   ConnCallback ccb;
3043   socket->connect(&ccb, server.getAddress(), 30);
3044   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
3045
3046   evb.loop();
3047   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
3048
3049   // Set SendMsgParamsCallback
3050   socket->setSendMsgParamCB(&sendMsgCB);
3051   ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
3052
3053   // Write the first portion of data. This data is expected to be
3054   // sent out immediately.
3055   std::vector<uint8_t> buf(128, 'a');
3056   WriteCallback wcb;
3057   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
3058   socket->write(&wcb, buf.data(), buf.size());
3059   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3060   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3061   ASSERT_FALSE(sendMsgCB.queriedData_);
3062
3063   // Using different flags for the second write operation.
3064   // MSG_MORE flag is expected to delay sending this
3065   // data to the wire.
3066   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
3067   socket->write(&wcb, buf.data(), buf.size());
3068   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3069   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3070   ASSERT_FALSE(sendMsgCB.queriedData_);
3071
3072   // Make sure the accepted socket saw only the data from
3073   // the first write request.
3074   std::vector<uint8_t> readbuf(2 * buf.size());
3075   uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
3076   ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
3077   ASSERT_EQ(bytesRead, buf.size());
3078
3079   // Make sure the server got a connection and received the data
3080   acceptedSocket->close();
3081   socket->close();
3082
3083   ASSERT_TRUE(socket->isClosedBySelf());
3084   ASSERT_FALSE(socket->isClosedByPeer());
3085 }
3086
3087 TEST(AsyncSocketTest, SendMessageAncillaryData) {
3088   int fds[2];
3089   EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);
3090
3091   // "Client" socket
3092   int cfd = fds[0];
3093   ASSERT_NE(cfd, -1);
3094
3095   // "Server" socket
3096   int sfd = fds[1];
3097   ASSERT_NE(sfd, -1);
3098   SCOPE_EXIT { close(sfd); };
3099
3100   // Instantiate AsyncSocket object for the connected socket
3101   EventBase evb;
3102   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);
3103
3104   // Open a temporary file and write a magic string to it
3105   // We'll transfer the file handle to test the message parameters
3106   // callback logic.
3107   TemporaryFile file(StringPiece(),
3108                      fs::path(),
3109                      TemporaryFile::Scope::UNLINK_IMMEDIATELY);
3110   int tmpfd = file.fd();
3111   ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
3112   std::string magicString("Magic string");
3113   ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
3114             magicString.length());
3115
3116   // Send message
3117   union {
3118     // Space large enough to hold an 'int'
3119     char control[CMSG_SPACE(sizeof(int))];
3120     struct cmsghdr cmh;
3121   } s_u;
3122   s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
3123   s_u.cmh.cmsg_level = SOL_SOCKET;
3124   s_u.cmh.cmsg_type = SCM_RIGHTS;
3125   memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
3126
3127   // Set up the callback providing message parameters
3128   TestSendMsgParamsCallback sendMsgCB(
3129       MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
3130   socket->setSendMsgParamCB(&sendMsgCB);
3131
3132   // We must transmit at least 1 byte of real data in order
3133   // to send ancillary data
3134   int s_data = 12345;
3135   WriteCallback wcb;
3136   socket->write(&wcb, &s_data, sizeof(s_data));
3137   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3138
3139   // Receive the message
3140   union {
3141     // Space large enough to hold an 'int'
3142     char control[CMSG_SPACE(sizeof(int))];
3143     struct cmsghdr cmh;
3144   } r_u;
3145   struct msghdr msgh;
3146   struct iovec iov;
3147   int r_data = 0;
3148
3149   msgh.msg_control = r_u.control;
3150   msgh.msg_controllen = sizeof(r_u.control);
3151   msgh.msg_name = nullptr;
3152   msgh.msg_namelen = 0;
3153   msgh.msg_iov = &iov;
3154   msgh.msg_iovlen = 1;
3155   iov.iov_base = &r_data;
3156   iov.iov_len = sizeof(r_data);
3157
3158   // Receive data
3159   ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
3160
3161   // Validate the received message
3162   ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
3163   ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
3164   ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
3165   ASSERT_EQ(r_data, s_data);
3166   int fd = 0;
3167   memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
3168   ASSERT_NE(fd, 0);
3169   SCOPE_EXIT { close(fd); };
3170
3171   std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
3172
3173   // Reposition to the beginning of the file
3174   ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
3175
3176   // Read the magic string back, and compare it with the original
3177   ASSERT_EQ(
3178       magicString.length(),
3179       read(fd, transferredMagicString.data(), transferredMagicString.size()));
3180   ASSERT_TRUE(std::equal(
3181       magicString.begin(),
3182       magicString.end(),
3183       transferredMagicString.begin()));
3184 }