d88f9d851deff798f25930e09ff2d7ecd6a44650
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22 #include <folly/Portability.h>
23 #include <folly/portability/Fcntl.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/SysUio.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <boost/preprocessor/control/if.hpp>
29 #include <errno.h>
30 #include <limits.h>
31 #include <sys/types.h>
32 #include <thread>
33
34 using std::string;
35 using std::unique_ptr;
36
37 namespace fsp = folly::portability::sockets;
38
39 namespace folly {
40
41 static constexpr bool msgErrQueueSupported =
42 #ifdef MSG_ERRQUEUE
43     true;
44 #else
45     false;
46 #endif // MSG_ERRQUEUE
47
48 // static members initializers
49 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
50
51 const AsyncSocketException socketClosedLocallyEx(
52     AsyncSocketException::END_OF_FILE, "socket closed locally");
53 const AsyncSocketException socketShutdownForWritesEx(
54     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
55
56 // TODO: It might help performance to provide a version of BytesWriteRequest that
57 // users could derive from, so we can avoid the extra allocation for each call
58 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
59 // protocols are currently templatized for transports.
60 //
61 // We would need the version for external users where they provide the iovec
62 // storage space, and only our internal version would allocate it at the end of
63 // the WriteRequest.
64
65 /* The default WriteRequest implementation, used for write(), writev() and
66  * writeChain()
67  *
68  * A new BytesWriteRequest operation is allocated on the heap for all write
69  * operations that cannot be completed immediately.
70  */
71 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
72  public:
73   static BytesWriteRequest* newRequest(AsyncSocket* socket,
74                                        WriteCallback* callback,
75                                        const iovec* ops,
76                                        uint32_t opCount,
77                                        uint32_t partialWritten,
78                                        uint32_t bytesWritten,
79                                        unique_ptr<IOBuf>&& ioBuf,
80                                        WriteFlags flags) {
81     assert(opCount > 0);
82     // Since we put a variable size iovec array at the end
83     // of each BytesWriteRequest, we have to manually allocate the memory.
84     void* buf = malloc(sizeof(BytesWriteRequest) +
85                        (opCount * sizeof(struct iovec)));
86     if (buf == nullptr) {
87       throw std::bad_alloc();
88     }
89
90     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
91                                       partialWritten, bytesWritten,
92                                       std::move(ioBuf), flags);
93   }
94
95   void destroy() override {
96     this->~BytesWriteRequest();
97     free(this);
98   }
99
100   WriteResult performWrite() override {
101     WriteFlags writeFlags = flags_;
102     if (getNext() != nullptr) {
103       writeFlags |= WriteFlags::CORK;
104     }
105     auto writeResult = socket_->performWrite(
106         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
107     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
108     return writeResult;
109   }
110
111   bool isComplete() override {
112     return opsWritten_ == getOpCount();
113   }
114
115   void consume() override {
116     // Advance opIndex_ forward by opsWritten_
117     opIndex_ += opsWritten_;
118     assert(opIndex_ < opCount_);
119
120     // If we've finished writing any IOBufs, release them
121     if (ioBuf_) {
122       for (uint32_t i = opsWritten_; i != 0; --i) {
123         assert(ioBuf_);
124         ioBuf_ = ioBuf_->pop();
125       }
126     }
127
128     // Move partialBytes_ forward into the current iovec buffer
129     struct iovec* currentOp = writeOps_ + opIndex_;
130     assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
131     currentOp->iov_base =
132       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
133     currentOp->iov_len -= partialBytes_;
134
135     // Increment the totalBytesWritten_ count by bytesWritten_;
136     assert(bytesWritten_ >= 0);
137     totalBytesWritten_ += uint32_t(bytesWritten_);
138   }
139
140  private:
141   BytesWriteRequest(AsyncSocket* socket,
142                     WriteCallback* callback,
143                     const struct iovec* ops,
144                     uint32_t opCount,
145                     uint32_t partialBytes,
146                     uint32_t bytesWritten,
147                     unique_ptr<IOBuf>&& ioBuf,
148                     WriteFlags flags)
149     : AsyncSocket::WriteRequest(socket, callback)
150     , opCount_(opCount)
151     , opIndex_(0)
152     , flags_(flags)
153     , ioBuf_(std::move(ioBuf))
154     , opsWritten_(0)
155     , partialBytes_(partialBytes)
156     , bytesWritten_(bytesWritten) {
157     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
158   }
159
160   // private destructor, to ensure callers use destroy()
161   ~BytesWriteRequest() override = default;
162
163   const struct iovec* getOps() const {
164     assert(opCount_ > opIndex_);
165     return writeOps_ + opIndex_;
166   }
167
168   uint32_t getOpCount() const {
169     assert(opCount_ > opIndex_);
170     return opCount_ - opIndex_;
171   }
172
173   uint32_t opCount_;            ///< number of entries in writeOps_
174   uint32_t opIndex_;            ///< current index into writeOps_
175   WriteFlags flags_;            ///< set for WriteFlags
176   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
177
178   // for consume(), how much we wrote on the last write
179   uint32_t opsWritten_;         ///< complete ops written
180   uint32_t partialBytes_;       ///< partial bytes of incomplete op written
181   ssize_t bytesWritten_;        ///< bytes written altogether
182
183   struct iovec writeOps_[];     ///< write operation(s) list
184 };
185
186 AsyncSocket::AsyncSocket()
187     : eventBase_(nullptr),
188       writeTimeout_(this, nullptr),
189       ioHandler_(this, nullptr),
190       immediateReadHandler_(this) {
191   VLOG(5) << "new AsyncSocket()";
192   init();
193 }
194
195 AsyncSocket::AsyncSocket(EventBase* evb)
196     : eventBase_(evb),
197       writeTimeout_(this, evb),
198       ioHandler_(this, evb),
199       immediateReadHandler_(this) {
200   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
201   init();
202 }
203
204 AsyncSocket::AsyncSocket(EventBase* evb,
205                            const folly::SocketAddress& address,
206                            uint32_t connectTimeout)
207   : AsyncSocket(evb) {
208   connect(nullptr, address, connectTimeout);
209 }
210
211 AsyncSocket::AsyncSocket(EventBase* evb,
212                            const std::string& ip,
213                            uint16_t port,
214                            uint32_t connectTimeout)
215   : AsyncSocket(evb) {
216   connect(nullptr, ip, port, connectTimeout);
217 }
218
219 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
220     : eventBase_(evb),
221       writeTimeout_(this, evb),
222       ioHandler_(this, evb, fd),
223       immediateReadHandler_(this) {
224   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
225           << fd << ")";
226   init();
227   fd_ = fd;
228   setCloseOnExec();
229   state_ = StateEnum::ESTABLISHED;
230 }
231
232 // init() method, since constructor forwarding isn't supported in most
233 // compilers yet.
234 void AsyncSocket::init() {
235   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
236   shutdownFlags_ = 0;
237   state_ = StateEnum::UNINIT;
238   eventFlags_ = EventHandler::NONE;
239   fd_ = -1;
240   sendTimeout_ = 0;
241   maxReadsPerEvent_ = 16;
242   connectCallback_ = nullptr;
243   errMessageCallback_ = nullptr;
244   readCallback_ = nullptr;
245   writeReqHead_ = nullptr;
246   writeReqTail_ = nullptr;
247   shutdownSocketSet_ = nullptr;
248   appBytesWritten_ = 0;
249   appBytesReceived_ = 0;
250 }
251
252 AsyncSocket::~AsyncSocket() {
253   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
254           << ", evb=" << eventBase_ << ", fd=" << fd_
255           << ", state=" << state_ << ")";
256 }
257
258 void AsyncSocket::destroy() {
259   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
260           << ", fd=" << fd_ << ", state=" << state_;
261   // When destroy is called, close the socket immediately
262   closeNow();
263
264   // Then call DelayedDestruction::destroy() to take care of
265   // whether or not we need immediate or delayed destruction
266   DelayedDestruction::destroy();
267 }
268
269 int AsyncSocket::detachFd() {
270   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
271           << ", evb=" << eventBase_ << ", state=" << state_
272           << ", events=" << std::hex << eventFlags_ << ")";
273   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
274   // actually close the descriptor.
275   if (shutdownSocketSet_) {
276     shutdownSocketSet_->remove(fd_);
277   }
278   int fd = fd_;
279   fd_ = -1;
280   // Call closeNow() to invoke all pending callbacks with an error.
281   closeNow();
282   // Update the EventHandler to stop using this fd.
283   // This can only be done after closeNow() unregisters the handler.
284   ioHandler_.changeHandlerFD(-1);
285   return fd;
286 }
287
288 const folly::SocketAddress& AsyncSocket::anyAddress() {
289   static const folly::SocketAddress anyAddress =
290     folly::SocketAddress("0.0.0.0", 0);
291   return anyAddress;
292 }
293
294 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
295   if (shutdownSocketSet_ == newSS) {
296     return;
297   }
298   if (shutdownSocketSet_ && fd_ != -1) {
299     shutdownSocketSet_->remove(fd_);
300   }
301   shutdownSocketSet_ = newSS;
302   if (shutdownSocketSet_ && fd_ != -1) {
303     shutdownSocketSet_->add(fd_);
304   }
305 }
306
307 void AsyncSocket::setCloseOnExec() {
308   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
309   if (rv != 0) {
310     auto errnoCopy = errno;
311     throw AsyncSocketException(
312         AsyncSocketException::INTERNAL_ERROR,
313         withAddr("failed to set close-on-exec flag"),
314         errnoCopy);
315   }
316 }
317
318 void AsyncSocket::connect(ConnectCallback* callback,
319                            const folly::SocketAddress& address,
320                            int timeout,
321                            const OptionMap &options,
322                            const folly::SocketAddress& bindAddr) noexcept {
323   DestructorGuard dg(this);
324   assert(eventBase_->isInEventBaseThread());
325
326   addr_ = address;
327
328   // Make sure we're in the uninitialized state
329   if (state_ != StateEnum::UNINIT) {
330     return invalidState(callback);
331   }
332
333   connectTimeout_ = std::chrono::milliseconds(timeout);
334   connectStartTime_ = std::chrono::steady_clock::now();
335   // Make connect end time at least >= connectStartTime.
336   connectEndTime_ = connectStartTime_;
337
338   assert(fd_ == -1);
339   state_ = StateEnum::CONNECTING;
340   connectCallback_ = callback;
341
342   sockaddr_storage addrStorage;
343   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
344
345   try {
346     // Create the socket
347     // Technically the first parameter should actually be a protocol family
348     // constant (PF_xxx) rather than an address family (AF_xxx), but the
349     // distinction is mainly just historical.  In pretty much all
350     // implementations the PF_foo and AF_foo constants are identical.
351     fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
352     if (fd_ < 0) {
353       auto errnoCopy = errno;
354       throw AsyncSocketException(
355           AsyncSocketException::INTERNAL_ERROR,
356           withAddr("failed to create socket"),
357           errnoCopy);
358     }
359     if (shutdownSocketSet_) {
360       shutdownSocketSet_->add(fd_);
361     }
362     ioHandler_.changeHandlerFD(fd_);
363
364     setCloseOnExec();
365
366     // Put the socket in non-blocking mode
367     int flags = fcntl(fd_, F_GETFL, 0);
368     if (flags == -1) {
369       auto errnoCopy = errno;
370       throw AsyncSocketException(
371           AsyncSocketException::INTERNAL_ERROR,
372           withAddr("failed to get socket flags"),
373           errnoCopy);
374     }
375     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
376     if (rv == -1) {
377       auto errnoCopy = errno;
378       throw AsyncSocketException(
379           AsyncSocketException::INTERNAL_ERROR,
380           withAddr("failed to put socket in non-blocking mode"),
381           errnoCopy);
382     }
383
384 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
385     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
386     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
387     if (rv == -1) {
388       auto errnoCopy = errno;
389       throw AsyncSocketException(
390           AsyncSocketException::INTERNAL_ERROR,
391           "failed to enable F_SETNOSIGPIPE on socket",
392           errnoCopy);
393     }
394 #endif
395
396     // By default, turn on TCP_NODELAY
397     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
398     // setNoDelay() will log an error message if it fails.
399     if (address.getFamily() != AF_UNIX) {
400       (void)setNoDelay(true);
401     }
402
403     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
404             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
405
406     // bind the socket
407     if (bindAddr != anyAddress()) {
408       int one = 1;
409       if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
410         auto errnoCopy = errno;
411         doClose();
412         throw AsyncSocketException(
413             AsyncSocketException::NOT_OPEN,
414             "failed to setsockopt prior to bind on " + bindAddr.describe(),
415             errnoCopy);
416       }
417
418       bindAddr.getAddress(&addrStorage);
419
420       if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
421         auto errnoCopy = errno;
422         doClose();
423         throw AsyncSocketException(
424             AsyncSocketException::NOT_OPEN,
425             "failed to bind to async socket: " + bindAddr.describe(),
426             errnoCopy);
427       }
428     }
429
430     // Apply the additional options if any.
431     for (const auto& opt: options) {
432       rv = opt.first.apply(fd_, opt.second);
433       if (rv != 0) {
434         auto errnoCopy = errno;
435         throw AsyncSocketException(
436             AsyncSocketException::INTERNAL_ERROR,
437             withAddr("failed to set socket option"),
438             errnoCopy);
439       }
440     }
441
442     // Perform the connect()
443     address.getAddress(&addrStorage);
444
445     if (tfoEnabled_) {
446       state_ = StateEnum::FAST_OPEN;
447       tfoAttempted_ = true;
448     } else {
449       if (socketConnect(saddr, addr_.getActualSize()) < 0) {
450         return;
451       }
452     }
453
454     // If we're still here the connect() succeeded immediately.
455     // Fall through to call the callback outside of this try...catch block
456   } catch (const AsyncSocketException& ex) {
457     return failConnect(__func__, ex);
458   } catch (const std::exception& ex) {
459     // shouldn't happen, but handle it just in case
460     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
461                << "): unexpected " << typeid(ex).name() << " exception: "
462                << ex.what();
463     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
464                             withAddr(string("unexpected exception: ") +
465                                      ex.what()));
466     return failConnect(__func__, tex);
467   }
468
469   // The connection succeeded immediately
470   // The read callback may not have been set yet, and no writes may be pending
471   // yet, so we don't have to register for any events at the moment.
472   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
473   assert(errMessageCallback_ == nullptr);
474   assert(readCallback_ == nullptr);
475   assert(writeReqHead_ == nullptr);
476   if (state_ != StateEnum::FAST_OPEN) {
477     state_ = StateEnum::ESTABLISHED;
478   }
479   invokeConnectSuccess();
480 }
481
482 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
483 #if __linux__
484   if (noTransparentTls_) {
485     // Ignore return value, errors are ok
486     setsockopt(fd_, SOL_SOCKET, SO_NO_TRANSPARENT_TLS, nullptr, 0);
487   }
488 #endif
489   int rv = fsp::connect(fd_, saddr, len);
490   if (rv < 0) {
491     auto errnoCopy = errno;
492     if (errnoCopy == EINPROGRESS) {
493       scheduleConnectTimeout();
494       registerForConnectEvents();
495     } else {
496       throw AsyncSocketException(
497           AsyncSocketException::NOT_OPEN,
498           "connect failed (immediately)",
499           errnoCopy);
500     }
501   }
502   return rv;
503 }
504
505 void AsyncSocket::scheduleConnectTimeout() {
506   // Connection in progress.
507   auto timeout = connectTimeout_.count();
508   if (timeout > 0) {
509     // Start a timer in case the connection takes too long.
510     if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
511       throw AsyncSocketException(
512           AsyncSocketException::INTERNAL_ERROR,
513           withAddr("failed to schedule AsyncSocket connect timeout"));
514     }
515   }
516 }
517
518 void AsyncSocket::registerForConnectEvents() {
519   // Register for write events, so we'll
520   // be notified when the connection finishes/fails.
521   // Note that we don't register for a persistent event here.
522   assert(eventFlags_ == EventHandler::NONE);
523   eventFlags_ = EventHandler::WRITE;
524   if (!ioHandler_.registerHandler(eventFlags_)) {
525     throw AsyncSocketException(
526         AsyncSocketException::INTERNAL_ERROR,
527         withAddr("failed to register AsyncSocket connect handler"));
528   }
529 }
530
531 void AsyncSocket::connect(ConnectCallback* callback,
532                            const string& ip, uint16_t port,
533                            int timeout,
534                            const OptionMap &options) noexcept {
535   DestructorGuard dg(this);
536   try {
537     connectCallback_ = callback;
538     connect(callback, folly::SocketAddress(ip, port), timeout, options);
539   } catch (const std::exception& ex) {
540     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
541                             ex.what());
542     return failConnect(__func__, tex);
543   }
544 }
545
546 void AsyncSocket::cancelConnect() {
547   connectCallback_ = nullptr;
548   if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
549     closeNow();
550   }
551 }
552
553 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
554   sendTimeout_ = milliseconds;
555   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
556
557   // If we are currently pending on write requests, immediately update
558   // writeTimeout_ with the new value.
559   if ((eventFlags_ & EventHandler::WRITE) &&
560       (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
561     assert(state_ == StateEnum::ESTABLISHED);
562     assert((shutdownFlags_ & SHUT_WRITE) == 0);
563     if (sendTimeout_ > 0) {
564       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
565         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
566             withAddr("failed to reschedule send timeout in setSendTimeout"));
567         return failWrite(__func__, ex);
568       }
569     } else {
570       writeTimeout_.cancelTimeout();
571     }
572   }
573 }
574
575 void AsyncSocket::setErrMessageCB(ErrMessageCallback* callback) {
576   VLOG(6) << "AsyncSocket::setErrMessageCB() this=" << this
577           << ", fd=" << fd_ << ", callback=" << callback
578           << ", state=" << state_;
579
580   // Short circuit if callback is the same as the existing timestampCallback_.
581   if (callback == errMessageCallback_) {
582     return;
583   }
584
585   if (!msgErrQueueSupported) {
586       // Per-socket error message queue is not supported on this platform.
587       return invalidState(callback);
588   }
589
590   DestructorGuard dg(this);
591   assert(eventBase_->isInEventBaseThread());
592
593   switch ((StateEnum)state_) {
594     case StateEnum::CONNECTING:
595     case StateEnum::FAST_OPEN:
596     case StateEnum::ESTABLISHED: {
597       errMessageCallback_ = callback;
598       return;
599     }
600     case StateEnum::CLOSED:
601     case StateEnum::ERROR:
602       // We should never reach here.  SHUT_READ should always be set
603       // if we are in STATE_CLOSED or STATE_ERROR.
604       assert(false);
605       return invalidState(callback);
606     case StateEnum::UNINIT:
607       // We do not allow setReadCallback() to be called before we start
608       // connecting.
609       return invalidState(callback);
610   }
611
612   // We don't put a default case in the switch statement, so that the compiler
613   // will warn us to update the switch statement if a new state is added.
614   return invalidState(callback);
615 }
616
617 AsyncSocket::ErrMessageCallback* AsyncSocket::getErrMessageCallback() const {
618   return errMessageCallback_;
619 }
620
621 void AsyncSocket::setReadCB(ReadCallback *callback) {
622   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
623           << ", callback=" << callback << ", state=" << state_;
624
625   // Short circuit if callback is the same as the existing readCallback_.
626   //
627   // Note that this is needed for proper functioning during some cleanup cases.
628   // During cleanup we allow setReadCallback(nullptr) to be called even if the
629   // read callback is already unset and we have been detached from an event
630   // base.  This check prevents us from asserting
631   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
632   if (callback == readCallback_) {
633     return;
634   }
635
636   /* We are removing a read callback */
637   if (callback == nullptr &&
638       immediateReadHandler_.isLoopCallbackScheduled()) {
639     immediateReadHandler_.cancelLoopCallback();
640   }
641
642   if (shutdownFlags_ & SHUT_READ) {
643     // Reads have already been shut down on this socket.
644     //
645     // Allow setReadCallback(nullptr) to be called in this case, but don't
646     // allow a new callback to be set.
647     //
648     // For example, setReadCallback(nullptr) can happen after an error if we
649     // invoke some other error callback before invoking readError().  The other
650     // error callback that is invoked first may go ahead and clear the read
651     // callback before we get a chance to invoke readError().
652     if (callback != nullptr) {
653       return invalidState(callback);
654     }
655     assert((eventFlags_ & EventHandler::READ) == 0);
656     readCallback_ = nullptr;
657     return;
658   }
659
660   DestructorGuard dg(this);
661   assert(eventBase_->isInEventBaseThread());
662
663   switch ((StateEnum)state_) {
664     case StateEnum::CONNECTING:
665     case StateEnum::FAST_OPEN:
666       // For convenience, we allow the read callback to be set while we are
667       // still connecting.  We just store the callback for now.  Once the
668       // connection completes we'll register for read events.
669       readCallback_ = callback;
670       return;
671     case StateEnum::ESTABLISHED:
672     {
673       readCallback_ = callback;
674       uint16_t oldFlags = eventFlags_;
675       if (readCallback_) {
676         eventFlags_ |= EventHandler::READ;
677       } else {
678         eventFlags_ &= ~EventHandler::READ;
679       }
680
681       // Update our registration if our flags have changed
682       if (eventFlags_ != oldFlags) {
683         // We intentionally ignore the return value here.
684         // updateEventRegistration() will move us into the error state if it
685         // fails, and we don't need to do anything else here afterwards.
686         (void)updateEventRegistration();
687       }
688
689       if (readCallback_) {
690         checkForImmediateRead();
691       }
692       return;
693     }
694     case StateEnum::CLOSED:
695     case StateEnum::ERROR:
696       // We should never reach here.  SHUT_READ should always be set
697       // if we are in STATE_CLOSED or STATE_ERROR.
698       assert(false);
699       return invalidState(callback);
700     case StateEnum::UNINIT:
701       // We do not allow setReadCallback() to be called before we start
702       // connecting.
703       return invalidState(callback);
704   }
705
706   // We don't put a default case in the switch statement, so that the compiler
707   // will warn us to update the switch statement if a new state is added.
708   return invalidState(callback);
709 }
710
711 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
712   return readCallback_;
713 }
714
715 void AsyncSocket::write(WriteCallback* callback,
716                          const void* buf, size_t bytes, WriteFlags flags) {
717   iovec op;
718   op.iov_base = const_cast<void*>(buf);
719   op.iov_len = bytes;
720   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
721 }
722
723 void AsyncSocket::writev(WriteCallback* callback,
724                           const iovec* vec,
725                           size_t count,
726                           WriteFlags flags) {
727   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
728 }
729
730 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
731                               WriteFlags flags) {
732   constexpr size_t kSmallSizeMax = 64;
733   size_t count = buf->countChainElements();
734   if (count <= kSmallSizeMax) {
735     // suppress "warning: variable length array 'vec' is used [-Wvla]"
736     FOLLY_PUSH_WARNING;
737     FOLLY_GCC_DISABLE_WARNING(vla);
738     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
739     FOLLY_POP_WARNING;
740
741     writeChainImpl(callback, vec, count, std::move(buf), flags);
742   } else {
743     iovec* vec = new iovec[count];
744     writeChainImpl(callback, vec, count, std::move(buf), flags);
745     delete[] vec;
746   }
747 }
748
749 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
750     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
751   size_t veclen = buf->fillIov(vec, count);
752   writeImpl(callback, vec, veclen, std::move(buf), flags);
753 }
754
755 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
756                              size_t count, unique_ptr<IOBuf>&& buf,
757                              WriteFlags flags) {
758   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
759           << ", callback=" << callback << ", count=" << count
760           << ", state=" << state_;
761   DestructorGuard dg(this);
762   unique_ptr<IOBuf>ioBuf(std::move(buf));
763   assert(eventBase_->isInEventBaseThread());
764
765   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
766     // No new writes may be performed after the write side of the socket has
767     // been shutdown.
768     //
769     // We could just call callback->writeError() here to fail just this write.
770     // However, fail hard and use invalidState() to fail all outstanding
771     // callbacks and move the socket into the error state.  There's most likely
772     // a bug in the caller's code, so we abort everything rather than trying to
773     // proceed as best we can.
774     return invalidState(callback);
775   }
776
777   uint32_t countWritten = 0;
778   uint32_t partialWritten = 0;
779   ssize_t bytesWritten = 0;
780   bool mustRegister = false;
781   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
782       !connecting()) {
783     if (writeReqHead_ == nullptr) {
784       // If we are established and there are no other writes pending,
785       // we can attempt to perform the write immediately.
786       assert(writeReqTail_ == nullptr);
787       assert((eventFlags_ & EventHandler::WRITE) == 0);
788
789       auto writeResult = performWrite(
790           vec, uint32_t(count), flags, &countWritten, &partialWritten);
791       bytesWritten = writeResult.writeReturn;
792       if (bytesWritten < 0) {
793         auto errnoCopy = errno;
794         if (writeResult.exception) {
795           return failWrite(__func__, callback, 0, *writeResult.exception);
796         }
797         AsyncSocketException ex(
798             AsyncSocketException::INTERNAL_ERROR,
799             withAddr("writev failed"),
800             errnoCopy);
801         return failWrite(__func__, callback, 0, ex);
802       } else if (countWritten == count) {
803         // We successfully wrote everything.
804         // Invoke the callback and return.
805         if (callback) {
806           callback->writeSuccess();
807         }
808         return;
809       } else { // continue writing the next writeReq
810         if (bufferCallback_) {
811           bufferCallback_->onEgressBuffered();
812         }
813       }
814       if (!connecting()) {
815         // Writes might put the socket back into connecting state
816         // if TFO is enabled, and using TFO fails.
817         // This means that write timeouts would not be active, however
818         // connect timeouts would affect this stage.
819         mustRegister = true;
820       }
821     }
822   } else if (!connecting()) {
823     // Invalid state for writing
824     return invalidState(callback);
825   }
826
827   // Create a new WriteRequest to add to the queue
828   WriteRequest* req;
829   try {
830     req = BytesWriteRequest::newRequest(
831         this,
832         callback,
833         vec + countWritten,
834         uint32_t(count - countWritten),
835         partialWritten,
836         uint32_t(bytesWritten),
837         std::move(ioBuf),
838         flags);
839   } catch (const std::exception& ex) {
840     // we mainly expect to catch std::bad_alloc here
841     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
842         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
843     return failWrite(__func__, callback, size_t(bytesWritten), tex);
844   }
845   req->consume();
846   if (writeReqTail_ == nullptr) {
847     assert(writeReqHead_ == nullptr);
848     writeReqHead_ = writeReqTail_ = req;
849   } else {
850     writeReqTail_->append(req);
851     writeReqTail_ = req;
852   }
853
854   // Register for write events if are established and not currently
855   // waiting on write events
856   if (mustRegister) {
857     assert(state_ == StateEnum::ESTABLISHED);
858     assert((eventFlags_ & EventHandler::WRITE) == 0);
859     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
860       assert(state_ == StateEnum::ERROR);
861       return;
862     }
863     if (sendTimeout_ > 0) {
864       // Schedule a timeout to fire if the write takes too long.
865       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
866         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
867                                withAddr("failed to schedule send timeout"));
868         return failWrite(__func__, ex);
869       }
870     }
871   }
872 }
873
874 void AsyncSocket::writeRequest(WriteRequest* req) {
875   if (writeReqTail_ == nullptr) {
876     assert(writeReqHead_ == nullptr);
877     writeReqHead_ = writeReqTail_ = req;
878     req->start();
879   } else {
880     writeReqTail_->append(req);
881     writeReqTail_ = req;
882   }
883 }
884
885 void AsyncSocket::close() {
886   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
887           << ", state=" << state_ << ", shutdownFlags="
888           << std::hex << (int) shutdownFlags_;
889
890   // close() is only different from closeNow() when there are pending writes
891   // that need to drain before we can close.  In all other cases, just call
892   // closeNow().
893   //
894   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
895   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
896   // is still running.  (e.g., If there are multiple pending writes, and we
897   // call writeError() on the first one, it may call close().  In this case we
898   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
899   // writes will still be in the queue.)
900   //
901   // We only need to drain pending writes if we are still in STATE_CONNECTING
902   // or STATE_ESTABLISHED
903   if ((writeReqHead_ == nullptr) ||
904       !(state_ == StateEnum::CONNECTING ||
905       state_ == StateEnum::ESTABLISHED)) {
906     closeNow();
907     return;
908   }
909
910   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
911   // destroyed until close() returns.
912   DestructorGuard dg(this);
913   assert(eventBase_->isInEventBaseThread());
914
915   // Since there are write requests pending, we have to set the
916   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
917   // connect finishes and we finish writing these requests.
918   //
919   // Set SHUT_READ to indicate that reads are shut down, and set the
920   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
921   // pending writes complete.
922   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
923
924   // If a read callback is set, invoke readEOF() immediately to inform it that
925   // the socket has been closed and no more data can be read.
926   if (readCallback_) {
927     // Disable reads if they are enabled
928     if (!updateEventRegistration(0, EventHandler::READ)) {
929       // We're now in the error state; callbacks have been cleaned up
930       assert(state_ == StateEnum::ERROR);
931       assert(readCallback_ == nullptr);
932     } else {
933       ReadCallback* callback = readCallback_;
934       readCallback_ = nullptr;
935       callback->readEOF();
936     }
937   }
938 }
939
940 void AsyncSocket::closeNow() {
941   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
942           << ", state=" << state_ << ", shutdownFlags="
943           << std::hex << (int) shutdownFlags_;
944   DestructorGuard dg(this);
945   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
946
947   switch (state_) {
948     case StateEnum::ESTABLISHED:
949     case StateEnum::CONNECTING:
950     case StateEnum::FAST_OPEN: {
951       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
952       state_ = StateEnum::CLOSED;
953
954       // If the write timeout was set, cancel it.
955       writeTimeout_.cancelTimeout();
956
957       // If we are registered for I/O events, unregister.
958       if (eventFlags_ != EventHandler::NONE) {
959         eventFlags_ = EventHandler::NONE;
960         if (!updateEventRegistration()) {
961           // We will have been moved into the error state.
962           assert(state_ == StateEnum::ERROR);
963           return;
964         }
965       }
966
967       if (immediateReadHandler_.isLoopCallbackScheduled()) {
968         immediateReadHandler_.cancelLoopCallback();
969       }
970
971       if (fd_ >= 0) {
972         ioHandler_.changeHandlerFD(-1);
973         doClose();
974       }
975
976       invokeConnectErr(socketClosedLocallyEx);
977
978       failAllWrites(socketClosedLocallyEx);
979
980       if (readCallback_) {
981         ReadCallback* callback = readCallback_;
982         readCallback_ = nullptr;
983         callback->readEOF();
984       }
985       return;
986     }
987     case StateEnum::CLOSED:
988       // Do nothing.  It's possible that we are being called recursively
989       // from inside a callback that we invoked inside another call to close()
990       // that is still running.
991       return;
992     case StateEnum::ERROR:
993       // Do nothing.  The error handling code has performed (or is performing)
994       // cleanup.
995       return;
996     case StateEnum::UNINIT:
997       assert(eventFlags_ == EventHandler::NONE);
998       assert(connectCallback_ == nullptr);
999       assert(readCallback_ == nullptr);
1000       assert(writeReqHead_ == nullptr);
1001       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1002       state_ = StateEnum::CLOSED;
1003       return;
1004   }
1005
1006   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
1007               << ") called in unknown state " << state_;
1008 }
1009
1010 void AsyncSocket::closeWithReset() {
1011   // Enable SO_LINGER, with the linger timeout set to 0.
1012   // This will trigger a TCP reset when we close the socket.
1013   if (fd_ >= 0) {
1014     struct linger optLinger = {1, 0};
1015     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
1016       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
1017               << "on " << fd_ << ": errno=" << errno;
1018     }
1019   }
1020
1021   // Then let closeNow() take care of the rest
1022   closeNow();
1023 }
1024
1025 void AsyncSocket::shutdownWrite() {
1026   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
1027           << ", state=" << state_ << ", shutdownFlags="
1028           << std::hex << (int) shutdownFlags_;
1029
1030   // If there are no pending writes, shutdownWrite() is identical to
1031   // shutdownWriteNow().
1032   if (writeReqHead_ == nullptr) {
1033     shutdownWriteNow();
1034     return;
1035   }
1036
1037   assert(eventBase_->isInEventBaseThread());
1038
1039   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
1040   // shutdown will be performed once all writes complete.
1041   shutdownFlags_ |= SHUT_WRITE_PENDING;
1042 }
1043
1044 void AsyncSocket::shutdownWriteNow() {
1045   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
1046           << ", fd=" << fd_ << ", state=" << state_
1047           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
1048
1049   if (shutdownFlags_ & SHUT_WRITE) {
1050     // Writes are already shutdown; nothing else to do.
1051     return;
1052   }
1053
1054   // If SHUT_READ is already set, just call closeNow() to completely
1055   // close the socket.  This can happen if close() was called with writes
1056   // pending, and then shutdownWriteNow() is called before all pending writes
1057   // complete.
1058   if (shutdownFlags_ & SHUT_READ) {
1059     closeNow();
1060     return;
1061   }
1062
1063   DestructorGuard dg(this);
1064   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
1065
1066   switch (static_cast<StateEnum>(state_)) {
1067     case StateEnum::ESTABLISHED:
1068     {
1069       shutdownFlags_ |= SHUT_WRITE;
1070
1071       // If the write timeout was set, cancel it.
1072       writeTimeout_.cancelTimeout();
1073
1074       // If we are registered for write events, unregister.
1075       if (!updateEventRegistration(0, EventHandler::WRITE)) {
1076         // We will have been moved into the error state.
1077         assert(state_ == StateEnum::ERROR);
1078         return;
1079       }
1080
1081       // Shutdown writes on the file descriptor
1082       shutdown(fd_, SHUT_WR);
1083
1084       // Immediately fail all write requests
1085       failAllWrites(socketShutdownForWritesEx);
1086       return;
1087     }
1088     case StateEnum::CONNECTING:
1089     {
1090       // Set the SHUT_WRITE_PENDING flag.
1091       // When the connection completes, it will check this flag,
1092       // shutdown the write half of the socket, and then set SHUT_WRITE.
1093       shutdownFlags_ |= SHUT_WRITE_PENDING;
1094
1095       // Immediately fail all write requests
1096       failAllWrites(socketShutdownForWritesEx);
1097       return;
1098     }
1099     case StateEnum::UNINIT:
1100       // Callers normally shouldn't call shutdownWriteNow() before the socket
1101       // even starts connecting.  Nonetheless, go ahead and set
1102       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1103       // immediately shut down the write side of the socket.
1104       shutdownFlags_ |= SHUT_WRITE_PENDING;
1105       return;
1106     case StateEnum::FAST_OPEN:
1107       // In fast open state we haven't call connected yet, and if we shutdown
1108       // the writes, we will never try to call connect, so shut everything down
1109       shutdownFlags_ |= SHUT_WRITE;
1110       // Immediately fail all write requests
1111       failAllWrites(socketShutdownForWritesEx);
1112       return;
1113     case StateEnum::CLOSED:
1114     case StateEnum::ERROR:
1115       // We should never get here.  SHUT_WRITE should always be set
1116       // in STATE_CLOSED and STATE_ERROR.
1117       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1118                  << ", fd=" << fd_ << ") in unexpected state " << state_
1119                  << " with SHUT_WRITE not set ("
1120                  << std::hex << (int) shutdownFlags_ << ")";
1121       assert(false);
1122       return;
1123   }
1124
1125   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
1126               << fd_ << ") called in unknown state " << state_;
1127 }
1128
1129 bool AsyncSocket::readable() const {
1130   if (fd_ == -1) {
1131     return false;
1132   }
1133   struct pollfd fds[1];
1134   fds[0].fd = fd_;
1135   fds[0].events = POLLIN;
1136   fds[0].revents = 0;
1137   int rc = poll(fds, 1, 0);
1138   return rc == 1;
1139 }
1140
1141 bool AsyncSocket::isPending() const {
1142   return ioHandler_.isPending();
1143 }
1144
1145 bool AsyncSocket::hangup() const {
1146   if (fd_ == -1) {
1147     // sanity check, no one should ask for hangup if we are not connected.
1148     assert(false);
1149     return false;
1150   }
1151 #ifdef POLLRDHUP // Linux-only
1152   struct pollfd fds[1];
1153   fds[0].fd = fd_;
1154   fds[0].events = POLLRDHUP|POLLHUP;
1155   fds[0].revents = 0;
1156   poll(fds, 1, 0);
1157   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1158 #else
1159   return false;
1160 #endif
1161 }
1162
1163 bool AsyncSocket::good() const {
1164   return (
1165       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1166        state_ == StateEnum::ESTABLISHED) &&
1167       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1168 }
1169
1170 bool AsyncSocket::error() const {
1171   return (state_ == StateEnum::ERROR);
1172 }
1173
1174 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1175   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1176           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1177           << ", state=" << state_ << ", events="
1178           << std::hex << eventFlags_ << ")";
1179   assert(eventBase_ == nullptr);
1180   assert(eventBase->isInEventBaseThread());
1181
1182   eventBase_ = eventBase;
1183   ioHandler_.attachEventBase(eventBase);
1184   writeTimeout_.attachEventBase(eventBase);
1185   if (evbChangeCb_) {
1186     evbChangeCb_->evbAttached(this);
1187   }
1188 }
1189
1190 void AsyncSocket::detachEventBase() {
1191   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1192           << ", old evb=" << eventBase_ << ", state=" << state_
1193           << ", events=" << std::hex << eventFlags_ << ")";
1194   assert(eventBase_ != nullptr);
1195   assert(eventBase_->isInEventBaseThread());
1196
1197   eventBase_ = nullptr;
1198   ioHandler_.detachEventBase();
1199   writeTimeout_.detachEventBase();
1200   if (evbChangeCb_) {
1201     evbChangeCb_->evbDetached(this);
1202   }
1203 }
1204
1205 bool AsyncSocket::isDetachable() const {
1206   DCHECK(eventBase_ != nullptr);
1207   DCHECK(eventBase_->isInEventBaseThread());
1208
1209   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1210 }
1211
1212 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1213   if (!localAddr_.isInitialized()) {
1214     localAddr_.setFromLocalAddress(fd_);
1215   }
1216   *address = localAddr_;
1217 }
1218
1219 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1220   if (!addr_.isInitialized()) {
1221     addr_.setFromPeerAddress(fd_);
1222   }
1223   *address = addr_;
1224 }
1225
1226 bool AsyncSocket::getTFOSucceded() const {
1227   return detail::tfo_succeeded(fd_);
1228 }
1229
1230 int AsyncSocket::setNoDelay(bool noDelay) {
1231   if (fd_ < 0) {
1232     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1233                << this << "(state=" << state_ << ")";
1234     return EINVAL;
1235
1236   }
1237
1238   int value = noDelay ? 1 : 0;
1239   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1240     int errnoCopy = errno;
1241     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1242             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1243             << strerror(errnoCopy);
1244     return errnoCopy;
1245   }
1246
1247   return 0;
1248 }
1249
1250 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1251
1252   #ifndef TCP_CONGESTION
1253   #define TCP_CONGESTION  13
1254   #endif
1255
1256   if (fd_ < 0) {
1257     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1258                << "socket " << this << "(state=" << state_ << ")";
1259     return EINVAL;
1260
1261   }
1262
1263   if (setsockopt(
1264           fd_,
1265           IPPROTO_TCP,
1266           TCP_CONGESTION,
1267           cname.c_str(),
1268           socklen_t(cname.length() + 1)) != 0) {
1269     int errnoCopy = errno;
1270     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1271             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1272             << strerror(errnoCopy);
1273     return errnoCopy;
1274   }
1275
1276   return 0;
1277 }
1278
1279 int AsyncSocket::setQuickAck(bool quickack) {
1280   if (fd_ < 0) {
1281     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1282                << this << "(state=" << state_ << ")";
1283     return EINVAL;
1284
1285   }
1286
1287 #ifdef TCP_QUICKACK // Linux-only
1288   int value = quickack ? 1 : 0;
1289   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1290     int errnoCopy = errno;
1291     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1292             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1293             << strerror(errnoCopy);
1294     return errnoCopy;
1295   }
1296
1297   return 0;
1298 #else
1299   return ENOSYS;
1300 #endif
1301 }
1302
1303 int AsyncSocket::setSendBufSize(size_t bufsize) {
1304   if (fd_ < 0) {
1305     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1306                << this << "(state=" << state_ << ")";
1307     return EINVAL;
1308   }
1309
1310   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1311     int errnoCopy = errno;
1312     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1313             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1314             << strerror(errnoCopy);
1315     return errnoCopy;
1316   }
1317
1318   return 0;
1319 }
1320
1321 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1322   if (fd_ < 0) {
1323     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1324                << this << "(state=" << state_ << ")";
1325     return EINVAL;
1326   }
1327
1328   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1329     int errnoCopy = errno;
1330     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1331             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1332             << strerror(errnoCopy);
1333     return errnoCopy;
1334   }
1335
1336   return 0;
1337 }
1338
1339 int AsyncSocket::setTCPProfile(int profd) {
1340   if (fd_ < 0) {
1341     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1342                << this << "(state=" << state_ << ")";
1343     return EINVAL;
1344   }
1345
1346   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1347     int errnoCopy = errno;
1348     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1349             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1350             << strerror(errnoCopy);
1351     return errnoCopy;
1352   }
1353
1354   return 0;
1355 }
1356
1357 void AsyncSocket::ioReady(uint16_t events) noexcept {
1358   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1359           << ", events=" << std::hex << events << ", state=" << state_;
1360   DestructorGuard dg(this);
1361   assert(events & EventHandler::READ_WRITE);
1362   assert(eventBase_->isInEventBaseThread());
1363
1364   uint16_t relevantEvents = uint16_t(events & EventHandler::READ_WRITE);
1365   EventBase* originalEventBase = eventBase_;
1366   // If we got there it means that either EventHandler::READ or
1367   // EventHandler::WRITE is set. Any of these flags can
1368   // indicate that there are messages available in the socket
1369   // error message queue.
1370   handleErrMessages();
1371
1372   // Return now if handleErrMessages() detached us from our EventBase
1373   if (eventBase_ != originalEventBase) {
1374     return;
1375   }
1376
1377   if (relevantEvents == EventHandler::READ) {
1378     handleRead();
1379   } else if (relevantEvents == EventHandler::WRITE) {
1380     handleWrite();
1381   } else if (relevantEvents == EventHandler::READ_WRITE) {
1382     // If both read and write events are ready, process writes first.
1383     handleWrite();
1384
1385     // Return now if handleWrite() detached us from our EventBase
1386     if (eventBase_ != originalEventBase) {
1387       return;
1388     }
1389
1390     // Only call handleRead() if a read callback is still installed.
1391     // (It's possible that the read callback was uninstalled during
1392     // handleWrite().)
1393     if (readCallback_) {
1394       handleRead();
1395     }
1396   } else {
1397     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1398                << std::hex << events << "(this=" << this << ")";
1399     abort();
1400   }
1401 }
1402
1403 AsyncSocket::ReadResult
1404 AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
1405   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
1406           << ", buflen=" << *buflen;
1407
1408   int recvFlags = 0;
1409   if (peek_) {
1410     recvFlags |= MSG_PEEK;
1411   }
1412
1413   ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
1414   if (bytes < 0) {
1415     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1416       // No more data to read right now.
1417       return ReadResult(READ_BLOCKING);
1418     } else {
1419       return ReadResult(READ_ERROR);
1420     }
1421   } else {
1422     appBytesReceived_ += bytes;
1423     return ReadResult(bytes);
1424   }
1425 }
1426
1427 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
1428   // no matter what, buffer should be preapared for non-ssl socket
1429   CHECK(readCallback_);
1430   readCallback_->getReadBuffer(buf, buflen);
1431 }
1432
1433 void AsyncSocket::handleErrMessages() noexcept {
1434   // This method has non-empty implementation only for platforms
1435   // supporting per-socket error queues.
1436   VLOG(5) << "AsyncSocket::handleErrMessages() this=" << this << ", fd=" << fd_
1437           << ", state=" << state_;
1438   if (errMessageCallback_ == nullptr) {
1439     VLOG(7) << "AsyncSocket::handleErrMessages(): "
1440             << "no callback installed - exiting.";
1441     return;
1442   }
1443
1444 #ifdef MSG_ERRQUEUE
1445   uint8_t ctrl[1024];
1446   unsigned char data;
1447   struct msghdr msg;
1448   iovec entry;
1449
1450   entry.iov_base = &data;
1451   entry.iov_len = sizeof(data);
1452   msg.msg_iov = &entry;
1453   msg.msg_iovlen = 1;
1454   msg.msg_name = nullptr;
1455   msg.msg_namelen = 0;
1456   msg.msg_control = ctrl;
1457   msg.msg_controllen = sizeof(ctrl);
1458   msg.msg_flags = 0;
1459
1460   int ret;
1461   while (true) {
1462     ret = recvmsg(fd_, &msg, MSG_ERRQUEUE);
1463     VLOG(5) << "AsyncSocket::handleErrMessages(): recvmsg returned " << ret;
1464
1465     if (ret < 0) {
1466       if (errno != EAGAIN) {
1467         auto errnoCopy = errno;
1468         LOG(ERROR) << "::recvmsg exited with code " << ret
1469                    << ", errno: " << errnoCopy;
1470         AsyncSocketException ex(
1471           AsyncSocketException::INTERNAL_ERROR,
1472           withAddr("recvmsg() failed"),
1473           errnoCopy);
1474         failErrMessageRead(__func__, ex);
1475       }
1476       return;
1477     }
1478
1479     for (struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
1480          cmsg != nullptr && cmsg->cmsg_len != 0;
1481          cmsg = CMSG_NXTHDR(&msg, cmsg)) {
1482       errMessageCallback_->errMessage(*cmsg);
1483     }
1484   }
1485 #endif //MSG_ERRQUEUE
1486 }
1487
1488 void AsyncSocket::handleRead() noexcept {
1489   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1490           << ", state=" << state_;
1491   assert(state_ == StateEnum::ESTABLISHED);
1492   assert((shutdownFlags_ & SHUT_READ) == 0);
1493   assert(readCallback_ != nullptr);
1494   assert(eventFlags_ & EventHandler::READ);
1495
1496   // Loop until:
1497   // - a read attempt would block
1498   // - readCallback_ is uninstalled
1499   // - the number of loop iterations exceeds the optional maximum
1500   // - this AsyncSocket is moved to another EventBase
1501   //
1502   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1503   // which is why need to check for it here.
1504   //
1505   // The last bullet point is slightly subtle.  readDataAvailable() may also
1506   // detach this socket from this EventBase.  However, before
1507   // readDataAvailable() returns another thread may pick it up, attach it to
1508   // a different EventBase, and install another readCallback_.  We need to
1509   // exit immediately after readDataAvailable() returns if the eventBase_ has
1510   // changed.  (The caller must perform some sort of locking to transfer the
1511   // AsyncSocket between threads properly.  This will be sufficient to ensure
1512   // that this thread sees the updated eventBase_ variable after
1513   // readDataAvailable() returns.)
1514   uint16_t numReads = 0;
1515   EventBase* originalEventBase = eventBase_;
1516   while (readCallback_ && eventBase_ == originalEventBase) {
1517     // Get the buffer to read into.
1518     void* buf = nullptr;
1519     size_t buflen = 0, offset = 0;
1520     try {
1521       prepareReadBuffer(&buf, &buflen);
1522       VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
1523     } catch (const AsyncSocketException& ex) {
1524       return failRead(__func__, ex);
1525     } catch (const std::exception& ex) {
1526       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1527                               string("ReadCallback::getReadBuffer() "
1528                                      "threw exception: ") +
1529                               ex.what());
1530       return failRead(__func__, tex);
1531     } catch (...) {
1532       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1533                              "ReadCallback::getReadBuffer() threw "
1534                              "non-exception type");
1535       return failRead(__func__, ex);
1536     }
1537     if (!isBufferMovable_ && (buf == nullptr || buflen == 0)) {
1538       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1539                              "ReadCallback::getReadBuffer() returned "
1540                              "empty buffer");
1541       return failRead(__func__, ex);
1542     }
1543
1544     // Perform the read
1545     auto readResult = performRead(&buf, &buflen, &offset);
1546     auto bytesRead = readResult.readReturn;
1547     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
1548             << bytesRead << " bytes";
1549     if (bytesRead > 0) {
1550       if (!isBufferMovable_) {
1551         readCallback_->readDataAvailable(size_t(bytesRead));
1552       } else {
1553         CHECK(kOpenSslModeMoveBufferOwnership);
1554         VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
1555                 << "buf=" << buf << ", " << bytesRead << "/" << buflen
1556                 << ", offset=" << offset;
1557         auto readBuf = folly::IOBuf::takeOwnership(buf, buflen);
1558         readBuf->trimStart(offset);
1559         readBuf->trimEnd(buflen - offset - bytesRead);
1560         readCallback_->readBufferAvailable(std::move(readBuf));
1561       }
1562
1563       // Fall through and continue around the loop if the read
1564       // completely filled the available buffer.
1565       // Note that readCallback_ may have been uninstalled or changed inside
1566       // readDataAvailable().
1567       if (size_t(bytesRead) < buflen) {
1568         return;
1569       }
1570     } else if (bytesRead == READ_BLOCKING) {
1571         // No more data to read right now.
1572         return;
1573     } else if (bytesRead == READ_ERROR) {
1574       readErr_ = READ_ERROR;
1575       if (readResult.exception) {
1576         return failRead(__func__, *readResult.exception);
1577       }
1578       auto errnoCopy = errno;
1579       AsyncSocketException ex(
1580           AsyncSocketException::INTERNAL_ERROR,
1581           withAddr("recv() failed"),
1582           errnoCopy);
1583       return failRead(__func__, ex);
1584     } else {
1585       assert(bytesRead == READ_EOF);
1586       readErr_ = READ_EOF;
1587       // EOF
1588       shutdownFlags_ |= SHUT_READ;
1589       if (!updateEventRegistration(0, EventHandler::READ)) {
1590         // we've already been moved into STATE_ERROR
1591         assert(state_ == StateEnum::ERROR);
1592         assert(readCallback_ == nullptr);
1593         return;
1594       }
1595
1596       ReadCallback* callback = readCallback_;
1597       readCallback_ = nullptr;
1598       callback->readEOF();
1599       return;
1600     }
1601     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1602       if (readCallback_ != nullptr) {
1603         // We might still have data in the socket.
1604         // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
1605         scheduleImmediateRead();
1606       }
1607       return;
1608     }
1609   }
1610 }
1611
1612 /**
1613  * This function attempts to write as much data as possible, until no more data
1614  * can be written.
1615  *
1616  * - If it sends all available data, it unregisters for write events, and stops
1617  *   the writeTimeout_.
1618  *
1619  * - If not all of the data can be sent immediately, it reschedules
1620  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1621  *   registered for write events.
1622  */
1623 void AsyncSocket::handleWrite() noexcept {
1624   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1625           << ", state=" << state_;
1626   DestructorGuard dg(this);
1627
1628   if (state_ == StateEnum::CONNECTING) {
1629     handleConnect();
1630     return;
1631   }
1632
1633   // Normal write
1634   assert(state_ == StateEnum::ESTABLISHED);
1635   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1636   assert(writeReqHead_ != nullptr);
1637
1638   // Loop until we run out of write requests,
1639   // or until this socket is moved to another EventBase.
1640   // (See the comment in handleRead() explaining how this can happen.)
1641   EventBase* originalEventBase = eventBase_;
1642   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1643     auto writeResult = writeReqHead_->performWrite();
1644     if (writeResult.writeReturn < 0) {
1645       if (writeResult.exception) {
1646         return failWrite(__func__, *writeResult.exception);
1647       }
1648       auto errnoCopy = errno;
1649       AsyncSocketException ex(
1650           AsyncSocketException::INTERNAL_ERROR,
1651           withAddr("writev() failed"),
1652           errnoCopy);
1653       return failWrite(__func__, ex);
1654     } else if (writeReqHead_->isComplete()) {
1655       // We finished this request
1656       WriteRequest* req = writeReqHead_;
1657       writeReqHead_ = req->getNext();
1658
1659       if (writeReqHead_ == nullptr) {
1660         writeReqTail_ = nullptr;
1661         // This is the last write request.
1662         // Unregister for write events and cancel the send timer
1663         // before we invoke the callback.  We have to update the state properly
1664         // before calling the callback, since it may want to detach us from
1665         // the EventBase.
1666         if (eventFlags_ & EventHandler::WRITE) {
1667           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1668             assert(state_ == StateEnum::ERROR);
1669             return;
1670           }
1671           // Stop the send timeout
1672           writeTimeout_.cancelTimeout();
1673         }
1674         assert(!writeTimeout_.isScheduled());
1675
1676         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1677         // we finish sending the last write request.
1678         //
1679         // We have to do this before invoking writeSuccess(), since
1680         // writeSuccess() may detach us from our EventBase.
1681         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1682           assert(connectCallback_ == nullptr);
1683           shutdownFlags_ |= SHUT_WRITE;
1684
1685           if (shutdownFlags_ & SHUT_READ) {
1686             // Reads have already been shutdown.  Fully close the socket and
1687             // move to STATE_CLOSED.
1688             //
1689             // Note: This code currently moves us to STATE_CLOSED even if
1690             // close() hasn't ever been called.  This can occur if we have
1691             // received EOF from the peer and shutdownWrite() has been called
1692             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1693             // case, until close() is actually called?  I can't think of a
1694             // reason why we would need to do so.  No other operations besides
1695             // calling close() or destroying the socket can be performed at
1696             // this point.
1697             assert(readCallback_ == nullptr);
1698             state_ = StateEnum::CLOSED;
1699             if (fd_ >= 0) {
1700               ioHandler_.changeHandlerFD(-1);
1701               doClose();
1702             }
1703           } else {
1704             // Reads are still enabled, so we are only doing a half-shutdown
1705             shutdown(fd_, SHUT_WR);
1706           }
1707         }
1708       }
1709
1710       // Invoke the callback
1711       WriteCallback* callback = req->getCallback();
1712       req->destroy();
1713       if (callback) {
1714         callback->writeSuccess();
1715       }
1716       // We'll continue around the loop, trying to write another request
1717     } else {
1718       // Partial write.
1719       if (bufferCallback_) {
1720         bufferCallback_->onEgressBuffered();
1721       }
1722       writeReqHead_->consume();
1723       // Stop after a partial write; it's highly likely that a subsequent write
1724       // attempt will just return EAGAIN.
1725       //
1726       // Ensure that we are registered for write events.
1727       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1728         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1729           assert(state_ == StateEnum::ERROR);
1730           return;
1731         }
1732       }
1733
1734       // Reschedule the send timeout, since we have made some write progress.
1735       if (sendTimeout_ > 0) {
1736         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1737           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1738               withAddr("failed to reschedule write timeout"));
1739           return failWrite(__func__, ex);
1740         }
1741       }
1742       return;
1743     }
1744   }
1745   if (!writeReqHead_ && bufferCallback_) {
1746     bufferCallback_->onEgressBufferCleared();
1747   }
1748 }
1749
1750 void AsyncSocket::checkForImmediateRead() noexcept {
1751   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1752   // (However, note that some subclasses do override this method.)
1753   //
1754   // Simply calling handleRead() here would be bad, as this would call
1755   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1756   // buffer even though no data may be available.  This would waste lots of
1757   // memory, since the buffer will sit around unused until the socket actually
1758   // becomes readable.
1759   //
1760   // Checking if the socket is readable now also seems like it would probably
1761   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1762   // would just waste an extra system call.  Even if it is readable, waiting to
1763   // find out from libevent on the next event loop doesn't seem that bad.
1764 }
1765
1766 void AsyncSocket::handleInitialReadWrite() noexcept {
1767   // Our callers should already be holding a DestructorGuard, but grab
1768   // one here just to make sure, in case one of our calling code paths ever
1769   // changes.
1770   DestructorGuard dg(this);
1771   // If we have a readCallback_, make sure we enable read events.  We
1772   // may already be registered for reads if connectSuccess() set
1773   // the read calback.
1774   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1775     assert(state_ == StateEnum::ESTABLISHED);
1776     assert((shutdownFlags_ & SHUT_READ) == 0);
1777     if (!updateEventRegistration(EventHandler::READ, 0)) {
1778       assert(state_ == StateEnum::ERROR);
1779       return;
1780     }
1781     checkForImmediateRead();
1782   } else if (readCallback_ == nullptr) {
1783     // Unregister for read events.
1784     updateEventRegistration(0, EventHandler::READ);
1785   }
1786
1787   // If we have write requests pending, try to send them immediately.
1788   // Since we just finished accepting, there is a very good chance that we can
1789   // write without blocking.
1790   //
1791   // However, we only process them if EventHandler::WRITE is not already set,
1792   // which means that we're already blocked on a write attempt.  (This can
1793   // happen if connectSuccess() called write() before returning.)
1794   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1795     // Call handleWrite() to perform write processing.
1796     handleWrite();
1797   } else if (writeReqHead_ == nullptr) {
1798     // Unregister for write event.
1799     updateEventRegistration(0, EventHandler::WRITE);
1800   }
1801 }
1802
1803 void AsyncSocket::handleConnect() noexcept {
1804   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1805           << ", state=" << state_;
1806   assert(state_ == StateEnum::CONNECTING);
1807   // SHUT_WRITE can never be set while we are still connecting;
1808   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1809   // finishes
1810   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1811
1812   // In case we had a connect timeout, cancel the timeout
1813   writeTimeout_.cancelTimeout();
1814   // We don't use a persistent registration when waiting on a connect event,
1815   // so we have been automatically unregistered now.  Update eventFlags_ to
1816   // reflect reality.
1817   assert(eventFlags_ == EventHandler::WRITE);
1818   eventFlags_ = EventHandler::NONE;
1819
1820   // Call getsockopt() to check if the connect succeeded
1821   int error;
1822   socklen_t len = sizeof(error);
1823   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1824   if (rv != 0) {
1825     auto errnoCopy = errno;
1826     AsyncSocketException ex(
1827         AsyncSocketException::INTERNAL_ERROR,
1828         withAddr("error calling getsockopt() after connect"),
1829         errnoCopy);
1830     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1831                << fd_ << " host=" << addr_.describe()
1832                << ") exception:" << ex.what();
1833     return failConnect(__func__, ex);
1834   }
1835
1836   if (error != 0) {
1837     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1838                            "connect failed", error);
1839     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1840             << fd_ << " host=" << addr_.describe()
1841             << ") exception: " << ex.what();
1842     return failConnect(__func__, ex);
1843   }
1844
1845   // Move into STATE_ESTABLISHED
1846   state_ = StateEnum::ESTABLISHED;
1847
1848   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1849   // perform, immediately shutdown the write half of the socket.
1850   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1851     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1852     // are still connecting we just abort the connect rather than waiting for
1853     // it to complete.
1854     assert((shutdownFlags_ & SHUT_READ) == 0);
1855     shutdown(fd_, SHUT_WR);
1856     shutdownFlags_ |= SHUT_WRITE;
1857   }
1858
1859   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1860           << "successfully connected; state=" << state_;
1861
1862   // Remember the EventBase we are attached to, before we start invoking any
1863   // callbacks (since the callbacks may call detachEventBase()).
1864   EventBase* originalEventBase = eventBase_;
1865
1866   invokeConnectSuccess();
1867   // Note that the connect callback may have changed our state.
1868   // (set or unset the read callback, called write(), closed the socket, etc.)
1869   // The following code needs to handle these situations correctly.
1870   //
1871   // If the socket has been closed, readCallback_ and writeReqHead_ will
1872   // always be nullptr, so that will prevent us from trying to read or write.
1873   //
1874   // The main thing to check for is if eventBase_ is still originalEventBase.
1875   // If not, we have been detached from this event base, so we shouldn't
1876   // perform any more operations.
1877   if (eventBase_ != originalEventBase) {
1878     return;
1879   }
1880
1881   handleInitialReadWrite();
1882 }
1883
1884 void AsyncSocket::timeoutExpired() noexcept {
1885   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1886           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1887   DestructorGuard dg(this);
1888   assert(eventBase_->isInEventBaseThread());
1889
1890   if (state_ == StateEnum::CONNECTING) {
1891     // connect() timed out
1892     // Unregister for I/O events.
1893     if (connectCallback_) {
1894       AsyncSocketException ex(
1895           AsyncSocketException::TIMED_OUT,
1896           folly::sformat(
1897               "connect timed out after {}ms", connectTimeout_.count()));
1898       failConnect(__func__, ex);
1899     } else {
1900       // we faced a connect error without a connect callback, which could
1901       // happen due to TFO.
1902       AsyncSocketException ex(
1903           AsyncSocketException::TIMED_OUT, "write timed out during connection");
1904       failWrite(__func__, ex);
1905     }
1906   } else {
1907     // a normal write operation timed out
1908     AsyncSocketException ex(
1909         AsyncSocketException::TIMED_OUT,
1910         folly::sformat("write timed out after {}ms", sendTimeout_));
1911     failWrite(__func__, ex);
1912   }
1913 }
1914
1915 ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
1916   return detail::tfo_sendmsg(fd, msg, msg_flags);
1917 }
1918
1919 AsyncSocket::WriteResult
1920 AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
1921   ssize_t totalWritten = 0;
1922   if (state_ == StateEnum::FAST_OPEN) {
1923     sockaddr_storage addr;
1924     auto len = addr_.getAddress(&addr);
1925     msg->msg_name = &addr;
1926     msg->msg_namelen = len;
1927     totalWritten = tfoSendMsg(fd_, msg, msg_flags);
1928     if (totalWritten >= 0) {
1929       tfoFinished_ = true;
1930       state_ = StateEnum::ESTABLISHED;
1931       // We schedule this asynchrously so that we don't end up
1932       // invoking initial read or write while a write is in progress.
1933       scheduleInitialReadWrite();
1934     } else if (errno == EINPROGRESS) {
1935       VLOG(4) << "TFO falling back to connecting";
1936       // A normal sendmsg doesn't return EINPROGRESS, however
1937       // TFO might fallback to connecting if there is no
1938       // cookie.
1939       state_ = StateEnum::CONNECTING;
1940       try {
1941         scheduleConnectTimeout();
1942         registerForConnectEvents();
1943       } catch (const AsyncSocketException& ex) {
1944         return WriteResult(
1945             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1946       }
1947       // Let's fake it that no bytes were written and return an errno.
1948       errno = EAGAIN;
1949       totalWritten = -1;
1950     } else if (errno == EOPNOTSUPP) {
1951       // Try falling back to connecting.
1952       VLOG(4) << "TFO not supported";
1953       state_ = StateEnum::CONNECTING;
1954       try {
1955         int ret = socketConnect((const sockaddr*)&addr, len);
1956         if (ret == 0) {
1957           // connect succeeded immediately
1958           // Treat this like no data was written.
1959           state_ = StateEnum::ESTABLISHED;
1960           scheduleInitialReadWrite();
1961         }
1962         // If there was no exception during connections,
1963         // we would return that no bytes were written.
1964         errno = EAGAIN;
1965         totalWritten = -1;
1966       } catch (const AsyncSocketException& ex) {
1967         return WriteResult(
1968             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1969       }
1970     } else if (errno == EAGAIN) {
1971       // Normally sendmsg would indicate that the write would block.
1972       // However in the fast open case, it would indicate that sendmsg
1973       // fell back to a connect. This is a return code from connect()
1974       // instead, and is an error condition indicating no fds available.
1975       return WriteResult(
1976           WRITE_ERROR,
1977           folly::make_unique<AsyncSocketException>(
1978               AsyncSocketException::UNKNOWN, "No more free local ports"));
1979     }
1980   } else {
1981     totalWritten = ::sendmsg(fd, msg, msg_flags);
1982   }
1983   return WriteResult(totalWritten);
1984 }
1985
1986 AsyncSocket::WriteResult AsyncSocket::performWrite(
1987     const iovec* vec,
1988     uint32_t count,
1989     WriteFlags flags,
1990     uint32_t* countWritten,
1991     uint32_t* partialWritten) {
1992   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1993   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1994   // (since it may terminate the program if the main program doesn't explicitly
1995   // ignore it).
1996   struct msghdr msg;
1997   msg.msg_name = nullptr;
1998   msg.msg_namelen = 0;
1999   msg.msg_iov = const_cast<iovec *>(vec);
2000   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
2001   msg.msg_control = nullptr;
2002   msg.msg_controllen = 0;
2003   msg.msg_flags = 0;
2004
2005   int msg_flags = MSG_DONTWAIT;
2006
2007 #ifdef MSG_NOSIGNAL // Linux-only
2008   msg_flags |= MSG_NOSIGNAL;
2009   if (isSet(flags, WriteFlags::CORK)) {
2010     // MSG_MORE tells the kernel we have more data to send, so wait for us to
2011     // give it the rest of the data rather than immediately sending a partial
2012     // frame, even when TCP_NODELAY is enabled.
2013     msg_flags |= MSG_MORE;
2014   }
2015 #endif
2016   if (isSet(flags, WriteFlags::EOR)) {
2017     // marks that this is the last byte of a record (response)
2018     msg_flags |= MSG_EOR;
2019   }
2020   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
2021   auto totalWritten = writeResult.writeReturn;
2022   if (totalWritten < 0) {
2023     bool tryAgain = (errno == EAGAIN);
2024 #ifdef __APPLE__
2025     // Apple has a bug where doing a second write on a socket which we
2026     // have opened with TFO causes an ENOTCONN to be thrown. However the
2027     // socket is really connected, so treat ENOTCONN as a EAGAIN until
2028     // this bug is fixed.
2029     tryAgain |= (errno == ENOTCONN);
2030 #endif
2031     if (!writeResult.exception && tryAgain) {
2032       // TCP buffer is full; we can't write any more data right now.
2033       *countWritten = 0;
2034       *partialWritten = 0;
2035       return WriteResult(0);
2036     }
2037     // error
2038     *countWritten = 0;
2039     *partialWritten = 0;
2040     return writeResult;
2041   }
2042
2043   appBytesWritten_ += totalWritten;
2044
2045   uint32_t bytesWritten;
2046   uint32_t n;
2047   for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
2048     const iovec* v = vec + n;
2049     if (v->iov_len > bytesWritten) {
2050       // Partial write finished in the middle of this iovec
2051       *countWritten = n;
2052       *partialWritten = bytesWritten;
2053       return WriteResult(totalWritten);
2054     }
2055
2056     bytesWritten -= uint32_t(v->iov_len);
2057   }
2058
2059   assert(bytesWritten == 0);
2060   *countWritten = n;
2061   *partialWritten = 0;
2062   return WriteResult(totalWritten);
2063 }
2064
2065 /**
2066  * Re-register the EventHandler after eventFlags_ has changed.
2067  *
2068  * If an error occurs, fail() is called to move the socket into the error state
2069  * and call all currently installed callbacks.  After an error, the
2070  * AsyncSocket is completely unregistered.
2071  *
2072  * @return Returns true on succcess, or false on error.
2073  */
2074 bool AsyncSocket::updateEventRegistration() {
2075   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
2076           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
2077           << ", events=" << std::hex << eventFlags_;
2078   assert(eventBase_->isInEventBaseThread());
2079   if (eventFlags_ == EventHandler::NONE) {
2080     ioHandler_.unregisterHandler();
2081     return true;
2082   }
2083
2084   // Always register for persistent events, so we don't have to re-register
2085   // after being called back.
2086   if (!ioHandler_.registerHandler(
2087           uint16_t(eventFlags_ | EventHandler::PERSIST))) {
2088     eventFlags_ = EventHandler::NONE; // we're not registered after error
2089     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
2090         withAddr("failed to update AsyncSocket event registration"));
2091     fail("updateEventRegistration", ex);
2092     return false;
2093   }
2094
2095   return true;
2096 }
2097
2098 bool AsyncSocket::updateEventRegistration(uint16_t enable,
2099                                            uint16_t disable) {
2100   uint16_t oldFlags = eventFlags_;
2101   eventFlags_ |= enable;
2102   eventFlags_ &= ~disable;
2103   if (eventFlags_ == oldFlags) {
2104     return true;
2105   } else {
2106     return updateEventRegistration();
2107   }
2108 }
2109
2110 void AsyncSocket::startFail() {
2111   // startFail() should only be called once
2112   assert(state_ != StateEnum::ERROR);
2113   assert(getDestructorGuardCount() > 0);
2114   state_ = StateEnum::ERROR;
2115   // Ensure that SHUT_READ and SHUT_WRITE are set,
2116   // so all future attempts to read or write will be rejected
2117   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
2118
2119   if (eventFlags_ != EventHandler::NONE) {
2120     eventFlags_ = EventHandler::NONE;
2121     ioHandler_.unregisterHandler();
2122   }
2123   writeTimeout_.cancelTimeout();
2124
2125   if (fd_ >= 0) {
2126     ioHandler_.changeHandlerFD(-1);
2127     doClose();
2128   }
2129 }
2130
2131 void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
2132   invokeConnectErr(ex);
2133   failAllWrites(ex);
2134
2135   if (readCallback_) {
2136     ReadCallback* callback = readCallback_;
2137     readCallback_ = nullptr;
2138     callback->readErr(ex);
2139   }
2140 }
2141
2142 void AsyncSocket::finishFail() {
2143   assert(state_ == StateEnum::ERROR);
2144   assert(getDestructorGuardCount() > 0);
2145
2146   AsyncSocketException ex(
2147       AsyncSocketException::INTERNAL_ERROR,
2148       withAddr("socket closing after error"));
2149   invokeAllErrors(ex);
2150 }
2151
2152 void AsyncSocket::finishFail(const AsyncSocketException& ex) {
2153   assert(state_ == StateEnum::ERROR);
2154   assert(getDestructorGuardCount() > 0);
2155   invokeAllErrors(ex);
2156 }
2157
2158 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
2159   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2160              << state_ << " host=" << addr_.describe()
2161              << "): failed in " << fn << "(): "
2162              << ex.what();
2163   startFail();
2164   finishFail();
2165 }
2166
2167 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
2168   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2169                << state_ << " host=" << addr_.describe()
2170                << "): failed while connecting in " << fn << "(): "
2171                << ex.what();
2172   startFail();
2173
2174   invokeConnectErr(ex);
2175   finishFail(ex);
2176 }
2177
2178 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
2179   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2180                << state_ << " host=" << addr_.describe()
2181                << "): failed while reading in " << fn << "(): "
2182                << ex.what();
2183   startFail();
2184
2185   if (readCallback_ != nullptr) {
2186     ReadCallback* callback = readCallback_;
2187     readCallback_ = nullptr;
2188     callback->readErr(ex);
2189   }
2190
2191   finishFail();
2192 }
2193
2194 void AsyncSocket::failErrMessageRead(const char* fn,
2195                                      const AsyncSocketException& ex) {
2196   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2197                << state_ << " host=" << addr_.describe()
2198                << "): failed while reading message in " << fn << "(): "
2199                << ex.what();
2200   startFail();
2201
2202   if (errMessageCallback_ != nullptr) {
2203     ErrMessageCallback* callback = errMessageCallback_;
2204     errMessageCallback_ = nullptr;
2205     callback->errMessageError(ex);
2206   }
2207
2208   finishFail();
2209 }
2210
2211 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
2212   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2213                << state_ << " host=" << addr_.describe()
2214                << "): failed while writing in " << fn << "(): "
2215                << ex.what();
2216   startFail();
2217
2218   // Only invoke the first write callback, since the error occurred while
2219   // writing this request.  Let any other pending write callbacks be invoked in
2220   // finishFail().
2221   if (writeReqHead_ != nullptr) {
2222     WriteRequest* req = writeReqHead_;
2223     writeReqHead_ = req->getNext();
2224     WriteCallback* callback = req->getCallback();
2225     uint32_t bytesWritten = req->getTotalBytesWritten();
2226     req->destroy();
2227     if (callback) {
2228       callback->writeErr(bytesWritten, ex);
2229     }
2230   }
2231
2232   finishFail();
2233 }
2234
2235 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
2236                              size_t bytesWritten,
2237                              const AsyncSocketException& ex) {
2238   // This version of failWrite() is used when the failure occurs before
2239   // we've added the callback to writeReqHead_.
2240   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2241              << state_ << " host=" << addr_.describe()
2242              <<"): failed while writing in " << fn << "(): "
2243              << ex.what();
2244   startFail();
2245
2246   if (callback != nullptr) {
2247     callback->writeErr(bytesWritten, ex);
2248   }
2249
2250   finishFail();
2251 }
2252
2253 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
2254   // Invoke writeError() on all write callbacks.
2255   // This is used when writes are forcibly shutdown with write requests
2256   // pending, or when an error occurs with writes pending.
2257   while (writeReqHead_ != nullptr) {
2258     WriteRequest* req = writeReqHead_;
2259     writeReqHead_ = req->getNext();
2260     WriteCallback* callback = req->getCallback();
2261     if (callback) {
2262       callback->writeErr(req->getTotalBytesWritten(), ex);
2263     }
2264     req->destroy();
2265   }
2266 }
2267
2268 void AsyncSocket::invalidState(ConnectCallback* callback) {
2269   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
2270           << "): connect() called in invalid state " << state_;
2271
2272   /*
2273    * The invalidState() methods don't use the normal failure mechanisms,
2274    * since we don't know what state we are in.  We don't want to call
2275    * startFail()/finishFail() recursively if we are already in the middle of
2276    * cleaning up.
2277    */
2278
2279   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
2280                          "connect() called with socket in invalid state");
2281   connectEndTime_ = std::chrono::steady_clock::now();
2282   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2283     if (callback) {
2284       callback->connectErr(ex);
2285     }
2286   } else {
2287     // We can't use failConnect() here since connectCallback_
2288     // may already be set to another callback.  Invoke this ConnectCallback
2289     // here; any other connectCallback_ will be invoked in finishFail()
2290     startFail();
2291     if (callback) {
2292       callback->connectErr(ex);
2293     }
2294     finishFail();
2295   }
2296 }
2297
2298 void AsyncSocket::invalidState(ErrMessageCallback* callback) {
2299   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2300           << "): setErrMessageCB(" << callback
2301           << ") called in invalid state " << state_;
2302
2303   AsyncSocketException ex(
2304       AsyncSocketException::NOT_OPEN,
2305       msgErrQueueSupported
2306       ? "setErrMessageCB() called with socket in invalid state"
2307       : "This platform does not support socket error message notifications");
2308   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2309     if (callback) {
2310       callback->errMessageError(ex);
2311     }
2312   } else {
2313     startFail();
2314     if (callback) {
2315       callback->errMessageError(ex);
2316     }
2317     finishFail();
2318   }
2319 }
2320
2321 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
2322   connectEndTime_ = std::chrono::steady_clock::now();
2323   if (connectCallback_) {
2324     ConnectCallback* callback = connectCallback_;
2325     connectCallback_ = nullptr;
2326     callback->connectErr(ex);
2327   }
2328 }
2329
2330 void AsyncSocket::invokeConnectSuccess() {
2331   connectEndTime_ = std::chrono::steady_clock::now();
2332   if (connectCallback_) {
2333     ConnectCallback* callback = connectCallback_;
2334     connectCallback_ = nullptr;
2335     callback->connectSuccess();
2336   }
2337 }
2338
2339 void AsyncSocket::invalidState(ReadCallback* callback) {
2340   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2341              << "): setReadCallback(" << callback
2342              << ") called in invalid state " << state_;
2343
2344   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2345                          "setReadCallback() called with socket in "
2346                          "invalid state");
2347   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2348     if (callback) {
2349       callback->readErr(ex);
2350     }
2351   } else {
2352     startFail();
2353     if (callback) {
2354       callback->readErr(ex);
2355     }
2356     finishFail();
2357   }
2358 }
2359
2360 void AsyncSocket::invalidState(WriteCallback* callback) {
2361   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2362              << "): write() called in invalid state " << state_;
2363
2364   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2365                          withAddr("write() called with socket in invalid state"));
2366   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2367     if (callback) {
2368       callback->writeErr(0, ex);
2369     }
2370   } else {
2371     startFail();
2372     if (callback) {
2373       callback->writeErr(0, ex);
2374     }
2375     finishFail();
2376   }
2377 }
2378
2379 void AsyncSocket::doClose() {
2380   if (fd_ == -1) return;
2381   if (shutdownSocketSet_) {
2382     shutdownSocketSet_->close(fd_);
2383   } else {
2384     ::close(fd_);
2385   }
2386   fd_ = -1;
2387 }
2388
2389 std::ostream& operator << (std::ostream& os,
2390                            const AsyncSocket::StateEnum& state) {
2391   os << static_cast<int>(state);
2392   return os;
2393 }
2394
2395 std::string AsyncSocket::withAddr(const std::string& s) {
2396   // Don't use addr_ directly because it may not be initialized
2397   // e.g. if constructed from fd
2398   folly::SocketAddress peer, local;
2399   try {
2400     getPeerAddress(&peer);
2401     getLocalAddress(&local);
2402   } catch (const std::exception&) {
2403     // ignore
2404   } catch (...) {
2405     // ignore
2406   }
2407   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
2408 }
2409
2410 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
2411   bufferCallback_ = cb;
2412 }
2413
2414 } // folly