Fix TFO refused case
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <folly/ExceptionWrapper.h>
17 #include <folly/RWSpinLock.h>
18 #include <folly/Random.h>
19 #include <folly/SocketAddress.h>
20 #include <folly/io/async/AsyncServerSocket.h>
21 #include <folly/io/async/AsyncSocket.h>
22 #include <folly/io/async/AsyncTimeout.h>
23 #include <folly/io/async/EventBase.h>
24
25 #include <folly/experimental/TestUtil.h>
26 #include <folly/io/IOBuf.h>
27 #include <folly/io/async/test/AsyncSocketTest.h>
28 #include <folly/io/async/test/Util.h>
29 #include <folly/portability/GMock.h>
30 #include <folly/portability/GTest.h>
31 #include <folly/portability/Sockets.h>
32 #include <folly/portability/Unistd.h>
33 #include <folly/test/SocketAddressTestHelper.h>
34
35 #include <boost/scoped_array.hpp>
36 #include <fcntl.h>
37 #include <sys/types.h>
38 #include <iostream>
39 #include <thread>
40
41 using namespace boost;
42
43 using std::string;
44 using std::vector;
45 using std::min;
46 using std::cerr;
47 using std::endl;
48 using std::unique_ptr;
49 using std::chrono::milliseconds;
50 using boost::scoped_array;
51
52 using namespace folly;
53 using namespace testing;
54
55 namespace fsp = folly::portability::sockets;
56
57 class DelayedWrite: public AsyncTimeout {
58  public:
59   DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
60       unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
61       bool cork, bool lastWrite = false):
62     AsyncTimeout(socket->getEventBase()),
63     socket_(socket),
64     bufs_(std::move(bufs)),
65     wcb_(wcb),
66     cork_(cork),
67     lastWrite_(lastWrite) {}
68
69  private:
70   void timeoutExpired() noexcept override {
71     WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
72     socket_->writeChain(wcb_, std::move(bufs_), flags);
73     if (lastWrite_) {
74       socket_->shutdownWrite();
75     }
76   }
77
78   std::shared_ptr<AsyncSocket> socket_;
79   unique_ptr<IOBuf> bufs_;
80   AsyncTransportWrapper::WriteCallback* wcb_;
81   bool cork_;
82   bool lastWrite_;
83 };
84
85 ///////////////////////////////////////////////////////////////////////////
86 // connect() tests
87 ///////////////////////////////////////////////////////////////////////////
88
89 /**
90  * Test connecting to a server
91  */
92 TEST(AsyncSocketTest, Connect) {
93   // Start listening on a local port
94   TestServer server;
95
96   // Connect using a AsyncSocket
97   EventBase evb;
98   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
99   ConnCallback cb;
100   socket->connect(&cb, server.getAddress(), 30);
101
102   evb.loop();
103
104   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
105   EXPECT_LE(0, socket->getConnectTime().count());
106   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
107 }
108
109 enum class TFOState {
110   DISABLED,
111   ENABLED,
112 };
113
114 class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
115
116 std::vector<TFOState> getTestingValues() {
117   std::vector<TFOState> vals;
118   vals.emplace_back(TFOState::DISABLED);
119
120 #if FOLLY_ALLOW_TFO
121   vals.emplace_back(TFOState::ENABLED);
122 #endif
123   return vals;
124 }
125
126 INSTANTIATE_TEST_CASE_P(
127     ConnectTests,
128     AsyncSocketConnectTest,
129     ::testing::ValuesIn(getTestingValues()));
130
131 /**
132  * Test connecting to a server that isn't listening
133  */
134 TEST(AsyncSocketTest, ConnectRefused) {
135   EventBase evb;
136
137   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
138
139   // Hopefully nothing is actually listening on this address
140   folly::SocketAddress addr("127.0.0.1", 65535);
141   ConnCallback cb;
142   socket->connect(&cb, addr, 30);
143
144   evb.loop();
145
146   EXPECT_EQ(STATE_FAILED, cb.state);
147   EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
148   EXPECT_LE(0, socket->getConnectTime().count());
149   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
150 }
151
152 /**
153  * Test connection timeout
154  */
155 TEST(AsyncSocketTest, ConnectTimeout) {
156   EventBase evb;
157
158   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
159
160   // Try connecting to server that won't respond.
161   //
162   // This depends somewhat on the network where this test is run.
163   // Hopefully this IP will be routable but unresponsive.
164   // (Alternatively, we could try listening on a local raw socket, but that
165   // normally requires root privileges.)
166   auto host =
167       SocketAddressTestHelper::isIPv6Enabled() ?
168       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 :
169       SocketAddressTestHelper::isIPv4Enabled() ?
170       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 :
171       nullptr;
172   SocketAddress addr(host, 65535);
173   ConnCallback cb;
174   socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
175
176   evb.loop();
177
178   ASSERT_EQ(cb.state, STATE_FAILED);
179   ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
180
181   // Verify that we can still get the peer address after a timeout.
182   // Use case is if the client was created from a client pool, and we want
183   // to log which peer failed.
184   folly::SocketAddress peer;
185   socket->getPeerAddress(&peer);
186   ASSERT_EQ(peer, addr);
187   EXPECT_LE(0, socket->getConnectTime().count());
188   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
189 }
190
191 /**
192  * Test writing immediately after connecting, without waiting for connect
193  * to finish.
194  */
195 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
196   TestServer server;
197
198   // connect()
199   EventBase evb;
200   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
201
202   if (GetParam() == TFOState::ENABLED) {
203     socket->enableTFO();
204   }
205
206   ConnCallback ccb;
207   socket->connect(&ccb, server.getAddress(), 30);
208
209   // write()
210   char buf[128];
211   memset(buf, 'a', sizeof(buf));
212   WriteCallback wcb;
213   socket->write(&wcb, buf, sizeof(buf));
214
215   // Loop.  We don't bother accepting on the server socket yet.
216   // The kernel should be able to buffer the write request so it can succeed.
217   evb.loop();
218
219   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
220   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
221
222   // Make sure the server got a connection and received the data
223   socket->close();
224   server.verifyConnection(buf, sizeof(buf));
225
226   ASSERT_TRUE(socket->isClosedBySelf());
227   ASSERT_FALSE(socket->isClosedByPeer());
228   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
229 }
230
231 /**
232  * Test connecting using a nullptr connect callback.
233  */
234 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
235   TestServer server;
236
237   // connect()
238   EventBase evb;
239   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
240   if (GetParam() == TFOState::ENABLED) {
241     socket->enableTFO();
242   }
243
244   socket->connect(nullptr, server.getAddress(), 30);
245
246   // write some data, just so we have some way of verifing
247   // that the socket works correctly after connecting
248   char buf[128];
249   memset(buf, 'a', sizeof(buf));
250   WriteCallback wcb;
251   socket->write(&wcb, buf, sizeof(buf));
252
253   evb.loop();
254
255   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
256
257   // Make sure the server got a connection and received the data
258   socket->close();
259   server.verifyConnection(buf, sizeof(buf));
260
261   ASSERT_TRUE(socket->isClosedBySelf());
262   ASSERT_FALSE(socket->isClosedByPeer());
263 }
264
265 /**
266  * Test calling both write() and close() immediately after connecting, without
267  * waiting for connect to finish.
268  *
269  * This exercises the STATE_CONNECTING_CLOSING code.
270  */
271 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
272   TestServer server;
273
274   // connect()
275   EventBase evb;
276   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
277   if (GetParam() == TFOState::ENABLED) {
278     socket->enableTFO();
279   }
280   ConnCallback ccb;
281   socket->connect(&ccb, server.getAddress(), 30);
282
283   // write()
284   char buf[128];
285   memset(buf, 'a', sizeof(buf));
286   WriteCallback wcb;
287   socket->write(&wcb, buf, sizeof(buf));
288
289   // close()
290   socket->close();
291
292   // Loop.  We don't bother accepting on the server socket yet.
293   // The kernel should be able to buffer the write request so it can succeed.
294   evb.loop();
295
296   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
297   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
298
299   // Make sure the server got a connection and received the data
300   server.verifyConnection(buf, sizeof(buf));
301
302   ASSERT_TRUE(socket->isClosedBySelf());
303   ASSERT_FALSE(socket->isClosedByPeer());
304 }
305
306 /**
307  * Test calling close() immediately after connect()
308  */
309 TEST(AsyncSocketTest, ConnectAndClose) {
310   TestServer server;
311
312   // Connect using a AsyncSocket
313   EventBase evb;
314   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
315   ConnCallback ccb;
316   socket->connect(&ccb, server.getAddress(), 30);
317
318   // Hopefully the connect didn't succeed immediately.
319   // If it did, we can't exercise the close-while-connecting code path.
320   if (ccb.state == STATE_SUCCEEDED) {
321     LOG(INFO) << "connect() succeeded immediately; aborting test "
322                        "of close-during-connect behavior";
323     return;
324   }
325
326   socket->close();
327
328   // Loop, although there shouldn't be anything to do.
329   evb.loop();
330
331   // Make sure the connection was aborted
332   ASSERT_EQ(ccb.state, STATE_FAILED);
333
334   ASSERT_TRUE(socket->isClosedBySelf());
335   ASSERT_FALSE(socket->isClosedByPeer());
336 }
337
338 /**
339  * Test calling closeNow() immediately after connect()
340  *
341  * This should be identical to the normal close behavior.
342  */
343 TEST(AsyncSocketTest, ConnectAndCloseNow) {
344   TestServer server;
345
346   // Connect using a AsyncSocket
347   EventBase evb;
348   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
349   ConnCallback ccb;
350   socket->connect(&ccb, server.getAddress(), 30);
351
352   // Hopefully the connect didn't succeed immediately.
353   // If it did, we can't exercise the close-while-connecting code path.
354   if (ccb.state == STATE_SUCCEEDED) {
355     LOG(INFO) << "connect() succeeded immediately; aborting test "
356                        "of closeNow()-during-connect behavior";
357     return;
358   }
359
360   socket->closeNow();
361
362   // Loop, although there shouldn't be anything to do.
363   evb.loop();
364
365   // Make sure the connection was aborted
366   ASSERT_EQ(ccb.state, STATE_FAILED);
367
368   ASSERT_TRUE(socket->isClosedBySelf());
369   ASSERT_FALSE(socket->isClosedByPeer());
370 }
371
372 /**
373  * Test calling both write() and closeNow() immediately after connecting,
374  * without waiting for connect to finish.
375  *
376  * This should abort the pending write.
377  */
378 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
379   TestServer server;
380
381   // connect()
382   EventBase evb;
383   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
384   ConnCallback ccb;
385   socket->connect(&ccb, server.getAddress(), 30);
386
387   // Hopefully the connect didn't succeed immediately.
388   // If it did, we can't exercise the close-while-connecting code path.
389   if (ccb.state == STATE_SUCCEEDED) {
390     LOG(INFO) << "connect() succeeded immediately; aborting test "
391                        "of write-during-connect behavior";
392     return;
393   }
394
395   // write()
396   char buf[128];
397   memset(buf, 'a', sizeof(buf));
398   WriteCallback wcb;
399   socket->write(&wcb, buf, sizeof(buf));
400
401   // close()
402   socket->closeNow();
403
404   // Loop, although there shouldn't be anything to do.
405   evb.loop();
406
407   ASSERT_EQ(ccb.state, STATE_FAILED);
408   ASSERT_EQ(wcb.state, STATE_FAILED);
409
410   ASSERT_TRUE(socket->isClosedBySelf());
411   ASSERT_FALSE(socket->isClosedByPeer());
412 }
413
414 /**
415  * Test installing a read callback immediately, before connect() finishes.
416  */
417 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
418   TestServer server;
419
420   // connect()
421   EventBase evb;
422   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
423   if (GetParam() == TFOState::ENABLED) {
424     socket->enableTFO();
425   }
426
427   ConnCallback ccb;
428   socket->connect(&ccb, server.getAddress(), 30);
429
430   ReadCallback rcb;
431   socket->setReadCB(&rcb);
432
433   if (GetParam() == TFOState::ENABLED) {
434     // Trigger a connection
435     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
436   }
437
438   // Even though we haven't looped yet, we should be able to accept
439   // the connection and send data to it.
440   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
441   uint8_t buf[128];
442   memset(buf, 'a', sizeof(buf));
443   acceptedSocket->write(buf, sizeof(buf));
444   acceptedSocket->flush();
445   acceptedSocket->close();
446
447   // Loop, although there shouldn't be anything to do.
448   evb.loop();
449
450   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
451   ASSERT_EQ(rcb.buffers.size(), 1);
452   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
453   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
454
455   ASSERT_FALSE(socket->isClosedBySelf());
456   ASSERT_FALSE(socket->isClosedByPeer());
457 }
458
459 /**
460  * Test installing a read callback and then closing immediately before the
461  * connect attempt finishes.
462  */
463 TEST(AsyncSocketTest, ConnectReadAndClose) {
464   TestServer server;
465
466   // connect()
467   EventBase evb;
468   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
469   ConnCallback ccb;
470   socket->connect(&ccb, server.getAddress(), 30);
471
472   // Hopefully the connect didn't succeed immediately.
473   // If it did, we can't exercise the close-while-connecting code path.
474   if (ccb.state == STATE_SUCCEEDED) {
475     LOG(INFO) << "connect() succeeded immediately; aborting test "
476                        "of read-during-connect behavior";
477     return;
478   }
479
480   ReadCallback rcb;
481   socket->setReadCB(&rcb);
482
483   // close()
484   socket->close();
485
486   // Loop, although there shouldn't be anything to do.
487   evb.loop();
488
489   ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
490   ASSERT_EQ(rcb.buffers.size(), 0);
491   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
492
493   ASSERT_TRUE(socket->isClosedBySelf());
494   ASSERT_FALSE(socket->isClosedByPeer());
495 }
496
497 /**
498  * Test both writing and installing a read callback immediately,
499  * before connect() finishes.
500  */
501 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
502   TestServer server;
503
504   // connect()
505   EventBase evb;
506   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
507   if (GetParam() == TFOState::ENABLED) {
508     socket->enableTFO();
509   }
510   ConnCallback ccb;
511   socket->connect(&ccb, server.getAddress(), 30);
512
513   // write()
514   char buf1[128];
515   memset(buf1, 'a', sizeof(buf1));
516   WriteCallback wcb;
517   socket->write(&wcb, buf1, sizeof(buf1));
518
519   // set a read callback
520   ReadCallback rcb;
521   socket->setReadCB(&rcb);
522
523   // Even though we haven't looped yet, we should be able to accept
524   // the connection and send data to it.
525   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
526   uint8_t buf2[128];
527   memset(buf2, 'b', sizeof(buf2));
528   acceptedSocket->write(buf2, sizeof(buf2));
529   acceptedSocket->flush();
530
531   // shut down the write half of acceptedSocket, so that the AsyncSocket
532   // will stop reading and we can break out of the event loop.
533   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
534
535   // Loop
536   evb.loop();
537
538   // Make sure the connect succeeded
539   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
540
541   // Make sure the AsyncSocket read the data written by the accepted socket
542   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
543   ASSERT_EQ(rcb.buffers.size(), 1);
544   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
545   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
546
547   // Close the AsyncSocket so we'll see EOF on acceptedSocket
548   socket->close();
549
550   // Make sure the accepted socket saw the data written by the AsyncSocket
551   uint8_t readbuf[sizeof(buf1)];
552   acceptedSocket->readAll(readbuf, sizeof(readbuf));
553   ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
554   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
555   ASSERT_EQ(bytesRead, 0);
556
557   ASSERT_FALSE(socket->isClosedBySelf());
558   ASSERT_TRUE(socket->isClosedByPeer());
559 }
560
561 /**
562  * Test writing to the socket then shutting down writes before the connect
563  * attempt finishes.
564  */
565 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
566   TestServer server;
567
568   // connect()
569   EventBase evb;
570   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
571   ConnCallback ccb;
572   socket->connect(&ccb, server.getAddress(), 30);
573
574   // Hopefully the connect didn't succeed immediately.
575   // If it did, we can't exercise the write-while-connecting code path.
576   if (ccb.state == STATE_SUCCEEDED) {
577     LOG(INFO) << "connect() succeeded immediately; skipping test";
578     return;
579   }
580
581   // Ask to write some data
582   char wbuf[128];
583   memset(wbuf, 'a', sizeof(wbuf));
584   WriteCallback wcb;
585   socket->write(&wcb, wbuf, sizeof(wbuf));
586   socket->shutdownWrite();
587
588   // Shutdown writes
589   socket->shutdownWrite();
590
591   // Even though we haven't looped yet, we should be able to accept
592   // the connection.
593   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
594
595   // Since the connection is still in progress, there should be no data to
596   // read yet.  Verify that the accepted socket is not readable.
597   struct pollfd fds[1];
598   fds[0].fd = acceptedSocket->getSocketFD();
599   fds[0].events = POLLIN;
600   fds[0].revents = 0;
601   int rc = poll(fds, 1, 0);
602   ASSERT_EQ(rc, 0);
603
604   // Write data to the accepted socket
605   uint8_t acceptedWbuf[192];
606   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
607   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
608   acceptedSocket->flush();
609
610   // Loop
611   evb.loop();
612
613   // The loop should have completed the connection, written the queued data,
614   // and shutdown writes on the socket.
615   //
616   // Check that the connection was completed successfully and that the write
617   // callback succeeded.
618   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
619   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
620
621   // Check that we can read the data that was written to the socket, and that
622   // we see an EOF, since its socket was half-shutdown.
623   uint8_t readbuf[sizeof(wbuf)];
624   acceptedSocket->readAll(readbuf, sizeof(readbuf));
625   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
626   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
627   ASSERT_EQ(bytesRead, 0);
628
629   // Close the accepted socket.  This will cause it to see EOF
630   // and uninstall the read callback when we loop next.
631   acceptedSocket->close();
632
633   // Install a read callback, then loop again.
634   ReadCallback rcb;
635   socket->setReadCB(&rcb);
636   evb.loop();
637
638   // This loop should have read the data and seen the EOF
639   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
640   ASSERT_EQ(rcb.buffers.size(), 1);
641   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
642   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
643                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
644
645   ASSERT_FALSE(socket->isClosedBySelf());
646   ASSERT_FALSE(socket->isClosedByPeer());
647 }
648
649 /**
650  * Test reading, writing, and shutting down writes before the connect attempt
651  * finishes.
652  */
653 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
654   TestServer server;
655
656   // connect()
657   EventBase evb;
658   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
659   ConnCallback ccb;
660   socket->connect(&ccb, server.getAddress(), 30);
661
662   // Hopefully the connect didn't succeed immediately.
663   // If it did, we can't exercise the write-while-connecting code path.
664   if (ccb.state == STATE_SUCCEEDED) {
665     LOG(INFO) << "connect() succeeded immediately; skipping test";
666     return;
667   }
668
669   // Install a read callback
670   ReadCallback rcb;
671   socket->setReadCB(&rcb);
672
673   // Ask to write some data
674   char wbuf[128];
675   memset(wbuf, 'a', sizeof(wbuf));
676   WriteCallback wcb;
677   socket->write(&wcb, wbuf, sizeof(wbuf));
678
679   // Shutdown writes
680   socket->shutdownWrite();
681
682   // Even though we haven't looped yet, we should be able to accept
683   // the connection.
684   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
685
686   // Since the connection is still in progress, there should be no data to
687   // read yet.  Verify that the accepted socket is not readable.
688   struct pollfd fds[1];
689   fds[0].fd = acceptedSocket->getSocketFD();
690   fds[0].events = POLLIN;
691   fds[0].revents = 0;
692   int rc = poll(fds, 1, 0);
693   ASSERT_EQ(rc, 0);
694
695   // Write data to the accepted socket
696   uint8_t acceptedWbuf[192];
697   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
698   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
699   acceptedSocket->flush();
700   // Shutdown writes to the accepted socket.  This will cause it to see EOF
701   // and uninstall the read callback.
702   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
703
704   // Loop
705   evb.loop();
706
707   // The loop should have completed the connection, written the queued data,
708   // shutdown writes on the socket, read the data we wrote to it, and see the
709   // EOF.
710   //
711   // Check that the connection was completed successfully and that the read
712   // and write callbacks were invoked as expected.
713   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
714   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
715   ASSERT_EQ(rcb.buffers.size(), 1);
716   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
717   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
718                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
719   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
720
721   // Check that we can read the data that was written to the socket, and that
722   // we see an EOF, since its socket was half-shutdown.
723   uint8_t readbuf[sizeof(wbuf)];
724   acceptedSocket->readAll(readbuf, sizeof(readbuf));
725   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
726   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
727   ASSERT_EQ(bytesRead, 0);
728
729   // Fully close both sockets
730   acceptedSocket->close();
731   socket->close();
732
733   ASSERT_FALSE(socket->isClosedBySelf());
734   ASSERT_TRUE(socket->isClosedByPeer());
735 }
736
737 /**
738  * Test reading, writing, and calling shutdownWriteNow() before the
739  * connect attempt finishes.
740  */
741 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
742   TestServer server;
743
744   // connect()
745   EventBase evb;
746   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
747   ConnCallback ccb;
748   socket->connect(&ccb, server.getAddress(), 30);
749
750   // Hopefully the connect didn't succeed immediately.
751   // If it did, we can't exercise the write-while-connecting code path.
752   if (ccb.state == STATE_SUCCEEDED) {
753     LOG(INFO) << "connect() succeeded immediately; skipping test";
754     return;
755   }
756
757   // Install a read callback
758   ReadCallback rcb;
759   socket->setReadCB(&rcb);
760
761   // Ask to write some data
762   char wbuf[128];
763   memset(wbuf, 'a', sizeof(wbuf));
764   WriteCallback wcb;
765   socket->write(&wcb, wbuf, sizeof(wbuf));
766
767   // Shutdown writes immediately.
768   // This should immediately discard the data that we just tried to write.
769   socket->shutdownWriteNow();
770
771   // Verify that writeError() was invoked on the write callback.
772   ASSERT_EQ(wcb.state, STATE_FAILED);
773   ASSERT_EQ(wcb.bytesWritten, 0);
774
775   // Even though we haven't looped yet, we should be able to accept
776   // the connection.
777   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
778
779   // Since the connection is still in progress, there should be no data to
780   // read yet.  Verify that the accepted socket is not readable.
781   struct pollfd fds[1];
782   fds[0].fd = acceptedSocket->getSocketFD();
783   fds[0].events = POLLIN;
784   fds[0].revents = 0;
785   int rc = poll(fds, 1, 0);
786   ASSERT_EQ(rc, 0);
787
788   // Write data to the accepted socket
789   uint8_t acceptedWbuf[192];
790   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
791   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
792   acceptedSocket->flush();
793   // Shutdown writes to the accepted socket.  This will cause it to see EOF
794   // and uninstall the read callback.
795   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
796
797   // Loop
798   evb.loop();
799
800   // The loop should have completed the connection, written the queued data,
801   // shutdown writes on the socket, read the data we wrote to it, and see the
802   // EOF.
803   //
804   // Check that the connection was completed successfully and that the read
805   // callback was invoked as expected.
806   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
807   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
808   ASSERT_EQ(rcb.buffers.size(), 1);
809   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
810   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
811                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
812
813   // Since we used shutdownWriteNow(), it should have discarded all pending
814   // write data.  Verify we see an immediate EOF when reading from the accepted
815   // socket.
816   uint8_t readbuf[sizeof(wbuf)];
817   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
818   ASSERT_EQ(bytesRead, 0);
819
820   // Fully close both sockets
821   acceptedSocket->close();
822   socket->close();
823
824   ASSERT_FALSE(socket->isClosedBySelf());
825   ASSERT_TRUE(socket->isClosedByPeer());
826 }
827
828 // Helper function for use in testConnectOptWrite()
829 // Temporarily disable the read callback
830 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
831   // Uninstall the read callback
832   socket->setReadCB(nullptr);
833   // Schedule the read callback to be reinstalled after 1ms
834   socket->getEventBase()->runInLoop(
835       std::bind(&AsyncSocket::setReadCB, socket, rcb));
836 }
837
838 /**
839  * Test connect+write, then have the connect callback perform another write.
840  *
841  * This tests interaction of the optimistic writing after connect with
842  * additional write attempts that occur in the connect callback.
843  */
844 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
845   TestServer server;
846   EventBase evb;
847   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
848
849   // connect()
850   ConnCallback ccb;
851   socket->connect(&ccb, server.getAddress(), 30);
852
853   // Hopefully the connect didn't succeed immediately.
854   // If it did, we can't exercise the optimistic write code path.
855   if (ccb.state == STATE_SUCCEEDED) {
856     LOG(INFO) << "connect() succeeded immediately; aborting test "
857                        "of optimistic write behavior";
858     return;
859   }
860
861   // Tell the connect callback to perform a write when the connect succeeds
862   WriteCallback wcb2;
863   scoped_array<char> buf2(new char[size2]);
864   memset(buf2.get(), 'b', size2);
865   if (size2 > 0) {
866     ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
867     // Tell the second write callback to close the connection when it is done
868     wcb2.successCallback = [&] { socket->closeNow(); };
869   }
870
871   // Schedule one write() immediately, before the connect finishes
872   scoped_array<char> buf1(new char[size1]);
873   memset(buf1.get(), 'a', size1);
874   WriteCallback wcb1;
875   if (size1 > 0) {
876     socket->write(&wcb1, buf1.get(), size1);
877   }
878
879   if (close) {
880     // immediately perform a close, before connect() completes
881     socket->close();
882   }
883
884   // Start reading from the other endpoint after 10ms.
885   // If we're using large buffers, we have to read so that the writes don't
886   // block forever.
887   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
888   ReadCallback rcb;
889   rcb.dataAvailableCallback = std::bind(tmpDisableReads,
890                                         acceptedSocket.get(), &rcb);
891   socket->getEventBase()->tryRunAfterDelay(
892       std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
893       10);
894
895   // Loop.  We don't bother accepting on the server socket yet.
896   // The kernel should be able to buffer the write request so it can succeed.
897   evb.loop();
898
899   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
900   if (size1 > 0) {
901     ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
902   }
903   if (size2 > 0) {
904     ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
905   }
906
907   socket->close();
908
909   // Make sure the read callback received all of the data
910   size_t bytesRead = 0;
911   for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
912        it != rcb.buffers.end();
913        ++it) {
914     size_t start = bytesRead;
915     bytesRead += it->length;
916     size_t end = bytesRead;
917     if (start < size1) {
918       size_t cmpLen = min(size1, end) - start;
919       ASSERT_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
920     }
921     if (end > size1 && end <= size1 + size2) {
922       size_t itOffset;
923       size_t buf2Offset;
924       size_t cmpLen;
925       if (start >= size1) {
926         itOffset = 0;
927         buf2Offset = start - size1;
928         cmpLen = end - start;
929       } else {
930         itOffset = size1 - start;
931         buf2Offset = 0;
932         cmpLen = end - size1;
933       }
934       ASSERT_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
935                                cmpLen),
936                         0);
937     }
938   }
939   ASSERT_EQ(bytesRead, size1 + size2);
940 }
941
942 TEST(AsyncSocketTest, ConnectCallbackWrite) {
943   // Test using small writes that should both succeed immediately
944   testConnectOptWrite(100, 200);
945
946   // Test using a large buffer in the connect callback, that should block
947   const size_t largeSize = 8*1024*1024;
948   testConnectOptWrite(100, largeSize);
949
950   // Test using a large initial write
951   testConnectOptWrite(largeSize, 100);
952
953   // Test using two large buffers
954   testConnectOptWrite(largeSize, largeSize);
955
956   // Test a small write in the connect callback,
957   // but no immediate write before connect completes
958   testConnectOptWrite(0, 64);
959
960   // Test a large write in the connect callback,
961   // but no immediate write before connect completes
962   testConnectOptWrite(0, largeSize);
963
964   // Test connect, a small write, then immediately call close() before connect
965   // completes
966   testConnectOptWrite(211, 0, true);
967
968   // Test connect, a large immediate write (that will block), then immediately
969   // call close() before connect completes
970   testConnectOptWrite(largeSize, 0, true);
971 }
972
973 ///////////////////////////////////////////////////////////////////////////
974 // write() related tests
975 ///////////////////////////////////////////////////////////////////////////
976
977 /**
978  * Test writing using a nullptr callback
979  */
980 TEST(AsyncSocketTest, WriteNullCallback) {
981   TestServer server;
982
983   // connect()
984   EventBase evb;
985   std::shared_ptr<AsyncSocket> socket =
986     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
987   evb.loop(); // loop until the socket is connected
988
989   // write() with a nullptr callback
990   char buf[128];
991   memset(buf, 'a', sizeof(buf));
992   socket->write(nullptr, buf, sizeof(buf));
993
994   evb.loop(); // loop until the data is sent
995
996   // Make sure the server got a connection and received the data
997   socket->close();
998   server.verifyConnection(buf, sizeof(buf));
999
1000   ASSERT_TRUE(socket->isClosedBySelf());
1001   ASSERT_FALSE(socket->isClosedByPeer());
1002 }
1003
1004 /**
1005  * Test writing with a send timeout
1006  */
1007 TEST(AsyncSocketTest, WriteTimeout) {
1008   TestServer server;
1009
1010   // connect()
1011   EventBase evb;
1012   std::shared_ptr<AsyncSocket> socket =
1013     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1014   evb.loop(); // loop until the socket is connected
1015
1016   // write() a large chunk of data, with no-one on the other end reading
1017   size_t writeLength = 8*1024*1024;
1018   uint32_t timeout = 200;
1019   socket->setSendTimeout(timeout);
1020   scoped_array<char> buf(new char[writeLength]);
1021   memset(buf.get(), 'a', writeLength);
1022   WriteCallback wcb;
1023   socket->write(&wcb, buf.get(), writeLength);
1024
1025   TimePoint start;
1026   evb.loop();
1027   TimePoint end;
1028
1029   // Make sure the write attempt timed out as requested
1030   ASSERT_EQ(wcb.state, STATE_FAILED);
1031   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1032
1033   // Check that the write timed out within a reasonable period of time.
1034   // We don't check for exactly the specified timeout, since AsyncSocket only
1035   // times out when it hasn't made progress for that period of time.
1036   //
1037   // On linux, the first write sends a few hundred kb of data, then blocks for
1038   // writability, and then unblocks again after 40ms and is able to write
1039   // another smaller of data before blocking permanently.  Therefore it doesn't
1040   // time out until 40ms + timeout.
1041   //
1042   // I haven't fully verified the cause of this, but I believe it probably
1043   // occurs because the receiving end delays sending an ack for up to 40ms.
1044   // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
1045   // the ack, it can send some more data.  However, after that point the
1046   // receiver's kernel buffer is full.  This 40ms delay happens even with
1047   // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
1048   // kernel may be automatically disabling TCP_QUICKACK after receiving some
1049   // data.
1050   //
1051   // For now, we simply check that the timeout occurred within 160ms of
1052   // the requested value.
1053   T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1054 }
1055
1056 /**
1057  * Test writing to a socket that the remote endpoint has closed
1058  */
1059 TEST(AsyncSocketTest, WritePipeError) {
1060   TestServer server;
1061
1062   // connect()
1063   EventBase evb;
1064   std::shared_ptr<AsyncSocket> socket =
1065     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1066   socket->setSendTimeout(1000);
1067   evb.loop(); // loop until the socket is connected
1068
1069   // accept and immediately close the socket
1070   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1071   acceptedSocket->close();
1072
1073   // write() a large chunk of data
1074   size_t writeLength = 8*1024*1024;
1075   scoped_array<char> buf(new char[writeLength]);
1076   memset(buf.get(), 'a', writeLength);
1077   WriteCallback wcb;
1078   socket->write(&wcb, buf.get(), writeLength);
1079
1080   evb.loop();
1081
1082   // Make sure the write failed.
1083   // It would be nice if AsyncSocketException could convey the errno value,
1084   // so that we could check for EPIPE
1085   ASSERT_EQ(wcb.state, STATE_FAILED);
1086   ASSERT_EQ(wcb.exception.getType(),
1087                     AsyncSocketException::INTERNAL_ERROR);
1088
1089   ASSERT_FALSE(socket->isClosedBySelf());
1090   ASSERT_FALSE(socket->isClosedByPeer());
1091 }
1092
1093 /**
1094  * Test writing a mix of simple buffers and IOBufs
1095  */
1096 TEST(AsyncSocketTest, WriteIOBuf) {
1097   TestServer server;
1098
1099   // connect()
1100   EventBase evb;
1101   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1102   ConnCallback ccb;
1103   socket->connect(&ccb, server.getAddress(), 30);
1104
1105   // Accept the connection
1106   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1107   ReadCallback rcb;
1108   acceptedSocket->setReadCB(&rcb);
1109
1110   // Write a simple buffer to the socket
1111   constexpr size_t simpleBufLength = 5;
1112   char simpleBuf[simpleBufLength];
1113   memset(simpleBuf, 'a', simpleBufLength);
1114   WriteCallback wcb;
1115   socket->write(&wcb, simpleBuf, simpleBufLength);
1116
1117   // Write a single-element IOBuf chain
1118   size_t buf1Length = 7;
1119   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1120   memset(buf1->writableData(), 'b', buf1Length);
1121   buf1->append(buf1Length);
1122   unique_ptr<IOBuf> buf1Copy(buf1->clone());
1123   WriteCallback wcb2;
1124   socket->writeChain(&wcb2, std::move(buf1));
1125
1126   // Write a multiple-element IOBuf chain
1127   size_t buf2Length = 11;
1128   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1129   memset(buf2->writableData(), 'c', buf2Length);
1130   buf2->append(buf2Length);
1131   size_t buf3Length = 13;
1132   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1133   memset(buf3->writableData(), 'd', buf3Length);
1134   buf3->append(buf3Length);
1135   buf2->appendChain(std::move(buf3));
1136   unique_ptr<IOBuf> buf2Copy(buf2->clone());
1137   buf2Copy->coalesce();
1138   WriteCallback wcb3;
1139   socket->writeChain(&wcb3, std::move(buf2));
1140   socket->shutdownWrite();
1141
1142   // Let the reads and writes run to completion
1143   evb.loop();
1144
1145   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1146   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1147   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1148
1149   // Make sure the reader got the right data in the right order
1150   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1151   ASSERT_EQ(rcb.buffers.size(), 1);
1152   ASSERT_EQ(rcb.buffers[0].length,
1153       simpleBufLength + buf1Length + buf2Length + buf3Length);
1154   ASSERT_EQ(
1155       memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1156   ASSERT_EQ(
1157       memcmp(rcb.buffers[0].buffer + simpleBufLength,
1158           buf1Copy->data(), buf1Copy->length()), 0);
1159   ASSERT_EQ(
1160       memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1161           buf2Copy->data(), buf2Copy->length()), 0);
1162
1163   acceptedSocket->close();
1164   socket->close();
1165
1166   ASSERT_TRUE(socket->isClosedBySelf());
1167   ASSERT_FALSE(socket->isClosedByPeer());
1168 }
1169
1170 TEST(AsyncSocketTest, WriteIOBufCorked) {
1171   TestServer server;
1172
1173   // connect()
1174   EventBase evb;
1175   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1176   ConnCallback ccb;
1177   socket->connect(&ccb, server.getAddress(), 30);
1178
1179   // Accept the connection
1180   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1181   ReadCallback rcb;
1182   acceptedSocket->setReadCB(&rcb);
1183
1184   // Do three writes, 100ms apart, with the "cork" flag set
1185   // on the second write.  The reader should see the first write
1186   // arrive by itself, followed by the second and third writes
1187   // arriving together.
1188   size_t buf1Length = 5;
1189   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1190   memset(buf1->writableData(), 'a', buf1Length);
1191   buf1->append(buf1Length);
1192   size_t buf2Length = 7;
1193   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1194   memset(buf2->writableData(), 'b', buf2Length);
1195   buf2->append(buf2Length);
1196   size_t buf3Length = 11;
1197   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1198   memset(buf3->writableData(), 'c', buf3Length);
1199   buf3->append(buf3Length);
1200   WriteCallback wcb1;
1201   socket->writeChain(&wcb1, std::move(buf1));
1202   WriteCallback wcb2;
1203   DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
1204   write2.scheduleTimeout(100);
1205   WriteCallback wcb3;
1206   DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
1207   write3.scheduleTimeout(140);
1208
1209   evb.loop();
1210   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1211   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1212   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1213   if (wcb3.state != STATE_SUCCEEDED) {
1214     throw(wcb3.exception);
1215   }
1216   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1217
1218   // Make sure the reader got the data with the right grouping
1219   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1220   ASSERT_EQ(rcb.buffers.size(), 2);
1221   ASSERT_EQ(rcb.buffers[0].length, buf1Length);
1222   ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
1223
1224   acceptedSocket->close();
1225   socket->close();
1226
1227   ASSERT_TRUE(socket->isClosedBySelf());
1228   ASSERT_FALSE(socket->isClosedByPeer());
1229 }
1230
1231 /**
1232  * Test performing a zero-length write
1233  */
1234 TEST(AsyncSocketTest, ZeroLengthWrite) {
1235   TestServer server;
1236
1237   // connect()
1238   EventBase evb;
1239   std::shared_ptr<AsyncSocket> socket =
1240     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1241   evb.loop(); // loop until the socket is connected
1242
1243   auto acceptedSocket = server.acceptAsync(&evb);
1244   ReadCallback rcb;
1245   acceptedSocket->setReadCB(&rcb);
1246
1247   size_t len1 = 1024*1024;
1248   size_t len2 = 1024*1024;
1249   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1250   memset(buf.get(), 'a', len1);
1251   memset(buf.get(), 'b', len2);
1252
1253   WriteCallback wcb1;
1254   WriteCallback wcb2;
1255   WriteCallback wcb3;
1256   WriteCallback wcb4;
1257   socket->write(&wcb1, buf.get(), 0);
1258   socket->write(&wcb2, buf.get(), len1);
1259   socket->write(&wcb3, buf.get() + len1, 0);
1260   socket->write(&wcb4, buf.get() + len1, len2);
1261   socket->close();
1262
1263   evb.loop(); // loop until the data is sent
1264
1265   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1266   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1267   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1268   ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1269   rcb.verifyData(buf.get(), len1 + len2);
1270
1271   ASSERT_TRUE(socket->isClosedBySelf());
1272   ASSERT_FALSE(socket->isClosedByPeer());
1273 }
1274
1275 TEST(AsyncSocketTest, ZeroLengthWritev) {
1276   TestServer server;
1277
1278   // connect()
1279   EventBase evb;
1280   std::shared_ptr<AsyncSocket> socket =
1281     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1282   evb.loop(); // loop until the socket is connected
1283
1284   auto acceptedSocket = server.acceptAsync(&evb);
1285   ReadCallback rcb;
1286   acceptedSocket->setReadCB(&rcb);
1287
1288   size_t len1 = 1024*1024;
1289   size_t len2 = 1024*1024;
1290   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1291   memset(buf.get(), 'a', len1);
1292   memset(buf.get(), 'b', len2);
1293
1294   WriteCallback wcb;
1295   constexpr size_t iovCount = 4;
1296   struct iovec iov[iovCount];
1297   iov[0].iov_base = buf.get();
1298   iov[0].iov_len = len1;
1299   iov[1].iov_base = buf.get() + len1;
1300   iov[1].iov_len = 0;
1301   iov[2].iov_base = buf.get() + len1;
1302   iov[2].iov_len = len2;
1303   iov[3].iov_base = buf.get() + len1 + len2;
1304   iov[3].iov_len = 0;
1305
1306   socket->writev(&wcb, iov, iovCount);
1307   socket->close();
1308   evb.loop(); // loop until the data is sent
1309
1310   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1311   rcb.verifyData(buf.get(), len1 + len2);
1312
1313   ASSERT_TRUE(socket->isClosedBySelf());
1314   ASSERT_FALSE(socket->isClosedByPeer());
1315 }
1316
1317 ///////////////////////////////////////////////////////////////////////////
1318 // close() related tests
1319 ///////////////////////////////////////////////////////////////////////////
1320
1321 /**
1322  * Test calling close() with pending writes when the socket is already closing.
1323  */
1324 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1325   TestServer server;
1326
1327   // connect()
1328   EventBase evb;
1329   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1330   ConnCallback ccb;
1331   socket->connect(&ccb, server.getAddress(), 30);
1332
1333   // accept the socket on the server side
1334   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1335
1336   // Loop to ensure the connect has completed
1337   evb.loop();
1338
1339   // Make sure we are connected
1340   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1341
1342   // Schedule pending writes, until several write attempts have blocked
1343   char buf[128];
1344   memset(buf, 'a', sizeof(buf));
1345   typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
1346   WriteCallbackVector writeCallbacks;
1347
1348   writeCallbacks.reserve(5);
1349   while (writeCallbacks.size() < 5) {
1350     std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1351
1352     socket->write(wcb.get(), buf, sizeof(buf));
1353     if (wcb->state == STATE_SUCCEEDED) {
1354       // Succeeded immediately.  Keep performing more writes
1355       continue;
1356     }
1357
1358     // This write is blocked.
1359     // Have the write callback call close() when writeError() is invoked
1360     wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1361     writeCallbacks.push_back(wcb);
1362   }
1363
1364   // Call closeNow() to immediately fail the pending writes
1365   socket->closeNow();
1366
1367   // Make sure writeError() was invoked on all of the pending write callbacks
1368   for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
1369        it != writeCallbacks.end();
1370        ++it) {
1371     ASSERT_EQ((*it)->state, STATE_FAILED);
1372   }
1373
1374   ASSERT_TRUE(socket->isClosedBySelf());
1375   ASSERT_FALSE(socket->isClosedByPeer());
1376 }
1377
1378 ///////////////////////////////////////////////////////////////////////////
1379 // ImmediateRead related tests
1380 ///////////////////////////////////////////////////////////////////////////
1381
1382 /* AsyncSocket use to verify immediate read works */
1383 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1384  public:
1385   bool immediateReadCalled = false;
1386   explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1387  protected:
1388   void checkForImmediateRead() noexcept override {
1389     immediateReadCalled = true;
1390     AsyncSocket::handleRead();
1391   }
1392 };
1393
1394 TEST(AsyncSocket, ConnectReadImmediateRead) {
1395   TestServer server;
1396
1397   const size_t maxBufferSz = 100;
1398   const size_t maxReadsPerEvent = 1;
1399   const size_t expectedDataSz = maxBufferSz * 3;
1400   char expectedData[expectedDataSz];
1401   memset(expectedData, 'j', expectedDataSz);
1402
1403   EventBase evb;
1404   ReadCallback rcb(maxBufferSz);
1405   AsyncSocketImmediateRead socket(&evb);
1406   socket.connect(nullptr, server.getAddress(), 30);
1407
1408   evb.loop(); // loop until the socket is connected
1409
1410   socket.setReadCB(&rcb);
1411   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1412   socket.immediateReadCalled = false;
1413
1414   auto acceptedSocket = server.acceptAsync(&evb);
1415
1416   ReadCallback rcbServer;
1417   WriteCallback wcbServer;
1418   rcbServer.dataAvailableCallback = [&]() {
1419     if (rcbServer.dataRead() == expectedDataSz) {
1420       // write back all data read
1421       rcbServer.verifyData(expectedData, expectedDataSz);
1422       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1423       acceptedSocket->close();
1424     }
1425   };
1426   acceptedSocket->setReadCB(&rcbServer);
1427
1428   // write data
1429   WriteCallback wcb1;
1430   socket.write(&wcb1, expectedData, expectedDataSz);
1431   evb.loop();
1432   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1433   rcb.verifyData(expectedData, expectedDataSz);
1434   ASSERT_EQ(socket.immediateReadCalled, true);
1435
1436   ASSERT_FALSE(socket.isClosedBySelf());
1437   ASSERT_FALSE(socket.isClosedByPeer());
1438 }
1439
1440 TEST(AsyncSocket, ConnectReadUninstallRead) {
1441   TestServer server;
1442
1443   const size_t maxBufferSz = 100;
1444   const size_t maxReadsPerEvent = 1;
1445   const size_t expectedDataSz = maxBufferSz * 3;
1446   char expectedData[expectedDataSz];
1447   memset(expectedData, 'k', expectedDataSz);
1448
1449   EventBase evb;
1450   ReadCallback rcb(maxBufferSz);
1451   AsyncSocketImmediateRead socket(&evb);
1452   socket.connect(nullptr, server.getAddress(), 30);
1453
1454   evb.loop(); // loop until the socket is connected
1455
1456   socket.setReadCB(&rcb);
1457   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1458   socket.immediateReadCalled = false;
1459
1460   auto acceptedSocket = server.acceptAsync(&evb);
1461
1462   ReadCallback rcbServer;
1463   WriteCallback wcbServer;
1464   rcbServer.dataAvailableCallback = [&]() {
1465     if (rcbServer.dataRead() == expectedDataSz) {
1466       // write back all data read
1467       rcbServer.verifyData(expectedData, expectedDataSz);
1468       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1469       acceptedSocket->close();
1470     }
1471   };
1472   acceptedSocket->setReadCB(&rcbServer);
1473
1474   rcb.dataAvailableCallback = [&]() {
1475     // we read data and reset readCB
1476     socket.setReadCB(nullptr);
1477   };
1478
1479   // write data
1480   WriteCallback wcb;
1481   socket.write(&wcb, expectedData, expectedDataSz);
1482   evb.loop();
1483   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1484
1485   /* we shoud've only read maxBufferSz data since readCallback_
1486    * was reset in dataAvailableCallback */
1487   ASSERT_EQ(rcb.dataRead(), maxBufferSz);
1488   ASSERT_EQ(socket.immediateReadCalled, false);
1489
1490   ASSERT_FALSE(socket.isClosedBySelf());
1491   ASSERT_FALSE(socket.isClosedByPeer());
1492 }
1493
1494 // TODO:
1495 // - Test connect() and have the connect callback set the read callback
1496 // - Test connect() and have the connect callback unset the read callback
1497 // - Test reading/writing/closing/destroying the socket in the connect callback
1498 // - Test reading/writing/closing/destroying the socket in the read callback
1499 // - Test reading/writing/closing/destroying the socket in the write callback
1500 // - Test one-way shutdown behavior
1501 // - Test changing the EventBase
1502 //
1503 // - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
1504 //   in connectSuccess(), readDataAvailable(), writeSuccess()
1505
1506
1507 ///////////////////////////////////////////////////////////////////////////
1508 // AsyncServerSocket tests
1509 ///////////////////////////////////////////////////////////////////////////
1510 namespace {
1511 /**
1512  * Helper ConnectionEventCallback class for the test code.
1513  * It maintains counters protected by a spin lock.
1514  */
1515 class TestConnectionEventCallback :
1516   public AsyncServerSocket::ConnectionEventCallback {
1517  public:
1518   virtual void onConnectionAccepted(
1519       const int /* socket */,
1520       const SocketAddress& /* addr */) noexcept override {
1521     folly::RWSpinLock::WriteHolder holder(spinLock_);
1522     connectionAccepted_++;
1523   }
1524
1525   virtual void onConnectionAcceptError(const int /* err */) noexcept override {
1526     folly::RWSpinLock::WriteHolder holder(spinLock_);
1527     connectionAcceptedError_++;
1528   }
1529
1530   virtual void onConnectionDropped(
1531       const int /* socket */,
1532       const SocketAddress& /* addr */) noexcept override {
1533     folly::RWSpinLock::WriteHolder holder(spinLock_);
1534     connectionDropped_++;
1535   }
1536
1537   virtual void onConnectionEnqueuedForAcceptorCallback(
1538       const int /* socket */,
1539       const SocketAddress& /* addr */) noexcept override {
1540     folly::RWSpinLock::WriteHolder holder(spinLock_);
1541     connectionEnqueuedForAcceptCallback_++;
1542   }
1543
1544   virtual void onConnectionDequeuedByAcceptorCallback(
1545       const int /* socket */,
1546       const SocketAddress& /* addr */) noexcept override {
1547     folly::RWSpinLock::WriteHolder holder(spinLock_);
1548     connectionDequeuedByAcceptCallback_++;
1549   }
1550
1551   virtual void onBackoffStarted() noexcept override {
1552     folly::RWSpinLock::WriteHolder holder(spinLock_);
1553     backoffStarted_++;
1554   }
1555
1556   virtual void onBackoffEnded() noexcept override {
1557     folly::RWSpinLock::WriteHolder holder(spinLock_);
1558     backoffEnded_++;
1559   }
1560
1561   virtual void onBackoffError() noexcept override {
1562     folly::RWSpinLock::WriteHolder holder(spinLock_);
1563     backoffError_++;
1564   }
1565
1566   unsigned int getConnectionAccepted() const {
1567     folly::RWSpinLock::ReadHolder holder(spinLock_);
1568     return connectionAccepted_;
1569   }
1570
1571   unsigned int getConnectionAcceptedError() const {
1572     folly::RWSpinLock::ReadHolder holder(spinLock_);
1573     return connectionAcceptedError_;
1574   }
1575
1576   unsigned int getConnectionDropped() const {
1577     folly::RWSpinLock::ReadHolder holder(spinLock_);
1578     return connectionDropped_;
1579   }
1580
1581   unsigned int getConnectionEnqueuedForAcceptCallback() const {
1582     folly::RWSpinLock::ReadHolder holder(spinLock_);
1583     return connectionEnqueuedForAcceptCallback_;
1584   }
1585
1586   unsigned int getConnectionDequeuedByAcceptCallback() const {
1587     folly::RWSpinLock::ReadHolder holder(spinLock_);
1588     return connectionDequeuedByAcceptCallback_;
1589   }
1590
1591   unsigned int getBackoffStarted() const {
1592     folly::RWSpinLock::ReadHolder holder(spinLock_);
1593     return backoffStarted_;
1594   }
1595
1596   unsigned int getBackoffEnded() const {
1597     folly::RWSpinLock::ReadHolder holder(spinLock_);
1598     return backoffEnded_;
1599   }
1600
1601   unsigned int getBackoffError() const {
1602     folly::RWSpinLock::ReadHolder holder(spinLock_);
1603     return backoffError_;
1604   }
1605
1606  private:
1607   mutable folly::RWSpinLock spinLock_;
1608   unsigned int connectionAccepted_{0};
1609   unsigned int connectionAcceptedError_{0};
1610   unsigned int connectionDropped_{0};
1611   unsigned int connectionEnqueuedForAcceptCallback_{0};
1612   unsigned int connectionDequeuedByAcceptCallback_{0};
1613   unsigned int backoffStarted_{0};
1614   unsigned int backoffEnded_{0};
1615   unsigned int backoffError_{0};
1616 };
1617
1618 /**
1619  * Helper AcceptCallback class for the test code
1620  * It records the callbacks that were invoked, and also supports calling
1621  * generic std::function objects in each callback.
1622  */
1623 class TestAcceptCallback : public AsyncServerSocket::AcceptCallback {
1624  public:
1625   enum EventType {
1626     TYPE_START,
1627     TYPE_ACCEPT,
1628     TYPE_ERROR,
1629     TYPE_STOP
1630   };
1631   struct EventInfo {
1632     EventInfo(int fd, const folly::SocketAddress& addr)
1633       : type(TYPE_ACCEPT),
1634         fd(fd),
1635         address(addr),
1636         errorMsg() {}
1637     explicit EventInfo(const std::string& msg)
1638       : type(TYPE_ERROR),
1639         fd(-1),
1640         address(),
1641         errorMsg(msg) {}
1642     explicit EventInfo(EventType et)
1643       : type(et),
1644         fd(-1),
1645         address(),
1646         errorMsg() {}
1647
1648     EventType type;
1649     int fd;  // valid for TYPE_ACCEPT
1650     folly::SocketAddress address;  // valid for TYPE_ACCEPT
1651     string errorMsg;  // valid for TYPE_ERROR
1652   };
1653   typedef std::deque<EventInfo> EventList;
1654
1655   TestAcceptCallback()
1656     : connectionAcceptedFn_(),
1657       acceptErrorFn_(),
1658       acceptStoppedFn_(),
1659       events_() {}
1660
1661   std::deque<EventInfo>* getEvents() {
1662     return &events_;
1663   }
1664
1665   void setConnectionAcceptedFn(
1666       const std::function<void(int, const folly::SocketAddress&)>& fn) {
1667     connectionAcceptedFn_ = fn;
1668   }
1669   void setAcceptErrorFn(const std::function<void(const std::exception&)>& fn) {
1670     acceptErrorFn_ = fn;
1671   }
1672   void setAcceptStartedFn(const std::function<void()>& fn) {
1673     acceptStartedFn_ = fn;
1674   }
1675   void setAcceptStoppedFn(const std::function<void()>& fn) {
1676     acceptStoppedFn_ = fn;
1677   }
1678
1679   void connectionAccepted(
1680       int fd, const folly::SocketAddress& clientAddr) noexcept override {
1681     events_.emplace_back(fd, clientAddr);
1682
1683     if (connectionAcceptedFn_) {
1684       connectionAcceptedFn_(fd, clientAddr);
1685     }
1686   }
1687   void acceptError(const std::exception& ex) noexcept override {
1688     events_.emplace_back(ex.what());
1689
1690     if (acceptErrorFn_) {
1691       acceptErrorFn_(ex);
1692     }
1693   }
1694   void acceptStarted() noexcept override {
1695     events_.emplace_back(TYPE_START);
1696
1697     if (acceptStartedFn_) {
1698       acceptStartedFn_();
1699     }
1700   }
1701   void acceptStopped() noexcept override {
1702     events_.emplace_back(TYPE_STOP);
1703
1704     if (acceptStoppedFn_) {
1705       acceptStoppedFn_();
1706     }
1707   }
1708
1709  private:
1710   std::function<void(int, const folly::SocketAddress&)> connectionAcceptedFn_;
1711   std::function<void(const std::exception&)> acceptErrorFn_;
1712   std::function<void()> acceptStartedFn_;
1713   std::function<void()> acceptStoppedFn_;
1714
1715   std::deque<EventInfo> events_;
1716 };
1717 }
1718
1719 /**
1720  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1721  */
1722 TEST(AsyncSocketTest, ServerAcceptOptions) {
1723   EventBase eventBase;
1724
1725   // Create a server socket
1726   std::shared_ptr<AsyncServerSocket> serverSocket(
1727       AsyncServerSocket::newSocket(&eventBase));
1728   serverSocket->bind(0);
1729   serverSocket->listen(16);
1730   folly::SocketAddress serverAddress;
1731   serverSocket->getAddress(&serverAddress);
1732
1733   // Add a callback to accept one connection then stop the loop
1734   TestAcceptCallback acceptCallback;
1735   acceptCallback.setConnectionAcceptedFn(
1736       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1737         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1738       });
1739   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1740     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1741   });
1742   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
1743   serverSocket->startAccepting();
1744
1745   // Connect to the server socket
1746   std::shared_ptr<AsyncSocket> socket(
1747       AsyncSocket::newSocket(&eventBase, serverAddress));
1748
1749   eventBase.loop();
1750
1751   // Verify that the server accepted a connection
1752   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1753   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1754                     TestAcceptCallback::TYPE_START);
1755   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1756                     TestAcceptCallback::TYPE_ACCEPT);
1757   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1758                     TestAcceptCallback::TYPE_STOP);
1759   int fd = acceptCallback.getEvents()->at(1).fd;
1760
1761   // The accepted connection should already be in non-blocking mode
1762   int flags = fcntl(fd, F_GETFL, 0);
1763   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
1764
1765 #ifndef TCP_NOPUSH
1766   // The accepted connection should already have TCP_NODELAY set
1767   int value;
1768   socklen_t valueLength = sizeof(value);
1769   int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
1770   ASSERT_EQ(rc, 0);
1771   ASSERT_EQ(value, 1);
1772 #endif
1773 }
1774
1775 /**
1776  * Test AsyncServerSocket::removeAcceptCallback()
1777  */
1778 TEST(AsyncSocketTest, RemoveAcceptCallback) {
1779   // Create a new AsyncServerSocket
1780   EventBase eventBase;
1781   std::shared_ptr<AsyncServerSocket> serverSocket(
1782       AsyncServerSocket::newSocket(&eventBase));
1783   serverSocket->bind(0);
1784   serverSocket->listen(16);
1785   folly::SocketAddress serverAddress;
1786   serverSocket->getAddress(&serverAddress);
1787
1788   // Add several accept callbacks
1789   TestAcceptCallback cb1;
1790   TestAcceptCallback cb2;
1791   TestAcceptCallback cb3;
1792   TestAcceptCallback cb4;
1793   TestAcceptCallback cb5;
1794   TestAcceptCallback cb6;
1795   TestAcceptCallback cb7;
1796
1797   // Test having callbacks remove other callbacks before them on the list,
1798   // after them on the list, or removing themselves.
1799   //
1800   // Have callback 2 remove callback 3 and callback 5 the first time it is
1801   // called.
1802   int cb2Count = 0;
1803   cb1.setConnectionAcceptedFn([&](int /* fd */,
1804                                   const folly::SocketAddress& /* addr */) {
1805     std::shared_ptr<AsyncSocket> sock2(
1806         AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
1807   });
1808   cb3.setConnectionAcceptedFn(
1809       [&](int /* fd */, const folly::SocketAddress& /* addr */) {});
1810   cb4.setConnectionAcceptedFn(
1811       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1812         std::shared_ptr<AsyncSocket> sock3(
1813             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
1814       });
1815   cb5.setConnectionAcceptedFn(
1816       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1817         std::shared_ptr<AsyncSocket> sock5(
1818             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
1819
1820       });
1821   cb2.setConnectionAcceptedFn(
1822       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1823         if (cb2Count == 0) {
1824           serverSocket->removeAcceptCallback(&cb3, nullptr);
1825           serverSocket->removeAcceptCallback(&cb5, nullptr);
1826         }
1827         ++cb2Count;
1828       });
1829   // Have callback 6 remove callback 4 the first time it is called,
1830   // and destroy the server socket the second time it is called
1831   int cb6Count = 0;
1832   cb6.setConnectionAcceptedFn(
1833       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1834         if (cb6Count == 0) {
1835           serverSocket->removeAcceptCallback(&cb4, nullptr);
1836           std::shared_ptr<AsyncSocket> sock6(
1837               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1838           std::shared_ptr<AsyncSocket> sock7(
1839               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
1840           std::shared_ptr<AsyncSocket> sock8(
1841               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
1842
1843         } else {
1844           serverSocket.reset();
1845         }
1846         ++cb6Count;
1847       });
1848   // Have callback 7 remove itself
1849   cb7.setConnectionAcceptedFn(
1850       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1851         serverSocket->removeAcceptCallback(&cb7, nullptr);
1852       });
1853
1854   serverSocket->addAcceptCallback(&cb1, &eventBase);
1855   serverSocket->addAcceptCallback(&cb2, &eventBase);
1856   serverSocket->addAcceptCallback(&cb3, &eventBase);
1857   serverSocket->addAcceptCallback(&cb4, &eventBase);
1858   serverSocket->addAcceptCallback(&cb5, &eventBase);
1859   serverSocket->addAcceptCallback(&cb6, &eventBase);
1860   serverSocket->addAcceptCallback(&cb7, &eventBase);
1861   serverSocket->startAccepting();
1862
1863   // Make several connections to the socket
1864   std::shared_ptr<AsyncSocket> sock1(
1865       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1866   std::shared_ptr<AsyncSocket> sock4(
1867       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
1868
1869   // Loop until we are stopped
1870   eventBase.loop();
1871
1872   // Check to make sure that the expected callbacks were invoked.
1873   //
1874   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1875   // the AcceptCallbacks in round-robin fashion, in the order that they were
1876   // added.  The code is implemented this way right now, but the API doesn't
1877   // explicitly require it be done this way.  If we change the code not to be
1878   // exactly round robin in the future, we can simplify the test checks here.
1879   // (We'll also need to update the termination code, since we expect cb6 to
1880   // get called twice to terminate the loop.)
1881   ASSERT_EQ(cb1.getEvents()->size(), 4);
1882   ASSERT_EQ(cb1.getEvents()->at(0).type,
1883                     TestAcceptCallback::TYPE_START);
1884   ASSERT_EQ(cb1.getEvents()->at(1).type,
1885                     TestAcceptCallback::TYPE_ACCEPT);
1886   ASSERT_EQ(cb1.getEvents()->at(2).type,
1887                     TestAcceptCallback::TYPE_ACCEPT);
1888   ASSERT_EQ(cb1.getEvents()->at(3).type,
1889                     TestAcceptCallback::TYPE_STOP);
1890
1891   ASSERT_EQ(cb2.getEvents()->size(), 4);
1892   ASSERT_EQ(cb2.getEvents()->at(0).type,
1893                     TestAcceptCallback::TYPE_START);
1894   ASSERT_EQ(cb2.getEvents()->at(1).type,
1895                     TestAcceptCallback::TYPE_ACCEPT);
1896   ASSERT_EQ(cb2.getEvents()->at(2).type,
1897                     TestAcceptCallback::TYPE_ACCEPT);
1898   ASSERT_EQ(cb2.getEvents()->at(3).type,
1899                     TestAcceptCallback::TYPE_STOP);
1900
1901   ASSERT_EQ(cb3.getEvents()->size(), 2);
1902   ASSERT_EQ(cb3.getEvents()->at(0).type,
1903                     TestAcceptCallback::TYPE_START);
1904   ASSERT_EQ(cb3.getEvents()->at(1).type,
1905                     TestAcceptCallback::TYPE_STOP);
1906
1907   ASSERT_EQ(cb4.getEvents()->size(), 3);
1908   ASSERT_EQ(cb4.getEvents()->at(0).type,
1909                     TestAcceptCallback::TYPE_START);
1910   ASSERT_EQ(cb4.getEvents()->at(1).type,
1911                     TestAcceptCallback::TYPE_ACCEPT);
1912   ASSERT_EQ(cb4.getEvents()->at(2).type,
1913                     TestAcceptCallback::TYPE_STOP);
1914
1915   ASSERT_EQ(cb5.getEvents()->size(), 2);
1916   ASSERT_EQ(cb5.getEvents()->at(0).type,
1917                     TestAcceptCallback::TYPE_START);
1918   ASSERT_EQ(cb5.getEvents()->at(1).type,
1919                     TestAcceptCallback::TYPE_STOP);
1920
1921   ASSERT_EQ(cb6.getEvents()->size(), 4);
1922   ASSERT_EQ(cb6.getEvents()->at(0).type,
1923                     TestAcceptCallback::TYPE_START);
1924   ASSERT_EQ(cb6.getEvents()->at(1).type,
1925                     TestAcceptCallback::TYPE_ACCEPT);
1926   ASSERT_EQ(cb6.getEvents()->at(2).type,
1927                     TestAcceptCallback::TYPE_ACCEPT);
1928   ASSERT_EQ(cb6.getEvents()->at(3).type,
1929                     TestAcceptCallback::TYPE_STOP);
1930
1931   ASSERT_EQ(cb7.getEvents()->size(), 3);
1932   ASSERT_EQ(cb7.getEvents()->at(0).type,
1933                     TestAcceptCallback::TYPE_START);
1934   ASSERT_EQ(cb7.getEvents()->at(1).type,
1935                     TestAcceptCallback::TYPE_ACCEPT);
1936   ASSERT_EQ(cb7.getEvents()->at(2).type,
1937                     TestAcceptCallback::TYPE_STOP);
1938 }
1939
1940 /**
1941  * Test AsyncServerSocket::removeAcceptCallback()
1942  */
1943 TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
1944   // Create a new AsyncServerSocket
1945   EventBase eventBase;
1946   std::shared_ptr<AsyncServerSocket> serverSocket(
1947       AsyncServerSocket::newSocket(&eventBase));
1948   serverSocket->bind(0);
1949   serverSocket->listen(16);
1950   folly::SocketAddress serverAddress;
1951   serverSocket->getAddress(&serverAddress);
1952
1953   // Add several accept callbacks
1954   TestAcceptCallback cb1;
1955   auto thread_id = std::this_thread::get_id();
1956   cb1.setAcceptStartedFn([&](){
1957     CHECK_NE(thread_id, std::this_thread::get_id());
1958     thread_id = std::this_thread::get_id();
1959   });
1960   cb1.setConnectionAcceptedFn(
1961       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1962         ASSERT_EQ(thread_id, std::this_thread::get_id());
1963         serverSocket->removeAcceptCallback(&cb1, &eventBase);
1964       });
1965   cb1.setAcceptStoppedFn([&](){
1966     ASSERT_EQ(thread_id, std::this_thread::get_id());
1967   });
1968
1969   // Test having callbacks remove other callbacks before them on the list,
1970   serverSocket->addAcceptCallback(&cb1, &eventBase);
1971   serverSocket->startAccepting();
1972
1973   // Make several connections to the socket
1974   std::shared_ptr<AsyncSocket> sock1(
1975       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1976
1977   // Loop in another thread
1978   auto other = std::thread([&](){
1979     eventBase.loop();
1980   });
1981   other.join();
1982
1983   // Check to make sure that the expected callbacks were invoked.
1984   //
1985   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1986   // the AcceptCallbacks in round-robin fashion, in the order that they were
1987   // added.  The code is implemented this way right now, but the API doesn't
1988   // explicitly require it be done this way.  If we change the code not to be
1989   // exactly round robin in the future, we can simplify the test checks here.
1990   // (We'll also need to update the termination code, since we expect cb6 to
1991   // get called twice to terminate the loop.)
1992   ASSERT_EQ(cb1.getEvents()->size(), 3);
1993   ASSERT_EQ(cb1.getEvents()->at(0).type,
1994                     TestAcceptCallback::TYPE_START);
1995   ASSERT_EQ(cb1.getEvents()->at(1).type,
1996                     TestAcceptCallback::TYPE_ACCEPT);
1997   ASSERT_EQ(cb1.getEvents()->at(2).type,
1998                     TestAcceptCallback::TYPE_STOP);
1999
2000 }
2001
2002 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
2003   EventBase* eventBase = serverSocket->getEventBase();
2004   CHECK(eventBase);
2005
2006   // Add a callback to accept one connection then stop accepting
2007   TestAcceptCallback acceptCallback;
2008   acceptCallback.setConnectionAcceptedFn(
2009       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2010         serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2011       });
2012   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2013     serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
2014   });
2015   serverSocket->addAcceptCallback(&acceptCallback, eventBase);
2016   serverSocket->startAccepting();
2017
2018   // Connect to the server socket
2019   folly::SocketAddress serverAddress;
2020   serverSocket->getAddress(&serverAddress);
2021   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
2022
2023   // Loop to process all events
2024   eventBase->loop();
2025
2026   // Verify that the server accepted a connection
2027   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2028   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
2029                     TestAcceptCallback::TYPE_START);
2030   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
2031                     TestAcceptCallback::TYPE_ACCEPT);
2032   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
2033                     TestAcceptCallback::TYPE_STOP);
2034 }
2035
2036 /* Verify that we don't leak sockets if we are destroyed()
2037  * and there are still writes pending
2038  *
2039  * If destroy() only calls close() instead of closeNow(),
2040  * it would shutdown(writes) on the socket, but it would
2041  * never be close()'d, and the socket would leak
2042  */
2043 TEST(AsyncSocketTest, DestroyCloseTest) {
2044   TestServer server;
2045
2046   // connect()
2047   EventBase clientEB;
2048   EventBase serverEB;
2049   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
2050   ConnCallback ccb;
2051   socket->connect(&ccb, server.getAddress(), 30);
2052
2053   // Accept the connection
2054   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
2055   ReadCallback rcb;
2056   acceptedSocket->setReadCB(&rcb);
2057
2058   // Write a large buffer to the socket that is larger than kernel buffer
2059   size_t simpleBufLength = 5000000;
2060   char* simpleBuf = new char[simpleBufLength];
2061   memset(simpleBuf, 'a', simpleBufLength);
2062   WriteCallback wcb;
2063
2064   // Let the reads and writes run to completion
2065   int fd = acceptedSocket->getFd();
2066
2067   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
2068   socket.reset();
2069   acceptedSocket.reset();
2070
2071   // Test that server socket was closed
2072   folly::test::msvcSuppressAbortOnInvalidParams([&] {
2073     ssize_t sz = read(fd, simpleBuf, simpleBufLength);
2074     ASSERT_EQ(sz, -1);
2075     ASSERT_EQ(errno, EBADF);
2076   });
2077   delete[] simpleBuf;
2078 }
2079
2080 /**
2081  * Test AsyncServerSocket::useExistingSocket()
2082  */
2083 TEST(AsyncSocketTest, ServerExistingSocket) {
2084   EventBase eventBase;
2085
2086   // Test creating a socket, and letting AsyncServerSocket bind and listen
2087   {
2088     // Manually create a socket
2089     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2090     ASSERT_GE(fd, 0);
2091
2092     // Create a server socket
2093     AsyncServerSocket::UniquePtr serverSocket(
2094         new AsyncServerSocket(&eventBase));
2095     serverSocket->useExistingSocket(fd);
2096     folly::SocketAddress address;
2097     serverSocket->getAddress(&address);
2098     address.setPort(0);
2099     serverSocket->bind(address);
2100     serverSocket->listen(16);
2101
2102     // Make sure the socket works
2103     serverSocketSanityTest(serverSocket.get());
2104   }
2105
2106   // Test creating a socket and binding manually,
2107   // then letting AsyncServerSocket listen
2108   {
2109     // Manually create a socket
2110     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2111     ASSERT_GE(fd, 0);
2112     // bind
2113     struct sockaddr_in addr;
2114     addr.sin_family = AF_INET;
2115     addr.sin_port = 0;
2116     addr.sin_addr.s_addr = INADDR_ANY;
2117     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2118                              sizeof(addr)), 0);
2119     // Look up the address that we bound to
2120     folly::SocketAddress boundAddress;
2121     boundAddress.setFromLocalAddress(fd);
2122
2123     // Create a server socket
2124     AsyncServerSocket::UniquePtr serverSocket(
2125         new AsyncServerSocket(&eventBase));
2126     serverSocket->useExistingSocket(fd);
2127     serverSocket->listen(16);
2128
2129     // Make sure AsyncServerSocket reports the same address that we bound to
2130     folly::SocketAddress serverSocketAddress;
2131     serverSocket->getAddress(&serverSocketAddress);
2132     ASSERT_EQ(boundAddress, serverSocketAddress);
2133
2134     // Make sure the socket works
2135     serverSocketSanityTest(serverSocket.get());
2136   }
2137
2138   // Test creating a socket, binding and listening manually,
2139   // then giving it to AsyncServerSocket
2140   {
2141     // Manually create a socket
2142     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2143     ASSERT_GE(fd, 0);
2144     // bind
2145     struct sockaddr_in addr;
2146     addr.sin_family = AF_INET;
2147     addr.sin_port = 0;
2148     addr.sin_addr.s_addr = INADDR_ANY;
2149     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2150                              sizeof(addr)), 0);
2151     // Look up the address that we bound to
2152     folly::SocketAddress boundAddress;
2153     boundAddress.setFromLocalAddress(fd);
2154     // listen
2155     ASSERT_EQ(listen(fd, 16), 0);
2156
2157     // Create a server socket
2158     AsyncServerSocket::UniquePtr serverSocket(
2159         new AsyncServerSocket(&eventBase));
2160     serverSocket->useExistingSocket(fd);
2161
2162     // Make sure AsyncServerSocket reports the same address that we bound to
2163     folly::SocketAddress serverSocketAddress;
2164     serverSocket->getAddress(&serverSocketAddress);
2165     ASSERT_EQ(boundAddress, serverSocketAddress);
2166
2167     // Make sure the socket works
2168     serverSocketSanityTest(serverSocket.get());
2169   }
2170 }
2171
2172 TEST(AsyncSocketTest, UnixDomainSocketTest) {
2173   EventBase eventBase;
2174
2175   // Create a server socket
2176   std::shared_ptr<AsyncServerSocket> serverSocket(
2177       AsyncServerSocket::newSocket(&eventBase));
2178   string path(1, 0);
2179   path.append(folly::to<string>("/anonymous", folly::Random::rand64()));
2180   folly::SocketAddress serverAddress;
2181   serverAddress.setFromPath(path);
2182   serverSocket->bind(serverAddress);
2183   serverSocket->listen(16);
2184
2185   // Add a callback to accept one connection then stop the loop
2186   TestAcceptCallback acceptCallback;
2187   acceptCallback.setConnectionAcceptedFn(
2188       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2189         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2190       });
2191   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2192     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2193   });
2194   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2195   serverSocket->startAccepting();
2196
2197   // Connect to the server socket
2198   std::shared_ptr<AsyncSocket> socket(
2199       AsyncSocket::newSocket(&eventBase, serverAddress));
2200
2201   eventBase.loop();
2202
2203   // Verify that the server accepted a connection
2204   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2205   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
2206                     TestAcceptCallback::TYPE_START);
2207   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
2208                     TestAcceptCallback::TYPE_ACCEPT);
2209   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
2210                     TestAcceptCallback::TYPE_STOP);
2211   int fd = acceptCallback.getEvents()->at(1).fd;
2212
2213   // The accepted connection should already be in non-blocking mode
2214   int flags = fcntl(fd, F_GETFL, 0);
2215   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
2216 }
2217
2218 TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
2219   EventBase eventBase;
2220   TestConnectionEventCallback connectionEventCallback;
2221
2222   // Create a server socket
2223   std::shared_ptr<AsyncServerSocket> serverSocket(
2224       AsyncServerSocket::newSocket(&eventBase));
2225   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2226   serverSocket->bind(0);
2227   serverSocket->listen(16);
2228   folly::SocketAddress serverAddress;
2229   serverSocket->getAddress(&serverAddress);
2230
2231   // Add a callback to accept one connection then stop the loop
2232   TestAcceptCallback acceptCallback;
2233   acceptCallback.setConnectionAcceptedFn(
2234       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2235         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2236       });
2237   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2238     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2239   });
2240   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2241   serverSocket->startAccepting();
2242
2243   // Connect to the server socket
2244   std::shared_ptr<AsyncSocket> socket(
2245       AsyncSocket::newSocket(&eventBase, serverAddress));
2246
2247   eventBase.loop();
2248
2249   // Validate the connection event counters
2250   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2251   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2252   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2253   ASSERT_EQ(
2254       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
2255   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
2256   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2257   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2258   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2259 }
2260
2261 TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
2262   EventBase eventBase;
2263   TestConnectionEventCallback connectionEventCallback;
2264
2265   // Create a server socket
2266   std::shared_ptr<AsyncServerSocket> serverSocket(
2267       AsyncServerSocket::newSocket(&eventBase));
2268   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2269   serverSocket->bind(0);
2270   serverSocket->listen(16);
2271   folly::SocketAddress serverAddress;
2272   serverSocket->getAddress(&serverAddress);
2273
2274   // Add a callback to accept one connection then stop the loop
2275   TestAcceptCallback acceptCallback;
2276   acceptCallback.setConnectionAcceptedFn(
2277       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2278         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2279       });
2280   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2281     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2282   });
2283   bool acceptStartedFlag{false};
2284   acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
2285     acceptStartedFlag = true;
2286   });
2287   bool acceptStoppedFlag{false};
2288   acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
2289     acceptStoppedFlag = true;
2290   });
2291   serverSocket->addAcceptCallback(&acceptCallback, nullptr);
2292   serverSocket->startAccepting();
2293
2294   // Connect to the server socket
2295   std::shared_ptr<AsyncSocket> socket(
2296       AsyncSocket::newSocket(&eventBase, serverAddress));
2297
2298   eventBase.loop();
2299
2300   ASSERT_TRUE(acceptStartedFlag);
2301   ASSERT_TRUE(acceptStoppedFlag);
2302   // Validate the connection event counters
2303   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2304   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2305   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2306   ASSERT_EQ(
2307       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
2308   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
2309   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2310   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2311   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2312 }
2313
2314
2315
2316 /**
2317  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
2318  */
2319 TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
2320   EventBase eventBase;
2321
2322   // Counter of how many connections have been accepted
2323   int count = 0;
2324
2325   // Create a server socket
2326   auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
2327   serverSocket->bind(0);
2328   serverSocket->listen(16);
2329   folly::SocketAddress serverAddress;
2330   serverSocket->getAddress(&serverAddress);
2331
2332   // Add a callback to accept connections
2333   TestAcceptCallback acceptCallback;
2334   acceptCallback.setConnectionAcceptedFn(
2335       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2336         count++;
2337         ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
2338
2339         if (count == 4) {
2340           // all messages are processed, remove accept callback
2341           serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2342         }
2343       });
2344   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2345     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2346   });
2347   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2348   serverSocket->startAccepting();
2349
2350   // Connect to the server socket, 4 clients, there are 4 connections
2351   auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
2352   auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
2353   auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
2354   auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));
2355
2356   eventBase.loop();
2357 }
2358
2359 /**
2360  * Test AsyncTransport::BufferCallback
2361  */
2362 TEST(AsyncSocketTest, BufferTest) {
2363   TestServer server;
2364
2365   EventBase evb;
2366   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2367   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2368   ConnCallback ccb;
2369   socket->connect(&ccb, server.getAddress(), 30, option);
2370
2371   char buf[100 * 1024];
2372   memset(buf, 'c', sizeof(buf));
2373   WriteCallback wcb;
2374   BufferCallback bcb;
2375   socket->setBufferCallback(&bcb);
2376   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2377
2378   evb.loop();
2379   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2380   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2381
2382   ASSERT_TRUE(bcb.hasBuffered());
2383   ASSERT_TRUE(bcb.hasBufferCleared());
2384
2385   socket->close();
2386   server.verifyConnection(buf, sizeof(buf));
2387
2388   ASSERT_TRUE(socket->isClosedBySelf());
2389   ASSERT_FALSE(socket->isClosedByPeer());
2390 }
2391
2392 TEST(AsyncSocketTest, BufferCallbackKill) {
2393   TestServer server;
2394   EventBase evb;
2395   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2396   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2397   ConnCallback ccb;
2398   socket->connect(&ccb, server.getAddress(), 30, option);
2399   evb.loopOnce();
2400
2401   char buf[100 * 1024];
2402   memset(buf, 'c', sizeof(buf));
2403   BufferCallback bcb;
2404   socket->setBufferCallback(&bcb);
2405   WriteCallback wcb;
2406   wcb.successCallback = [&] {
2407     ASSERT_TRUE(socket.unique());
2408     socket.reset();
2409   };
2410
2411   // This will trigger AsyncSocket::handleWrite,
2412   // which calls WriteCallback::writeSuccess,
2413   // which calls wcb.successCallback above,
2414   // which tries to delete socket
2415   // Then, the socket will also try to use this BufferCallback
2416   // And that should crash us, if there is no DestructorGuard on the stack
2417   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2418
2419   evb.loop();
2420   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2421 }
2422
2423 #if FOLLY_ALLOW_TFO
2424 TEST(AsyncSocketTest, ConnectTFO) {
2425   // Start listening on a local port
2426   TestServer server(true);
2427
2428   // Connect using a AsyncSocket
2429   EventBase evb;
2430   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2431   socket->enableTFO();
2432   ConnCallback cb;
2433   socket->connect(&cb, server.getAddress(), 30);
2434
2435   std::array<uint8_t, 128> buf;
2436   memset(buf.data(), 'a', buf.size());
2437
2438   std::array<uint8_t, 3> readBuf;
2439   auto sendBuf = IOBuf::copyBuffer("hey");
2440
2441   std::thread t([&] {
2442     auto acceptedSocket = server.accept();
2443     acceptedSocket->write(buf.data(), buf.size());
2444     acceptedSocket->flush();
2445     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2446     acceptedSocket->close();
2447   });
2448
2449   evb.loop();
2450
2451   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2452   EXPECT_LE(0, socket->getConnectTime().count());
2453   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2454   EXPECT_TRUE(socket->getTFOAttempted());
2455
2456   // Should trigger the connect
2457   WriteCallback write;
2458   ReadCallback rcb;
2459   socket->writeChain(&write, sendBuf->clone());
2460   socket->setReadCB(&rcb);
2461   evb.loop();
2462
2463   t.join();
2464
2465   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2466   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2467   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2468   ASSERT_EQ(1, rcb.buffers.size());
2469   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2470   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2471   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2472 }
2473
2474 TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
2475   // Start listening on a local port
2476   TestServer server(true);
2477
2478   // Connect using a AsyncSocket
2479   EventBase evb;
2480   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2481   socket->enableTFO();
2482   ConnCallback cb;
2483   socket->connect(&cb, server.getAddress(), 30);
2484   ReadCallback rcb;
2485   socket->setReadCB(&rcb);
2486
2487   std::array<uint8_t, 128> buf;
2488   memset(buf.data(), 'a', buf.size());
2489
2490   std::array<uint8_t, 3> readBuf;
2491   auto sendBuf = IOBuf::copyBuffer("hey");
2492
2493   std::thread t([&] {
2494     auto acceptedSocket = server.accept();
2495     acceptedSocket->write(buf.data(), buf.size());
2496     acceptedSocket->flush();
2497     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2498     acceptedSocket->close();
2499   });
2500
2501   evb.loop();
2502
2503   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2504   EXPECT_LE(0, socket->getConnectTime().count());
2505   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2506   EXPECT_TRUE(socket->getTFOAttempted());
2507
2508   // Should trigger the connect
2509   WriteCallback write;
2510   socket->writeChain(&write, sendBuf->clone());
2511   evb.loop();
2512
2513   t.join();
2514
2515   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2516   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2517   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2518   ASSERT_EQ(1, rcb.buffers.size());
2519   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2520   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2521   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2522 }
2523
2524 /**
2525  * Test connecting to a server that isn't listening
2526  */
2527 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2528   EventBase evb;
2529
2530   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2531
2532   socket->enableTFO();
2533
2534   // Hopefully nothing is actually listening on this address
2535   folly::SocketAddress addr("::1", 65535);
2536   ConnCallback cb;
2537   socket->connect(&cb, addr, 30);
2538
2539   evb.loop();
2540
2541   WriteCallback write1;
2542   // Trigger the connect if TFO attempt is supported.
2543   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2544   WriteCallback write2;
2545   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2546   evb.loop();
2547
2548   if (!socket->getTFOFinished()) {
2549     EXPECT_EQ(STATE_FAILED, write1.state);
2550   } else {
2551     EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2552     EXPECT_FALSE(socket->getTFOSucceded());
2553   }
2554
2555   EXPECT_EQ(STATE_FAILED, write2.state);
2556
2557   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2558   EXPECT_LE(0, socket->getConnectTime().count());
2559   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2560   EXPECT_TRUE(socket->getTFOAttempted());
2561 }
2562
2563 /**
2564  * Test calling closeNow() immediately after connecting.
2565  */
2566 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2567   TestServer server(true);
2568
2569   // connect()
2570   EventBase evb;
2571   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2572   socket->enableTFO();
2573
2574   ConnCallback ccb;
2575   socket->connect(&ccb, server.getAddress(), 30);
2576
2577   // write()
2578   std::array<char, 128> buf;
2579   memset(buf.data(), 'a', buf.size());
2580
2581   // close()
2582   socket->closeNow();
2583
2584   // Loop, although there shouldn't be anything to do.
2585   evb.loop();
2586
2587   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2588
2589   ASSERT_TRUE(socket->isClosedBySelf());
2590   ASSERT_FALSE(socket->isClosedByPeer());
2591 }
2592
2593 /**
2594  * Test calling close() immediately after connect()
2595  */
2596 TEST(AsyncSocketTest, ConnectAndCloseTFO) {
2597   TestServer server(true);
2598
2599   // Connect using a AsyncSocket
2600   EventBase evb;
2601   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2602   socket->enableTFO();
2603
2604   ConnCallback ccb;
2605   socket->connect(&ccb, server.getAddress(), 30);
2606
2607   socket->close();
2608
2609   // Loop, although there shouldn't be anything to do.
2610   evb.loop();
2611
2612   // Make sure the connection was aborted
2613   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2614
2615   ASSERT_TRUE(socket->isClosedBySelf());
2616   ASSERT_FALSE(socket->isClosedByPeer());
2617 }
2618
2619 class MockAsyncTFOSocket : public AsyncSocket {
2620  public:
2621   using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
2622
2623   explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
2624
2625   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
2626 };
2627
2628 TEST(AsyncSocketTest, TestTFOUnsupported) {
2629   TestServer server(true);
2630
2631   // Connect using a AsyncSocket
2632   EventBase evb;
2633   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2634   socket->enableTFO();
2635
2636   ConnCallback ccb;
2637   socket->connect(&ccb, server.getAddress(), 30);
2638   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2639
2640   ReadCallback rcb;
2641   socket->setReadCB(&rcb);
2642
2643   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2644       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2645   WriteCallback write;
2646   auto sendBuf = IOBuf::copyBuffer("hey");
2647   socket->writeChain(&write, sendBuf->clone());
2648   EXPECT_EQ(STATE_WAITING, write.state);
2649
2650   std::array<uint8_t, 128> buf;
2651   memset(buf.data(), 'a', buf.size());
2652
2653   std::array<uint8_t, 3> readBuf;
2654
2655   std::thread t([&] {
2656     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2657     acceptedSocket->write(buf.data(), buf.size());
2658     acceptedSocket->flush();
2659     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2660     acceptedSocket->close();
2661   });
2662
2663   evb.loop();
2664
2665   t.join();
2666   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2667   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2668
2669   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2670   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2671   ASSERT_EQ(1, rcb.buffers.size());
2672   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2673   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2674   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2675 }
2676
2677 TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
2678   EventBase evb;
2679
2680   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2681   socket->enableTFO();
2682
2683   // Hopefully this fails
2684   folly::SocketAddress fakeAddr("127.0.0.1", 65535);
2685   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2686       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2687         sockaddr_storage addr;
2688         auto len = fakeAddr.getAddress(&addr);
2689         int ret = connect(fd, (const struct sockaddr*)&addr, len);
2690         LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
2691                   << errno;
2692         return ret;
2693       }));
2694
2695   // Hopefully nothing is actually listening on this address
2696   ConnCallback cb;
2697   socket->connect(&cb, fakeAddr, 30);
2698
2699   WriteCallback write1;
2700   // Trigger the connect if TFO attempt is supported.
2701   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2702
2703   if (socket->getTFOFinished()) {
2704     // This test is useless now.
2705     return;
2706   }
2707   WriteCallback write2;
2708   // Trigger the connect if TFO attempt is supported.
2709   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2710   evb.loop();
2711
2712   EXPECT_EQ(STATE_FAILED, write1.state);
2713   EXPECT_EQ(STATE_FAILED, write2.state);
2714   EXPECT_FALSE(socket->getTFOSucceded());
2715
2716   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2717   EXPECT_LE(0, socket->getConnectTime().count());
2718   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2719   EXPECT_TRUE(socket->getTFOAttempted());
2720 }
2721
2722 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
2723   // Try connecting to server that won't respond.
2724   //
2725   // This depends somewhat on the network where this test is run.
2726   // Hopefully this IP will be routable but unresponsive.
2727   // (Alternatively, we could try listening on a local raw socket, but that
2728   // normally requires root privileges.)
2729   auto host = SocketAddressTestHelper::isIPv6Enabled()
2730       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2731       : SocketAddressTestHelper::isIPv4Enabled()
2732           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2733           : nullptr;
2734   SocketAddress addr(host, 65535);
2735
2736   // Connect using a AsyncSocket
2737   EventBase evb;
2738   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2739   socket->enableTFO();
2740
2741   ConnCallback ccb;
2742   // Set a very small timeout
2743   socket->connect(&ccb, addr, 1);
2744   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2745
2746   ReadCallback rcb;
2747   socket->setReadCB(&rcb);
2748
2749   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2750       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2751   WriteCallback write;
2752   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2753
2754   evb.loop();
2755
2756   EXPECT_EQ(STATE_FAILED, write.state);
2757 }
2758
2759 TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
2760   TestServer server(true);
2761
2762   // Connect using a AsyncSocket
2763   EventBase evb;
2764   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2765   socket->enableTFO();
2766
2767   ConnCallback ccb;
2768   socket->connect(&ccb, server.getAddress(), 30);
2769   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2770
2771   ReadCallback rcb;
2772   socket->setReadCB(&rcb);
2773
2774   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2775       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2776         sockaddr_storage addr;
2777         auto len = server.getAddress().getAddress(&addr);
2778         return connect(fd, (const struct sockaddr*)&addr, len);
2779       }));
2780   WriteCallback write;
2781   auto sendBuf = IOBuf::copyBuffer("hey");
2782   socket->writeChain(&write, sendBuf->clone());
2783   EXPECT_EQ(STATE_WAITING, write.state);
2784
2785   std::array<uint8_t, 128> buf;
2786   memset(buf.data(), 'a', buf.size());
2787
2788   std::array<uint8_t, 3> readBuf;
2789
2790   std::thread t([&] {
2791     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2792     acceptedSocket->write(buf.data(), buf.size());
2793     acceptedSocket->flush();
2794     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2795     acceptedSocket->close();
2796   });
2797
2798   evb.loop();
2799
2800   t.join();
2801   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2802
2803   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2804   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2805
2806   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2807   ASSERT_EQ(1, rcb.buffers.size());
2808   ASSERT_EQ(buf.size(), rcb.buffers[0].length);
2809   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2810 }
2811
2812 TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
2813   // Try connecting to server that won't respond.
2814   //
2815   // This depends somewhat on the network where this test is run.
2816   // Hopefully this IP will be routable but unresponsive.
2817   // (Alternatively, we could try listening on a local raw socket, but that
2818   // normally requires root privileges.)
2819   auto host = SocketAddressTestHelper::isIPv6Enabled()
2820       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2821       : SocketAddressTestHelper::isIPv4Enabled()
2822           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2823           : nullptr;
2824   SocketAddress addr(host, 65535);
2825
2826   // Connect using a AsyncSocket
2827   EventBase evb;
2828   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2829   socket->enableTFO();
2830
2831   ConnCallback ccb;
2832   // Set a very small timeout
2833   socket->connect(&ccb, addr, 1);
2834   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2835
2836   ReadCallback rcb;
2837   socket->setReadCB(&rcb);
2838
2839   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2840       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2841         sockaddr_storage addr2;
2842         auto len = addr.getAddress(&addr2);
2843         return connect(fd, (const struct sockaddr*)&addr2, len);
2844       }));
2845   WriteCallback write;
2846   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2847
2848   evb.loop();
2849
2850   EXPECT_EQ(STATE_FAILED, write.state);
2851 }
2852
2853 TEST(AsyncSocketTest, TestTFOEagain) {
2854   TestServer server(true);
2855
2856   // Connect using a AsyncSocket
2857   EventBase evb;
2858   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2859   socket->enableTFO();
2860
2861   ConnCallback ccb;
2862   socket->connect(&ccb, server.getAddress(), 30);
2863
2864   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2865       .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
2866   WriteCallback write;
2867   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2868
2869   evb.loop();
2870
2871   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2872   EXPECT_EQ(STATE_FAILED, write.state);
2873 }
2874
2875 // Sending a large amount of data in the first write which will
2876 // definitely not fit into MSS.
2877 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
2878   // Start listening on a local port
2879   TestServer server(true);
2880
2881   // Connect using a AsyncSocket
2882   EventBase evb;
2883   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2884   socket->enableTFO();
2885   ConnCallback cb;
2886   socket->connect(&cb, server.getAddress(), 30);
2887
2888   std::array<uint8_t, 128> buf;
2889   memset(buf.data(), 'a', buf.size());
2890
2891   constexpr size_t len = 10 * 1024;
2892   auto sendBuf = IOBuf::create(len);
2893   sendBuf->append(len);
2894   std::array<uint8_t, len> readBuf;
2895
2896   std::thread t([&] {
2897     auto acceptedSocket = server.accept();
2898     acceptedSocket->write(buf.data(), buf.size());
2899     acceptedSocket->flush();
2900     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2901     acceptedSocket->close();
2902   });
2903
2904   evb.loop();
2905
2906   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2907   EXPECT_LE(0, socket->getConnectTime().count());
2908   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2909   EXPECT_TRUE(socket->getTFOAttempted());
2910
2911   // Should trigger the connect
2912   WriteCallback write;
2913   ReadCallback rcb;
2914   socket->writeChain(&write, sendBuf->clone());
2915   socket->setReadCB(&rcb);
2916   evb.loop();
2917
2918   t.join();
2919
2920   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2921   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2922   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2923   ASSERT_EQ(1, rcb.buffers.size());
2924   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2925   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2926   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2927 }
2928
2929 #endif