Adds writer test case for RCU
[folly.git] / folly / io / async / test / AsyncSocketTest2.cpp
1 /*
2  * Copyright 2010-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/test/AsyncSocketTest2.h>
18
19 #include <folly/ConstexprMath.h>
20 #include <folly/ExceptionWrapper.h>
21 #include <folly/Random.h>
22 #include <folly/SocketAddress.h>
23 #include <folly/io/async/AsyncSocket.h>
24 #include <folly/io/async/AsyncTimeout.h>
25 #include <folly/io/async/EventBase.h>
26
27 #include <folly/experimental/TestUtil.h>
28 #include <folly/io/IOBuf.h>
29 #include <folly/io/async/test/AsyncSocketTest.h>
30 #include <folly/io/async/test/Util.h>
31 #include <folly/portability/GMock.h>
32 #include <folly/portability/GTest.h>
33 #include <folly/portability/Sockets.h>
34 #include <folly/portability/Unistd.h>
35 #include <folly/test/SocketAddressTestHelper.h>
36
37 #include <boost/scoped_array.hpp>
38 #include <fcntl.h>
39 #include <sys/types.h>
40 #include <iostream>
41 #include <thread>
42
43 using namespace boost;
44
45 using std::string;
46 using std::vector;
47 using std::min;
48 using std::cerr;
49 using std::endl;
50 using std::unique_ptr;
51 using std::chrono::milliseconds;
52 using boost::scoped_array;
53
54 using namespace folly;
55 using namespace folly::test;
56 using namespace testing;
57
58 namespace fsp = folly::portability::sockets;
59
60 class DelayedWrite: public AsyncTimeout {
61  public:
62   DelayedWrite(const std::shared_ptr<AsyncSocket>& socket,
63       unique_ptr<IOBuf>&& bufs, AsyncTransportWrapper::WriteCallback* wcb,
64       bool cork, bool lastWrite = false):
65     AsyncTimeout(socket->getEventBase()),
66     socket_(socket),
67     bufs_(std::move(bufs)),
68     wcb_(wcb),
69     cork_(cork),
70     lastWrite_(lastWrite) {}
71
72  private:
73   void timeoutExpired() noexcept override {
74     WriteFlags flags = cork_ ? WriteFlags::CORK : WriteFlags::NONE;
75     socket_->writeChain(wcb_, std::move(bufs_), flags);
76     if (lastWrite_) {
77       socket_->shutdownWrite();
78     }
79   }
80
81   std::shared_ptr<AsyncSocket> socket_;
82   unique_ptr<IOBuf> bufs_;
83   AsyncTransportWrapper::WriteCallback* wcb_;
84   bool cork_;
85   bool lastWrite_;
86 };
87
88 ///////////////////////////////////////////////////////////////////////////
89 // connect() tests
90 ///////////////////////////////////////////////////////////////////////////
91
92 /**
93  * Test connecting to a server
94  */
95 TEST(AsyncSocketTest, Connect) {
96   // Start listening on a local port
97   TestServer server;
98
99   // Connect using a AsyncSocket
100   EventBase evb;
101   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
102   ConnCallback cb;
103   socket->connect(&cb, server.getAddress(), 30);
104
105   evb.loop();
106
107   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
108   EXPECT_LE(0, socket->getConnectTime().count());
109   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
110 }
111
112 enum class TFOState {
113   DISABLED,
114   ENABLED,
115 };
116
117 class AsyncSocketConnectTest : public ::testing::TestWithParam<TFOState> {};
118
119 std::vector<TFOState> getTestingValues() {
120   std::vector<TFOState> vals;
121   vals.emplace_back(TFOState::DISABLED);
122
123 #if FOLLY_ALLOW_TFO
124   vals.emplace_back(TFOState::ENABLED);
125 #endif
126   return vals;
127 }
128
129 INSTANTIATE_TEST_CASE_P(
130     ConnectTests,
131     AsyncSocketConnectTest,
132     ::testing::ValuesIn(getTestingValues()));
133
134 /**
135  * Test connecting to a server that isn't listening
136  */
137 TEST(AsyncSocketTest, ConnectRefused) {
138   EventBase evb;
139
140   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
141
142   // Hopefully nothing is actually listening on this address
143   folly::SocketAddress addr("127.0.0.1", 65535);
144   ConnCallback cb;
145   socket->connect(&cb, addr, 30);
146
147   evb.loop();
148
149   EXPECT_EQ(STATE_FAILED, cb.state);
150   EXPECT_EQ(AsyncSocketException::NOT_OPEN, cb.exception.getType());
151   EXPECT_LE(0, socket->getConnectTime().count());
152   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
153 }
154
155 /**
156  * Test connection timeout
157  */
158 TEST(AsyncSocketTest, ConnectTimeout) {
159   EventBase evb;
160
161   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
162
163   // Try connecting to server that won't respond.
164   //
165   // This depends somewhat on the network where this test is run.
166   // Hopefully this IP will be routable but unresponsive.
167   // (Alternatively, we could try listening on a local raw socket, but that
168   // normally requires root privileges.)
169   auto host =
170       SocketAddressTestHelper::isIPv6Enabled() ?
171       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6 :
172       SocketAddressTestHelper::isIPv4Enabled() ?
173       SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4 :
174       nullptr;
175   SocketAddress addr(host, 65535);
176   ConnCallback cb;
177   socket->connect(&cb, addr, 1); // also set a ridiculously small timeout
178
179   evb.loop();
180
181   ASSERT_EQ(cb.state, STATE_FAILED);
182   ASSERT_EQ(cb.exception.getType(), AsyncSocketException::TIMED_OUT);
183
184   // Verify that we can still get the peer address after a timeout.
185   // Use case is if the client was created from a client pool, and we want
186   // to log which peer failed.
187   folly::SocketAddress peer;
188   socket->getPeerAddress(&peer);
189   ASSERT_EQ(peer, addr);
190   EXPECT_LE(0, socket->getConnectTime().count());
191   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(1));
192 }
193
194 /**
195  * Test writing immediately after connecting, without waiting for connect
196  * to finish.
197  */
198 TEST_P(AsyncSocketConnectTest, ConnectAndWrite) {
199   TestServer server;
200
201   // connect()
202   EventBase evb;
203   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
204
205   if (GetParam() == TFOState::ENABLED) {
206     socket->enableTFO();
207   }
208
209   ConnCallback ccb;
210   socket->connect(&ccb, server.getAddress(), 30);
211
212   // write()
213   char buf[128];
214   memset(buf, 'a', sizeof(buf));
215   WriteCallback wcb;
216   socket->write(&wcb, buf, sizeof(buf));
217
218   // Loop.  We don't bother accepting on the server socket yet.
219   // The kernel should be able to buffer the write request so it can succeed.
220   evb.loop();
221
222   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
223   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
224
225   // Make sure the server got a connection and received the data
226   socket->close();
227   server.verifyConnection(buf, sizeof(buf));
228
229   ASSERT_TRUE(socket->isClosedBySelf());
230   ASSERT_FALSE(socket->isClosedByPeer());
231   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
232 }
233
234 /**
235  * Test connecting using a nullptr connect callback.
236  */
237 TEST_P(AsyncSocketConnectTest, ConnectNullCallback) {
238   TestServer server;
239
240   // connect()
241   EventBase evb;
242   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
243   if (GetParam() == TFOState::ENABLED) {
244     socket->enableTFO();
245   }
246
247   socket->connect(nullptr, server.getAddress(), 30);
248
249   // write some data, just so we have some way of verifing
250   // that the socket works correctly after connecting
251   char buf[128];
252   memset(buf, 'a', sizeof(buf));
253   WriteCallback wcb;
254   socket->write(&wcb, buf, sizeof(buf));
255
256   evb.loop();
257
258   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
259
260   // Make sure the server got a connection and received the data
261   socket->close();
262   server.verifyConnection(buf, sizeof(buf));
263
264   ASSERT_TRUE(socket->isClosedBySelf());
265   ASSERT_FALSE(socket->isClosedByPeer());
266 }
267
268 /**
269  * Test calling both write() and close() immediately after connecting, without
270  * waiting for connect to finish.
271  *
272  * This exercises the STATE_CONNECTING_CLOSING code.
273  */
274 TEST_P(AsyncSocketConnectTest, ConnectWriteAndClose) {
275   TestServer server;
276
277   // connect()
278   EventBase evb;
279   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
280   if (GetParam() == TFOState::ENABLED) {
281     socket->enableTFO();
282   }
283   ConnCallback ccb;
284   socket->connect(&ccb, server.getAddress(), 30);
285
286   // write()
287   char buf[128];
288   memset(buf, 'a', sizeof(buf));
289   WriteCallback wcb;
290   socket->write(&wcb, buf, sizeof(buf));
291
292   // close()
293   socket->close();
294
295   // Loop.  We don't bother accepting on the server socket yet.
296   // The kernel should be able to buffer the write request so it can succeed.
297   evb.loop();
298
299   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
300   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
301
302   // Make sure the server got a connection and received the data
303   server.verifyConnection(buf, sizeof(buf));
304
305   ASSERT_TRUE(socket->isClosedBySelf());
306   ASSERT_FALSE(socket->isClosedByPeer());
307 }
308
309 /**
310  * Test calling close() immediately after connect()
311  */
312 TEST(AsyncSocketTest, ConnectAndClose) {
313   TestServer server;
314
315   // Connect using a AsyncSocket
316   EventBase evb;
317   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
318   ConnCallback ccb;
319   socket->connect(&ccb, server.getAddress(), 30);
320
321   // Hopefully the connect didn't succeed immediately.
322   // If it did, we can't exercise the close-while-connecting code path.
323   if (ccb.state == STATE_SUCCEEDED) {
324     LOG(INFO) << "connect() succeeded immediately; aborting test "
325                        "of close-during-connect behavior";
326     return;
327   }
328
329   socket->close();
330
331   // Loop, although there shouldn't be anything to do.
332   evb.loop();
333
334   // Make sure the connection was aborted
335   ASSERT_EQ(ccb.state, STATE_FAILED);
336
337   ASSERT_TRUE(socket->isClosedBySelf());
338   ASSERT_FALSE(socket->isClosedByPeer());
339 }
340
341 /**
342  * Test calling closeNow() immediately after connect()
343  *
344  * This should be identical to the normal close behavior.
345  */
346 TEST(AsyncSocketTest, ConnectAndCloseNow) {
347   TestServer server;
348
349   // Connect using a AsyncSocket
350   EventBase evb;
351   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
352   ConnCallback ccb;
353   socket->connect(&ccb, server.getAddress(), 30);
354
355   // Hopefully the connect didn't succeed immediately.
356   // If it did, we can't exercise the close-while-connecting code path.
357   if (ccb.state == STATE_SUCCEEDED) {
358     LOG(INFO) << "connect() succeeded immediately; aborting test "
359                        "of closeNow()-during-connect behavior";
360     return;
361   }
362
363   socket->closeNow();
364
365   // Loop, although there shouldn't be anything to do.
366   evb.loop();
367
368   // Make sure the connection was aborted
369   ASSERT_EQ(ccb.state, STATE_FAILED);
370
371   ASSERT_TRUE(socket->isClosedBySelf());
372   ASSERT_FALSE(socket->isClosedByPeer());
373 }
374
375 /**
376  * Test calling both write() and closeNow() immediately after connecting,
377  * without waiting for connect to finish.
378  *
379  * This should abort the pending write.
380  */
381 TEST(AsyncSocketTest, ConnectWriteAndCloseNow) {
382   TestServer server;
383
384   // connect()
385   EventBase evb;
386   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
387   ConnCallback ccb;
388   socket->connect(&ccb, server.getAddress(), 30);
389
390   // Hopefully the connect didn't succeed immediately.
391   // If it did, we can't exercise the close-while-connecting code path.
392   if (ccb.state == STATE_SUCCEEDED) {
393     LOG(INFO) << "connect() succeeded immediately; aborting test "
394                        "of write-during-connect behavior";
395     return;
396   }
397
398   // write()
399   char buf[128];
400   memset(buf, 'a', sizeof(buf));
401   WriteCallback wcb;
402   socket->write(&wcb, buf, sizeof(buf));
403
404   // close()
405   socket->closeNow();
406
407   // Loop, although there shouldn't be anything to do.
408   evb.loop();
409
410   ASSERT_EQ(ccb.state, STATE_FAILED);
411   ASSERT_EQ(wcb.state, STATE_FAILED);
412
413   ASSERT_TRUE(socket->isClosedBySelf());
414   ASSERT_FALSE(socket->isClosedByPeer());
415 }
416
417 /**
418  * Test installing a read callback immediately, before connect() finishes.
419  */
420 TEST_P(AsyncSocketConnectTest, ConnectAndRead) {
421   TestServer server;
422
423   // connect()
424   EventBase evb;
425   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
426   if (GetParam() == TFOState::ENABLED) {
427     socket->enableTFO();
428   }
429
430   ConnCallback ccb;
431   socket->connect(&ccb, server.getAddress(), 30);
432
433   ReadCallback rcb;
434   socket->setReadCB(&rcb);
435
436   if (GetParam() == TFOState::ENABLED) {
437     // Trigger a connection
438     socket->writeChain(nullptr, IOBuf::copyBuffer("hey"));
439   }
440
441   // Even though we haven't looped yet, we should be able to accept
442   // the connection and send data to it.
443   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
444   uint8_t buf[128];
445   memset(buf, 'a', sizeof(buf));
446   acceptedSocket->write(buf, sizeof(buf));
447   acceptedSocket->flush();
448   acceptedSocket->close();
449
450   // Loop, although there shouldn't be anything to do.
451   evb.loop();
452
453   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
454   ASSERT_EQ(rcb.buffers.size(), 1);
455   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf));
456   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf, sizeof(buf)), 0);
457
458   ASSERT_FALSE(socket->isClosedBySelf());
459   ASSERT_FALSE(socket->isClosedByPeer());
460 }
461
462 /**
463  * Test installing a read callback and then closing immediately before the
464  * connect attempt finishes.
465  */
466 TEST(AsyncSocketTest, ConnectReadAndClose) {
467   TestServer server;
468
469   // connect()
470   EventBase evb;
471   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
472   ConnCallback ccb;
473   socket->connect(&ccb, server.getAddress(), 30);
474
475   // Hopefully the connect didn't succeed immediately.
476   // If it did, we can't exercise the close-while-connecting code path.
477   if (ccb.state == STATE_SUCCEEDED) {
478     LOG(INFO) << "connect() succeeded immediately; aborting test "
479                        "of read-during-connect behavior";
480     return;
481   }
482
483   ReadCallback rcb;
484   socket->setReadCB(&rcb);
485
486   // close()
487   socket->close();
488
489   // Loop, although there shouldn't be anything to do.
490   evb.loop();
491
492   ASSERT_EQ(ccb.state, STATE_FAILED); // we aborted the close attempt
493   ASSERT_EQ(rcb.buffers.size(), 0);
494   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
495
496   ASSERT_TRUE(socket->isClosedBySelf());
497   ASSERT_FALSE(socket->isClosedByPeer());
498 }
499
500 /**
501  * Test both writing and installing a read callback immediately,
502  * before connect() finishes.
503  */
504 TEST_P(AsyncSocketConnectTest, ConnectWriteAndRead) {
505   TestServer server;
506
507   // connect()
508   EventBase evb;
509   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
510   if (GetParam() == TFOState::ENABLED) {
511     socket->enableTFO();
512   }
513   ConnCallback ccb;
514   socket->connect(&ccb, server.getAddress(), 30);
515
516   // write()
517   char buf1[128];
518   memset(buf1, 'a', sizeof(buf1));
519   WriteCallback wcb;
520   socket->write(&wcb, buf1, sizeof(buf1));
521
522   // set a read callback
523   ReadCallback rcb;
524   socket->setReadCB(&rcb);
525
526   // Even though we haven't looped yet, we should be able to accept
527   // the connection and send data to it.
528   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
529   uint8_t buf2[128];
530   memset(buf2, 'b', sizeof(buf2));
531   acceptedSocket->write(buf2, sizeof(buf2));
532   acceptedSocket->flush();
533
534   // shut down the write half of acceptedSocket, so that the AsyncSocket
535   // will stop reading and we can break out of the event loop.
536   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
537
538   // Loop
539   evb.loop();
540
541   // Make sure the connect succeeded
542   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
543
544   // Make sure the AsyncSocket read the data written by the accepted socket
545   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
546   ASSERT_EQ(rcb.buffers.size(), 1);
547   ASSERT_EQ(rcb.buffers[0].length, sizeof(buf2));
548   ASSERT_EQ(memcmp(rcb.buffers[0].buffer, buf2, sizeof(buf2)), 0);
549
550   // Close the AsyncSocket so we'll see EOF on acceptedSocket
551   socket->close();
552
553   // Make sure the accepted socket saw the data written by the AsyncSocket
554   uint8_t readbuf[sizeof(buf1)];
555   acceptedSocket->readAll(readbuf, sizeof(readbuf));
556   ASSERT_EQ(memcmp(buf1, readbuf, sizeof(buf1)), 0);
557   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
558   ASSERT_EQ(bytesRead, 0);
559
560   ASSERT_FALSE(socket->isClosedBySelf());
561   ASSERT_TRUE(socket->isClosedByPeer());
562 }
563
564 /**
565  * Test writing to the socket then shutting down writes before the connect
566  * attempt finishes.
567  */
568 TEST(AsyncSocketTest, ConnectWriteAndShutdownWrite) {
569   TestServer server;
570
571   // connect()
572   EventBase evb;
573   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
574   ConnCallback ccb;
575   socket->connect(&ccb, server.getAddress(), 30);
576
577   // Hopefully the connect didn't succeed immediately.
578   // If it did, we can't exercise the write-while-connecting code path.
579   if (ccb.state == STATE_SUCCEEDED) {
580     LOG(INFO) << "connect() succeeded immediately; skipping test";
581     return;
582   }
583
584   // Ask to write some data
585   char wbuf[128];
586   memset(wbuf, 'a', sizeof(wbuf));
587   WriteCallback wcb;
588   socket->write(&wcb, wbuf, sizeof(wbuf));
589   socket->shutdownWrite();
590
591   // Shutdown writes
592   socket->shutdownWrite();
593
594   // Even though we haven't looped yet, we should be able to accept
595   // the connection.
596   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
597
598   // Since the connection is still in progress, there should be no data to
599   // read yet.  Verify that the accepted socket is not readable.
600   struct pollfd fds[1];
601   fds[0].fd = acceptedSocket->getSocketFD();
602   fds[0].events = POLLIN;
603   fds[0].revents = 0;
604   int rc = poll(fds, 1, 0);
605   ASSERT_EQ(rc, 0);
606
607   // Write data to the accepted socket
608   uint8_t acceptedWbuf[192];
609   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
610   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
611   acceptedSocket->flush();
612
613   // Loop
614   evb.loop();
615
616   // The loop should have completed the connection, written the queued data,
617   // and shutdown writes on the socket.
618   //
619   // Check that the connection was completed successfully and that the write
620   // callback succeeded.
621   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
622   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
623
624   // Check that we can read the data that was written to the socket, and that
625   // we see an EOF, since its socket was half-shutdown.
626   uint8_t readbuf[sizeof(wbuf)];
627   acceptedSocket->readAll(readbuf, sizeof(readbuf));
628   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
629   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
630   ASSERT_EQ(bytesRead, 0);
631
632   // Close the accepted socket.  This will cause it to see EOF
633   // and uninstall the read callback when we loop next.
634   acceptedSocket->close();
635
636   // Install a read callback, then loop again.
637   ReadCallback rcb;
638   socket->setReadCB(&rcb);
639   evb.loop();
640
641   // This loop should have read the data and seen the EOF
642   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
643   ASSERT_EQ(rcb.buffers.size(), 1);
644   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
645   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
646                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
647
648   ASSERT_FALSE(socket->isClosedBySelf());
649   ASSERT_FALSE(socket->isClosedByPeer());
650 }
651
652 /**
653  * Test reading, writing, and shutting down writes before the connect attempt
654  * finishes.
655  */
656 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWrite) {
657   TestServer server;
658
659   // connect()
660   EventBase evb;
661   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
662   ConnCallback ccb;
663   socket->connect(&ccb, server.getAddress(), 30);
664
665   // Hopefully the connect didn't succeed immediately.
666   // If it did, we can't exercise the write-while-connecting code path.
667   if (ccb.state == STATE_SUCCEEDED) {
668     LOG(INFO) << "connect() succeeded immediately; skipping test";
669     return;
670   }
671
672   // Install a read callback
673   ReadCallback rcb;
674   socket->setReadCB(&rcb);
675
676   // Ask to write some data
677   char wbuf[128];
678   memset(wbuf, 'a', sizeof(wbuf));
679   WriteCallback wcb;
680   socket->write(&wcb, wbuf, sizeof(wbuf));
681
682   // Shutdown writes
683   socket->shutdownWrite();
684
685   // Even though we haven't looped yet, we should be able to accept
686   // the connection.
687   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
688
689   // Since the connection is still in progress, there should be no data to
690   // read yet.  Verify that the accepted socket is not readable.
691   struct pollfd fds[1];
692   fds[0].fd = acceptedSocket->getSocketFD();
693   fds[0].events = POLLIN;
694   fds[0].revents = 0;
695   int rc = poll(fds, 1, 0);
696   ASSERT_EQ(rc, 0);
697
698   // Write data to the accepted socket
699   uint8_t acceptedWbuf[192];
700   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
701   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
702   acceptedSocket->flush();
703   // Shutdown writes to the accepted socket.  This will cause it to see EOF
704   // and uninstall the read callback.
705   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
706
707   // Loop
708   evb.loop();
709
710   // The loop should have completed the connection, written the queued data,
711   // shutdown writes on the socket, read the data we wrote to it, and see the
712   // EOF.
713   //
714   // Check that the connection was completed successfully and that the read
715   // and write callbacks were invoked as expected.
716   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
717   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
718   ASSERT_EQ(rcb.buffers.size(), 1);
719   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
720   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
721                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
722   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
723
724   // Check that we can read the data that was written to the socket, and that
725   // we see an EOF, since its socket was half-shutdown.
726   uint8_t readbuf[sizeof(wbuf)];
727   acceptedSocket->readAll(readbuf, sizeof(readbuf));
728   ASSERT_EQ(memcmp(wbuf, readbuf, sizeof(wbuf)), 0);
729   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
730   ASSERT_EQ(bytesRead, 0);
731
732   // Fully close both sockets
733   acceptedSocket->close();
734   socket->close();
735
736   ASSERT_FALSE(socket->isClosedBySelf());
737   ASSERT_TRUE(socket->isClosedByPeer());
738 }
739
740 /**
741  * Test reading, writing, and calling shutdownWriteNow() before the
742  * connect attempt finishes.
743  */
744 TEST(AsyncSocketTest, ConnectReadWriteAndShutdownWriteNow) {
745   TestServer server;
746
747   // connect()
748   EventBase evb;
749   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
750   ConnCallback ccb;
751   socket->connect(&ccb, server.getAddress(), 30);
752
753   // Hopefully the connect didn't succeed immediately.
754   // If it did, we can't exercise the write-while-connecting code path.
755   if (ccb.state == STATE_SUCCEEDED) {
756     LOG(INFO) << "connect() succeeded immediately; skipping test";
757     return;
758   }
759
760   // Install a read callback
761   ReadCallback rcb;
762   socket->setReadCB(&rcb);
763
764   // Ask to write some data
765   char wbuf[128];
766   memset(wbuf, 'a', sizeof(wbuf));
767   WriteCallback wcb;
768   socket->write(&wcb, wbuf, sizeof(wbuf));
769
770   // Shutdown writes immediately.
771   // This should immediately discard the data that we just tried to write.
772   socket->shutdownWriteNow();
773
774   // Verify that writeError() was invoked on the write callback.
775   ASSERT_EQ(wcb.state, STATE_FAILED);
776   ASSERT_EQ(wcb.bytesWritten, 0);
777
778   // Even though we haven't looped yet, we should be able to accept
779   // the connection.
780   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
781
782   // Since the connection is still in progress, there should be no data to
783   // read yet.  Verify that the accepted socket is not readable.
784   struct pollfd fds[1];
785   fds[0].fd = acceptedSocket->getSocketFD();
786   fds[0].events = POLLIN;
787   fds[0].revents = 0;
788   int rc = poll(fds, 1, 0);
789   ASSERT_EQ(rc, 0);
790
791   // Write data to the accepted socket
792   uint8_t acceptedWbuf[192];
793   memset(acceptedWbuf, 'b', sizeof(acceptedWbuf));
794   acceptedSocket->write(acceptedWbuf, sizeof(acceptedWbuf));
795   acceptedSocket->flush();
796   // Shutdown writes to the accepted socket.  This will cause it to see EOF
797   // and uninstall the read callback.
798   shutdown(acceptedSocket->getSocketFD(), SHUT_WR);
799
800   // Loop
801   evb.loop();
802
803   // The loop should have completed the connection, written the queued data,
804   // shutdown writes on the socket, read the data we wrote to it, and see the
805   // EOF.
806   //
807   // Check that the connection was completed successfully and that the read
808   // callback was invoked as expected.
809   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
810   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
811   ASSERT_EQ(rcb.buffers.size(), 1);
812   ASSERT_EQ(rcb.buffers[0].length, sizeof(acceptedWbuf));
813   ASSERT_EQ(memcmp(rcb.buffers[0].buffer,
814                            acceptedWbuf, sizeof(acceptedWbuf)), 0);
815
816   // Since we used shutdownWriteNow(), it should have discarded all pending
817   // write data.  Verify we see an immediate EOF when reading from the accepted
818   // socket.
819   uint8_t readbuf[sizeof(wbuf)];
820   uint32_t bytesRead = acceptedSocket->read(readbuf, sizeof(readbuf));
821   ASSERT_EQ(bytesRead, 0);
822
823   // Fully close both sockets
824   acceptedSocket->close();
825   socket->close();
826
827   ASSERT_FALSE(socket->isClosedBySelf());
828   ASSERT_TRUE(socket->isClosedByPeer());
829 }
830
831 // Helper function for use in testConnectOptWrite()
832 // Temporarily disable the read callback
833 void tmpDisableReads(AsyncSocket* socket, ReadCallback* rcb) {
834   // Uninstall the read callback
835   socket->setReadCB(nullptr);
836   // Schedule the read callback to be reinstalled after 1ms
837   socket->getEventBase()->runInLoop(
838       std::bind(&AsyncSocket::setReadCB, socket, rcb));
839 }
840
841 /**
842  * Test connect+write, then have the connect callback perform another write.
843  *
844  * This tests interaction of the optimistic writing after connect with
845  * additional write attempts that occur in the connect callback.
846  */
847 void testConnectOptWrite(size_t size1, size_t size2, bool close = false) {
848   TestServer server;
849   EventBase evb;
850   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
851
852   // connect()
853   ConnCallback ccb;
854   socket->connect(&ccb, server.getAddress(), 30);
855
856   // Hopefully the connect didn't succeed immediately.
857   // If it did, we can't exercise the optimistic write code path.
858   if (ccb.state == STATE_SUCCEEDED) {
859     LOG(INFO) << "connect() succeeded immediately; aborting test "
860                        "of optimistic write behavior";
861     return;
862   }
863
864   // Tell the connect callback to perform a write when the connect succeeds
865   WriteCallback wcb2;
866   scoped_array<char> buf2(new char[size2]);
867   memset(buf2.get(), 'b', size2);
868   if (size2 > 0) {
869     ccb.successCallback = [&] { socket->write(&wcb2, buf2.get(), size2); };
870     // Tell the second write callback to close the connection when it is done
871     wcb2.successCallback = [&] { socket->closeNow(); };
872   }
873
874   // Schedule one write() immediately, before the connect finishes
875   scoped_array<char> buf1(new char[size1]);
876   memset(buf1.get(), 'a', size1);
877   WriteCallback wcb1;
878   if (size1 > 0) {
879     socket->write(&wcb1, buf1.get(), size1);
880   }
881
882   if (close) {
883     // immediately perform a close, before connect() completes
884     socket->close();
885   }
886
887   // Start reading from the other endpoint after 10ms.
888   // If we're using large buffers, we have to read so that the writes don't
889   // block forever.
890   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
891   ReadCallback rcb;
892   rcb.dataAvailableCallback = std::bind(tmpDisableReads,
893                                         acceptedSocket.get(), &rcb);
894   socket->getEventBase()->tryRunAfterDelay(
895       std::bind(&AsyncSocket::setReadCB, acceptedSocket.get(), &rcb),
896       10);
897
898   // Loop.  We don't bother accepting on the server socket yet.
899   // The kernel should be able to buffer the write request so it can succeed.
900   evb.loop();
901
902   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
903   if (size1 > 0) {
904     ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
905   }
906   if (size2 > 0) {
907     ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
908   }
909
910   socket->close();
911
912   // Make sure the read callback received all of the data
913   size_t bytesRead = 0;
914   for (vector<ReadCallback::Buffer>::const_iterator it = rcb.buffers.begin();
915        it != rcb.buffers.end();
916        ++it) {
917     size_t start = bytesRead;
918     bytesRead += it->length;
919     size_t end = bytesRead;
920     if (start < size1) {
921       size_t cmpLen = min(size1, end) - start;
922       ASSERT_EQ(memcmp(it->buffer, buf1.get() + start, cmpLen), 0);
923     }
924     if (end > size1 && end <= size1 + size2) {
925       size_t itOffset;
926       size_t buf2Offset;
927       size_t cmpLen;
928       if (start >= size1) {
929         itOffset = 0;
930         buf2Offset = start - size1;
931         cmpLen = end - start;
932       } else {
933         itOffset = size1 - start;
934         buf2Offset = 0;
935         cmpLen = end - size1;
936       }
937       ASSERT_EQ(memcmp(it->buffer + itOffset, buf2.get() + buf2Offset,
938                                cmpLen),
939                         0);
940     }
941   }
942   ASSERT_EQ(bytesRead, size1 + size2);
943 }
944
945 TEST(AsyncSocketTest, ConnectCallbackWrite) {
946   // Test using small writes that should both succeed immediately
947   testConnectOptWrite(100, 200);
948
949   // Test using a large buffer in the connect callback, that should block
950   const size_t largeSize = 32 * 1024 * 1024;
951   testConnectOptWrite(100, largeSize);
952
953   // Test using a large initial write
954   testConnectOptWrite(largeSize, 100);
955
956   // Test using two large buffers
957   testConnectOptWrite(largeSize, largeSize);
958
959   // Test a small write in the connect callback,
960   // but no immediate write before connect completes
961   testConnectOptWrite(0, 64);
962
963   // Test a large write in the connect callback,
964   // but no immediate write before connect completes
965   testConnectOptWrite(0, largeSize);
966
967   // Test connect, a small write, then immediately call close() before connect
968   // completes
969   testConnectOptWrite(211, 0, true);
970
971   // Test connect, a large immediate write (that will block), then immediately
972   // call close() before connect completes
973   testConnectOptWrite(largeSize, 0, true);
974 }
975
976 ///////////////////////////////////////////////////////////////////////////
977 // write() related tests
978 ///////////////////////////////////////////////////////////////////////////
979
980 /**
981  * Test writing using a nullptr callback
982  */
983 TEST(AsyncSocketTest, WriteNullCallback) {
984   TestServer server;
985
986   // connect()
987   EventBase evb;
988   std::shared_ptr<AsyncSocket> socket =
989     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
990   evb.loop(); // loop until the socket is connected
991
992   // write() with a nullptr callback
993   char buf[128];
994   memset(buf, 'a', sizeof(buf));
995   socket->write(nullptr, buf, sizeof(buf));
996
997   evb.loop(); // loop until the data is sent
998
999   // Make sure the server got a connection and received the data
1000   socket->close();
1001   server.verifyConnection(buf, sizeof(buf));
1002
1003   ASSERT_TRUE(socket->isClosedBySelf());
1004   ASSERT_FALSE(socket->isClosedByPeer());
1005 }
1006
1007 /**
1008  * Test writing with a send timeout
1009  */
1010 TEST(AsyncSocketTest, WriteTimeout) {
1011   TestServer server;
1012
1013   // connect()
1014   EventBase evb;
1015   std::shared_ptr<AsyncSocket> socket =
1016     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1017   evb.loop(); // loop until the socket is connected
1018
1019   // write() a large chunk of data, with no-one on the other end reading.
1020   // Tricky: the kernel caches the connection metrics for recently-used
1021   // routes (see tcp_no_metrics_save) so a freshly opened connection can
1022   // have a send buffer size bigger than wmem_default.  This makes the test
1023   // flaky on contbuild if writeLength is < wmem_max (20M on our systems).
1024   size_t writeLength = 32 * 1024 * 1024;
1025   uint32_t timeout = 200;
1026   socket->setSendTimeout(timeout);
1027   scoped_array<char> buf(new char[writeLength]);
1028   memset(buf.get(), 'a', writeLength);
1029   WriteCallback wcb;
1030   socket->write(&wcb, buf.get(), writeLength);
1031
1032   TimePoint start;
1033   evb.loop();
1034   TimePoint end;
1035
1036   // Make sure the write attempt timed out as requested
1037   ASSERT_EQ(wcb.state, STATE_FAILED);
1038   ASSERT_EQ(wcb.exception.getType(), AsyncSocketException::TIMED_OUT);
1039
1040   // Check that the write timed out within a reasonable period of time.
1041   // We don't check for exactly the specified timeout, since AsyncSocket only
1042   // times out when it hasn't made progress for that period of time.
1043   //
1044   // On linux, the first write sends a few hundred kb of data, then blocks for
1045   // writability, and then unblocks again after 40ms and is able to write
1046   // another smaller of data before blocking permanently.  Therefore it doesn't
1047   // time out until 40ms + timeout.
1048   //
1049   // I haven't fully verified the cause of this, but I believe it probably
1050   // occurs because the receiving end delays sending an ack for up to 40ms.
1051   // (This is the default value for TCP_DELACK_MIN.)  Once the sender receives
1052   // the ack, it can send some more data.  However, after that point the
1053   // receiver's kernel buffer is full.  This 40ms delay happens even with
1054   // TCP_NODELAY and TCP_QUICKACK enabled on both endpoints.  However, the
1055   // kernel may be automatically disabling TCP_QUICKACK after receiving some
1056   // data.
1057   //
1058   // For now, we simply check that the timeout occurred within 160ms of
1059   // the requested value.
1060   T_CHECK_TIMEOUT(start, end, milliseconds(timeout), milliseconds(160));
1061 }
1062
1063 /**
1064  * Test writing to a socket that the remote endpoint has closed
1065  */
1066 TEST(AsyncSocketTest, WritePipeError) {
1067   TestServer server;
1068
1069   // connect()
1070   EventBase evb;
1071   std::shared_ptr<AsyncSocket> socket =
1072     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1073   socket->setSendTimeout(1000);
1074   evb.loop(); // loop until the socket is connected
1075
1076   // accept and immediately close the socket
1077   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1078   acceptedSocket->close();
1079
1080   // write() a large chunk of data
1081   size_t writeLength = 32 * 1024 * 1024;
1082   scoped_array<char> buf(new char[writeLength]);
1083   memset(buf.get(), 'a', writeLength);
1084   WriteCallback wcb;
1085   socket->write(&wcb, buf.get(), writeLength);
1086
1087   evb.loop();
1088
1089   // Make sure the write failed.
1090   // It would be nice if AsyncSocketException could convey the errno value,
1091   // so that we could check for EPIPE
1092   ASSERT_EQ(wcb.state, STATE_FAILED);
1093   ASSERT_EQ(wcb.exception.getType(),
1094                     AsyncSocketException::INTERNAL_ERROR);
1095
1096   ASSERT_FALSE(socket->isClosedBySelf());
1097   ASSERT_FALSE(socket->isClosedByPeer());
1098 }
1099
1100 /**
1101  * Test writing to a socket that has its read side closed
1102  */
1103 TEST(AsyncSocketTest, WriteAfterReadEOF) {
1104   TestServer server;
1105
1106   // connect()
1107   EventBase evb;
1108   std::shared_ptr<AsyncSocket> socket =
1109       AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1110   evb.loop(); // loop until the socket is connected
1111
1112   // Accept the connection
1113   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1114   ReadCallback rcb;
1115   acceptedSocket->setReadCB(&rcb);
1116
1117   // Shutdown the write side of client socket (read side of server socket)
1118   socket->shutdownWrite();
1119   evb.loop();
1120
1121   // Check that accepted socket is still writable
1122   ASSERT_FALSE(acceptedSocket->good());
1123   ASSERT_TRUE(acceptedSocket->writable());
1124
1125   // Write data to accepted socket
1126   constexpr size_t simpleBufLength = 5;
1127   char simpleBuf[simpleBufLength];
1128   memset(simpleBuf, 'a', simpleBufLength);
1129   WriteCallback wcb;
1130   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1131   evb.loop();
1132
1133   // Make sure we were able to write even after getting a read EOF
1134   ASSERT_EQ(rcb.state, STATE_SUCCEEDED); // this indicates EOF
1135   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1136 }
1137
1138 /**
1139  * Test that bytes written is correctly computed in case of write failure
1140  */
1141 TEST(AsyncSocketTest, WriteErrorCallbackBytesWritten) {
1142   // Send and receive buffer sizes for the sockets.
1143   constexpr size_t kSockBufSize = 8 * 1024;
1144
1145   TestServer server(false, kSockBufSize);
1146
1147   AsyncSocket::OptionMap options{
1148       {{SOL_SOCKET, SO_SNDBUF}, kSockBufSize},
1149       {{SOL_SOCKET, SO_RCVBUF}, kSockBufSize},
1150       {{IPPROTO_TCP, TCP_NODELAY}, 1},
1151   };
1152
1153   // The current thread will be used by the receiver - use a separate thread
1154   // for the sender.
1155   EventBase senderEvb;
1156   std::thread senderThread([&]() { senderEvb.loopForever(); });
1157
1158   ConnCallback ccb;
1159   std::shared_ptr<AsyncSocket> socket;
1160
1161   senderEvb.runInEventBaseThreadAndWait([&]() {
1162     socket = AsyncSocket::newSocket(&senderEvb);
1163     socket->connect(&ccb, server.getAddress(), 30, options);
1164   });
1165
1166   // accept the socket on the server side
1167   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1168
1169   // Send a big (45KB) write so that it is partially written. The first write
1170   // is 16KB (8KB on both sides) and subsequent writes are 8KB each. Reading
1171   // just under 24KB would cause 3-4 writes for the total of 32-40KB in the
1172   // following sequence: 16KB + 8KB + 8KB (+ 8KB). This ensures that not all
1173   // bytes are written when the socket is reset. Having at least 3 writes
1174   // ensures that the total size (45KB) would be exceeed in case of overcounting
1175   // based on the initial write size of 16KB.
1176   constexpr size_t kSendSize = 45 * 1024;
1177   auto const sendBuf = std::vector<char>(kSendSize, 'a');
1178
1179   WriteCallback wcb;
1180
1181   senderEvb.runInEventBaseThreadAndWait(
1182       [&]() { socket->write(&wcb, sendBuf.data(), kSendSize); });
1183
1184   // Reading 20KB would cause three additional writes of 8KB, but less
1185   // than 45KB total, so the socket is reset before all bytes are written.
1186   constexpr size_t kRecvSize = 20 * 1024;
1187   uint8_t recvBuf[kRecvSize];
1188   int bytesRead = acceptedSocket->readAll(recvBuf, sizeof(recvBuf));
1189   ASSERT_EQ(kRecvSize, bytesRead);
1190
1191   constexpr size_t kMinExpectedBytesWritten = // 20 ACK + 8 send buf
1192       kRecvSize + kSockBufSize;
1193   static_assert(kMinExpectedBytesWritten == 28 * 1024, "bad math");
1194   static_assert(kMinExpectedBytesWritten > kRecvSize, "bad math");
1195
1196   constexpr size_t kMaxExpectedBytesWritten = // 24 ACK + 8 sent + 8 send buf
1197       constexpr_ceil(kRecvSize, kSockBufSize) + 2 * kSockBufSize;
1198   static_assert(kMaxExpectedBytesWritten == 40 * 1024, "bad math");
1199   static_assert(kMaxExpectedBytesWritten < kSendSize, "bad math");
1200
1201   // Need to delay after receiving 20KB and before closing the receive side so
1202   // that the send side has a chance to fill the send buffer past.
1203   using clock = std::chrono::steady_clock;
1204   auto const deadline = clock::now() + std::chrono::seconds(2);
1205   while (wcb.bytesWritten < kMinExpectedBytesWritten &&
1206          clock::now() < deadline) {
1207     std::this_thread::yield();
1208   }
1209   acceptedSocket->closeWithReset();
1210
1211   senderEvb.terminateLoopSoon();
1212   senderThread.join();
1213
1214   ASSERT_EQ(STATE_FAILED, wcb.state);
1215   ASSERT_LE(kMinExpectedBytesWritten, wcb.bytesWritten);
1216   ASSERT_GE(kMaxExpectedBytesWritten, wcb.bytesWritten);
1217 }
1218
1219 /**
1220  * Test writing a mix of simple buffers and IOBufs
1221  */
1222 TEST(AsyncSocketTest, WriteIOBuf) {
1223   TestServer server;
1224
1225   // connect()
1226   EventBase evb;
1227   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1228   ConnCallback ccb;
1229   socket->connect(&ccb, server.getAddress(), 30);
1230
1231   // Accept the connection
1232   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1233   ReadCallback rcb;
1234   acceptedSocket->setReadCB(&rcb);
1235
1236   // Check if EOR tracking flag can be set and reset.
1237   EXPECT_FALSE(socket->isEorTrackingEnabled());
1238   socket->setEorTracking(true);
1239   EXPECT_TRUE(socket->isEorTrackingEnabled());
1240   socket->setEorTracking(false);
1241   EXPECT_FALSE(socket->isEorTrackingEnabled());
1242
1243   // Write a simple buffer to the socket
1244   constexpr size_t simpleBufLength = 5;
1245   char simpleBuf[simpleBufLength];
1246   memset(simpleBuf, 'a', simpleBufLength);
1247   WriteCallback wcb;
1248   socket->write(&wcb, simpleBuf, simpleBufLength);
1249
1250   // Write a single-element IOBuf chain
1251   size_t buf1Length = 7;
1252   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1253   memset(buf1->writableData(), 'b', buf1Length);
1254   buf1->append(buf1Length);
1255   unique_ptr<IOBuf> buf1Copy(buf1->clone());
1256   WriteCallback wcb2;
1257   socket->writeChain(&wcb2, std::move(buf1));
1258
1259   // Write a multiple-element IOBuf chain
1260   size_t buf2Length = 11;
1261   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1262   memset(buf2->writableData(), 'c', buf2Length);
1263   buf2->append(buf2Length);
1264   size_t buf3Length = 13;
1265   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1266   memset(buf3->writableData(), 'd', buf3Length);
1267   buf3->append(buf3Length);
1268   buf2->appendChain(std::move(buf3));
1269   unique_ptr<IOBuf> buf2Copy(buf2->clone());
1270   buf2Copy->coalesce();
1271   WriteCallback wcb3;
1272   socket->writeChain(&wcb3, std::move(buf2));
1273   socket->shutdownWrite();
1274
1275   // Let the reads and writes run to completion
1276   evb.loop();
1277
1278   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1279   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1280   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1281
1282   // Make sure the reader got the right data in the right order
1283   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1284   ASSERT_EQ(rcb.buffers.size(), 1);
1285   ASSERT_EQ(rcb.buffers[0].length,
1286       simpleBufLength + buf1Length + buf2Length + buf3Length);
1287   ASSERT_EQ(
1288       memcmp(rcb.buffers[0].buffer, simpleBuf, simpleBufLength), 0);
1289   ASSERT_EQ(
1290       memcmp(rcb.buffers[0].buffer + simpleBufLength,
1291           buf1Copy->data(), buf1Copy->length()), 0);
1292   ASSERT_EQ(
1293       memcmp(rcb.buffers[0].buffer + simpleBufLength + buf1Length,
1294           buf2Copy->data(), buf2Copy->length()), 0);
1295
1296   acceptedSocket->close();
1297   socket->close();
1298
1299   ASSERT_TRUE(socket->isClosedBySelf());
1300   ASSERT_FALSE(socket->isClosedByPeer());
1301 }
1302
1303 TEST(AsyncSocketTest, WriteIOBufCorked) {
1304   TestServer server;
1305
1306   // connect()
1307   EventBase evb;
1308   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1309   ConnCallback ccb;
1310   socket->connect(&ccb, server.getAddress(), 30);
1311
1312   // Accept the connection
1313   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&evb);
1314   ReadCallback rcb;
1315   acceptedSocket->setReadCB(&rcb);
1316
1317   // Do three writes, 100ms apart, with the "cork" flag set
1318   // on the second write.  The reader should see the first write
1319   // arrive by itself, followed by the second and third writes
1320   // arriving together.
1321   size_t buf1Length = 5;
1322   unique_ptr<IOBuf> buf1(IOBuf::create(buf1Length));
1323   memset(buf1->writableData(), 'a', buf1Length);
1324   buf1->append(buf1Length);
1325   size_t buf2Length = 7;
1326   unique_ptr<IOBuf> buf2(IOBuf::create(buf2Length));
1327   memset(buf2->writableData(), 'b', buf2Length);
1328   buf2->append(buf2Length);
1329   size_t buf3Length = 11;
1330   unique_ptr<IOBuf> buf3(IOBuf::create(buf3Length));
1331   memset(buf3->writableData(), 'c', buf3Length);
1332   buf3->append(buf3Length);
1333   WriteCallback wcb1;
1334   socket->writeChain(&wcb1, std::move(buf1));
1335   WriteCallback wcb2;
1336   DelayedWrite write2(socket, std::move(buf2), &wcb2, true);
1337   write2.scheduleTimeout(100);
1338   WriteCallback wcb3;
1339   DelayedWrite write3(socket, std::move(buf3), &wcb3, false, true);
1340   write3.scheduleTimeout(140);
1341
1342   evb.loop();
1343   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1344   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1345   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1346   if (wcb3.state != STATE_SUCCEEDED) {
1347     throw(wcb3.exception);
1348   }
1349   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1350
1351   // Make sure the reader got the data with the right grouping
1352   ASSERT_EQ(rcb.state, STATE_SUCCEEDED);
1353   ASSERT_EQ(rcb.buffers.size(), 2);
1354   ASSERT_EQ(rcb.buffers[0].length, buf1Length);
1355   ASSERT_EQ(rcb.buffers[1].length, buf2Length + buf3Length);
1356
1357   acceptedSocket->close();
1358   socket->close();
1359
1360   ASSERT_TRUE(socket->isClosedBySelf());
1361   ASSERT_FALSE(socket->isClosedByPeer());
1362 }
1363
1364 /**
1365  * Test performing a zero-length write
1366  */
1367 TEST(AsyncSocketTest, ZeroLengthWrite) {
1368   TestServer server;
1369
1370   // connect()
1371   EventBase evb;
1372   std::shared_ptr<AsyncSocket> socket =
1373     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1374   evb.loop(); // loop until the socket is connected
1375
1376   auto acceptedSocket = server.acceptAsync(&evb);
1377   ReadCallback rcb;
1378   acceptedSocket->setReadCB(&rcb);
1379
1380   size_t len1 = 1024*1024;
1381   size_t len2 = 1024*1024;
1382   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1383   memset(buf.get(), 'a', len1);
1384   memset(buf.get(), 'b', len2);
1385
1386   WriteCallback wcb1;
1387   WriteCallback wcb2;
1388   WriteCallback wcb3;
1389   WriteCallback wcb4;
1390   socket->write(&wcb1, buf.get(), 0);
1391   socket->write(&wcb2, buf.get(), len1);
1392   socket->write(&wcb3, buf.get() + len1, 0);
1393   socket->write(&wcb4, buf.get() + len1, len2);
1394   socket->close();
1395
1396   evb.loop(); // loop until the data is sent
1397
1398   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1399   ASSERT_EQ(wcb2.state, STATE_SUCCEEDED);
1400   ASSERT_EQ(wcb3.state, STATE_SUCCEEDED);
1401   ASSERT_EQ(wcb4.state, STATE_SUCCEEDED);
1402   rcb.verifyData(buf.get(), len1 + len2);
1403
1404   ASSERT_TRUE(socket->isClosedBySelf());
1405   ASSERT_FALSE(socket->isClosedByPeer());
1406 }
1407
1408 TEST(AsyncSocketTest, ZeroLengthWritev) {
1409   TestServer server;
1410
1411   // connect()
1412   EventBase evb;
1413   std::shared_ptr<AsyncSocket> socket =
1414     AsyncSocket::newSocket(&evb, server.getAddress(), 30);
1415   evb.loop(); // loop until the socket is connected
1416
1417   auto acceptedSocket = server.acceptAsync(&evb);
1418   ReadCallback rcb;
1419   acceptedSocket->setReadCB(&rcb);
1420
1421   size_t len1 = 1024*1024;
1422   size_t len2 = 1024*1024;
1423   std::unique_ptr<char[]> buf(new char[len1 + len2]);
1424   memset(buf.get(), 'a', len1);
1425   memset(buf.get(), 'b', len2);
1426
1427   WriteCallback wcb;
1428   constexpr size_t iovCount = 4;
1429   struct iovec iov[iovCount];
1430   iov[0].iov_base = buf.get();
1431   iov[0].iov_len = len1;
1432   iov[1].iov_base = buf.get() + len1;
1433   iov[1].iov_len = 0;
1434   iov[2].iov_base = buf.get() + len1;
1435   iov[2].iov_len = len2;
1436   iov[3].iov_base = buf.get() + len1 + len2;
1437   iov[3].iov_len = 0;
1438
1439   socket->writev(&wcb, iov, iovCount);
1440   socket->close();
1441   evb.loop(); // loop until the data is sent
1442
1443   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1444   rcb.verifyData(buf.get(), len1 + len2);
1445
1446   ASSERT_TRUE(socket->isClosedBySelf());
1447   ASSERT_FALSE(socket->isClosedByPeer());
1448 }
1449
1450 ///////////////////////////////////////////////////////////////////////////
1451 // close() related tests
1452 ///////////////////////////////////////////////////////////////////////////
1453
1454 /**
1455  * Test calling close() with pending writes when the socket is already closing.
1456  */
1457 TEST(AsyncSocketTest, ClosePendingWritesWhileClosing) {
1458   TestServer server;
1459
1460   // connect()
1461   EventBase evb;
1462   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
1463   ConnCallback ccb;
1464   socket->connect(&ccb, server.getAddress(), 30);
1465
1466   // accept the socket on the server side
1467   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
1468
1469   // Loop to ensure the connect has completed
1470   evb.loop();
1471
1472   // Make sure we are connected
1473   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
1474
1475   // Schedule pending writes, until several write attempts have blocked
1476   char buf[128];
1477   memset(buf, 'a', sizeof(buf));
1478   typedef vector< std::shared_ptr<WriteCallback> > WriteCallbackVector;
1479   WriteCallbackVector writeCallbacks;
1480
1481   writeCallbacks.reserve(5);
1482   while (writeCallbacks.size() < 5) {
1483     std::shared_ptr<WriteCallback> wcb(new WriteCallback);
1484
1485     socket->write(wcb.get(), buf, sizeof(buf));
1486     if (wcb->state == STATE_SUCCEEDED) {
1487       // Succeeded immediately.  Keep performing more writes
1488       continue;
1489     }
1490
1491     // This write is blocked.
1492     // Have the write callback call close() when writeError() is invoked
1493     wcb->errorCallback = std::bind(&AsyncSocket::close, socket.get());
1494     writeCallbacks.push_back(wcb);
1495   }
1496
1497   // Call closeNow() to immediately fail the pending writes
1498   socket->closeNow();
1499
1500   // Make sure writeError() was invoked on all of the pending write callbacks
1501   for (WriteCallbackVector::const_iterator it = writeCallbacks.begin();
1502        it != writeCallbacks.end();
1503        ++it) {
1504     ASSERT_EQ((*it)->state, STATE_FAILED);
1505   }
1506
1507   ASSERT_TRUE(socket->isClosedBySelf());
1508   ASSERT_FALSE(socket->isClosedByPeer());
1509 }
1510
1511 ///////////////////////////////////////////////////////////////////////////
1512 // ImmediateRead related tests
1513 ///////////////////////////////////////////////////////////////////////////
1514
1515 /* AsyncSocket use to verify immediate read works */
1516 class AsyncSocketImmediateRead : public folly::AsyncSocket {
1517  public:
1518   bool immediateReadCalled = false;
1519   explicit AsyncSocketImmediateRead(folly::EventBase* evb) : AsyncSocket(evb) {}
1520  protected:
1521   void checkForImmediateRead() noexcept override {
1522     immediateReadCalled = true;
1523     AsyncSocket::handleRead();
1524   }
1525 };
1526
1527 TEST(AsyncSocket, ConnectReadImmediateRead) {
1528   TestServer server;
1529
1530   const size_t maxBufferSz = 100;
1531   const size_t maxReadsPerEvent = 1;
1532   const size_t expectedDataSz = maxBufferSz * 3;
1533   char expectedData[expectedDataSz];
1534   memset(expectedData, 'j', expectedDataSz);
1535
1536   EventBase evb;
1537   ReadCallback rcb(maxBufferSz);
1538   AsyncSocketImmediateRead socket(&evb);
1539   socket.connect(nullptr, server.getAddress(), 30);
1540
1541   evb.loop(); // loop until the socket is connected
1542
1543   socket.setReadCB(&rcb);
1544   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1545   socket.immediateReadCalled = false;
1546
1547   auto acceptedSocket = server.acceptAsync(&evb);
1548
1549   ReadCallback rcbServer;
1550   WriteCallback wcbServer;
1551   rcbServer.dataAvailableCallback = [&]() {
1552     if (rcbServer.dataRead() == expectedDataSz) {
1553       // write back all data read
1554       rcbServer.verifyData(expectedData, expectedDataSz);
1555       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1556       acceptedSocket->close();
1557     }
1558   };
1559   acceptedSocket->setReadCB(&rcbServer);
1560
1561   // write data
1562   WriteCallback wcb1;
1563   socket.write(&wcb1, expectedData, expectedDataSz);
1564   evb.loop();
1565   ASSERT_EQ(wcb1.state, STATE_SUCCEEDED);
1566   rcb.verifyData(expectedData, expectedDataSz);
1567   ASSERT_EQ(socket.immediateReadCalled, true);
1568
1569   ASSERT_FALSE(socket.isClosedBySelf());
1570   ASSERT_FALSE(socket.isClosedByPeer());
1571 }
1572
1573 TEST(AsyncSocket, ConnectReadUninstallRead) {
1574   TestServer server;
1575
1576   const size_t maxBufferSz = 100;
1577   const size_t maxReadsPerEvent = 1;
1578   const size_t expectedDataSz = maxBufferSz * 3;
1579   char expectedData[expectedDataSz];
1580   memset(expectedData, 'k', expectedDataSz);
1581
1582   EventBase evb;
1583   ReadCallback rcb(maxBufferSz);
1584   AsyncSocketImmediateRead socket(&evb);
1585   socket.connect(nullptr, server.getAddress(), 30);
1586
1587   evb.loop(); // loop until the socket is connected
1588
1589   socket.setReadCB(&rcb);
1590   socket.setMaxReadsPerEvent(maxReadsPerEvent);
1591   socket.immediateReadCalled = false;
1592
1593   auto acceptedSocket = server.acceptAsync(&evb);
1594
1595   ReadCallback rcbServer;
1596   WriteCallback wcbServer;
1597   rcbServer.dataAvailableCallback = [&]() {
1598     if (rcbServer.dataRead() == expectedDataSz) {
1599       // write back all data read
1600       rcbServer.verifyData(expectedData, expectedDataSz);
1601       acceptedSocket->write(&wcbServer, expectedData, expectedDataSz);
1602       acceptedSocket->close();
1603     }
1604   };
1605   acceptedSocket->setReadCB(&rcbServer);
1606
1607   rcb.dataAvailableCallback = [&]() {
1608     // we read data and reset readCB
1609     socket.setReadCB(nullptr);
1610   };
1611
1612   // write data
1613   WriteCallback wcb;
1614   socket.write(&wcb, expectedData, expectedDataSz);
1615   evb.loop();
1616   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
1617
1618   /* we shoud've only read maxBufferSz data since readCallback_
1619    * was reset in dataAvailableCallback */
1620   ASSERT_EQ(rcb.dataRead(), maxBufferSz);
1621   ASSERT_EQ(socket.immediateReadCalled, false);
1622
1623   ASSERT_FALSE(socket.isClosedBySelf());
1624   ASSERT_FALSE(socket.isClosedByPeer());
1625 }
1626
1627 // TODO:
1628 // - Test connect() and have the connect callback set the read callback
1629 // - Test connect() and have the connect callback unset the read callback
1630 // - Test reading/writing/closing/destroying the socket in the connect callback
1631 // - Test reading/writing/closing/destroying the socket in the read callback
1632 // - Test reading/writing/closing/destroying the socket in the write callback
1633 // - Test one-way shutdown behavior
1634 // - Test changing the EventBase
1635 //
1636 // - TODO: test multiple threads sharing a AsyncSocket, and detaching from it
1637 //   in connectSuccess(), readDataAvailable(), writeSuccess()
1638
1639
1640 ///////////////////////////////////////////////////////////////////////////
1641 // AsyncServerSocket tests
1642 ///////////////////////////////////////////////////////////////////////////
1643
1644 /**
1645  * Make sure accepted sockets have O_NONBLOCK and TCP_NODELAY set
1646  */
1647 TEST(AsyncSocketTest, ServerAcceptOptions) {
1648   EventBase eventBase;
1649
1650   // Create a server socket
1651   std::shared_ptr<AsyncServerSocket> serverSocket(
1652       AsyncServerSocket::newSocket(&eventBase));
1653   serverSocket->bind(0);
1654   serverSocket->listen(16);
1655   folly::SocketAddress serverAddress;
1656   serverSocket->getAddress(&serverAddress);
1657
1658   // Add a callback to accept one connection then stop the loop
1659   TestAcceptCallback acceptCallback;
1660   acceptCallback.setConnectionAcceptedFn(
1661       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1662         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1663       });
1664   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1665     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
1666   });
1667   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
1668   serverSocket->startAccepting();
1669
1670   // Connect to the server socket
1671   std::shared_ptr<AsyncSocket> socket(
1672       AsyncSocket::newSocket(&eventBase, serverAddress));
1673
1674   eventBase.loop();
1675
1676   // Verify that the server accepted a connection
1677   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1678   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1679                     TestAcceptCallback::TYPE_START);
1680   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1681                     TestAcceptCallback::TYPE_ACCEPT);
1682   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1683                     TestAcceptCallback::TYPE_STOP);
1684   int fd = acceptCallback.getEvents()->at(1).fd;
1685
1686   // The accepted connection should already be in non-blocking mode
1687   int flags = fcntl(fd, F_GETFL, 0);
1688   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
1689
1690 #ifndef TCP_NOPUSH
1691   // The accepted connection should already have TCP_NODELAY set
1692   int value;
1693   socklen_t valueLength = sizeof(value);
1694   int rc = getsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &value, &valueLength);
1695   ASSERT_EQ(rc, 0);
1696   ASSERT_EQ(value, 1);
1697 #endif
1698 }
1699
1700 /**
1701  * Test AsyncServerSocket::removeAcceptCallback()
1702  */
1703 TEST(AsyncSocketTest, RemoveAcceptCallback) {
1704   // Create a new AsyncServerSocket
1705   EventBase eventBase;
1706   std::shared_ptr<AsyncServerSocket> serverSocket(
1707       AsyncServerSocket::newSocket(&eventBase));
1708   serverSocket->bind(0);
1709   serverSocket->listen(16);
1710   folly::SocketAddress serverAddress;
1711   serverSocket->getAddress(&serverAddress);
1712
1713   // Add several accept callbacks
1714   TestAcceptCallback cb1;
1715   TestAcceptCallback cb2;
1716   TestAcceptCallback cb3;
1717   TestAcceptCallback cb4;
1718   TestAcceptCallback cb5;
1719   TestAcceptCallback cb6;
1720   TestAcceptCallback cb7;
1721
1722   // Test having callbacks remove other callbacks before them on the list,
1723   // after them on the list, or removing themselves.
1724   //
1725   // Have callback 2 remove callback 3 and callback 5 the first time it is
1726   // called.
1727   int cb2Count = 0;
1728   cb1.setConnectionAcceptedFn([&](int /* fd */,
1729                                   const folly::SocketAddress& /* addr */) {
1730     std::shared_ptr<AsyncSocket> sock2(
1731         AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2: -cb3 -cb5
1732   });
1733   cb3.setConnectionAcceptedFn(
1734       [&](int /* fd */, const folly::SocketAddress& /* addr */) {});
1735   cb4.setConnectionAcceptedFn(
1736       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1737         std::shared_ptr<AsyncSocket> sock3(
1738             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb4
1739       });
1740   cb5.setConnectionAcceptedFn(
1741       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1742         std::shared_ptr<AsyncSocket> sock5(
1743             AsyncSocket::newSocket(&eventBase, serverAddress)); // cb7: -cb7
1744
1745       });
1746   cb2.setConnectionAcceptedFn(
1747       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1748         if (cb2Count == 0) {
1749           serverSocket->removeAcceptCallback(&cb3, nullptr);
1750           serverSocket->removeAcceptCallback(&cb5, nullptr);
1751         }
1752         ++cb2Count;
1753       });
1754   // Have callback 6 remove callback 4 the first time it is called,
1755   // and destroy the server socket the second time it is called
1756   int cb6Count = 0;
1757   cb6.setConnectionAcceptedFn(
1758       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1759         if (cb6Count == 0) {
1760           serverSocket->removeAcceptCallback(&cb4, nullptr);
1761           std::shared_ptr<AsyncSocket> sock6(
1762               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1763           std::shared_ptr<AsyncSocket> sock7(
1764               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb2
1765           std::shared_ptr<AsyncSocket> sock8(
1766               AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: stop
1767
1768         } else {
1769           serverSocket.reset();
1770         }
1771         ++cb6Count;
1772       });
1773   // Have callback 7 remove itself
1774   cb7.setConnectionAcceptedFn(
1775       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1776         serverSocket->removeAcceptCallback(&cb7, nullptr);
1777       });
1778
1779   serverSocket->addAcceptCallback(&cb1, &eventBase);
1780   serverSocket->addAcceptCallback(&cb2, &eventBase);
1781   serverSocket->addAcceptCallback(&cb3, &eventBase);
1782   serverSocket->addAcceptCallback(&cb4, &eventBase);
1783   serverSocket->addAcceptCallback(&cb5, &eventBase);
1784   serverSocket->addAcceptCallback(&cb6, &eventBase);
1785   serverSocket->addAcceptCallback(&cb7, &eventBase);
1786   serverSocket->startAccepting();
1787
1788   // Make several connections to the socket
1789   std::shared_ptr<AsyncSocket> sock1(
1790       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1791   std::shared_ptr<AsyncSocket> sock4(
1792       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb6: -cb4
1793
1794   // Loop until we are stopped
1795   eventBase.loop();
1796
1797   // Check to make sure that the expected callbacks were invoked.
1798   //
1799   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1800   // the AcceptCallbacks in round-robin fashion, in the order that they were
1801   // added.  The code is implemented this way right now, but the API doesn't
1802   // explicitly require it be done this way.  If we change the code not to be
1803   // exactly round robin in the future, we can simplify the test checks here.
1804   // (We'll also need to update the termination code, since we expect cb6 to
1805   // get called twice to terminate the loop.)
1806   ASSERT_EQ(cb1.getEvents()->size(), 4);
1807   ASSERT_EQ(cb1.getEvents()->at(0).type,
1808                     TestAcceptCallback::TYPE_START);
1809   ASSERT_EQ(cb1.getEvents()->at(1).type,
1810                     TestAcceptCallback::TYPE_ACCEPT);
1811   ASSERT_EQ(cb1.getEvents()->at(2).type,
1812                     TestAcceptCallback::TYPE_ACCEPT);
1813   ASSERT_EQ(cb1.getEvents()->at(3).type,
1814                     TestAcceptCallback::TYPE_STOP);
1815
1816   ASSERT_EQ(cb2.getEvents()->size(), 4);
1817   ASSERT_EQ(cb2.getEvents()->at(0).type,
1818                     TestAcceptCallback::TYPE_START);
1819   ASSERT_EQ(cb2.getEvents()->at(1).type,
1820                     TestAcceptCallback::TYPE_ACCEPT);
1821   ASSERT_EQ(cb2.getEvents()->at(2).type,
1822                     TestAcceptCallback::TYPE_ACCEPT);
1823   ASSERT_EQ(cb2.getEvents()->at(3).type,
1824                     TestAcceptCallback::TYPE_STOP);
1825
1826   ASSERT_EQ(cb3.getEvents()->size(), 2);
1827   ASSERT_EQ(cb3.getEvents()->at(0).type,
1828                     TestAcceptCallback::TYPE_START);
1829   ASSERT_EQ(cb3.getEvents()->at(1).type,
1830                     TestAcceptCallback::TYPE_STOP);
1831
1832   ASSERT_EQ(cb4.getEvents()->size(), 3);
1833   ASSERT_EQ(cb4.getEvents()->at(0).type,
1834                     TestAcceptCallback::TYPE_START);
1835   ASSERT_EQ(cb4.getEvents()->at(1).type,
1836                     TestAcceptCallback::TYPE_ACCEPT);
1837   ASSERT_EQ(cb4.getEvents()->at(2).type,
1838                     TestAcceptCallback::TYPE_STOP);
1839
1840   ASSERT_EQ(cb5.getEvents()->size(), 2);
1841   ASSERT_EQ(cb5.getEvents()->at(0).type,
1842                     TestAcceptCallback::TYPE_START);
1843   ASSERT_EQ(cb5.getEvents()->at(1).type,
1844                     TestAcceptCallback::TYPE_STOP);
1845
1846   ASSERT_EQ(cb6.getEvents()->size(), 4);
1847   ASSERT_EQ(cb6.getEvents()->at(0).type,
1848                     TestAcceptCallback::TYPE_START);
1849   ASSERT_EQ(cb6.getEvents()->at(1).type,
1850                     TestAcceptCallback::TYPE_ACCEPT);
1851   ASSERT_EQ(cb6.getEvents()->at(2).type,
1852                     TestAcceptCallback::TYPE_ACCEPT);
1853   ASSERT_EQ(cb6.getEvents()->at(3).type,
1854                     TestAcceptCallback::TYPE_STOP);
1855
1856   ASSERT_EQ(cb7.getEvents()->size(), 3);
1857   ASSERT_EQ(cb7.getEvents()->at(0).type,
1858                     TestAcceptCallback::TYPE_START);
1859   ASSERT_EQ(cb7.getEvents()->at(1).type,
1860                     TestAcceptCallback::TYPE_ACCEPT);
1861   ASSERT_EQ(cb7.getEvents()->at(2).type,
1862                     TestAcceptCallback::TYPE_STOP);
1863 }
1864
1865 /**
1866  * Test AsyncServerSocket::removeAcceptCallback()
1867  */
1868 TEST(AsyncSocketTest, OtherThreadAcceptCallback) {
1869   // Create a new AsyncServerSocket
1870   EventBase eventBase;
1871   std::shared_ptr<AsyncServerSocket> serverSocket(
1872       AsyncServerSocket::newSocket(&eventBase));
1873   serverSocket->bind(0);
1874   serverSocket->listen(16);
1875   folly::SocketAddress serverAddress;
1876   serverSocket->getAddress(&serverAddress);
1877
1878   // Add several accept callbacks
1879   TestAcceptCallback cb1;
1880   auto thread_id = std::this_thread::get_id();
1881   cb1.setAcceptStartedFn([&](){
1882     CHECK_NE(thread_id, std::this_thread::get_id());
1883     thread_id = std::this_thread::get_id();
1884   });
1885   cb1.setConnectionAcceptedFn(
1886       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1887         ASSERT_EQ(thread_id, std::this_thread::get_id());
1888         serverSocket->removeAcceptCallback(&cb1, &eventBase);
1889       });
1890   cb1.setAcceptStoppedFn([&](){
1891     ASSERT_EQ(thread_id, std::this_thread::get_id());
1892   });
1893
1894   // Test having callbacks remove other callbacks before them on the list,
1895   serverSocket->addAcceptCallback(&cb1, &eventBase);
1896   serverSocket->startAccepting();
1897
1898   // Make several connections to the socket
1899   std::shared_ptr<AsyncSocket> sock1(
1900       AsyncSocket::newSocket(&eventBase, serverAddress)); // cb1
1901
1902   // Loop in another thread
1903   auto other = std::thread([&](){
1904     eventBase.loop();
1905   });
1906   other.join();
1907
1908   // Check to make sure that the expected callbacks were invoked.
1909   //
1910   // NOTE: This code depends on the AsyncServerSocket operating calling all of
1911   // the AcceptCallbacks in round-robin fashion, in the order that they were
1912   // added.  The code is implemented this way right now, but the API doesn't
1913   // explicitly require it be done this way.  If we change the code not to be
1914   // exactly round robin in the future, we can simplify the test checks here.
1915   // (We'll also need to update the termination code, since we expect cb6 to
1916   // get called twice to terminate the loop.)
1917   ASSERT_EQ(cb1.getEvents()->size(), 3);
1918   ASSERT_EQ(cb1.getEvents()->at(0).type,
1919                     TestAcceptCallback::TYPE_START);
1920   ASSERT_EQ(cb1.getEvents()->at(1).type,
1921                     TestAcceptCallback::TYPE_ACCEPT);
1922   ASSERT_EQ(cb1.getEvents()->at(2).type,
1923                     TestAcceptCallback::TYPE_STOP);
1924
1925 }
1926
1927 void serverSocketSanityTest(AsyncServerSocket* serverSocket) {
1928   EventBase* eventBase = serverSocket->getEventBase();
1929   CHECK(eventBase);
1930
1931   // Add a callback to accept one connection then stop accepting
1932   TestAcceptCallback acceptCallback;
1933   acceptCallback.setConnectionAcceptedFn(
1934       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
1935         serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1936       });
1937   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
1938     serverSocket->removeAcceptCallback(&acceptCallback, eventBase);
1939   });
1940   serverSocket->addAcceptCallback(&acceptCallback, eventBase);
1941   serverSocket->startAccepting();
1942
1943   // Connect to the server socket
1944   folly::SocketAddress serverAddress;
1945   serverSocket->getAddress(&serverAddress);
1946   AsyncSocket::UniquePtr socket(new AsyncSocket(eventBase, serverAddress));
1947
1948   // Loop to process all events
1949   eventBase->loop();
1950
1951   // Verify that the server accepted a connection
1952   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
1953   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
1954                     TestAcceptCallback::TYPE_START);
1955   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
1956                     TestAcceptCallback::TYPE_ACCEPT);
1957   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
1958                     TestAcceptCallback::TYPE_STOP);
1959 }
1960
1961 /* Verify that we don't leak sockets if we are destroyed()
1962  * and there are still writes pending
1963  *
1964  * If destroy() only calls close() instead of closeNow(),
1965  * it would shutdown(writes) on the socket, but it would
1966  * never be close()'d, and the socket would leak
1967  */
1968 TEST(AsyncSocketTest, DestroyCloseTest) {
1969   TestServer server;
1970
1971   // connect()
1972   EventBase clientEB;
1973   EventBase serverEB;
1974   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&clientEB);
1975   ConnCallback ccb;
1976   socket->connect(&ccb, server.getAddress(), 30);
1977
1978   // Accept the connection
1979   std::shared_ptr<AsyncSocket> acceptedSocket = server.acceptAsync(&serverEB);
1980   ReadCallback rcb;
1981   acceptedSocket->setReadCB(&rcb);
1982
1983   // Write a large buffer to the socket that is larger than kernel buffer
1984   size_t simpleBufLength = 5000000;
1985   char* simpleBuf = new char[simpleBufLength];
1986   memset(simpleBuf, 'a', simpleBufLength);
1987   WriteCallback wcb;
1988
1989   // Let the reads and writes run to completion
1990   int fd = acceptedSocket->getFd();
1991
1992   acceptedSocket->write(&wcb, simpleBuf, simpleBufLength);
1993   socket.reset();
1994   acceptedSocket.reset();
1995
1996   // Test that server socket was closed
1997   folly::test::msvcSuppressAbortOnInvalidParams([&] {
1998     ssize_t sz = read(fd, simpleBuf, simpleBufLength);
1999     ASSERT_EQ(sz, -1);
2000     ASSERT_EQ(errno, EBADF);
2001   });
2002   delete[] simpleBuf;
2003 }
2004
2005 /**
2006  * Test AsyncServerSocket::useExistingSocket()
2007  */
2008 TEST(AsyncSocketTest, ServerExistingSocket) {
2009   EventBase eventBase;
2010
2011   // Test creating a socket, and letting AsyncServerSocket bind and listen
2012   {
2013     // Manually create a socket
2014     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2015     ASSERT_GE(fd, 0);
2016
2017     // Create a server socket
2018     AsyncServerSocket::UniquePtr serverSocket(
2019         new AsyncServerSocket(&eventBase));
2020     serverSocket->useExistingSocket(fd);
2021     folly::SocketAddress address;
2022     serverSocket->getAddress(&address);
2023     address.setPort(0);
2024     serverSocket->bind(address);
2025     serverSocket->listen(16);
2026
2027     // Make sure the socket works
2028     serverSocketSanityTest(serverSocket.get());
2029   }
2030
2031   // Test creating a socket and binding manually,
2032   // then letting AsyncServerSocket listen
2033   {
2034     // Manually create a socket
2035     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2036     ASSERT_GE(fd, 0);
2037     // bind
2038     struct sockaddr_in addr;
2039     addr.sin_family = AF_INET;
2040     addr.sin_port = 0;
2041     addr.sin_addr.s_addr = INADDR_ANY;
2042     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2043                              sizeof(addr)), 0);
2044     // Look up the address that we bound to
2045     folly::SocketAddress boundAddress;
2046     boundAddress.setFromLocalAddress(fd);
2047
2048     // Create a server socket
2049     AsyncServerSocket::UniquePtr serverSocket(
2050         new AsyncServerSocket(&eventBase));
2051     serverSocket->useExistingSocket(fd);
2052     serverSocket->listen(16);
2053
2054     // Make sure AsyncServerSocket reports the same address that we bound to
2055     folly::SocketAddress serverSocketAddress;
2056     serverSocket->getAddress(&serverSocketAddress);
2057     ASSERT_EQ(boundAddress, serverSocketAddress);
2058
2059     // Make sure the socket works
2060     serverSocketSanityTest(serverSocket.get());
2061   }
2062
2063   // Test creating a socket, binding and listening manually,
2064   // then giving it to AsyncServerSocket
2065   {
2066     // Manually create a socket
2067     int fd = fsp::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
2068     ASSERT_GE(fd, 0);
2069     // bind
2070     struct sockaddr_in addr;
2071     addr.sin_family = AF_INET;
2072     addr.sin_port = 0;
2073     addr.sin_addr.s_addr = INADDR_ANY;
2074     ASSERT_EQ(bind(fd, reinterpret_cast<struct sockaddr*>(&addr),
2075                              sizeof(addr)), 0);
2076     // Look up the address that we bound to
2077     folly::SocketAddress boundAddress;
2078     boundAddress.setFromLocalAddress(fd);
2079     // listen
2080     ASSERT_EQ(listen(fd, 16), 0);
2081
2082     // Create a server socket
2083     AsyncServerSocket::UniquePtr serverSocket(
2084         new AsyncServerSocket(&eventBase));
2085     serverSocket->useExistingSocket(fd);
2086
2087     // Make sure AsyncServerSocket reports the same address that we bound to
2088     folly::SocketAddress serverSocketAddress;
2089     serverSocket->getAddress(&serverSocketAddress);
2090     ASSERT_EQ(boundAddress, serverSocketAddress);
2091
2092     // Make sure the socket works
2093     serverSocketSanityTest(serverSocket.get());
2094   }
2095 }
2096
2097 TEST(AsyncSocketTest, UnixDomainSocketTest) {
2098   EventBase eventBase;
2099
2100   // Create a server socket
2101   std::shared_ptr<AsyncServerSocket> serverSocket(
2102       AsyncServerSocket::newSocket(&eventBase));
2103   string path(1, 0);
2104   path.append(folly::to<string>("/anonymous", folly::Random::rand64()));
2105   folly::SocketAddress serverAddress;
2106   serverAddress.setFromPath(path);
2107   serverSocket->bind(serverAddress);
2108   serverSocket->listen(16);
2109
2110   // Add a callback to accept one connection then stop the loop
2111   TestAcceptCallback acceptCallback;
2112   acceptCallback.setConnectionAcceptedFn(
2113       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2114         serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2115       });
2116   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2117     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2118   });
2119   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2120   serverSocket->startAccepting();
2121
2122   // Connect to the server socket
2123   std::shared_ptr<AsyncSocket> socket(
2124       AsyncSocket::newSocket(&eventBase, serverAddress));
2125
2126   eventBase.loop();
2127
2128   // Verify that the server accepted a connection
2129   ASSERT_EQ(acceptCallback.getEvents()->size(), 3);
2130   ASSERT_EQ(acceptCallback.getEvents()->at(0).type,
2131                     TestAcceptCallback::TYPE_START);
2132   ASSERT_EQ(acceptCallback.getEvents()->at(1).type,
2133                     TestAcceptCallback::TYPE_ACCEPT);
2134   ASSERT_EQ(acceptCallback.getEvents()->at(2).type,
2135                     TestAcceptCallback::TYPE_STOP);
2136   int fd = acceptCallback.getEvents()->at(1).fd;
2137
2138   // The accepted connection should already be in non-blocking mode
2139   int flags = fcntl(fd, F_GETFL, 0);
2140   ASSERT_EQ(flags & O_NONBLOCK, O_NONBLOCK);
2141 }
2142
2143 TEST(AsyncSocketTest, ConnectionEventCallbackDefault) {
2144   EventBase eventBase;
2145   TestConnectionEventCallback connectionEventCallback;
2146
2147   // Create a server socket
2148   std::shared_ptr<AsyncServerSocket> serverSocket(
2149       AsyncServerSocket::newSocket(&eventBase));
2150   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2151   serverSocket->bind(0);
2152   serverSocket->listen(16);
2153   folly::SocketAddress serverAddress;
2154   serverSocket->getAddress(&serverAddress);
2155
2156   // Add a callback to accept one connection then stop the loop
2157   TestAcceptCallback acceptCallback;
2158   acceptCallback.setConnectionAcceptedFn(
2159       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2160         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2161       });
2162   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2163     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2164   });
2165   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2166   serverSocket->startAccepting();
2167
2168   // Connect to the server socket
2169   std::shared_ptr<AsyncSocket> socket(
2170       AsyncSocket::newSocket(&eventBase, serverAddress));
2171
2172   eventBase.loop();
2173
2174   // Validate the connection event counters
2175   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2176   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2177   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2178   ASSERT_EQ(
2179       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 1);
2180   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 1);
2181   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2182   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2183   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2184 }
2185
2186 TEST(AsyncSocketTest, CallbackInPrimaryEventBase) {
2187   EventBase eventBase;
2188   TestConnectionEventCallback connectionEventCallback;
2189
2190   // Create a server socket
2191   std::shared_ptr<AsyncServerSocket> serverSocket(
2192       AsyncServerSocket::newSocket(&eventBase));
2193   serverSocket->setConnectionEventCallback(&connectionEventCallback);
2194   serverSocket->bind(0);
2195   serverSocket->listen(16);
2196   folly::SocketAddress serverAddress;
2197   serverSocket->getAddress(&serverAddress);
2198
2199   // Add a callback to accept one connection then stop the loop
2200   TestAcceptCallback acceptCallback;
2201   acceptCallback.setConnectionAcceptedFn(
2202       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2203         serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2204       });
2205   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2206     serverSocket->removeAcceptCallback(&acceptCallback, nullptr);
2207   });
2208   bool acceptStartedFlag{false};
2209   acceptCallback.setAcceptStartedFn([&acceptStartedFlag](){
2210     acceptStartedFlag = true;
2211   });
2212   bool acceptStoppedFlag{false};
2213   acceptCallback.setAcceptStoppedFn([&acceptStoppedFlag](){
2214     acceptStoppedFlag = true;
2215   });
2216   serverSocket->addAcceptCallback(&acceptCallback, nullptr);
2217   serverSocket->startAccepting();
2218
2219   // Connect to the server socket
2220   std::shared_ptr<AsyncSocket> socket(
2221       AsyncSocket::newSocket(&eventBase, serverAddress));
2222
2223   eventBase.loop();
2224
2225   ASSERT_TRUE(acceptStartedFlag);
2226   ASSERT_TRUE(acceptStoppedFlag);
2227   // Validate the connection event counters
2228   ASSERT_EQ(connectionEventCallback.getConnectionAccepted(), 1);
2229   ASSERT_EQ(connectionEventCallback.getConnectionAcceptedError(), 0);
2230   ASSERT_EQ(connectionEventCallback.getConnectionDropped(), 0);
2231   ASSERT_EQ(
2232       connectionEventCallback.getConnectionEnqueuedForAcceptCallback(), 0);
2233   ASSERT_EQ(connectionEventCallback.getConnectionDequeuedByAcceptCallback(), 0);
2234   ASSERT_EQ(connectionEventCallback.getBackoffStarted(), 0);
2235   ASSERT_EQ(connectionEventCallback.getBackoffEnded(), 0);
2236   ASSERT_EQ(connectionEventCallback.getBackoffError(), 0);
2237 }
2238
2239
2240
2241 /**
2242  * Test AsyncServerSocket::getNumPendingMessagesInQueue()
2243  */
2244 TEST(AsyncSocketTest, NumPendingMessagesInQueue) {
2245   EventBase eventBase;
2246
2247   // Counter of how many connections have been accepted
2248   int count = 0;
2249
2250   // Create a server socket
2251   auto serverSocket(AsyncServerSocket::newSocket(&eventBase));
2252   serverSocket->bind(0);
2253   serverSocket->listen(16);
2254   folly::SocketAddress serverAddress;
2255   serverSocket->getAddress(&serverAddress);
2256
2257   // Add a callback to accept connections
2258   TestAcceptCallback acceptCallback;
2259   acceptCallback.setConnectionAcceptedFn(
2260       [&](int /* fd */, const folly::SocketAddress& /* addr */) {
2261         count++;
2262         ASSERT_EQ(4 - count, serverSocket->getNumPendingMessagesInQueue());
2263
2264         if (count == 4) {
2265           // all messages are processed, remove accept callback
2266           serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2267         }
2268       });
2269   acceptCallback.setAcceptErrorFn([&](const std::exception& /* ex */) {
2270     serverSocket->removeAcceptCallback(&acceptCallback, &eventBase);
2271   });
2272   serverSocket->addAcceptCallback(&acceptCallback, &eventBase);
2273   serverSocket->startAccepting();
2274
2275   // Connect to the server socket, 4 clients, there are 4 connections
2276   auto socket1(AsyncSocket::newSocket(&eventBase, serverAddress));
2277   auto socket2(AsyncSocket::newSocket(&eventBase, serverAddress));
2278   auto socket3(AsyncSocket::newSocket(&eventBase, serverAddress));
2279   auto socket4(AsyncSocket::newSocket(&eventBase, serverAddress));
2280
2281   eventBase.loop();
2282 }
2283
2284 /**
2285  * Test AsyncTransport::BufferCallback
2286  */
2287 TEST(AsyncSocketTest, BufferTest) {
2288   TestServer server;
2289
2290   EventBase evb;
2291   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2292   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2293   ConnCallback ccb;
2294   socket->connect(&ccb, server.getAddress(), 30, option);
2295
2296   char buf[100 * 1024];
2297   memset(buf, 'c', sizeof(buf));
2298   WriteCallback wcb;
2299   BufferCallback bcb;
2300   socket->setBufferCallback(&bcb);
2301   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2302
2303   evb.loop();
2304   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2305   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
2306
2307   ASSERT_TRUE(bcb.hasBuffered());
2308   ASSERT_TRUE(bcb.hasBufferCleared());
2309
2310   socket->close();
2311   server.verifyConnection(buf, sizeof(buf));
2312
2313   ASSERT_TRUE(socket->isClosedBySelf());
2314   ASSERT_FALSE(socket->isClosedByPeer());
2315 }
2316
2317 TEST(AsyncSocketTest, BufferCallbackKill) {
2318   TestServer server;
2319   EventBase evb;
2320   AsyncSocket::OptionMap option{{{SOL_SOCKET, SO_SNDBUF}, 128}};
2321   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2322   ConnCallback ccb;
2323   socket->connect(&ccb, server.getAddress(), 30, option);
2324   evb.loopOnce();
2325
2326   char buf[100 * 1024];
2327   memset(buf, 'c', sizeof(buf));
2328   BufferCallback bcb;
2329   socket->setBufferCallback(&bcb);
2330   WriteCallback wcb;
2331   wcb.successCallback = [&] {
2332     ASSERT_TRUE(socket.unique());
2333     socket.reset();
2334   };
2335
2336   // This will trigger AsyncSocket::handleWrite,
2337   // which calls WriteCallback::writeSuccess,
2338   // which calls wcb.successCallback above,
2339   // which tries to delete socket
2340   // Then, the socket will also try to use this BufferCallback
2341   // And that should crash us, if there is no DestructorGuard on the stack
2342   socket->write(&wcb, buf, sizeof(buf), WriteFlags::NONE);
2343
2344   evb.loop();
2345   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2346 }
2347
2348 #if FOLLY_ALLOW_TFO
2349 TEST(AsyncSocketTest, ConnectTFO) {
2350   // Start listening on a local port
2351   TestServer server(true);
2352
2353   // Connect using a AsyncSocket
2354   EventBase evb;
2355   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2356   socket->enableTFO();
2357   ConnCallback cb;
2358   socket->connect(&cb, server.getAddress(), 30);
2359
2360   std::array<uint8_t, 128> buf;
2361   memset(buf.data(), 'a', buf.size());
2362
2363   std::array<uint8_t, 3> readBuf;
2364   auto sendBuf = IOBuf::copyBuffer("hey");
2365
2366   std::thread t([&] {
2367     auto acceptedSocket = server.accept();
2368     acceptedSocket->write(buf.data(), buf.size());
2369     acceptedSocket->flush();
2370     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2371     acceptedSocket->close();
2372   });
2373
2374   evb.loop();
2375
2376   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2377   EXPECT_LE(0, socket->getConnectTime().count());
2378   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2379   EXPECT_TRUE(socket->getTFOAttempted());
2380
2381   // Should trigger the connect
2382   WriteCallback write;
2383   ReadCallback rcb;
2384   socket->writeChain(&write, sendBuf->clone());
2385   socket->setReadCB(&rcb);
2386   evb.loop();
2387
2388   t.join();
2389
2390   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2391   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2392   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2393   ASSERT_EQ(1, rcb.buffers.size());
2394   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2395   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2396   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2397 }
2398
2399 TEST(AsyncSocketTest, ConnectTFOSupplyEarlyReadCB) {
2400   // Start listening on a local port
2401   TestServer server(true);
2402
2403   // Connect using a AsyncSocket
2404   EventBase evb;
2405   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2406   socket->enableTFO();
2407   ConnCallback cb;
2408   socket->connect(&cb, server.getAddress(), 30);
2409   ReadCallback rcb;
2410   socket->setReadCB(&rcb);
2411
2412   std::array<uint8_t, 128> buf;
2413   memset(buf.data(), 'a', buf.size());
2414
2415   std::array<uint8_t, 3> readBuf;
2416   auto sendBuf = IOBuf::copyBuffer("hey");
2417
2418   std::thread t([&] {
2419     auto acceptedSocket = server.accept();
2420     acceptedSocket->write(buf.data(), buf.size());
2421     acceptedSocket->flush();
2422     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2423     acceptedSocket->close();
2424   });
2425
2426   evb.loop();
2427
2428   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2429   EXPECT_LE(0, socket->getConnectTime().count());
2430   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2431   EXPECT_TRUE(socket->getTFOAttempted());
2432
2433   // Should trigger the connect
2434   WriteCallback write;
2435   socket->writeChain(&write, sendBuf->clone());
2436   evb.loop();
2437
2438   t.join();
2439
2440   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2441   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2442   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2443   ASSERT_EQ(1, rcb.buffers.size());
2444   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2445   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2446   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2447 }
2448
2449 /**
2450  * Test connecting to a server that isn't listening
2451  */
2452 TEST(AsyncSocketTest, ConnectRefusedImmediatelyTFO) {
2453   EventBase evb;
2454
2455   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2456
2457   socket->enableTFO();
2458
2459   // Hopefully nothing is actually listening on this address
2460   folly::SocketAddress addr("::1", 65535);
2461   ConnCallback cb;
2462   socket->connect(&cb, addr, 30);
2463
2464   evb.loop();
2465
2466   WriteCallback write1;
2467   // Trigger the connect if TFO attempt is supported.
2468   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2469   WriteCallback write2;
2470   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2471   evb.loop();
2472
2473   if (!socket->getTFOFinished()) {
2474     EXPECT_EQ(STATE_FAILED, write1.state);
2475   } else {
2476     EXPECT_EQ(STATE_SUCCEEDED, write1.state);
2477     EXPECT_FALSE(socket->getTFOSucceded());
2478   }
2479
2480   EXPECT_EQ(STATE_FAILED, write2.state);
2481
2482   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2483   EXPECT_LE(0, socket->getConnectTime().count());
2484   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2485   EXPECT_TRUE(socket->getTFOAttempted());
2486 }
2487
2488 /**
2489  * Test calling closeNow() immediately after connecting.
2490  */
2491 TEST(AsyncSocketTest, ConnectWriteAndCloseNowTFO) {
2492   TestServer server(true);
2493
2494   // connect()
2495   EventBase evb;
2496   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2497   socket->enableTFO();
2498
2499   ConnCallback ccb;
2500   socket->connect(&ccb, server.getAddress(), 30);
2501
2502   // write()
2503   std::array<char, 128> buf;
2504   memset(buf.data(), 'a', buf.size());
2505
2506   // close()
2507   socket->closeNow();
2508
2509   // Loop, although there shouldn't be anything to do.
2510   evb.loop();
2511
2512   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2513
2514   ASSERT_TRUE(socket->isClosedBySelf());
2515   ASSERT_FALSE(socket->isClosedByPeer());
2516 }
2517
2518 /**
2519  * Test calling close() immediately after connect()
2520  */
2521 TEST(AsyncSocketTest, ConnectAndCloseTFO) {
2522   TestServer server(true);
2523
2524   // Connect using a AsyncSocket
2525   EventBase evb;
2526   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2527   socket->enableTFO();
2528
2529   ConnCallback ccb;
2530   socket->connect(&ccb, server.getAddress(), 30);
2531
2532   socket->close();
2533
2534   // Loop, although there shouldn't be anything to do.
2535   evb.loop();
2536
2537   // Make sure the connection was aborted
2538   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2539
2540   ASSERT_TRUE(socket->isClosedBySelf());
2541   ASSERT_FALSE(socket->isClosedByPeer());
2542 }
2543
2544 class MockAsyncTFOSocket : public AsyncSocket {
2545  public:
2546   using UniquePtr = std::unique_ptr<MockAsyncTFOSocket, Destructor>;
2547
2548   explicit MockAsyncTFOSocket(EventBase* evb) : AsyncSocket(evb) {}
2549
2550   MOCK_METHOD3(tfoSendMsg, ssize_t(int fd, struct msghdr* msg, int msg_flags));
2551 };
2552
2553 TEST(AsyncSocketTest, TestTFOUnsupported) {
2554   TestServer server(true);
2555
2556   // Connect using a AsyncSocket
2557   EventBase evb;
2558   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2559   socket->enableTFO();
2560
2561   ConnCallback ccb;
2562   socket->connect(&ccb, server.getAddress(), 30);
2563   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2564
2565   ReadCallback rcb;
2566   socket->setReadCB(&rcb);
2567
2568   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2569       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2570   WriteCallback write;
2571   auto sendBuf = IOBuf::copyBuffer("hey");
2572   socket->writeChain(&write, sendBuf->clone());
2573   EXPECT_EQ(STATE_WAITING, write.state);
2574
2575   std::array<uint8_t, 128> buf;
2576   memset(buf.data(), 'a', buf.size());
2577
2578   std::array<uint8_t, 3> readBuf;
2579
2580   std::thread t([&] {
2581     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2582     acceptedSocket->write(buf.data(), buf.size());
2583     acceptedSocket->flush();
2584     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2585     acceptedSocket->close();
2586   });
2587
2588   evb.loop();
2589
2590   t.join();
2591   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2592   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2593
2594   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2595   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2596   ASSERT_EQ(1, rcb.buffers.size());
2597   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2598   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2599   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2600 }
2601
2602 TEST(AsyncSocketTest, ConnectRefusedDelayedTFO) {
2603   EventBase evb;
2604
2605   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2606   socket->enableTFO();
2607
2608   // Hopefully this fails
2609   folly::SocketAddress fakeAddr("127.0.0.1", 65535);
2610   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2611       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2612         sockaddr_storage addr;
2613         auto len = fakeAddr.getAddress(&addr);
2614         int ret = connect(fd, (const struct sockaddr*)&addr, len);
2615         LOG(INFO) << "connecting the socket " << fd << " : " << ret << " : "
2616                   << errno;
2617         return ret;
2618       }));
2619
2620   // Hopefully nothing is actually listening on this address
2621   ConnCallback cb;
2622   socket->connect(&cb, fakeAddr, 30);
2623
2624   WriteCallback write1;
2625   // Trigger the connect if TFO attempt is supported.
2626   socket->writeChain(&write1, IOBuf::copyBuffer("hey"));
2627
2628   if (socket->getTFOFinished()) {
2629     // This test is useless now.
2630     return;
2631   }
2632   WriteCallback write2;
2633   // Trigger the connect if TFO attempt is supported.
2634   socket->writeChain(&write2, IOBuf::copyBuffer("hey"));
2635   evb.loop();
2636
2637   EXPECT_EQ(STATE_FAILED, write1.state);
2638   EXPECT_EQ(STATE_FAILED, write2.state);
2639   EXPECT_FALSE(socket->getTFOSucceded());
2640
2641   EXPECT_EQ(STATE_SUCCEEDED, cb.state);
2642   EXPECT_LE(0, socket->getConnectTime().count());
2643   EXPECT_EQ(std::chrono::milliseconds(30), socket->getConnectTimeout());
2644   EXPECT_TRUE(socket->getTFOAttempted());
2645 }
2646
2647 TEST(AsyncSocketTest, TestTFOUnsupportedTimeout) {
2648   // Try connecting to server that won't respond.
2649   //
2650   // This depends somewhat on the network where this test is run.
2651   // Hopefully this IP will be routable but unresponsive.
2652   // (Alternatively, we could try listening on a local raw socket, but that
2653   // normally requires root privileges.)
2654   auto host = SocketAddressTestHelper::isIPv6Enabled()
2655       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2656       : SocketAddressTestHelper::isIPv4Enabled()
2657           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2658           : nullptr;
2659   SocketAddress addr(host, 65535);
2660
2661   // Connect using a AsyncSocket
2662   EventBase evb;
2663   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2664   socket->enableTFO();
2665
2666   ConnCallback ccb;
2667   // Set a very small timeout
2668   socket->connect(&ccb, addr, 1);
2669   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2670
2671   ReadCallback rcb;
2672   socket->setReadCB(&rcb);
2673
2674   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2675       .WillOnce(SetErrnoAndReturn(EOPNOTSUPP, -1));
2676   WriteCallback write;
2677   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2678
2679   evb.loop();
2680
2681   EXPECT_EQ(STATE_FAILED, write.state);
2682 }
2683
2684 TEST(AsyncSocketTest, TestTFOFallbackToConnect) {
2685   TestServer server(true);
2686
2687   // Connect using a AsyncSocket
2688   EventBase evb;
2689   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2690   socket->enableTFO();
2691
2692   ConnCallback ccb;
2693   socket->connect(&ccb, server.getAddress(), 30);
2694   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2695
2696   ReadCallback rcb;
2697   socket->setReadCB(&rcb);
2698
2699   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2700       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2701         sockaddr_storage addr;
2702         auto len = server.getAddress().getAddress(&addr);
2703         return connect(fd, (const struct sockaddr*)&addr, len);
2704       }));
2705   WriteCallback write;
2706   auto sendBuf = IOBuf::copyBuffer("hey");
2707   socket->writeChain(&write, sendBuf->clone());
2708   EXPECT_EQ(STATE_WAITING, write.state);
2709
2710   std::array<uint8_t, 128> buf;
2711   memset(buf.data(), 'a', buf.size());
2712
2713   std::array<uint8_t, 3> readBuf;
2714
2715   std::thread t([&] {
2716     std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
2717     acceptedSocket->write(buf.data(), buf.size());
2718     acceptedSocket->flush();
2719     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2720     acceptedSocket->close();
2721   });
2722
2723   evb.loop();
2724
2725   t.join();
2726   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2727
2728   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2729   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2730
2731   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2732   ASSERT_EQ(1, rcb.buffers.size());
2733   ASSERT_EQ(buf.size(), rcb.buffers[0].length);
2734   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2735 }
2736
2737 TEST(AsyncSocketTest, TestTFOFallbackTimeout) {
2738   // Try connecting to server that won't respond.
2739   //
2740   // This depends somewhat on the network where this test is run.
2741   // Hopefully this IP will be routable but unresponsive.
2742   // (Alternatively, we could try listening on a local raw socket, but that
2743   // normally requires root privileges.)
2744   auto host = SocketAddressTestHelper::isIPv6Enabled()
2745       ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv6
2746       : SocketAddressTestHelper::isIPv4Enabled()
2747           ? SocketAddressTestHelper::kGooglePublicDnsAAddrIPv4
2748           : nullptr;
2749   SocketAddress addr(host, 65535);
2750
2751   // Connect using a AsyncSocket
2752   EventBase evb;
2753   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2754   socket->enableTFO();
2755
2756   ConnCallback ccb;
2757   // Set a very small timeout
2758   socket->connect(&ccb, addr, 1);
2759   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2760
2761   ReadCallback rcb;
2762   socket->setReadCB(&rcb);
2763
2764   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2765       .WillOnce(Invoke([&](int fd, struct msghdr*, int) {
2766         sockaddr_storage addr2;
2767         auto len = addr.getAddress(&addr2);
2768         return connect(fd, (const struct sockaddr*)&addr2, len);
2769       }));
2770   WriteCallback write;
2771   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2772
2773   evb.loop();
2774
2775   EXPECT_EQ(STATE_FAILED, write.state);
2776 }
2777
2778 TEST(AsyncSocketTest, TestTFOEagain) {
2779   TestServer server(true);
2780
2781   // Connect using a AsyncSocket
2782   EventBase evb;
2783   auto socket = MockAsyncTFOSocket::UniquePtr(new MockAsyncTFOSocket(&evb));
2784   socket->enableTFO();
2785
2786   ConnCallback ccb;
2787   socket->connect(&ccb, server.getAddress(), 30);
2788
2789   EXPECT_CALL(*socket, tfoSendMsg(_, _, _))
2790       .WillOnce(SetErrnoAndReturn(EAGAIN, -1));
2791   WriteCallback write;
2792   socket->writeChain(&write, IOBuf::copyBuffer("hey"));
2793
2794   evb.loop();
2795
2796   EXPECT_EQ(STATE_SUCCEEDED, ccb.state);
2797   EXPECT_EQ(STATE_FAILED, write.state);
2798 }
2799
2800 // Sending a large amount of data in the first write which will
2801 // definitely not fit into MSS.
2802 TEST(AsyncSocketTest, ConnectTFOWithBigData) {
2803   // Start listening on a local port
2804   TestServer server(true);
2805
2806   // Connect using a AsyncSocket
2807   EventBase evb;
2808   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2809   socket->enableTFO();
2810   ConnCallback cb;
2811   socket->connect(&cb, server.getAddress(), 30);
2812
2813   std::array<uint8_t, 128> buf;
2814   memset(buf.data(), 'a', buf.size());
2815
2816   constexpr size_t len = 10 * 1024;
2817   auto sendBuf = IOBuf::create(len);
2818   sendBuf->append(len);
2819   std::array<uint8_t, len> readBuf;
2820
2821   std::thread t([&] {
2822     auto acceptedSocket = server.accept();
2823     acceptedSocket->write(buf.data(), buf.size());
2824     acceptedSocket->flush();
2825     acceptedSocket->readAll(readBuf.data(), readBuf.size());
2826     acceptedSocket->close();
2827   });
2828
2829   evb.loop();
2830
2831   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2832   EXPECT_LE(0, socket->getConnectTime().count());
2833   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2834   EXPECT_TRUE(socket->getTFOAttempted());
2835
2836   // Should trigger the connect
2837   WriteCallback write;
2838   ReadCallback rcb;
2839   socket->writeChain(&write, sendBuf->clone());
2840   socket->setReadCB(&rcb);
2841   evb.loop();
2842
2843   t.join();
2844
2845   EXPECT_EQ(STATE_SUCCEEDED, write.state);
2846   EXPECT_EQ(0, memcmp(readBuf.data(), sendBuf->data(), readBuf.size()));
2847   EXPECT_EQ(STATE_SUCCEEDED, rcb.state);
2848   ASSERT_EQ(1, rcb.buffers.size());
2849   ASSERT_EQ(sizeof(buf), rcb.buffers[0].length);
2850   EXPECT_EQ(0, memcmp(rcb.buffers[0].buffer, buf.data(), buf.size()));
2851   EXPECT_EQ(socket->getTFOFinished(), socket->getTFOSucceded());
2852 }
2853
2854 #endif // FOLLY_ALLOW_TFO
2855
2856 class MockEvbChangeCallback : public AsyncSocket::EvbChangeCallback {
2857  public:
2858   MOCK_METHOD1(evbAttached, void(AsyncSocket*));
2859   MOCK_METHOD1(evbDetached, void(AsyncSocket*));
2860 };
2861
2862 TEST(AsyncSocketTest, EvbCallbacks) {
2863   auto cb = std::make_unique<MockEvbChangeCallback>();
2864   EventBase evb;
2865   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2866
2867   InSequence seq;
2868   EXPECT_CALL(*cb, evbDetached(socket.get())).Times(1);
2869   EXPECT_CALL(*cb, evbAttached(socket.get())).Times(1);
2870
2871   socket->setEvbChangedCallback(std::move(cb));
2872   socket->detachEventBase();
2873   socket->attachEventBase(&evb);
2874 }
2875
2876 TEST(AsyncSocketTest, TestEvbDetachWtRegisteredIOHandlers) {
2877   // Start listening on a local port
2878   TestServer server;
2879
2880   // Connect using a AsyncSocket
2881   EventBase evb;
2882   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2883   ConnCallback cb;
2884   socket->connect(&cb, server.getAddress(), 30);
2885
2886   evb.loop();
2887
2888   ASSERT_EQ(cb.state, STATE_SUCCEEDED);
2889   EXPECT_LE(0, socket->getConnectTime().count());
2890   EXPECT_EQ(socket->getConnectTimeout(), std::chrono::milliseconds(30));
2891
2892   // After the ioHandlers are registered, still should be able to detach/attach
2893   ReadCallback rcb;
2894   socket->setReadCB(&rcb);
2895
2896   auto cbEvbChg = std::make_unique<MockEvbChangeCallback>();
2897   InSequence seq;
2898   EXPECT_CALL(*cbEvbChg, evbDetached(socket.get())).Times(1);
2899   EXPECT_CALL(*cbEvbChg, evbAttached(socket.get())).Times(1);
2900
2901   socket->setEvbChangedCallback(std::move(cbEvbChg));
2902   EXPECT_TRUE(socket->isDetachable());
2903   socket->detachEventBase();
2904   socket->attachEventBase(&evb);
2905
2906   socket->close();
2907 }
2908
2909 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
2910 /* copied from include/uapi/linux/net_tstamp.h */
2911 /* SO_TIMESTAMPING gets an integer bit field comprised of these values */
2912 enum SOF_TIMESTAMPING {
2913   SOF_TIMESTAMPING_SOFTWARE = (1 << 4),
2914   SOF_TIMESTAMPING_OPT_ID = (1 << 7),
2915   SOF_TIMESTAMPING_TX_SCHED = (1 << 8),
2916   SOF_TIMESTAMPING_OPT_CMSG = (1 << 10),
2917   SOF_TIMESTAMPING_OPT_TSONLY = (1 << 11),
2918 };
2919
2920 class TestErrMessageCallback : public folly::AsyncSocket::ErrMessageCallback {
2921  public:
2922   TestErrMessageCallback()
2923       : exception_(folly::AsyncSocketException::UNKNOWN, "none") {}
2924
2925   void errMessage(const cmsghdr& cmsg) noexcept override {
2926     if (cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_TIMESTAMPING) {
2927       gotTimestamp_++;
2928       checkResetCallback();
2929     } else if (
2930         (cmsg.cmsg_level == SOL_IP && cmsg.cmsg_type == IP_RECVERR) ||
2931         (cmsg.cmsg_level == SOL_IPV6 && cmsg.cmsg_type == IPV6_RECVERR)) {
2932       gotByteSeq_++;
2933       checkResetCallback();
2934     }
2935   }
2936
2937   void errMessageError(
2938       const folly::AsyncSocketException& ex) noexcept override {
2939     exception_ = ex;
2940   }
2941
2942   void checkResetCallback() noexcept {
2943     if (socket_ != nullptr && resetAfter_ != -1 &&
2944         gotTimestamp_ + gotByteSeq_ == resetAfter_) {
2945       socket_->setErrMessageCB(nullptr);
2946     }
2947   }
2948
2949   folly::AsyncSocket* socket_{nullptr};
2950   folly::AsyncSocketException exception_;
2951   int gotTimestamp_{0};
2952   int gotByteSeq_{0};
2953   int resetAfter_{-1};
2954 };
2955
2956 TEST(AsyncSocketTest, ErrMessageCallback) {
2957   TestServer server;
2958
2959   // connect()
2960   EventBase evb;
2961   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
2962
2963   ConnCallback ccb;
2964   socket->connect(&ccb, server.getAddress(), 30);
2965   LOG(INFO) << "Client socket fd=" << socket->getFd();
2966
2967   // Let the socket
2968   evb.loop();
2969
2970   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
2971
2972   // Set read callback to keep the socket subscribed for event
2973   // notifications. Though we're no planning to read anything from
2974   // this side of the connection.
2975   ReadCallback rcb(1);
2976   socket->setReadCB(&rcb);
2977
2978   // Set up timestamp callbacks
2979   TestErrMessageCallback errMsgCB;
2980   socket->setErrMessageCB(&errMsgCB);
2981   ASSERT_EQ(socket->getErrMessageCallback(),
2982             static_cast<folly::AsyncSocket::ErrMessageCallback*>(&errMsgCB));
2983
2984   errMsgCB.socket_ = socket.get();
2985   errMsgCB.resetAfter_ = 3;
2986
2987   // Enable timestamp notifications
2988   ASSERT_GT(socket->getFd(), 0);
2989   int flags = SOF_TIMESTAMPING_OPT_ID
2990               | SOF_TIMESTAMPING_OPT_TSONLY
2991               | SOF_TIMESTAMPING_SOFTWARE
2992               | SOF_TIMESTAMPING_OPT_CMSG
2993               | SOF_TIMESTAMPING_TX_SCHED;
2994   AsyncSocket::OptionKey tstampingOpt = {SOL_SOCKET, SO_TIMESTAMPING};
2995   EXPECT_EQ(tstampingOpt.apply(socket->getFd(), flags), 0);
2996
2997   // write()
2998   std::vector<uint8_t> wbuf(128, 'a');
2999   WriteCallback wcb;
3000   // Send two packets to get two EOM notifications
3001   socket->write(&wcb, wbuf.data(), wbuf.size() / 2);
3002   socket->write(&wcb, wbuf.data() + wbuf.size() / 2, wbuf.size() / 2);
3003
3004   // Accept the connection.
3005   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
3006   LOG(INFO) << "Server socket fd=" << acceptedSocket->getSocketFD();
3007
3008   // Loop
3009   evb.loopOnce();
3010   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3011
3012   // Check that we can read the data that was written to the socket
3013   std::vector<uint8_t> rbuf(1 + wbuf.size(), 0);
3014   uint32_t bytesRead = acceptedSocket->read(rbuf.data(), rbuf.size());
3015   ASSERT_TRUE(std::equal(wbuf.begin(), wbuf.end(), rbuf.begin()));
3016   ASSERT_EQ(bytesRead, wbuf.size());
3017
3018   // Close both sockets
3019   acceptedSocket->close();
3020   socket->close();
3021
3022   ASSERT_TRUE(socket->isClosedBySelf());
3023   ASSERT_FALSE(socket->isClosedByPeer());
3024
3025   // Check for the timestamp notifications.
3026   ASSERT_EQ(errMsgCB.exception_.type_, folly::AsyncSocketException::UNKNOWN);
3027   ASSERT_GT(errMsgCB.gotByteSeq_, 0);
3028   ASSERT_GT(errMsgCB.gotTimestamp_, 0);
3029   ASSERT_EQ(
3030       errMsgCB.gotByteSeq_ + errMsgCB.gotTimestamp_, errMsgCB.resetAfter_);
3031 }
3032 #endif // FOLLY_HAVE_MSG_ERRQUEUE
3033
3034 TEST(AsyncSocket, PreReceivedData) {
3035   TestServer server;
3036
3037   EventBase evb;
3038   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3039   socket->connect(nullptr, server.getAddress(), 30);
3040   evb.loop();
3041
3042   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3043
3044   auto acceptedSocket = server.acceptAsync(&evb);
3045
3046   ReadCallback peekCallback(2);
3047   ReadCallback readCallback;
3048   peekCallback.dataAvailableCallback = [&]() {
3049     peekCallback.verifyData("he", 2);
3050     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("h"));
3051     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("e"));
3052     acceptedSocket->setReadCB(nullptr);
3053     acceptedSocket->setReadCB(&readCallback);
3054   };
3055   readCallback.dataAvailableCallback = [&]() {
3056     if (readCallback.dataRead() == 5) {
3057       readCallback.verifyData("hello", 5);
3058       acceptedSocket->setReadCB(nullptr);
3059     }
3060   };
3061
3062   acceptedSocket->setReadCB(&peekCallback);
3063
3064   evb.loop();
3065 }
3066
3067 TEST(AsyncSocket, PreReceivedDataOnly) {
3068   TestServer server;
3069
3070   EventBase evb;
3071   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3072   socket->connect(nullptr, server.getAddress(), 30);
3073   evb.loop();
3074
3075   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3076
3077   auto acceptedSocket = server.acceptAsync(&evb);
3078
3079   ReadCallback peekCallback;
3080   ReadCallback readCallback;
3081   peekCallback.dataAvailableCallback = [&]() {
3082     peekCallback.verifyData("hello", 5);
3083     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3084     acceptedSocket->setReadCB(&readCallback);
3085   };
3086   readCallback.dataAvailableCallback = [&]() {
3087     readCallback.verifyData("hello", 5);
3088     acceptedSocket->setReadCB(nullptr);
3089   };
3090
3091   acceptedSocket->setReadCB(&peekCallback);
3092
3093   evb.loop();
3094 }
3095
3096 TEST(AsyncSocket, PreReceivedDataPartial) {
3097   TestServer server;
3098
3099   EventBase evb;
3100   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3101   socket->connect(nullptr, server.getAddress(), 30);
3102   evb.loop();
3103
3104   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3105
3106   auto acceptedSocket = server.acceptAsync(&evb);
3107
3108   ReadCallback peekCallback;
3109   ReadCallback smallReadCallback(3);
3110   ReadCallback normalReadCallback;
3111   peekCallback.dataAvailableCallback = [&]() {
3112     peekCallback.verifyData("hello", 5);
3113     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3114     acceptedSocket->setReadCB(&smallReadCallback);
3115   };
3116   smallReadCallback.dataAvailableCallback = [&]() {
3117     smallReadCallback.verifyData("hel", 3);
3118     acceptedSocket->setReadCB(&normalReadCallback);
3119   };
3120   normalReadCallback.dataAvailableCallback = [&]() {
3121     normalReadCallback.verifyData("lo", 2);
3122     acceptedSocket->setReadCB(nullptr);
3123   };
3124
3125   acceptedSocket->setReadCB(&peekCallback);
3126
3127   evb.loop();
3128 }
3129
3130 TEST(AsyncSocket, PreReceivedDataTakeover) {
3131   TestServer server;
3132
3133   EventBase evb;
3134   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3135   socket->connect(nullptr, server.getAddress(), 30);
3136   evb.loop();
3137
3138   socket->writeChain(nullptr, IOBuf::copyBuffer("hello"));
3139
3140   auto acceptedSocket =
3141       AsyncSocket::UniquePtr(new AsyncSocket(&evb, server.acceptFD()));
3142   AsyncSocket::UniquePtr takeoverSocket;
3143
3144   ReadCallback peekCallback(3);
3145   ReadCallback readCallback;
3146   peekCallback.dataAvailableCallback = [&]() {
3147     peekCallback.verifyData("hel", 3);
3148     acceptedSocket->setPreReceivedData(IOBuf::copyBuffer("hello"));
3149     acceptedSocket->setReadCB(nullptr);
3150     takeoverSocket =
3151         AsyncSocket::UniquePtr(new AsyncSocket(std::move(acceptedSocket)));
3152     takeoverSocket->setReadCB(&readCallback);
3153   };
3154   readCallback.dataAvailableCallback = [&]() {
3155     readCallback.verifyData("hello", 5);
3156     takeoverSocket->setReadCB(nullptr);
3157   };
3158
3159   acceptedSocket->setReadCB(&peekCallback);
3160
3161   evb.loop();
3162 }
3163
3164 #ifdef MSG_NOSIGNAL
3165 TEST(AsyncSocketTest, SendMessageFlags) {
3166   TestServer server;
3167   TestSendMsgParamsCallback sendMsgCB(
3168       MSG_DONTWAIT|MSG_NOSIGNAL|MSG_MORE, 0, nullptr);
3169
3170   // connect()
3171   EventBase evb;
3172   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb);
3173
3174   ConnCallback ccb;
3175   socket->connect(&ccb, server.getAddress(), 30);
3176   std::shared_ptr<BlockingSocket> acceptedSocket = server.accept();
3177
3178   evb.loop();
3179   ASSERT_EQ(ccb.state, STATE_SUCCEEDED);
3180
3181   // Set SendMsgParamsCallback
3182   socket->setSendMsgParamCB(&sendMsgCB);
3183   ASSERT_EQ(socket->getSendMsgParamsCB(), &sendMsgCB);
3184
3185   // Write the first portion of data. This data is expected to be
3186   // sent out immediately.
3187   std::vector<uint8_t> buf(128, 'a');
3188   WriteCallback wcb;
3189   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL);
3190   socket->write(&wcb, buf.data(), buf.size());
3191   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3192   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3193   ASSERT_FALSE(sendMsgCB.queriedData_);
3194
3195   // Using different flags for the second write operation.
3196   // MSG_MORE flag is expected to delay sending this
3197   // data to the wire.
3198   sendMsgCB.reset(MSG_DONTWAIT | MSG_NOSIGNAL | MSG_MORE);
3199   socket->write(&wcb, buf.data(), buf.size());
3200   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3201   ASSERT_TRUE(sendMsgCB.queriedFlags_);
3202   ASSERT_FALSE(sendMsgCB.queriedData_);
3203
3204   // Make sure the accepted socket saw only the data from
3205   // the first write request.
3206   std::vector<uint8_t> readbuf(2 * buf.size());
3207   uint32_t bytesRead = acceptedSocket->read(readbuf.data(), readbuf.size());
3208   ASSERT_TRUE(std::equal(buf.begin(), buf.end(), readbuf.begin()));
3209   ASSERT_EQ(bytesRead, buf.size());
3210
3211   // Make sure the server got a connection and received the data
3212   acceptedSocket->close();
3213   socket->close();
3214
3215   ASSERT_TRUE(socket->isClosedBySelf());
3216   ASSERT_FALSE(socket->isClosedByPeer());
3217 }
3218
3219 TEST(AsyncSocketTest, SendMessageAncillaryData) {
3220   int fds[2];
3221   EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fds), 0);
3222
3223   // "Client" socket
3224   int cfd = fds[0];
3225   ASSERT_NE(cfd, -1);
3226
3227   // "Server" socket
3228   int sfd = fds[1];
3229   ASSERT_NE(sfd, -1);
3230   SCOPE_EXIT { close(sfd); };
3231
3232   // Instantiate AsyncSocket object for the connected socket
3233   EventBase evb;
3234   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, cfd);
3235
3236   // Open a temporary file and write a magic string to it
3237   // We'll transfer the file handle to test the message parameters
3238   // callback logic.
3239   TemporaryFile file(StringPiece(),
3240                      fs::path(),
3241                      TemporaryFile::Scope::UNLINK_IMMEDIATELY);
3242   int tmpfd = file.fd();
3243   ASSERT_NE(tmpfd, -1) << "Failed to open a temporary file";
3244   std::string magicString("Magic string");
3245   ASSERT_EQ(write(tmpfd, magicString.c_str(), magicString.length()),
3246             magicString.length());
3247
3248   // Send message
3249   union {
3250     // Space large enough to hold an 'int'
3251     char control[CMSG_SPACE(sizeof(int))];
3252     struct cmsghdr cmh;
3253   } s_u;
3254   s_u.cmh.cmsg_len = CMSG_LEN(sizeof(int));
3255   s_u.cmh.cmsg_level = SOL_SOCKET;
3256   s_u.cmh.cmsg_type = SCM_RIGHTS;
3257   memcpy(CMSG_DATA(&s_u.cmh), &tmpfd, sizeof(int));
3258
3259   // Set up the callback providing message parameters
3260   TestSendMsgParamsCallback sendMsgCB(
3261       MSG_DONTWAIT | MSG_NOSIGNAL, sizeof(s_u.control), s_u.control);
3262   socket->setSendMsgParamCB(&sendMsgCB);
3263
3264   // We must transmit at least 1 byte of real data in order
3265   // to send ancillary data
3266   int s_data = 12345;
3267   WriteCallback wcb;
3268   socket->write(&wcb, &s_data, sizeof(s_data));
3269   ASSERT_EQ(wcb.state, STATE_SUCCEEDED);
3270
3271   // Receive the message
3272   union {
3273     // Space large enough to hold an 'int'
3274     char control[CMSG_SPACE(sizeof(int))];
3275     struct cmsghdr cmh;
3276   } r_u;
3277   struct msghdr msgh;
3278   struct iovec iov;
3279   int r_data = 0;
3280
3281   msgh.msg_control = r_u.control;
3282   msgh.msg_controllen = sizeof(r_u.control);
3283   msgh.msg_name = nullptr;
3284   msgh.msg_namelen = 0;
3285   msgh.msg_iov = &iov;
3286   msgh.msg_iovlen = 1;
3287   iov.iov_base = &r_data;
3288   iov.iov_len = sizeof(r_data);
3289
3290   // Receive data
3291   ASSERT_NE(recvmsg(sfd, &msgh, 0), -1) << "recvmsg failed: " << errno;
3292
3293   // Validate the received message
3294   ASSERT_EQ(r_u.cmh.cmsg_len, CMSG_LEN(sizeof(int)));
3295   ASSERT_EQ(r_u.cmh.cmsg_level, SOL_SOCKET);
3296   ASSERT_EQ(r_u.cmh.cmsg_type, SCM_RIGHTS);
3297   ASSERT_EQ(r_data, s_data);
3298   int fd = 0;
3299   memcpy(&fd, CMSG_DATA(&r_u.cmh), sizeof(int));
3300   ASSERT_NE(fd, 0);
3301   SCOPE_EXIT { close(fd); };
3302
3303   std::vector<uint8_t> transferredMagicString(magicString.length() + 1, 0);
3304
3305   // Reposition to the beginning of the file
3306   ASSERT_EQ(0, lseek(fd, 0, SEEK_SET));
3307
3308   // Read the magic string back, and compare it with the original
3309   ASSERT_EQ(
3310       magicString.length(),
3311       read(fd, transferredMagicString.data(), transferredMagicString.size()));
3312   ASSERT_TRUE(std::equal(
3313       magicString.begin(),
3314       magicString.end(),
3315       transferredMagicString.begin()));
3316 }
3317
3318 TEST(AsyncSocketTest, UnixDomainSocketErrMessageCB) {
3319   // In the latest stable kernel 4.14.3 as of 2017-12-04, Unix Domain
3320   // Socket (UDS) does not support MSG_ERRQUEUE. So
3321   // recvmsg(MSG_ERRQUEUE) will read application data from UDS which
3322   // breaks application message flow.  To avoid this problem,
3323   // AsyncSocket currently disables setErrMessageCB for UDS.
3324   //
3325   // This tests two things for UDS
3326   // 1. setErrMessageCB fails
3327   // 2. recvmsg(MSG_ERRQUEUE) reads application data
3328   //
3329   // Feel free to remove this test if UDS supports MSG_ERRQUEUE in the future.
3330
3331   int fd[2];
3332   EXPECT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, fd), 0);
3333   ASSERT_NE(fd[0], -1);
3334   ASSERT_NE(fd[1], -1);
3335   SCOPE_EXIT {
3336     close(fd[1]);
3337   };
3338
3339   EXPECT_EQ(fcntl(fd[0], F_SETFL, O_NONBLOCK), 0);
3340   EXPECT_EQ(fcntl(fd[1], F_SETFL, O_NONBLOCK), 0);
3341
3342   EventBase evb;
3343   std::shared_ptr<AsyncSocket> socket = AsyncSocket::newSocket(&evb, fd[0]);
3344
3345   // setErrMessageCB should fail for unix domain socket
3346   TestErrMessageCallback errMsgCB;
3347   ASSERT_NE(&errMsgCB, nullptr);
3348   socket->setErrMessageCB(&errMsgCB);
3349   ASSERT_EQ(socket->getErrMessageCallback(), nullptr);
3350
3351 #ifdef FOLLY_HAVE_MSG_ERRQUEUE
3352   // The following verifies that MSG_ERRQUEUE does not work for UDS,
3353   // and recvmsg reads application data
3354   union {
3355     // Space large enough to hold an 'int'
3356     char control[CMSG_SPACE(sizeof(int))];
3357     struct cmsghdr cmh;
3358   } r_u;
3359   struct msghdr msgh;
3360   struct iovec iov;
3361   int recv_data = 0;
3362
3363   msgh.msg_control = r_u.control;
3364   msgh.msg_controllen = sizeof(r_u.control);
3365   msgh.msg_name = nullptr;
3366   msgh.msg_namelen = 0;
3367   msgh.msg_iov = &iov;
3368   msgh.msg_iovlen = 1;
3369   iov.iov_base = &recv_data;
3370   iov.iov_len = sizeof(recv_data);
3371
3372   // there is no data, recvmsg should fail
3373   EXPECT_EQ(recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
3374   EXPECT_TRUE(errno == EAGAIN || errno == EWOULDBLOCK);
3375
3376   // provide some application data, error queue should be empty if it exists
3377   // However, UDS reads application data as error message
3378   int test_data = 123456;
3379   WriteCallback wcb;
3380   socket->write(&wcb, &test_data, sizeof(test_data));
3381   recv_data = 0;
3382   ASSERT_NE(recvmsg(fd[1], &msgh, MSG_ERRQUEUE), -1);
3383   ASSERT_EQ(recv_data, test_data);
3384 #endif // FOLLY_HAVE_MSG_ERRQUEUE
3385 }
3386 #endif