Print expected/actual thread names when running EventBase logic in unexpected thread
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/Portability.h>
21 #include <folly/SocketAddress.h>
22 #include <folly/io/Cursor.h>
23 #include <folly/io/IOBuf.h>
24 #include <folly/io/IOBufQueue.h>
25 #include <folly/portability/Fcntl.h>
26 #include <folly/portability/Sockets.h>
27 #include <folly/portability/SysUio.h>
28 #include <folly/portability/Unistd.h>
29
30 #include <boost/preprocessor/control/if.hpp>
31 #include <errno.h>
32 #include <limits.h>
33 #include <sys/types.h>
34 #include <thread>
35
36 using std::string;
37 using std::unique_ptr;
38
39 namespace fsp = folly::portability::sockets;
40
41 namespace folly {
42
43 static constexpr bool msgErrQueueSupported =
44 #ifdef MSG_ERRQUEUE
45     true;
46 #else
47     false;
48 #endif // MSG_ERRQUEUE
49
50 // static members initializers
51 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
52
53 const AsyncSocketException socketClosedLocallyEx(
54     AsyncSocketException::END_OF_FILE, "socket closed locally");
55 const AsyncSocketException socketShutdownForWritesEx(
56     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
57
58 // TODO: It might help performance to provide a version of BytesWriteRequest that
59 // users could derive from, so we can avoid the extra allocation for each call
60 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
61 // protocols are currently templatized for transports.
62 //
63 // We would need the version for external users where they provide the iovec
64 // storage space, and only our internal version would allocate it at the end of
65 // the WriteRequest.
66
67 /* The default WriteRequest implementation, used for write(), writev() and
68  * writeChain()
69  *
70  * A new BytesWriteRequest operation is allocated on the heap for all write
71  * operations that cannot be completed immediately.
72  */
73 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
74  public:
75   static BytesWriteRequest* newRequest(AsyncSocket* socket,
76                                        WriteCallback* callback,
77                                        const iovec* ops,
78                                        uint32_t opCount,
79                                        uint32_t partialWritten,
80                                        uint32_t bytesWritten,
81                                        unique_ptr<IOBuf>&& ioBuf,
82                                        WriteFlags flags) {
83     assert(opCount > 0);
84     // Since we put a variable size iovec array at the end
85     // of each BytesWriteRequest, we have to manually allocate the memory.
86     void* buf = malloc(sizeof(BytesWriteRequest) +
87                        (opCount * sizeof(struct iovec)));
88     if (buf == nullptr) {
89       throw std::bad_alloc();
90     }
91
92     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
93                                       partialWritten, bytesWritten,
94                                       std::move(ioBuf), flags);
95   }
96
97   void destroy() override {
98     this->~BytesWriteRequest();
99     free(this);
100   }
101
102   WriteResult performWrite() override {
103     WriteFlags writeFlags = flags_;
104     if (getNext() != nullptr) {
105       writeFlags |= WriteFlags::CORK;
106     }
107     auto writeResult = socket_->performWrite(
108         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
109     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
110     return writeResult;
111   }
112
113   bool isComplete() override {
114     return opsWritten_ == getOpCount();
115   }
116
117   void consume() override {
118     // Advance opIndex_ forward by opsWritten_
119     opIndex_ += opsWritten_;
120     assert(opIndex_ < opCount_);
121
122     // If we've finished writing any IOBufs, release them
123     if (ioBuf_) {
124       for (uint32_t i = opsWritten_; i != 0; --i) {
125         assert(ioBuf_);
126         ioBuf_ = ioBuf_->pop();
127       }
128     }
129
130     // Move partialBytes_ forward into the current iovec buffer
131     struct iovec* currentOp = writeOps_ + opIndex_;
132     assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
133     currentOp->iov_base =
134       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
135     currentOp->iov_len -= partialBytes_;
136
137     // Increment the totalBytesWritten_ count by bytesWritten_;
138     assert(bytesWritten_ >= 0);
139     totalBytesWritten_ += uint32_t(bytesWritten_);
140   }
141
142  private:
143   BytesWriteRequest(AsyncSocket* socket,
144                     WriteCallback* callback,
145                     const struct iovec* ops,
146                     uint32_t opCount,
147                     uint32_t partialBytes,
148                     uint32_t bytesWritten,
149                     unique_ptr<IOBuf>&& ioBuf,
150                     WriteFlags flags)
151     : AsyncSocket::WriteRequest(socket, callback)
152     , opCount_(opCount)
153     , opIndex_(0)
154     , flags_(flags)
155     , ioBuf_(std::move(ioBuf))
156     , opsWritten_(0)
157     , partialBytes_(partialBytes)
158     , bytesWritten_(bytesWritten) {
159     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
160   }
161
162   // private destructor, to ensure callers use destroy()
163   ~BytesWriteRequest() override = default;
164
165   const struct iovec* getOps() const {
166     assert(opCount_ > opIndex_);
167     return writeOps_ + opIndex_;
168   }
169
170   uint32_t getOpCount() const {
171     assert(opCount_ > opIndex_);
172     return opCount_ - opIndex_;
173   }
174
175   uint32_t opCount_;            ///< number of entries in writeOps_
176   uint32_t opIndex_;            ///< current index into writeOps_
177   WriteFlags flags_;            ///< set for WriteFlags
178   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
179
180   // for consume(), how much we wrote on the last write
181   uint32_t opsWritten_;         ///< complete ops written
182   uint32_t partialBytes_;       ///< partial bytes of incomplete op written
183   ssize_t bytesWritten_;        ///< bytes written altogether
184
185   struct iovec writeOps_[];     ///< write operation(s) list
186 };
187
188 int AsyncSocket::SendMsgParamsCallback::getDefaultFlags(folly::WriteFlags flags)
189                                                                       noexcept {
190   int msg_flags = MSG_DONTWAIT;
191
192 #ifdef MSG_NOSIGNAL // Linux-only
193   msg_flags |= MSG_NOSIGNAL;
194 #ifdef MSG_MORE
195   if (isSet(flags, WriteFlags::CORK)) {
196     // MSG_MORE tells the kernel we have more data to send, so wait for us to
197     // give it the rest of the data rather than immediately sending a partial
198     // frame, even when TCP_NODELAY is enabled.
199     msg_flags |= MSG_MORE;
200   }
201 #endif // MSG_MORE
202 #endif // MSG_NOSIGNAL
203   if (isSet(flags, WriteFlags::EOR)) {
204     // marks that this is the last byte of a record (response)
205     msg_flags |= MSG_EOR;
206   }
207
208   return msg_flags;
209 }
210
211 namespace {
212 static AsyncSocket::SendMsgParamsCallback defaultSendMsgParamsCallback;
213 }
214
215 AsyncSocket::AsyncSocket()
216     : eventBase_(nullptr),
217       writeTimeout_(this, nullptr),
218       ioHandler_(this, nullptr),
219       immediateReadHandler_(this) {
220   VLOG(5) << "new AsyncSocket()";
221   init();
222 }
223
224 AsyncSocket::AsyncSocket(EventBase* evb)
225     : eventBase_(evb),
226       writeTimeout_(this, evb),
227       ioHandler_(this, evb),
228       immediateReadHandler_(this) {
229   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
230   init();
231 }
232
233 AsyncSocket::AsyncSocket(EventBase* evb,
234                            const folly::SocketAddress& address,
235                            uint32_t connectTimeout)
236   : AsyncSocket(evb) {
237   connect(nullptr, address, connectTimeout);
238 }
239
240 AsyncSocket::AsyncSocket(EventBase* evb,
241                            const std::string& ip,
242                            uint16_t port,
243                            uint32_t connectTimeout)
244   : AsyncSocket(evb) {
245   connect(nullptr, ip, port, connectTimeout);
246 }
247
248 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
249     : eventBase_(evb),
250       writeTimeout_(this, evb),
251       ioHandler_(this, evb, fd),
252       immediateReadHandler_(this) {
253   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
254           << fd << ")";
255   init();
256   fd_ = fd;
257   setCloseOnExec();
258   state_ = StateEnum::ESTABLISHED;
259 }
260
261 AsyncSocket::AsyncSocket(AsyncSocket::UniquePtr oldAsyncSocket)
262     : AsyncSocket(oldAsyncSocket->getEventBase(), oldAsyncSocket->detachFd()) {
263   preReceivedData_ = std::move(oldAsyncSocket->preReceivedData_);
264 }
265
266 // init() method, since constructor forwarding isn't supported in most
267 // compilers yet.
268 void AsyncSocket::init() {
269   if (eventBase_) {
270     eventBase_->dcheckIsInEventBaseThread();
271   }
272   shutdownFlags_ = 0;
273   state_ = StateEnum::UNINIT;
274   eventFlags_ = EventHandler::NONE;
275   fd_ = -1;
276   sendTimeout_ = 0;
277   maxReadsPerEvent_ = 16;
278   connectCallback_ = nullptr;
279   errMessageCallback_ = nullptr;
280   readCallback_ = nullptr;
281   writeReqHead_ = nullptr;
282   writeReqTail_ = nullptr;
283   shutdownSocketSet_ = nullptr;
284   appBytesWritten_ = 0;
285   appBytesReceived_ = 0;
286   sendMsgParamCallback_ = &defaultSendMsgParamsCallback;
287 }
288
289 AsyncSocket::~AsyncSocket() {
290   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
291           << ", evb=" << eventBase_ << ", fd=" << fd_
292           << ", state=" << state_ << ")";
293 }
294
295 void AsyncSocket::destroy() {
296   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
297           << ", fd=" << fd_ << ", state=" << state_;
298   // When destroy is called, close the socket immediately
299   closeNow();
300
301   // Then call DelayedDestruction::destroy() to take care of
302   // whether or not we need immediate or delayed destruction
303   DelayedDestruction::destroy();
304 }
305
306 int AsyncSocket::detachFd() {
307   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
308           << ", evb=" << eventBase_ << ", state=" << state_
309           << ", events=" << std::hex << eventFlags_ << ")";
310   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
311   // actually close the descriptor.
312   if (shutdownSocketSet_) {
313     shutdownSocketSet_->remove(fd_);
314   }
315   int fd = fd_;
316   fd_ = -1;
317   // Call closeNow() to invoke all pending callbacks with an error.
318   closeNow();
319   // Update the EventHandler to stop using this fd.
320   // This can only be done after closeNow() unregisters the handler.
321   ioHandler_.changeHandlerFD(-1);
322   return fd;
323 }
324
325 const folly::SocketAddress& AsyncSocket::anyAddress() {
326   static const folly::SocketAddress anyAddress =
327     folly::SocketAddress("0.0.0.0", 0);
328   return anyAddress;
329 }
330
331 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
332   if (shutdownSocketSet_ == newSS) {
333     return;
334   }
335   if (shutdownSocketSet_ && fd_ != -1) {
336     shutdownSocketSet_->remove(fd_);
337   }
338   shutdownSocketSet_ = newSS;
339   if (shutdownSocketSet_ && fd_ != -1) {
340     shutdownSocketSet_->add(fd_);
341   }
342 }
343
344 void AsyncSocket::setCloseOnExec() {
345   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
346   if (rv != 0) {
347     auto errnoCopy = errno;
348     throw AsyncSocketException(
349         AsyncSocketException::INTERNAL_ERROR,
350         withAddr("failed to set close-on-exec flag"),
351         errnoCopy);
352   }
353 }
354
355 void AsyncSocket::connect(ConnectCallback* callback,
356                            const folly::SocketAddress& address,
357                            int timeout,
358                            const OptionMap &options,
359                            const folly::SocketAddress& bindAddr) noexcept {
360   DestructorGuard dg(this);
361   eventBase_->dcheckIsInEventBaseThread();
362
363   addr_ = address;
364
365   // Make sure we're in the uninitialized state
366   if (state_ != StateEnum::UNINIT) {
367     return invalidState(callback);
368   }
369
370   connectTimeout_ = std::chrono::milliseconds(timeout);
371   connectStartTime_ = std::chrono::steady_clock::now();
372   // Make connect end time at least >= connectStartTime.
373   connectEndTime_ = connectStartTime_;
374
375   assert(fd_ == -1);
376   state_ = StateEnum::CONNECTING;
377   connectCallback_ = callback;
378
379   sockaddr_storage addrStorage;
380   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
381
382   try {
383     // Create the socket
384     // Technically the first parameter should actually be a protocol family
385     // constant (PF_xxx) rather than an address family (AF_xxx), but the
386     // distinction is mainly just historical.  In pretty much all
387     // implementations the PF_foo and AF_foo constants are identical.
388     fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
389     if (fd_ < 0) {
390       auto errnoCopy = errno;
391       throw AsyncSocketException(
392           AsyncSocketException::INTERNAL_ERROR,
393           withAddr("failed to create socket"),
394           errnoCopy);
395     }
396     if (shutdownSocketSet_) {
397       shutdownSocketSet_->add(fd_);
398     }
399     ioHandler_.changeHandlerFD(fd_);
400
401     setCloseOnExec();
402
403     // Put the socket in non-blocking mode
404     int flags = fcntl(fd_, F_GETFL, 0);
405     if (flags == -1) {
406       auto errnoCopy = errno;
407       throw AsyncSocketException(
408           AsyncSocketException::INTERNAL_ERROR,
409           withAddr("failed to get socket flags"),
410           errnoCopy);
411     }
412     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
413     if (rv == -1) {
414       auto errnoCopy = errno;
415       throw AsyncSocketException(
416           AsyncSocketException::INTERNAL_ERROR,
417           withAddr("failed to put socket in non-blocking mode"),
418           errnoCopy);
419     }
420
421 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
422     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
423     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
424     if (rv == -1) {
425       auto errnoCopy = errno;
426       throw AsyncSocketException(
427           AsyncSocketException::INTERNAL_ERROR,
428           "failed to enable F_SETNOSIGPIPE on socket",
429           errnoCopy);
430     }
431 #endif
432
433     // By default, turn on TCP_NODELAY
434     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
435     // setNoDelay() will log an error message if it fails.
436     if (address.getFamily() != AF_UNIX) {
437       (void)setNoDelay(true);
438     }
439
440     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
441             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
442
443     // bind the socket
444     if (bindAddr != anyAddress()) {
445       int one = 1;
446       if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
447         auto errnoCopy = errno;
448         doClose();
449         throw AsyncSocketException(
450             AsyncSocketException::NOT_OPEN,
451             "failed to setsockopt prior to bind on " + bindAddr.describe(),
452             errnoCopy);
453       }
454
455       bindAddr.getAddress(&addrStorage);
456
457       if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
458         auto errnoCopy = errno;
459         doClose();
460         throw AsyncSocketException(
461             AsyncSocketException::NOT_OPEN,
462             "failed to bind to async socket: " + bindAddr.describe(),
463             errnoCopy);
464       }
465     }
466
467     // Apply the additional options if any.
468     for (const auto& opt: options) {
469       rv = opt.first.apply(fd_, opt.second);
470       if (rv != 0) {
471         auto errnoCopy = errno;
472         throw AsyncSocketException(
473             AsyncSocketException::INTERNAL_ERROR,
474             withAddr("failed to set socket option"),
475             errnoCopy);
476       }
477     }
478
479     // Perform the connect()
480     address.getAddress(&addrStorage);
481
482     if (tfoEnabled_) {
483       state_ = StateEnum::FAST_OPEN;
484       tfoAttempted_ = true;
485     } else {
486       if (socketConnect(saddr, addr_.getActualSize()) < 0) {
487         return;
488       }
489     }
490
491     // If we're still here the connect() succeeded immediately.
492     // Fall through to call the callback outside of this try...catch block
493   } catch (const AsyncSocketException& ex) {
494     return failConnect(__func__, ex);
495   } catch (const std::exception& ex) {
496     // shouldn't happen, but handle it just in case
497     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
498                << "): unexpected " << typeid(ex).name() << " exception: "
499                << ex.what();
500     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
501                             withAddr(string("unexpected exception: ") +
502                                      ex.what()));
503     return failConnect(__func__, tex);
504   }
505
506   // The connection succeeded immediately
507   // The read callback may not have been set yet, and no writes may be pending
508   // yet, so we don't have to register for any events at the moment.
509   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
510   assert(errMessageCallback_ == nullptr);
511   assert(readCallback_ == nullptr);
512   assert(writeReqHead_ == nullptr);
513   if (state_ != StateEnum::FAST_OPEN) {
514     state_ = StateEnum::ESTABLISHED;
515   }
516   invokeConnectSuccess();
517 }
518
519 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
520 #if __linux__
521   if (noTransparentTls_) {
522     // Ignore return value, errors are ok
523     setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
524   }
525   if (noTSocks_) {
526     VLOG(4) << "Disabling TSOCKS for fd " << fd_;
527     // Ignore return value, errors are ok
528     setsockopt(fd_, SOL_SOCKET, SO_NO_TSOCKS, nullptr, 0);
529   }
530 #endif
531   int rv = fsp::connect(fd_, saddr, len);
532   if (rv < 0) {
533     auto errnoCopy = errno;
534     if (errnoCopy == EINPROGRESS) {
535       scheduleConnectTimeout();
536       registerForConnectEvents();
537     } else {
538       throw AsyncSocketException(
539           AsyncSocketException::NOT_OPEN,
540           "connect failed (immediately)",
541           errnoCopy);
542     }
543   }
544   return rv;
545 }
546
547 void AsyncSocket::scheduleConnectTimeout() {
548   // Connection in progress.
549   auto timeout = connectTimeout_.count();
550   if (timeout > 0) {
551     // Start a timer in case the connection takes too long.
552     if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
553       throw AsyncSocketException(
554           AsyncSocketException::INTERNAL_ERROR,
555           withAddr("failed to schedule AsyncSocket connect timeout"));
556     }
557   }
558 }
559
560 void AsyncSocket::registerForConnectEvents() {
561   // Register for write events, so we'll
562   // be notified when the connection finishes/fails.
563   // Note that we don't register for a persistent event here.
564   assert(eventFlags_ == EventHandler::NONE);
565   eventFlags_ = EventHandler::WRITE;
566   if (!ioHandler_.registerHandler(eventFlags_)) {
567     throw AsyncSocketException(
568         AsyncSocketException::INTERNAL_ERROR,
569         withAddr("failed to register AsyncSocket connect handler"));
570   }
571 }
572
573 void AsyncSocket::connect(ConnectCallback* callback,
574                            const string& ip, uint16_t port,
575                            int timeout,
576                            const OptionMap &options) noexcept {
577   DestructorGuard dg(this);
578   try {
579     connectCallback_ = callback;
580     connect(callback, folly::SocketAddress(ip, port), timeout, options);
581   } catch (const std::exception& ex) {
582     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
583                             ex.what());
584     return failConnect(__func__, tex);
585   }
586 }
587
588 void AsyncSocket::cancelConnect() {
589   connectCallback_ = nullptr;
590   if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
591     closeNow();
592   }
593 }
594
595 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
596   sendTimeout_ = milliseconds;
597   if (eventBase_) {
598     eventBase_->dcheckIsInEventBaseThread();
599   }
600
601   // If we are currently pending on write requests, immediately update
602   // writeTimeout_ with the new value.
603   if ((eventFlags_ & EventHandler::WRITE) &&
604       (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
605     assert(state_ == StateEnum::ESTABLISHED);
606     assert((shutdownFlags_ & SHUT_WRITE) == 0);
607     if (sendTimeout_ > 0) {
608       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
609         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
610             withAddr("failed to reschedule send timeout in setSendTimeout"));
611         return failWrite(__func__, ex);
612       }
613     } else {
614       writeTimeout_.cancelTimeout();
615     }
616   }
617 }
618
619 void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
620   VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this
621           << ", fd=" << fd_ << ", callback=" << callback
622           << ", state=" << state_;
623
624   // Short circuit if callback is the same as the existing errMessageCallback_.
625   if (callback == errMessageCallback_) {
626     return;
627   }
628
629   if (!msgErrQueueSupported) {
630       // Per-socket error message queue is not supported on this platform.
631       return invalidState(callback);
632   }
633
634   DestructorGuard dg(this);
635   eventBase_->dcheckIsInEventBaseThread();
636
637   if (callback == nullptr) {
638     // We should be able to reset the callback regardless of the
639     // socket state. It's important to have a reliable callback
640     // cancellation mechanism.
641     errMessageCallback_ = callback;
642     return;
643   }
644
645   switch ((StateEnum)state_) {
646     case StateEnum::CONNECTING:
647     case StateEnum::FAST_OPEN:
648     case StateEnum::ESTABLISHED: {
649       errMessageCallback_ = callback;
650       return;
651     }
652     case StateEnum::CLOSED:
653     case StateEnum::ERROR:
654       // We should never reach here.  SHUT_READ should always be set
655       // if we are in STATE_CLOSED or STATE_ERROR.
656       assert(false);
657       return invalidState(callback);
658     case StateEnum::UNINIT:
659       // We do not allow setReadCallback() to be called before we start
660       // connecting.
661       return invalidState(callback);
662   }
663
664   // We don't put a default case in the switch statement, so that the compiler
665   // will warn us to update the switch statement if a new state is added.
666   return invalidState(callback);
667 }
668
669 AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
670   return errMessageCallback_;
671 }
672
673 void AsyncSocket::setSendMsgParamCB(SendMsgParamsCallback* callback) {
674   sendMsgParamCallback_ = callback;
675 }
676
677 AsyncSocket::SendMsgParamsCallback* AsyncSocket::getSendMsgParamsCB() const {
678   return sendMsgParamCallback_;
679 }
680
681 void AsyncSocket::setReadCB(ReadCallback *callback) {
682   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
683           << ", callback=" << callback << ", state=" << state_;
684
685   // Short circuit if callback is the same as the existing readCallback_.
686   //
687   // Note that this is needed for proper functioning during some cleanup cases.
688   // During cleanup we allow setReadCallback(nullptr) to be called even if the
689   // read callback is already unset and we have been detached from an event
690   // base.  This check prevents us from asserting
691   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
692   if (callback == readCallback_) {
693     return;
694   }
695
696   /* We are removing a read callback */
697   if (callback == nullptr &&
698       immediateReadHandler_.isLoopCallbackScheduled()) {
699     immediateReadHandler_.cancelLoopCallback();
700   }
701
702   if (shutdownFlags_ & SHUT_READ) {
703     // Reads have already been shut down on this socket.
704     //
705     // Allow setReadCallback(nullptr) to be called in this case, but don't
706     // allow a new callback to be set.
707     //
708     // For example, setReadCallback(nullptr) can happen after an error if we
709     // invoke some other error callback before invoking readError().  The other
710     // error callback that is invoked first may go ahead and clear the read
711     // callback before we get a chance to invoke readError().
712     if (callback != nullptr) {
713       return invalidState(callback);
714     }
715     assert((eventFlags_ & EventHandler::READ) == 0);
716     readCallback_ = nullptr;
717     return;
718   }
719
720   DestructorGuard dg(this);
721   eventBase_->dcheckIsInEventBaseThread();
722
723   switch ((StateEnum)state_) {
724     case StateEnum::CONNECTING:
725     case StateEnum::FAST_OPEN:
726       // For convenience, we allow the read callback to be set while we are
727       // still connecting.  We just store the callback for now.  Once the
728       // connection completes we'll register for read events.
729       readCallback_ = callback;
730       return;
731     case StateEnum::ESTABLISHED:
732     {
733       readCallback_ = callback;
734       uint16_t oldFlags = eventFlags_;
735       if (readCallback_) {
736         eventFlags_ |= EventHandler::READ;
737       } else {
738         eventFlags_ &= ~EventHandler::READ;
739       }
740
741       // Update our registration if our flags have changed
742       if (eventFlags_ != oldFlags) {
743         // We intentionally ignore the return value here.
744         // updateEventRegistration() will move us into the error state if it
745         // fails, and we don't need to do anything else here afterwards.
746         (void)updateEventRegistration();
747       }
748
749       if (readCallback_) {
750         checkForImmediateRead();
751       }
752       return;
753     }
754     case StateEnum::CLOSED:
755     case StateEnum::ERROR:
756       // We should never reach here.  SHUT_READ should always be set
757       // if we are in STATE_CLOSED or STATE_ERROR.
758       assert(false);
759       return invalidState(callback);
760     case StateEnum::UNINIT:
761       // We do not allow setReadCallback() to be called before we start
762       // connecting.
763       return invalidState(callback);
764   }
765
766   // We don't put a default case in the switch statement, so that the compiler
767   // will warn us to update the switch statement if a new state is added.
768   return invalidState(callback);
769 }
770
771 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
772   return readCallback_;
773 }
774
775 void AsyncSocket::write(WriteCallback* callback,
776                          const void* buf, size_t bytes, WriteFlags flags) {
777   iovec op;
778   op.iov_base = const_cast<void*>(buf);
779   op.iov_len = bytes;
780   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
781 }
782
783 void AsyncSocket::writev(WriteCallback* callback,
784                           const iovec* vec,
785                           size_t count,
786                           WriteFlags flags) {
787   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
788 }
789
790 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
791                               WriteFlags flags) {
792   constexpr size_t kSmallSizeMax = 64;
793   size_t count = buf->countChainElements();
794   if (count <= kSmallSizeMax) {
795     // suppress "warning: variable length array 'vec' is used [-Wvla]"
796     FOLLY_PUSH_WARNING
797     FOLLY_GCC_DISABLE_WARNING("-Wvla")
798     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
799     FOLLY_POP_WARNING
800
801     writeChainImpl(callback, vec, count, std::move(buf), flags);
802   } else {
803     iovec* vec = new iovec[count];
804     writeChainImpl(callback, vec, count, std::move(buf), flags);
805     delete[] vec;
806   }
807 }
808
809 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
810     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
811   size_t veclen = buf->fillIov(vec, count);
812   writeImpl(callback, vec, veclen, std::move(buf), flags);
813 }
814
815 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
816                              size_t count, unique_ptr<IOBuf>&& buf,
817                              WriteFlags flags) {
818   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
819           << ", callback=" << callback << ", count=" << count
820           << ", state=" << state_;
821   DestructorGuard dg(this);
822   unique_ptr<IOBuf>ioBuf(std::move(buf));
823   eventBase_->dcheckIsInEventBaseThread();
824
825   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
826     // No new writes may be performed after the write side of the socket has
827     // been shutdown.
828     //
829     // We could just call callback->writeError() here to fail just this write.
830     // However, fail hard and use invalidState() to fail all outstanding
831     // callbacks and move the socket into the error state.  There's most likely
832     // a bug in the caller's code, so we abort everything rather than trying to
833     // proceed as best we can.
834     return invalidState(callback);
835   }
836
837   uint32_t countWritten = 0;
838   uint32_t partialWritten = 0;
839   ssize_t bytesWritten = 0;
840   bool mustRegister = false;
841   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
842       !connecting()) {
843     if (writeReqHead_ == nullptr) {
844       // If we are established and there are no other writes pending,
845       // we can attempt to perform the write immediately.
846       assert(writeReqTail_ == nullptr);
847       assert((eventFlags_ & EventHandler::WRITE) == 0);
848
849       auto writeResult = performWrite(
850           vec, uint32_t(count), flags, &countWritten, &partialWritten);
851       bytesWritten = writeResult.writeReturn;
852       if (bytesWritten < 0) {
853         auto errnoCopy = errno;
854         if (writeResult.exception) {
855           return failWrite(__func__, callback, 0, *writeResult.exception);
856         }
857         AsyncSocketException ex(
858             AsyncSocketException::INTERNAL_ERROR,
859             withAddr("writev failed"),
860             errnoCopy);
861         return failWrite(__func__, callback, 0, ex);
862       } else if (countWritten == count) {
863         // We successfully wrote everything.
864         // Invoke the callback and return.
865         if (callback) {
866           callback->writeSuccess();
867         }
868         return;
869       } else { // continue writing the next writeReq
870         if (bufferCallback_) {
871           bufferCallback_->onEgressBuffered();
872         }
873       }
874       if (!connecting()) {
875         // Writes might put the socket back into connecting state
876         // if TFO is enabled, and using TFO fails.
877         // This means that write timeouts would not be active, however
878         // connect timeouts would affect this stage.
879         mustRegister = true;
880       }
881     }
882   } else if (!connecting()) {
883     // Invalid state for writing
884     return invalidState(callback);
885   }
886
887   // Create a new WriteRequest to add to the queue
888   WriteRequest* req;
889   try {
890     req = BytesWriteRequest::newRequest(
891         this,
892         callback,
893         vec + countWritten,
894         uint32_t(count - countWritten),
895         partialWritten,
896         uint32_t(bytesWritten),
897         std::move(ioBuf),
898         flags);
899   } catch (const std::exception& ex) {
900     // we mainly expect to catch std::bad_alloc here
901     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
902         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
903     return failWrite(__func__, callback, size_t(bytesWritten), tex);
904   }
905   req->consume();
906   if (writeReqTail_ == nullptr) {
907     assert(writeReqHead_ == nullptr);
908     writeReqHead_ = writeReqTail_ = req;
909   } else {
910     writeReqTail_->append(req);
911     writeReqTail_ = req;
912   }
913
914   // Register for write events if are established and not currently
915   // waiting on write events
916   if (mustRegister) {
917     assert(state_ == StateEnum::ESTABLISHED);
918     assert((eventFlags_ & EventHandler::WRITE) == 0);
919     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
920       assert(state_ == StateEnum::ERROR);
921       return;
922     }
923     if (sendTimeout_ > 0) {
924       // Schedule a timeout to fire if the write takes too long.
925       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
926         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
927                                withAddr("failed to schedule send timeout"));
928         return failWrite(__func__, ex);
929       }
930     }
931   }
932 }
933
934 void AsyncSocket::writeRequest(WriteRequest* req) {
935   if (writeReqTail_ == nullptr) {
936     assert(writeReqHead_ == nullptr);
937     writeReqHead_ = writeReqTail_ = req;
938     req->start();
939   } else {
940     writeReqTail_->append(req);
941     writeReqTail_ = req;
942   }
943 }
944
945 void AsyncSocket::close() {
946   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
947           << ", state=" << state_ << ", shutdownFlags="
948           << std::hex << (int) shutdownFlags_;
949
950   // close() is only different from closeNow() when there are pending writes
951   // that need to drain before we can close.  In all other cases, just call
952   // closeNow().
953   //
954   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
955   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
956   // is still running.  (e.g., If there are multiple pending writes, and we
957   // call writeError() on the first one, it may call close().  In this case we
958   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
959   // writes will still be in the queue.)
960   //
961   // We only need to drain pending writes if we are still in STATE_CONNECTING
962   // or STATE_ESTABLISHED
963   if ((writeReqHead_ == nullptr) ||
964       !(state_ == StateEnum::CONNECTING ||
965       state_ == StateEnum::ESTABLISHED)) {
966     closeNow();
967     return;
968   }
969
970   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
971   // destroyed until close() returns.
972   DestructorGuard dg(this);
973   eventBase_->dcheckIsInEventBaseThread();
974
975   // Since there are write requests pending, we have to set the
976   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
977   // connect finishes and we finish writing these requests.
978   //
979   // Set SHUT_READ to indicate that reads are shut down, and set the
980   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
981   // pending writes complete.
982   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
983
984   // If a read callback is set, invoke readEOF() immediately to inform it that
985   // the socket has been closed and no more data can be read.
986   if (readCallback_) {
987     // Disable reads if they are enabled
988     if (!updateEventRegistration(0, EventHandler::READ)) {
989       // We're now in the error state; callbacks have been cleaned up
990       assert(state_ == StateEnum::ERROR);
991       assert(readCallback_ == nullptr);
992     } else {
993       ReadCallback* callback = readCallback_;
994       readCallback_ = nullptr;
995       callback->readEOF();
996     }
997   }
998 }
999
1000 void AsyncSocket::closeNow() {
1001   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
1002           << ", state=" << state_ << ", shutdownFlags="
1003           << std::hex << (int) shutdownFlags_;
1004   DestructorGuard dg(this);
1005   if (eventBase_) {
1006     eventBase_->dcheckIsInEventBaseThread();
1007   }
1008
1009   switch (state_) {
1010     case StateEnum::ESTABLISHED:
1011     case StateEnum::CONNECTING:
1012     case StateEnum::FAST_OPEN: {
1013       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1014       state_ = StateEnum::CLOSED;
1015
1016       // If the write timeout was set, cancel it.
1017       writeTimeout_.cancelTimeout();
1018
1019       // If we are registered for I/O events, unregister.
1020       if (eventFlags_ != EventHandler::NONE) {
1021         eventFlags_ = EventHandler::NONE;
1022         if (!updateEventRegistration()) {
1023           // We will have been moved into the error state.
1024           assert(state_ == StateEnum::ERROR);
1025           return;
1026         }
1027       }
1028
1029       if (immediateReadHandler_.isLoopCallbackScheduled()) {
1030         immediateReadHandler_.cancelLoopCallback();
1031       }
1032
1033       if (fd_ >= 0) {
1034         ioHandler_.changeHandlerFD(-1);
1035         doClose();
1036       }
1037
1038       invokeConnectErr(socketClosedLocallyEx);
1039
1040       failAllWrites(socketClosedLocallyEx);
1041
1042       if (readCallback_) {
1043         ReadCallback* callback = readCallback_;
1044         readCallback_ = nullptr;
1045         callback->readEOF();
1046       }
1047       return;
1048     }
1049     case StateEnum::CLOSED:
1050       // Do nothing.  It's possible that we are being called recursively
1051       // from inside a callback that we invoked inside another call to close()
1052       // that is still running.
1053       return;
1054     case StateEnum::ERROR:
1055       // Do nothing.  The error handling code has performed (or is performing)
1056       // cleanup.
1057       return;
1058     case StateEnum::UNINIT:
1059       assert(eventFlags_ == EventHandler::NONE);
1060       assert(connectCallback_ == nullptr);
1061       assert(readCallback_ == nullptr);
1062       assert(writeReqHead_ == nullptr);
1063       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1064       state_ = StateEnum::CLOSED;
1065       return;
1066   }
1067
1068   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
1069               << ") called in unknown state " << state_;
1070 }
1071
1072 void AsyncSocket::closeWithReset() {
1073   // Enable SO_LINGER, with the linger timeout set to 0.
1074   // This will trigger a TCP reset when we close the socket.
1075   if (fd_ >= 0) {
1076     struct linger optLinger = {1, 0};
1077     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
1078       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
1079               << "on " << fd_ << ": errno=" << errno;
1080     }
1081   }
1082
1083   // Then let closeNow() take care of the rest
1084   closeNow();
1085 }
1086
1087 void AsyncSocket::shutdownWrite() {
1088   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
1089           << ", state=" << state_ << ", shutdownFlags="
1090           << std::hex << (int) shutdownFlags_;
1091
1092   // If there are no pending writes, shutdownWrite() is identical to
1093   // shutdownWriteNow().
1094   if (writeReqHead_ == nullptr) {
1095     shutdownWriteNow();
1096     return;
1097   }
1098
1099   eventBase_->dcheckIsInEventBaseThread();
1100
1101   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
1102   // shutdown will be performed once all writes complete.
1103   shutdownFlags_ |= SHUT_WRITE_PENDING;
1104 }
1105
1106 void AsyncSocket::shutdownWriteNow() {
1107   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
1108           << ", fd=" << fd_ << ", state=" << state_
1109           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
1110
1111   if (shutdownFlags_ & SHUT_WRITE) {
1112     // Writes are already shutdown; nothing else to do.
1113     return;
1114   }
1115
1116   // If SHUT_READ is already set, just call closeNow() to completely
1117   // close the socket.  This can happen if close() was called with writes
1118   // pending, and then shutdownWriteNow() is called before all pending writes
1119   // complete.
1120   if (shutdownFlags_ & SHUT_READ) {
1121     closeNow();
1122     return;
1123   }
1124
1125   DestructorGuard dg(this);
1126   if (eventBase_) {
1127     eventBase_->dcheckIsInEventBaseThread();
1128   }
1129
1130   switch (static_cast<StateEnum>(state_)) {
1131     case StateEnum::ESTABLISHED:
1132     {
1133       shutdownFlags_ |= SHUT_WRITE;
1134
1135       // If the write timeout was set, cancel it.
1136       writeTimeout_.cancelTimeout();
1137
1138       // If we are registered for write events, unregister.
1139       if (!updateEventRegistration(0, EventHandler::WRITE)) {
1140         // We will have been moved into the error state.
1141         assert(state_ == StateEnum::ERROR);
1142         return;
1143       }
1144
1145       // Shutdown writes on the file descriptor
1146       shutdown(fd_, SHUT_WR);
1147
1148       // Immediately fail all write requests
1149       failAllWrites(socketShutdownForWritesEx);
1150       return;
1151     }
1152     case StateEnum::CONNECTING:
1153     {
1154       // Set the SHUT_WRITE_PENDING flag.
1155       // When the connection completes, it will check this flag,
1156       // shutdown the write half of the socket, and then set SHUT_WRITE.
1157       shutdownFlags_ |= SHUT_WRITE_PENDING;
1158
1159       // Immediately fail all write requests
1160       failAllWrites(socketShutdownForWritesEx);
1161       return;
1162     }
1163     case StateEnum::UNINIT:
1164       // Callers normally shouldn't call shutdownWriteNow() before the socket
1165       // even starts connecting.  Nonetheless, go ahead and set
1166       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1167       // immediately shut down the write side of the socket.
1168       shutdownFlags_ |= SHUT_WRITE_PENDING;
1169       return;
1170     case StateEnum::FAST_OPEN:
1171       // In fast open state we haven't call connected yet, and if we shutdown
1172       // the writes, we will never try to call connect, so shut everything down
1173       shutdownFlags_ |= SHUT_WRITE;
1174       // Immediately fail all write requests
1175       failAllWrites(socketShutdownForWritesEx);
1176       return;
1177     case StateEnum::CLOSED:
1178     case StateEnum::ERROR:
1179       // We should never get here.  SHUT_WRITE should always be set
1180       // in STATE_CLOSED and STATE_ERROR.
1181       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1182                  << ", fd=" << fd_ << ") in unexpected state " << state_
1183                  << " with SHUT_WRITE not set ("
1184                  << std::hex << (int) shutdownFlags_ << ")";
1185       assert(false);
1186       return;
1187   }
1188
1189   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
1190               << fd_ << ") called in unknown state " << state_;
1191 }
1192
1193 bool AsyncSocket::readable() const {
1194   if (fd_ == -1) {
1195     return false;
1196   }
1197   struct pollfd fds[1];
1198   fds[0].fd = fd_;
1199   fds[0].events = POLLIN;
1200   fds[0].revents = 0;
1201   int rc = poll(fds, 1, 0);
1202   return rc == 1;
1203 }
1204
1205 bool AsyncSocket::writable() const {
1206   if (fd_ == -1) {
1207     return false;
1208   }
1209   struct pollfd fds[1];
1210   fds[0].fd = fd_;
1211   fds[0].events = POLLOUT;
1212   fds[0].revents = 0;
1213   int rc = poll(fds, 1, 0);
1214   return rc == 1;
1215 }
1216
1217 bool AsyncSocket::isPending() const {
1218   return ioHandler_.isPending();
1219 }
1220
1221 bool AsyncSocket::hangup() const {
1222   if (fd_ == -1) {
1223     // sanity check, no one should ask for hangup if we are not connected.
1224     assert(false);
1225     return false;
1226   }
1227 #ifdef POLLRDHUP // Linux-only
1228   struct pollfd fds[1];
1229   fds[0].fd = fd_;
1230   fds[0].events = POLLRDHUP|POLLHUP;
1231   fds[0].revents = 0;
1232   poll(fds, 1, 0);
1233   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1234 #else
1235   return false;
1236 #endif
1237 }
1238
1239 bool AsyncSocket::good() const {
1240   return (
1241       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1242        state_ == StateEnum::ESTABLISHED) &&
1243       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1244 }
1245
1246 bool AsyncSocket::error() const {
1247   return (state_ == StateEnum::ERROR);
1248 }
1249
1250 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1251   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1252           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1253           << ", state=" << state_ << ", events="
1254           << std::hex << eventFlags_ << ")";
1255   assert(eventBase_ == nullptr);
1256   eventBase->dcheckIsInEventBaseThread();
1257
1258   eventBase_ = eventBase;
1259   ioHandler_.attachEventBase(eventBase);
1260   writeTimeout_.attachEventBase(eventBase);
1261   if (evbChangeCb_) {
1262     evbChangeCb_->evbAttached(this);
1263   }
1264 }
1265
1266 void AsyncSocket::detachEventBase() {
1267   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1268           << ", old evb=" << eventBase_ << ", state=" << state_
1269           << ", events=" << std::hex << eventFlags_ << ")";
1270   assert(eventBase_ != nullptr);
1271   eventBase_->dcheckIsInEventBaseThread();
1272
1273   eventBase_ = nullptr;
1274   ioHandler_.detachEventBase();
1275   writeTimeout_.detachEventBase();
1276   if (evbChangeCb_) {
1277     evbChangeCb_->evbDetached(this);
1278   }
1279 }
1280
1281 bool AsyncSocket::isDetachable() const {
1282   DCHECK(eventBase_ != nullptr);
1283   eventBase_->dcheckIsInEventBaseThread();
1284
1285   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1286 }
1287
1288 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1289   if (!localAddr_.isInitialized()) {
1290     localAddr_.setFromLocalAddress(fd_);
1291   }
1292   *address = localAddr_;
1293 }
1294
1295 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1296   if (!addr_.isInitialized()) {
1297     addr_.setFromPeerAddress(fd_);
1298   }
1299   *address = addr_;
1300 }
1301
1302 bool AsyncSocket::getTFOSucceded() const {
1303   return detail::tfo_succeeded(fd_);
1304 }
1305
1306 int AsyncSocket::setNoDelay(bool noDelay) {
1307   if (fd_ < 0) {
1308     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1309                << this << "(state=" << state_ << ")";
1310     return EINVAL;
1311
1312   }
1313
1314   int value = noDelay ? 1 : 0;
1315   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1316     int errnoCopy = errno;
1317     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1318             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1319             << strerror(errnoCopy);
1320     return errnoCopy;
1321   }
1322
1323   return 0;
1324 }
1325
1326 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1327
1328   #ifndef TCP_CONGESTION
1329   #define TCP_CONGESTION  13
1330   #endif
1331
1332   if (fd_ < 0) {
1333     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1334                << "socket " << this << "(state=" << state_ << ")";
1335     return EINVAL;
1336
1337   }
1338
1339   if (setsockopt(
1340           fd_,
1341           IPPROTO_TCP,
1342           TCP_CONGESTION,
1343           cname.c_str(),
1344           socklen_t(cname.length() + 1)) != 0) {
1345     int errnoCopy = errno;
1346     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1347             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1348             << strerror(errnoCopy);
1349     return errnoCopy;
1350   }
1351
1352   return 0;
1353 }
1354
1355 int AsyncSocket::setQuickAck(bool quickack) {
1356   (void)quickack;
1357   if (fd_ < 0) {
1358     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1359                << this << "(state=" << state_ << ")";
1360     return EINVAL;
1361
1362   }
1363
1364 #ifdef TCP_QUICKACK // Linux-only
1365   int value = quickack ? 1 : 0;
1366   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1367     int errnoCopy = errno;
1368     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1369             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1370             << strerror(errnoCopy);
1371     return errnoCopy;
1372   }
1373
1374   return 0;
1375 #else
1376   return ENOSYS;
1377 #endif
1378 }
1379
1380 int AsyncSocket::setSendBufSize(size_t bufsize) {
1381   if (fd_ < 0) {
1382     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1383                << this << "(state=" << state_ << ")";
1384     return EINVAL;
1385   }
1386
1387   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1388     int errnoCopy = errno;
1389     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1390             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1391             << strerror(errnoCopy);
1392     return errnoCopy;
1393   }
1394
1395   return 0;
1396 }
1397
1398 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1399   if (fd_ < 0) {
1400     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1401                << this << "(state=" << state_ << ")";
1402     return EINVAL;
1403   }
1404
1405   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1406     int errnoCopy = errno;
1407     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1408             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1409             << strerror(errnoCopy);
1410     return errnoCopy;
1411   }
1412
1413   return 0;
1414 }
1415
1416 int AsyncSocket::setTCPProfile(int profd) {
1417   if (fd_ < 0) {
1418     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1419                << this << "(state=" << state_ << ")";
1420     return EINVAL;
1421   }
1422
1423   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1424     int errnoCopy = errno;
1425     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1426             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1427             << strerror(errnoCopy);
1428     return errnoCopy;
1429   }
1430
1431   return 0;
1432 }
1433
1434 void AsyncSocket::ioReady(uint16_t events) noexcept {
1435   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd=" << fd_
1436           << ", events=" << std::hex << events << ", state=" << state_;
1437   DestructorGuard dg(this);
1438   assert(events & EventHandler::READ_WRITE);
1439   eventBase_->dcheckIsInEventBaseThread();
1440
1441   uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
1442   EventBase* originalEventBase = eventBase_;
1443   // If we got there it means that either EventHandler::READ or
1444   // EventHandler::WRITE is set. Any of these flags can
1445   // indicate that there are messages available in the socket
1446   // error message queue.
1447   handleErrMessages();
1448
1449   // Return now if handleErrMessages() detached us from our EventBase
1450   if (eventBase_ != originalEventBase) {
1451     return;
1452   }
1453
1454   if (relevantEvents == EventHandler::READ) {
1455     handleRead();
1456   } else if (relevantEvents == EventHandler::WRITE) {
1457     handleWrite();
1458   } else if (relevantEvents == EventHandler::READ_WRITE) {
1459     // If both read and write events are ready, process writes first.
1460     handleWrite();
1461
1462     // Return now if handleWrite() detached us from our EventBase
1463     if (eventBase_ != originalEventBase) {
1464       return;
1465     }
1466
1467     // Only call handleRead() if a read callback is still installed.
1468     // (It's possible that the read callback was uninstalled during
1469     // handleWrite().)
1470     if (readCallback_) {
1471       handleRead();
1472     }
1473   } else {
1474     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1475                << std::hex << events << "(this=" << this << ")";
1476     abort();
1477   }
1478 }
1479
1480 AsyncSocket::ReadResult
1481 AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
1482   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
1483           << ", buflen=" << *buflen;
1484
1485   if (preReceivedData_ && !preReceivedData_->empty()) {
1486     VLOG(5) << "AsyncSocket::performRead() this=" << this
1487             << ", reading pre-received data";
1488
1489     io::Cursor cursor(preReceivedData_.get());
1490     auto len = cursor.pullAtMost(*buf, *buflen);
1491
1492     IOBufQueue queue;
1493     queue.append(std::move(preReceivedData_));
1494     queue.trimStart(len);
1495     preReceivedData_ = queue.move();
1496
1497     appBytesReceived_ += len;
1498     return ReadResult(len);
1499   }
1500
1501   ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT);
1502   if (bytes < 0) {
1503     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1504       // No more data to read right now.
1505       return ReadResult(READ_BLOCKING);
1506     } else {
1507       return ReadResult(READ_ERROR);
1508     }
1509   } else {
1510     appBytesReceived_ += bytes;
1511     return ReadResult(bytes);
1512   }
1513 }
1514
1515 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
1516   // no matter what, buffer should be preapared for non-ssl socket
1517   CHECK(readCallback_);
1518   readCallback_->getReadBuffer(buf, buflen);
1519 }
1520
1521 void AsyncSocket::handleErrMessages() noexcept {
1522   // This method has non-empty implementation only for platforms
1523   // supporting per-socket error queues.
1524   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
1525           << ", state=" << state_;
1526   if (errMessageCallback_ == nullptr) {
1527     VLOG(7) << "AsyncSocket::handleErrMessages(): "
1528             << "no callback installed - exiting.";
1529     return;
1530   }
1531
1532 #ifdef MSG_ERRQUEUE
1533   uint8_t ctrl[1024];
1534   unsigned char data;
1535   struct msghdr msg;
1536   iovec entry;
1537
1538   entry.iov_base = &data;
1539   entry.iov_len = sizeof(data);
1540   msg.msg_iov = &entry;
1541   msg.msg_iovlen = 1;
1542   msg.msg_name = nullptr;
1543   msg.msg_namelen = 0;
1544   msg.msg_control = ctrl;
1545   msg.msg_controllen = sizeof(ctrl);
1546   msg.msg_flags = 0;
1547
1548   int ret;
1549   while (true) {
1550     ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
1551     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
1552
1553     if (ret < 0) {
1554       if (errno != EAGAIN) {
1555         auto errnoCopy = errno;
1556         LOG(ERROR) << "::recvmsg exited with code " << ret
1557                    << ", errno: " << errnoCopy;
1558         AsyncSocketException ex(
1559           AsyncSocketException::INTERNAL_ERROR,
1560           withAddr("recvmsg() failed"),
1561           errnoCopy);
1562         failErrMessageRead(__func__, ex);
1563       }
1564       return;
1565     }
1566
1567     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
1568          cmsg != nullptr &&
1569            cmsg->cmsg_len != 0 &&
1570            errMessageCallback_ != nullptr;
1571          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
1572       errMessageCallback_->errMessage(*cmsg);
1573     }
1574   }
1575 #endif //MSG_ERRQUEUE
1576 }
1577
1578 void AsyncSocket::handleRead() noexcept {
1579   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1580           << ", state=" << state_;
1581   assert(state_ == StateEnum::ESTABLISHED);
1582   assert((shutdownFlags_ & SHUT_READ) == 0);
1583   assert(readCallback_ != nullptr);
1584   assert(eventFlags_ & EventHandler::READ);
1585
1586   // Loop until:
1587   // - a read attempt would block
1588   // - readCallback_ is uninstalled
1589   // - the number of loop iterations exceeds the optional maximum
1590   // - this AsyncSocket is moved to another EventBase
1591   //
1592   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1593   // which is why need to check for it here.
1594   //
1595   // The last bullet point is slightly subtle.  readDataAvailable() may also
1596   // detach this socket from this EventBase.  However, before
1597   // readDataAvailable() returns another thread may pick it up, attach it to
1598   // a different EventBase, and install another readCallback_.  We need to
1599   // exit immediately after readDataAvailable() returns if the eventBase_ has
1600   // changed.  (The caller must perform some sort of locking to transfer the
1601   // AsyncSocket between threads properly.  This will be sufficient to ensure
1602   // that this thread sees the updated eventBase_ variable after
1603   // readDataAvailable() returns.)
1604   uint16_t numReads = 0;
1605   EventBase* originalEventBase = eventBase_;
1606   while (readCallback_ && eventBase_ == originalEventBase) {
1607     // Get the buffer to read into.
1608     void* buf = nullptr;
1609     size_t buflen = 0, offset = 0;
1610     try {
1611       prepareReadBuffer(&buf, &buflen);
1612       VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
1613     } catch (const AsyncSocketException& ex) {
1614       return failRead(__func__, ex);
1615     } catch (const std::exception& ex) {
1616       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1617                               string("ReadCallback::getReadBuffer() "
1618                                      "threw exception: ") +
1619                               ex.what());
1620       return failRead(__func__, tex);
1621     } catch (...) {
1622       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1623                              "ReadCallback::getReadBuffer() threw "
1624                              "non-exception type");
1625       return failRead(__func__, ex);
1626     }
1627     if (!isBufferMovable_ && (buf == nullptr || buflen == 0)) {
1628       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1629                              "ReadCallback::getReadBuffer() returned "
1630                              "empty buffer");
1631       return failRead(__func__, ex);
1632     }
1633
1634     // Perform the read
1635     auto readResult = performRead(&buf, &buflen, &offset);
1636     auto bytesRead = readResult.readReturn;
1637     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
1638             << bytesRead << " bytes";
1639     if (bytesRead > 0) {
1640       if (!isBufferMovable_) {
1641         readCallback_->readDataAvailable(size_t(bytesRead));
1642       } else {
1643         CHECK(kOpenSslModeMoveBufferOwnership);
1644         VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
1645                 << "buf=" << buf << ", " << bytesRead << "/" << buflen
1646                 << ", offset=" << offset;
1647         auto readBuf = folly::IOBuf::takeOwnership(buf, buflen);
1648         readBuf->trimStart(offset);
1649         readBuf->trimEnd(buflen - offset - bytesRead);
1650         readCallback_->readBufferAvailable(std::move(readBuf));
1651       }
1652
1653       // Fall through and continue around the loop if the read
1654       // completely filled the available buffer.
1655       // Note that readCallback_ may have been uninstalled or changed inside
1656       // readDataAvailable().
1657       if (size_t(bytesRead) < buflen) {
1658         return;
1659       }
1660     } else if (bytesRead == READ_BLOCKING) {
1661         // No more data to read right now.
1662         return;
1663     } else if (bytesRead == READ_ERROR) {
1664       readErr_ = READ_ERROR;
1665       if (readResult.exception) {
1666         return failRead(__func__, *readResult.exception);
1667       }
1668       auto errnoCopy = errno;
1669       AsyncSocketException ex(
1670           AsyncSocketException::INTERNAL_ERROR,
1671           withAddr("recv() failed"),
1672           errnoCopy);
1673       return failRead(__func__, ex);
1674     } else {
1675       assert(bytesRead == READ_EOF);
1676       readErr_ = READ_EOF;
1677       // EOF
1678       shutdownFlags_ |= SHUT_READ;
1679       if (!updateEventRegistration(0, EventHandler::READ)) {
1680         // we've already been moved into STATE_ERROR
1681         assert(state_ == StateEnum::ERROR);
1682         assert(readCallback_ == nullptr);
1683         return;
1684       }
1685
1686       ReadCallback* callback = readCallback_;
1687       readCallback_ = nullptr;
1688       callback->readEOF();
1689       return;
1690     }
1691     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1692       if (readCallback_ != nullptr) {
1693         // We might still have data in the socket.
1694         // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
1695         scheduleImmediateRead();
1696       }
1697       return;
1698     }
1699   }
1700 }
1701
1702 /**
1703  * This function attempts to write as much data as possible, until no more data
1704  * can be written.
1705  *
1706  * - If it sends all available data, it unregisters for write events, and stops
1707  *   the writeTimeout_.
1708  *
1709  * - If not all of the data can be sent immediately, it reschedules
1710  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1711  *   registered for write events.
1712  */
1713 void AsyncSocket::handleWrite() noexcept {
1714   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1715           << ", state=" << state_;
1716   DestructorGuard dg(this);
1717
1718   if (state_ == StateEnum::CONNECTING) {
1719     handleConnect();
1720     return;
1721   }
1722
1723   // Normal write
1724   assert(state_ == StateEnum::ESTABLISHED);
1725   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1726   assert(writeReqHead_ != nullptr);
1727
1728   // Loop until we run out of write requests,
1729   // or until this socket is moved to another EventBase.
1730   // (See the comment in handleRead() explaining how this can happen.)
1731   EventBase* originalEventBase = eventBase_;
1732   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1733     auto writeResult = writeReqHead_->performWrite();
1734     if (writeResult.writeReturn < 0) {
1735       if (writeResult.exception) {
1736         return failWrite(__func__, *writeResult.exception);
1737       }
1738       auto errnoCopy = errno;
1739       AsyncSocketException ex(
1740           AsyncSocketException::INTERNAL_ERROR,
1741           withAddr("writev() failed"),
1742           errnoCopy);
1743       return failWrite(__func__, ex);
1744     } else if (writeReqHead_->isComplete()) {
1745       // We finished this request
1746       WriteRequest* req = writeReqHead_;
1747       writeReqHead_ = req->getNext();
1748
1749       if (writeReqHead_ == nullptr) {
1750         writeReqTail_ = nullptr;
1751         // This is the last write request.
1752         // Unregister for write events and cancel the send timer
1753         // before we invoke the callback.  We have to update the state properly
1754         // before calling the callback, since it may want to detach us from
1755         // the EventBase.
1756         if (eventFlags_ & EventHandler::WRITE) {
1757           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1758             assert(state_ == StateEnum::ERROR);
1759             return;
1760           }
1761           // Stop the send timeout
1762           writeTimeout_.cancelTimeout();
1763         }
1764         assert(!writeTimeout_.isScheduled());
1765
1766         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1767         // we finish sending the last write request.
1768         //
1769         // We have to do this before invoking writeSuccess(), since
1770         // writeSuccess() may detach us from our EventBase.
1771         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1772           assert(connectCallback_ == nullptr);
1773           shutdownFlags_ |= SHUT_WRITE;
1774
1775           if (shutdownFlags_ & SHUT_READ) {
1776             // Reads have already been shutdown.  Fully close the socket and
1777             // move to STATE_CLOSED.
1778             //
1779             // Note: This code currently moves us to STATE_CLOSED even if
1780             // close() hasn't ever been called.  This can occur if we have
1781             // received EOF from the peer and shutdownWrite() has been called
1782             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1783             // case, until close() is actually called?  I can't think of a
1784             // reason why we would need to do so.  No other operations besides
1785             // calling close() or destroying the socket can be performed at
1786             // this point.
1787             assert(readCallback_ == nullptr);
1788             state_ = StateEnum::CLOSED;
1789             if (fd_ >= 0) {
1790               ioHandler_.changeHandlerFD(-1);
1791               doClose();
1792             }
1793           } else {
1794             // Reads are still enabled, so we are only doing a half-shutdown
1795             shutdown(fd_, SHUT_WR);
1796           }
1797         }
1798       }
1799
1800       // Invoke the callback
1801       WriteCallback* callback = req->getCallback();
1802       req->destroy();
1803       if (callback) {
1804         callback->writeSuccess();
1805       }
1806       // We'll continue around the loop, trying to write another request
1807     } else {
1808       // Partial write.
1809       if (bufferCallback_) {
1810         bufferCallback_->onEgressBuffered();
1811       }
1812       writeReqHead_->consume();
1813       // Stop after a partial write; it's highly likely that a subsequent write
1814       // attempt will just return EAGAIN.
1815       //
1816       // Ensure that we are registered for write events.
1817       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1818         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1819           assert(state_ == StateEnum::ERROR);
1820           return;
1821         }
1822       }
1823
1824       // Reschedule the send timeout, since we have made some write progress.
1825       if (sendTimeout_ > 0) {
1826         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1827           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1828               withAddr("failed to reschedule write timeout"));
1829           return failWrite(__func__, ex);
1830         }
1831       }
1832       return;
1833     }
1834   }
1835   if (!writeReqHead_ && bufferCallback_) {
1836     bufferCallback_->onEgressBufferCleared();
1837   }
1838 }
1839
1840 void AsyncSocket::checkForImmediateRead() noexcept {
1841   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1842   // (However, note that some subclasses do override this method.)
1843   //
1844   // Simply calling handleRead() here would be bad, as this would call
1845   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1846   // buffer even though no data may be available.  This would waste lots of
1847   // memory, since the buffer will sit around unused until the socket actually
1848   // becomes readable.
1849   //
1850   // Checking if the socket is readable now also seems like it would probably
1851   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1852   // would just waste an extra system call.  Even if it is readable, waiting to
1853   // find out from libevent on the next event loop doesn't seem that bad.
1854   //
1855   // The exception to this is if we have pre-received data. In that case there
1856   // is definitely data available immediately.
1857   if (preReceivedData_ && !preReceivedData_->empty()) {
1858     handleRead();
1859   }
1860 }
1861
1862 void AsyncSocket::handleInitialReadWrite() noexcept {
1863   // Our callers should already be holding a DestructorGuard, but grab
1864   // one here just to make sure, in case one of our calling code paths ever
1865   // changes.
1866   DestructorGuard dg(this);
1867   // If we have a readCallback_, make sure we enable read events.  We
1868   // may already be registered for reads if connectSuccess() set
1869   // the read calback.
1870   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1871     assert(state_ == StateEnum::ESTABLISHED);
1872     assert((shutdownFlags_ & SHUT_READ) == 0);
1873     if (!updateEventRegistration(EventHandler::READ, 0)) {
1874       assert(state_ == StateEnum::ERROR);
1875       return;
1876     }
1877     checkForImmediateRead();
1878   } else if (readCallback_ == nullptr) {
1879     // Unregister for read events.
1880     updateEventRegistration(0, EventHandler::READ);
1881   }
1882
1883   // If we have write requests pending, try to send them immediately.
1884   // Since we just finished accepting, there is a very good chance that we can
1885   // write without blocking.
1886   //
1887   // However, we only process them if EventHandler::WRITE is not already set,
1888   // which means that we're already blocked on a write attempt.  (This can
1889   // happen if connectSuccess() called write() before returning.)
1890   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1891     // Call handleWrite() to perform write processing.
1892     handleWrite();
1893   } else if (writeReqHead_ == nullptr) {
1894     // Unregister for write event.
1895     updateEventRegistration(0, EventHandler::WRITE);
1896   }
1897 }
1898
1899 void AsyncSocket::handleConnect() noexcept {
1900   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1901           << ", state=" << state_;
1902   assert(state_ == StateEnum::CONNECTING);
1903   // SHUT_WRITE can never be set while we are still connecting;
1904   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1905   // finishes
1906   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1907
1908   // In case we had a connect timeout, cancel the timeout
1909   writeTimeout_.cancelTimeout();
1910   // We don't use a persistent registration when waiting on a connect event,
1911   // so we have been automatically unregistered now.  Update eventFlags_ to
1912   // reflect reality.
1913   assert(eventFlags_ == EventHandler::WRITE);
1914   eventFlags_ = EventHandler::NONE;
1915
1916   // Call getsockopt() to check if the connect succeeded
1917   int error;
1918   socklen_t len = sizeof(error);
1919   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1920   if (rv != 0) {
1921     auto errnoCopy = errno;
1922     AsyncSocketException ex(
1923         AsyncSocketException::INTERNAL_ERROR,
1924         withAddr("error calling getsockopt() after connect"),
1925         errnoCopy);
1926     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1927                << fd_ << " host=" << addr_.describe()
1928                << ") exception:" << ex.what();
1929     return failConnect(__func__, ex);
1930   }
1931
1932   if (error != 0) {
1933     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1934                            "connect failed", error);
1935     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1936             << fd_ << " host=" << addr_.describe()
1937             << ") exception: " << ex.what();
1938     return failConnect(__func__, ex);
1939   }
1940
1941   // Move into STATE_ESTABLISHED
1942   state_ = StateEnum::ESTABLISHED;
1943
1944   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1945   // perform, immediately shutdown the write half of the socket.
1946   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1947     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1948     // are still connecting we just abort the connect rather than waiting for
1949     // it to complete.
1950     assert((shutdownFlags_ & SHUT_READ) == 0);
1951     shutdown(fd_, SHUT_WR);
1952     shutdownFlags_ |= SHUT_WRITE;
1953   }
1954
1955   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1956           << "successfully connected; state=" << state_;
1957
1958   // Remember the EventBase we are attached to, before we start invoking any
1959   // callbacks (since the callbacks may call detachEventBase()).
1960   EventBase* originalEventBase = eventBase_;
1961
1962   invokeConnectSuccess();
1963   // Note that the connect callback may have changed our state.
1964   // (set or unset the read callback, called write(), closed the socket, etc.)
1965   // The following code needs to handle these situations correctly.
1966   //
1967   // If the socket has been closed, readCallback_ and writeReqHead_ will
1968   // always be nullptr, so that will prevent us from trying to read or write.
1969   //
1970   // The main thing to check for is if eventBase_ is still originalEventBase.
1971   // If not, we have been detached from this event base, so we shouldn't
1972   // perform any more operations.
1973   if (eventBase_ != originalEventBase) {
1974     return;
1975   }
1976
1977   handleInitialReadWrite();
1978 }
1979
1980 void AsyncSocket::timeoutExpired() noexcept {
1981   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1982           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1983   DestructorGuard dg(this);
1984   eventBase_->dcheckIsInEventBaseThread();
1985
1986   if (state_ == StateEnum::CONNECTING) {
1987     // connect() timed out
1988     // Unregister for I/O events.
1989     if (connectCallback_) {
1990       AsyncSocketException ex(
1991           AsyncSocketException::TIMED_OUT,
1992           folly::sformat(
1993               "connect timed out after {}ms", connectTimeout_.count()));
1994       failConnect(__func__, ex);
1995     } else {
1996       // we faced a connect error without a connect callback, which could
1997       // happen due to TFO.
1998       AsyncSocketException ex(
1999           AsyncSocketException::TIMED_OUT, "write timed out during connection");
2000       failWrite(__func__, ex);
2001     }
2002   } else {
2003     // a normal write operation timed out
2004     AsyncSocketException ex(
2005         AsyncSocketException::TIMED_OUT,
2006         folly::sformat("write timed out after {}ms", sendTimeout_));
2007     failWrite(__func__, ex);
2008   }
2009 }
2010
2011 ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
2012   return detail::tfo_sendmsg(fd, msg, msg_flags);
2013 }
2014
2015 AsyncSocket::WriteResult
2016 AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
2017   ssize_t totalWritten = 0;
2018   if (state_ == StateEnum::FAST_OPEN) {
2019     sockaddr_storage addr;
2020     auto len = addr_.getAddress(&addr);
2021     msg->msg_name = &addr;
2022     msg->msg_namelen = len;
2023     totalWritten = tfoSendMsg(fd_, msg, msg_flags);
2024     if (totalWritten >= 0) {
2025       tfoFinished_ = true;
2026       state_ = StateEnum::ESTABLISHED;
2027       // We schedule this asynchrously so that we don't end up
2028       // invoking initial read or write while a write is in progress.
2029       scheduleInitialReadWrite();
2030     } else if (errno == EINPROGRESS) {
2031       VLOG(4) << "TFO falling back to connecting";
2032       // A normal sendmsg doesn't return EINPROGRESS, however
2033       // TFO might fallback to connecting if there is no
2034       // cookie.
2035       state_ = StateEnum::CONNECTING;
2036       try {
2037         scheduleConnectTimeout();
2038         registerForConnectEvents();
2039       } catch (const AsyncSocketException& ex) {
2040         return WriteResult(
2041             WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
2042       }
2043       // Let's fake it that no bytes were written and return an errno.
2044       errno = EAGAIN;
2045       totalWritten = -1;
2046     } else if (errno == EOPNOTSUPP) {
2047       // Try falling back to connecting.
2048       VLOG(4) << "TFO not supported";
2049       state_ = StateEnum::CONNECTING;
2050       try {
2051         int ret = socketConnect((const sockaddr*)&addr, len);
2052         if (ret == 0) {
2053           // connect succeeded immediately
2054           // Treat this like no data was written.
2055           state_ = StateEnum::ESTABLISHED;
2056           scheduleInitialReadWrite();
2057         }
2058         // If there was no exception during connections,
2059         // we would return that no bytes were written.
2060         errno = EAGAIN;
2061         totalWritten = -1;
2062       } catch (const AsyncSocketException& ex) {
2063         return WriteResult(
2064             WRITE_ERROR, std::make_unique<AsyncSocketException>(ex));
2065       }
2066     } else if (errno == EAGAIN) {
2067       // Normally sendmsg would indicate that the write would block.
2068       // However in the fast open case, it would indicate that sendmsg
2069       // fell back to a connect. This is a return code from connect()
2070       // instead, and is an error condition indicating no fds available.
2071       return WriteResult(
2072           WRITE_ERROR,
2073           std::make_unique<AsyncSocketException>(
2074               AsyncSocketException::UNKNOWN, "No more free local ports"));
2075     }
2076   } else {
2077     totalWritten = ::sendmsg(fd, msg, msg_flags);
2078   }
2079   return WriteResult(totalWritten);
2080 }
2081
2082 AsyncSocket::WriteResult AsyncSocket::performWrite(
2083     const iovec* vec,
2084     uint32_t count,
2085     WriteFlags flags,
2086     uint32_t* countWritten,
2087     uint32_t* partialWritten) {
2088   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
2089   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
2090   // (since it may terminate the program if the main program doesn't explicitly
2091   // ignore it).
2092   struct msghdr msg;
2093   msg.msg_name = nullptr;
2094   msg.msg_namelen = 0;
2095   msg.msg_iov = const_cast<iovec *>(vec);
2096   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
2097   msg.msg_flags = 0;
2098   msg.msg_controllen = sendMsgParamCallback_->getAncillaryDataSize(flags);
2099   CHECK_GE(AsyncSocket::SendMsgParamsCallback::maxAncillaryDataSize,
2100            msg.msg_controllen);
2101
2102   if (msg.msg_controllen != 0) {
2103     msg.msg_control = reinterpret_cast<char*>(alloca(msg.msg_controllen));
2104     sendMsgParamCallback_->getAncillaryData(flags, msg.msg_control);
2105   } else {
2106     msg.msg_control = nullptr;
2107   }
2108   int msg_flags = sendMsgParamCallback_->getFlags(flags);
2109
2110   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
2111   auto totalWritten = writeResult.writeReturn;
2112   if (totalWritten < 0) {
2113     bool tryAgain = (errno == EAGAIN);
2114 #ifdef __APPLE__
2115     // Apple has a bug where doing a second write on a socket which we
2116     // have opened with TFO causes an ENOTCONN to be thrown. However the
2117     // socket is really connected, so treat ENOTCONN as a EAGAIN until
2118     // this bug is fixed.
2119     tryAgain |= (errno == ENOTCONN);
2120 #endif
2121     if (!writeResult.exception && tryAgain) {
2122       // TCP buffer is full; we can't write any more data right now.
2123       *countWritten = 0;
2124       *partialWritten = 0;
2125       return WriteResult(0);
2126     }
2127     // error
2128     *countWritten = 0;
2129     *partialWritten = 0;
2130     return writeResult;
2131   }
2132
2133   appBytesWritten_ += totalWritten;
2134
2135   uint32_t bytesWritten;
2136   uint32_t n;
2137   for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
2138     const iovec* v = vec + n;
2139     if (v->iov_len > bytesWritten) {
2140       // Partial write finished in the middle of this iovec
2141       *countWritten = n;
2142       *partialWritten = bytesWritten;
2143       return WriteResult(totalWritten);
2144     }
2145
2146     bytesWritten -= uint32_t(v->iov_len);
2147   }
2148
2149   assert(bytesWritten == 0);
2150   *countWritten = n;
2151   *partialWritten = 0;
2152   return WriteResult(totalWritten);
2153 }
2154
2155 /**
2156  * Re-register the EventHandler after eventFlags_ has changed.
2157  *
2158  * If an error occurs, fail() is called to move the socket into the error state
2159  * and call all currently installed callbacks.  After an error, the
2160  * AsyncSocket is completely unregistered.
2161  *
2162  * @return Returns true on success, or false on error.
2163  */
2164 bool AsyncSocket::updateEventRegistration() {
2165   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
2166           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
2167           << ", events=" << std::hex << eventFlags_;
2168   eventBase_->dcheckIsInEventBaseThread();
2169   if (eventFlags_ == EventHandler::NONE) {
2170     ioHandler_.unregisterHandler();
2171     return true;
2172   }
2173
2174   // Always register for persistent events, so we don't have to re-register
2175   // after being called back.
2176   if (!ioHandler_.registerHandler(
2177           uint16_t(eventFlags_ | EventHandler::PERSIST))) {
2178     eventFlags_ = EventHandler::NONE; // we're not registered after error
2179     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
2180         withAddr("failed to update AsyncSocket event registration"));
2181     fail("updateEventRegistration", ex);
2182     return false;
2183   }
2184
2185   return true;
2186 }
2187
2188 bool AsyncSocket::updateEventRegistration(uint16_t enable,
2189                                            uint16_t disable) {
2190   uint16_t oldFlags = eventFlags_;
2191   eventFlags_ |= enable;
2192   eventFlags_ &= ~disable;
2193   if (eventFlags_ == oldFlags) {
2194     return true;
2195   } else {
2196     return updateEventRegistration();
2197   }
2198 }
2199
2200 void AsyncSocket::startFail() {
2201   // startFail() should only be called once
2202   assert(state_ != StateEnum::ERROR);
2203   assert(getDestructorGuardCount() > 0);
2204   state_ = StateEnum::ERROR;
2205   // Ensure that SHUT_READ and SHUT_WRITE are set,
2206   // so all future attempts to read or write will be rejected
2207   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
2208
2209   if (eventFlags_ != EventHandler::NONE) {
2210     eventFlags_ = EventHandler::NONE;
2211     ioHandler_.unregisterHandler();
2212   }
2213   writeTimeout_.cancelTimeout();
2214
2215   if (fd_ >= 0) {
2216     ioHandler_.changeHandlerFD(-1);
2217     doClose();
2218   }
2219 }
2220
2221 void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
2222   invokeConnectErr(ex);
2223   failAllWrites(ex);
2224
2225   if (readCallback_) {
2226     ReadCallback* callback = readCallback_;
2227     readCallback_ = nullptr;
2228     callback->readErr(ex);
2229   }
2230 }
2231
2232 void AsyncSocket::finishFail() {
2233   assert(state_ == StateEnum::ERROR);
2234   assert(getDestructorGuardCount() > 0);
2235
2236   AsyncSocketException ex(
2237       AsyncSocketException::INTERNAL_ERROR,
2238       withAddr("socket closing after error"));
2239   invokeAllErrors(ex);
2240 }
2241
2242 void AsyncSocket::finishFail(const AsyncSocketException& ex) {
2243   assert(state_ == StateEnum::ERROR);
2244   assert(getDestructorGuardCount() > 0);
2245   invokeAllErrors(ex);
2246 }
2247
2248 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
2249   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2250              << state_ << " host=" << addr_.describe()
2251              << "): failed in " << fn << "(): "
2252              << ex.what();
2253   startFail();
2254   finishFail();
2255 }
2256
2257 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
2258   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2259                << state_ << " host=" << addr_.describe()
2260                << "): failed while connecting in " << fn << "(): "
2261                << ex.what();
2262   startFail();
2263
2264   invokeConnectErr(ex);
2265   finishFail(ex);
2266 }
2267
2268 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
2269   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2270                << state_ << " host=" << addr_.describe()
2271                << "): failed while reading in " << fn << "(): "
2272                << ex.what();
2273   startFail();
2274
2275   if (readCallback_ != nullptr) {
2276     ReadCallback* callback = readCallback_;
2277     readCallback_ = nullptr;
2278     callback->readErr(ex);
2279   }
2280
2281   finishFail();
2282 }
2283
2284 void AsyncSocket::failErrMessageRead(const char* fn,
2285                                      const AsyncSocketException& ex) {
2286   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2287                << state_ << " host=" << addr_.describe()
2288                << "): failed while reading message in " << fn << "(): "
2289                << ex.what();
2290   startFail();
2291
2292   if (errMessageCallback_ != nullptr) {
2293     ErrMessageCallback* callback = errMessageCallback_;
2294     errMessageCallback_ = nullptr;
2295     callback->errMessageError(ex);
2296   }
2297
2298   finishFail();
2299 }
2300
2301 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
2302   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2303                << state_ << " host=" << addr_.describe()
2304                << "): failed while writing in " << fn << "(): "
2305                << ex.what();
2306   startFail();
2307
2308   // Only invoke the first write callback, since the error occurred while
2309   // writing this request.  Let any other pending write callbacks be invoked in
2310   // finishFail().
2311   if (writeReqHead_ != nullptr) {
2312     WriteRequest* req = writeReqHead_;
2313     writeReqHead_ = req->getNext();
2314     WriteCallback* callback = req->getCallback();
2315     uint32_t bytesWritten = req->getTotalBytesWritten();
2316     req->destroy();
2317     if (callback) {
2318       callback->writeErr(bytesWritten, ex);
2319     }
2320   }
2321
2322   finishFail();
2323 }
2324
2325 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
2326                              size_t bytesWritten,
2327                              const AsyncSocketException& ex) {
2328   // This version of failWrite() is used when the failure occurs before
2329   // we've added the callback to writeReqHead_.
2330   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2331              << state_ << " host=" << addr_.describe()
2332              <<"): failed while writing in " << fn << "(): "
2333              << ex.what();
2334   startFail();
2335
2336   if (callback != nullptr) {
2337     callback->writeErr(bytesWritten, ex);
2338   }
2339
2340   finishFail();
2341 }
2342
2343 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
2344   // Invoke writeError() on all write callbacks.
2345   // This is used when writes are forcibly shutdown with write requests
2346   // pending, or when an error occurs with writes pending.
2347   while (writeReqHead_ != nullptr) {
2348     WriteRequest* req = writeReqHead_;
2349     writeReqHead_ = req->getNext();
2350     WriteCallback* callback = req->getCallback();
2351     if (callback) {
2352       callback->writeErr(req->getTotalBytesWritten(), ex);
2353     }
2354     req->destroy();
2355   }
2356 }
2357
2358 void AsyncSocket::invalidState(ConnectCallback* callback) {
2359   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
2360           << "): connect() called in invalid state " << state_;
2361
2362   /*
2363    * The invalidState() methods don't use the normal failure mechanisms,
2364    * since we don't know what state we are in.  We don't want to call
2365    * startFail()/finishFail() recursively if we are already in the middle of
2366    * cleaning up.
2367    */
2368
2369   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
2370                          "connect() called with socket in invalid state");
2371   connectEndTime_ = std::chrono::steady_clock::now();
2372   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2373     if (callback) {
2374       callback->connectErr(ex);
2375     }
2376   } else {
2377     // We can't use failConnect() here since connectCallback_
2378     // may already be set to another callback.  Invoke this ConnectCallback
2379     // here; any other connectCallback_ will be invoked in finishFail()
2380     startFail();
2381     if (callback) {
2382       callback->connectErr(ex);
2383     }
2384     finishFail();
2385   }
2386 }
2387
2388 void AsyncSocket::invalidState(ErrMessageCallback* callback) {
2389   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2390           << "): setErrMessageCB(" << callback
2391           << ") called in invalid state " << state_;
2392
2393   AsyncSocketException ex(
2394       AsyncSocketException::NOT_OPEN,
2395       msgErrQueueSupported
2396       ? "setErrMessageCB() called with socket in invalid state"
2397       : "This platform does not support socket error message notifications");
2398   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2399     if (callback) {
2400       callback->errMessageError(ex);
2401     }
2402   } else {
2403     startFail();
2404     if (callback) {
2405       callback->errMessageError(ex);
2406     }
2407     finishFail();
2408   }
2409 }
2410
2411 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
2412   connectEndTime_ = std::chrono::steady_clock::now();
2413   if (connectCallback_) {
2414     ConnectCallback* callback = connectCallback_;
2415     connectCallback_ = nullptr;
2416     callback->connectErr(ex);
2417   }
2418 }
2419
2420 void AsyncSocket::invokeConnectSuccess() {
2421   connectEndTime_ = std::chrono::steady_clock::now();
2422   if (connectCallback_) {
2423     ConnectCallback* callback = connectCallback_;
2424     connectCallback_ = nullptr;
2425     callback->connectSuccess();
2426   }
2427 }
2428
2429 void AsyncSocket::invalidState(ReadCallback* callback) {
2430   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2431              << "): setReadCallback(" << callback
2432              << ") called in invalid state " << state_;
2433
2434   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2435                          "setReadCallback() called with socket in "
2436                          "invalid state");
2437   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2438     if (callback) {
2439       callback->readErr(ex);
2440     }
2441   } else {
2442     startFail();
2443     if (callback) {
2444       callback->readErr(ex);
2445     }
2446     finishFail();
2447   }
2448 }
2449
2450 void AsyncSocket::invalidState(WriteCallback* callback) {
2451   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2452              << "): write() called in invalid state " << state_;
2453
2454   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2455                          withAddr("write() called with socket in invalid state"));
2456   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2457     if (callback) {
2458       callback->writeErr(0, ex);
2459     }
2460   } else {
2461     startFail();
2462     if (callback) {
2463       callback->writeErr(0, ex);
2464     }
2465     finishFail();
2466   }
2467 }
2468
2469 void AsyncSocket::doClose() {
2470   if (fd_ == -1) return;
2471   if (shutdownSocketSet_) {
2472     shutdownSocketSet_->close(fd_);
2473   } else {
2474     ::close(fd_);
2475   }
2476   fd_ = -1;
2477 }
2478
2479 std::ostream& operator << (std::ostream& os,
2480                            const AsyncSocket::StateEnum& state) {
2481   os << static_cast<int>(state);
2482   return os;
2483 }
2484
2485 std::string AsyncSocket::withAddr(const std::string& s) {
2486   // Don't use addr_ directly because it may not be initialized
2487   // e.g. if constructed from fd
2488   folly::SocketAddress peer, local;
2489   try {
2490     getPeerAddress(&peer);
2491     getLocalAddress(&local);
2492   } catch (const std::exception&) {
2493     // ignore
2494   } catch (...) {
2495     // ignore
2496   }
2497   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
2498 }
2499
2500 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
2501   bufferCallback_ = cb;
2502 }
2503
2504 } // folly