Add TFO support to AsyncSSLSocket
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22 #include <folly/portability/Fcntl.h>
23 #include <folly/portability/Sockets.h>
24 #include <folly/portability/SysUio.h>
25 #include <folly/portability/Unistd.h>
26
27 #include <errno.h>
28 #include <limits.h>
29 #include <thread>
30 #include <sys/types.h>
31 #include <boost/preprocessor/control/if.hpp>
32
33 using std::string;
34 using std::unique_ptr;
35
36 namespace folly {
37
38 // static members initializers
39 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
40
41 const AsyncSocketException socketClosedLocallyEx(
42     AsyncSocketException::END_OF_FILE, "socket closed locally");
43 const AsyncSocketException socketShutdownForWritesEx(
44     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
45
46 // TODO: It might help performance to provide a version of BytesWriteRequest that
47 // users could derive from, so we can avoid the extra allocation for each call
48 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
49 // protocols are currently templatized for transports.
50 //
51 // We would need the version for external users where they provide the iovec
52 // storage space, and only our internal version would allocate it at the end of
53 // the WriteRequest.
54
55 /* The default WriteRequest implementation, used for write(), writev() and
56  * writeChain()
57  *
58  * A new BytesWriteRequest operation is allocated on the heap for all write
59  * operations that cannot be completed immediately.
60  */
61 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
62  public:
63   static BytesWriteRequest* newRequest(AsyncSocket* socket,
64                                        WriteCallback* callback,
65                                        const iovec* ops,
66                                        uint32_t opCount,
67                                        uint32_t partialWritten,
68                                        uint32_t bytesWritten,
69                                        unique_ptr<IOBuf>&& ioBuf,
70                                        WriteFlags flags) {
71     assert(opCount > 0);
72     // Since we put a variable size iovec array at the end
73     // of each BytesWriteRequest, we have to manually allocate the memory.
74     void* buf = malloc(sizeof(BytesWriteRequest) +
75                        (opCount * sizeof(struct iovec)));
76     if (buf == nullptr) {
77       throw std::bad_alloc();
78     }
79
80     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
81                                       partialWritten, bytesWritten,
82                                       std::move(ioBuf), flags);
83   }
84
85   void destroy() override {
86     this->~BytesWriteRequest();
87     free(this);
88   }
89
90   WriteResult performWrite() override {
91     WriteFlags writeFlags = flags_;
92     if (getNext() != nullptr) {
93       writeFlags = writeFlags | WriteFlags::CORK;
94     }
95     return socket_->performWrite(
96         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
97   }
98
99   bool isComplete() override {
100     return opsWritten_ == getOpCount();
101   }
102
103   void consume() override {
104     // Advance opIndex_ forward by opsWritten_
105     opIndex_ += opsWritten_;
106     assert(opIndex_ < opCount_);
107
108     // If we've finished writing any IOBufs, release them
109     if (ioBuf_) {
110       for (uint32_t i = opsWritten_; i != 0; --i) {
111         assert(ioBuf_);
112         ioBuf_ = ioBuf_->pop();
113       }
114     }
115
116     // Move partialBytes_ forward into the current iovec buffer
117     struct iovec* currentOp = writeOps_ + opIndex_;
118     assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
119     currentOp->iov_base =
120       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
121     currentOp->iov_len -= partialBytes_;
122
123     // Increment the totalBytesWritten_ count by bytesWritten_;
124     totalBytesWritten_ += bytesWritten_;
125   }
126
127  private:
128   BytesWriteRequest(AsyncSocket* socket,
129                     WriteCallback* callback,
130                     const struct iovec* ops,
131                     uint32_t opCount,
132                     uint32_t partialBytes,
133                     uint32_t bytesWritten,
134                     unique_ptr<IOBuf>&& ioBuf,
135                     WriteFlags flags)
136     : AsyncSocket::WriteRequest(socket, callback)
137     , opCount_(opCount)
138     , opIndex_(0)
139     , flags_(flags)
140     , ioBuf_(std::move(ioBuf))
141     , opsWritten_(0)
142     , partialBytes_(partialBytes)
143     , bytesWritten_(bytesWritten) {
144     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
145   }
146
147   // private destructor, to ensure callers use destroy()
148   ~BytesWriteRequest() override = default;
149
150   const struct iovec* getOps() const {
151     assert(opCount_ > opIndex_);
152     return writeOps_ + opIndex_;
153   }
154
155   uint32_t getOpCount() const {
156     assert(opCount_ > opIndex_);
157     return opCount_ - opIndex_;
158   }
159
160   uint32_t opCount_;            ///< number of entries in writeOps_
161   uint32_t opIndex_;            ///< current index into writeOps_
162   WriteFlags flags_;            ///< set for WriteFlags
163   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
164
165   // for consume(), how much we wrote on the last write
166   uint32_t opsWritten_;         ///< complete ops written
167   uint32_t partialBytes_;       ///< partial bytes of incomplete op written
168   ssize_t bytesWritten_;        ///< bytes written altogether
169
170   struct iovec writeOps_[];     ///< write operation(s) list
171 };
172
173 AsyncSocket::AsyncSocket()
174   : eventBase_(nullptr)
175   , writeTimeout_(this, nullptr)
176   , ioHandler_(this, nullptr)
177   , immediateReadHandler_(this) {
178   VLOG(5) << "new AsyncSocket()";
179   init();
180 }
181
182 AsyncSocket::AsyncSocket(EventBase* evb)
183   : eventBase_(evb)
184   , writeTimeout_(this, evb)
185   , ioHandler_(this, evb)
186   , immediateReadHandler_(this) {
187   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
188   init();
189 }
190
191 AsyncSocket::AsyncSocket(EventBase* evb,
192                            const folly::SocketAddress& address,
193                            uint32_t connectTimeout)
194   : AsyncSocket(evb) {
195   connect(nullptr, address, connectTimeout);
196 }
197
198 AsyncSocket::AsyncSocket(EventBase* evb,
199                            const std::string& ip,
200                            uint16_t port,
201                            uint32_t connectTimeout)
202   : AsyncSocket(evb) {
203   connect(nullptr, ip, port, connectTimeout);
204 }
205
206 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
207   : eventBase_(evb)
208   , writeTimeout_(this, evb)
209   , ioHandler_(this, evb, fd)
210   , immediateReadHandler_(this) {
211   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
212           << fd << ")";
213   init();
214   fd_ = fd;
215   setCloseOnExec();
216   state_ = StateEnum::ESTABLISHED;
217 }
218
219 // init() method, since constructor forwarding isn't supported in most
220 // compilers yet.
221 void AsyncSocket::init() {
222   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
223   shutdownFlags_ = 0;
224   state_ = StateEnum::UNINIT;
225   eventFlags_ = EventHandler::NONE;
226   fd_ = -1;
227   sendTimeout_ = 0;
228   maxReadsPerEvent_ = 16;
229   connectCallback_ = nullptr;
230   readCallback_ = nullptr;
231   writeReqHead_ = nullptr;
232   writeReqTail_ = nullptr;
233   shutdownSocketSet_ = nullptr;
234   appBytesWritten_ = 0;
235   appBytesReceived_ = 0;
236 }
237
238 AsyncSocket::~AsyncSocket() {
239   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
240           << ", evb=" << eventBase_ << ", fd=" << fd_
241           << ", state=" << state_ << ")";
242 }
243
244 void AsyncSocket::destroy() {
245   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
246           << ", fd=" << fd_ << ", state=" << state_;
247   // When destroy is called, close the socket immediately
248   closeNow();
249
250   // Then call DelayedDestruction::destroy() to take care of
251   // whether or not we need immediate or delayed destruction
252   DelayedDestruction::destroy();
253 }
254
255 int AsyncSocket::detachFd() {
256   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
257           << ", evb=" << eventBase_ << ", state=" << state_
258           << ", events=" << std::hex << eventFlags_ << ")";
259   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
260   // actually close the descriptor.
261   if (shutdownSocketSet_) {
262     shutdownSocketSet_->remove(fd_);
263   }
264   int fd = fd_;
265   fd_ = -1;
266   // Call closeNow() to invoke all pending callbacks with an error.
267   closeNow();
268   // Update the EventHandler to stop using this fd.
269   // This can only be done after closeNow() unregisters the handler.
270   ioHandler_.changeHandlerFD(-1);
271   return fd;
272 }
273
274 const folly::SocketAddress& AsyncSocket::anyAddress() {
275   static const folly::SocketAddress anyAddress =
276     folly::SocketAddress("0.0.0.0", 0);
277   return anyAddress;
278 }
279
280 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
281   if (shutdownSocketSet_ == newSS) {
282     return;
283   }
284   if (shutdownSocketSet_ && fd_ != -1) {
285     shutdownSocketSet_->remove(fd_);
286   }
287   shutdownSocketSet_ = newSS;
288   if (shutdownSocketSet_ && fd_ != -1) {
289     shutdownSocketSet_->add(fd_);
290   }
291 }
292
293 void AsyncSocket::setCloseOnExec() {
294   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
295   if (rv != 0) {
296     auto errnoCopy = errno;
297     throw AsyncSocketException(
298         AsyncSocketException::INTERNAL_ERROR,
299         withAddr("failed to set close-on-exec flag"),
300         errnoCopy);
301   }
302 }
303
304 void AsyncSocket::connect(ConnectCallback* callback,
305                            const folly::SocketAddress& address,
306                            int timeout,
307                            const OptionMap &options,
308                            const folly::SocketAddress& bindAddr) noexcept {
309   DestructorGuard dg(this);
310   assert(eventBase_->isInEventBaseThread());
311
312   addr_ = address;
313
314   // Make sure we're in the uninitialized state
315   if (state_ != StateEnum::UNINIT) {
316     return invalidState(callback);
317   }
318
319   connectTimeout_ = std::chrono::milliseconds(timeout);
320   connectStartTime_ = std::chrono::steady_clock::now();
321   // Make connect end time at least >= connectStartTime.
322   connectEndTime_ = connectStartTime_;
323
324   assert(fd_ == -1);
325   state_ = StateEnum::CONNECTING;
326   connectCallback_ = callback;
327
328   sockaddr_storage addrStorage;
329   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
330
331   try {
332     // Create the socket
333     // Technically the first parameter should actually be a protocol family
334     // constant (PF_xxx) rather than an address family (AF_xxx), but the
335     // distinction is mainly just historical.  In pretty much all
336     // implementations the PF_foo and AF_foo constants are identical.
337     fd_ = socket(address.getFamily(), SOCK_STREAM, 0);
338     if (fd_ < 0) {
339       auto errnoCopy = errno;
340       throw AsyncSocketException(
341           AsyncSocketException::INTERNAL_ERROR,
342           withAddr("failed to create socket"),
343           errnoCopy);
344     }
345     if (shutdownSocketSet_) {
346       shutdownSocketSet_->add(fd_);
347     }
348     ioHandler_.changeHandlerFD(fd_);
349
350     setCloseOnExec();
351
352     // Put the socket in non-blocking mode
353     int flags = fcntl(fd_, F_GETFL, 0);
354     if (flags == -1) {
355       auto errnoCopy = errno;
356       throw AsyncSocketException(
357           AsyncSocketException::INTERNAL_ERROR,
358           withAddr("failed to get socket flags"),
359           errnoCopy);
360     }
361     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
362     if (rv == -1) {
363       auto errnoCopy = errno;
364       throw AsyncSocketException(
365           AsyncSocketException::INTERNAL_ERROR,
366           withAddr("failed to put socket in non-blocking mode"),
367           errnoCopy);
368     }
369
370 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
371     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
372     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
373     if (rv == -1) {
374       auto errnoCopy = errno;
375       throw AsyncSocketException(
376           AsyncSocketException::INTERNAL_ERROR,
377           "failed to enable F_SETNOSIGPIPE on socket",
378           errnoCopy);
379     }
380 #endif
381
382     // By default, turn on TCP_NODELAY
383     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
384     // setNoDelay() will log an error message if it fails.
385     if (address.getFamily() != AF_UNIX) {
386       (void)setNoDelay(true);
387     }
388
389     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
390             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
391
392     // bind the socket
393     if (bindAddr != anyAddress()) {
394       int one = 1;
395       if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
396         auto errnoCopy = errno;
397         doClose();
398         throw AsyncSocketException(
399             AsyncSocketException::NOT_OPEN,
400             "failed to setsockopt prior to bind on " + bindAddr.describe(),
401             errnoCopy);
402       }
403
404       bindAddr.getAddress(&addrStorage);
405
406       if (::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
407         auto errnoCopy = errno;
408         doClose();
409         throw AsyncSocketException(
410             AsyncSocketException::NOT_OPEN,
411             "failed to bind to async socket: " + bindAddr.describe(),
412             errnoCopy);
413       }
414     }
415
416     // Apply the additional options if any.
417     for (const auto& opt: options) {
418       int rv = opt.first.apply(fd_, opt.second);
419       if (rv != 0) {
420         auto errnoCopy = errno;
421         throw AsyncSocketException(
422             AsyncSocketException::INTERNAL_ERROR,
423             withAddr("failed to set socket option"),
424             errnoCopy);
425       }
426     }
427
428     // Perform the connect()
429     address.getAddress(&addrStorage);
430
431     if (tfoEnabled_) {
432       state_ = StateEnum::FAST_OPEN;
433       tfoAttempted_ = true;
434     } else {
435       if (socketConnect(saddr, addr_.getActualSize()) < 0) {
436         return;
437       }
438     }
439
440     // If we're still here the connect() succeeded immediately.
441     // Fall through to call the callback outside of this try...catch block
442   } catch (const AsyncSocketException& ex) {
443     return failConnect(__func__, ex);
444   } catch (const std::exception& ex) {
445     // shouldn't happen, but handle it just in case
446     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
447                << "): unexpected " << typeid(ex).name() << " exception: "
448                << ex.what();
449     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
450                             withAddr(string("unexpected exception: ") +
451                                      ex.what()));
452     return failConnect(__func__, tex);
453   }
454
455   // The connection succeeded immediately
456   // The read callback may not have been set yet, and no writes may be pending
457   // yet, so we don't have to register for any events at the moment.
458   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
459   assert(readCallback_ == nullptr);
460   assert(writeReqHead_ == nullptr);
461   if (state_ != StateEnum::FAST_OPEN) {
462     state_ = StateEnum::ESTABLISHED;
463   }
464   invokeConnectSuccess();
465 }
466
467 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
468   int rv = ::connect(fd_, saddr, len);
469   if (rv < 0) {
470     auto errnoCopy = errno;
471     if (errnoCopy == EINPROGRESS) {
472       scheduleConnectTimeoutAndRegisterForEvents();
473     } else {
474       throw AsyncSocketException(
475           AsyncSocketException::NOT_OPEN,
476           "connect failed (immediately)",
477           errnoCopy);
478     }
479   }
480   return rv;
481 }
482
483 void AsyncSocket::scheduleConnectTimeoutAndRegisterForEvents() {
484   // Connection in progress.
485   int timeout = connectTimeout_.count();
486   if (timeout > 0) {
487     // Start a timer in case the connection takes too long.
488     if (!writeTimeout_.scheduleTimeout(timeout)) {
489       throw AsyncSocketException(
490           AsyncSocketException::INTERNAL_ERROR,
491           withAddr("failed to schedule AsyncSocket connect timeout"));
492     }
493   }
494
495   // Register for write events, so we'll
496   // be notified when the connection finishes/fails.
497   // Note that we don't register for a persistent event here.
498   assert(eventFlags_ == EventHandler::NONE);
499   eventFlags_ = EventHandler::WRITE;
500   if (!ioHandler_.registerHandler(eventFlags_)) {
501     throw AsyncSocketException(
502         AsyncSocketException::INTERNAL_ERROR,
503         withAddr("failed to register AsyncSocket connect handler"));
504   }
505 }
506
507 void AsyncSocket::connect(ConnectCallback* callback,
508                            const string& ip, uint16_t port,
509                            int timeout,
510                            const OptionMap &options) noexcept {
511   DestructorGuard dg(this);
512   try {
513     connectCallback_ = callback;
514     connect(callback, folly::SocketAddress(ip, port), timeout, options);
515   } catch (const std::exception& ex) {
516     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
517                             ex.what());
518     return failConnect(__func__, tex);
519   }
520 }
521
522 void AsyncSocket::cancelConnect() {
523   connectCallback_ = nullptr;
524   if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
525     closeNow();
526   }
527 }
528
529 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
530   sendTimeout_ = milliseconds;
531   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
532
533   // If we are currently pending on write requests, immediately update
534   // writeTimeout_ with the new value.
535   if ((eventFlags_ & EventHandler::WRITE) &&
536       (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
537     assert(state_ == StateEnum::ESTABLISHED);
538     assert((shutdownFlags_ & SHUT_WRITE) == 0);
539     if (sendTimeout_ > 0) {
540       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
541         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
542             withAddr("failed to reschedule send timeout in setSendTimeout"));
543         return failWrite(__func__, ex);
544       }
545     } else {
546       writeTimeout_.cancelTimeout();
547     }
548   }
549 }
550
551 void AsyncSocket::setReadCB(ReadCallback *callback) {
552   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
553           << ", callback=" << callback << ", state=" << state_;
554
555   // Short circuit if callback is the same as the existing readCallback_.
556   //
557   // Note that this is needed for proper functioning during some cleanup cases.
558   // During cleanup we allow setReadCallback(nullptr) to be called even if the
559   // read callback is already unset and we have been detached from an event
560   // base.  This check prevents us from asserting
561   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
562   if (callback == readCallback_) {
563     return;
564   }
565
566   /* We are removing a read callback */
567   if (callback == nullptr &&
568       immediateReadHandler_.isLoopCallbackScheduled()) {
569     immediateReadHandler_.cancelLoopCallback();
570   }
571
572   if (shutdownFlags_ & SHUT_READ) {
573     // Reads have already been shut down on this socket.
574     //
575     // Allow setReadCallback(nullptr) to be called in this case, but don't
576     // allow a new callback to be set.
577     //
578     // For example, setReadCallback(nullptr) can happen after an error if we
579     // invoke some other error callback before invoking readError().  The other
580     // error callback that is invoked first may go ahead and clear the read
581     // callback before we get a chance to invoke readError().
582     if (callback != nullptr) {
583       return invalidState(callback);
584     }
585     assert((eventFlags_ & EventHandler::READ) == 0);
586     readCallback_ = nullptr;
587     return;
588   }
589
590   DestructorGuard dg(this);
591   assert(eventBase_->isInEventBaseThread());
592
593   switch ((StateEnum)state_) {
594     case StateEnum::CONNECTING:
595     case StateEnum::FAST_OPEN:
596       // For convenience, we allow the read callback to be set while we are
597       // still connecting.  We just store the callback for now.  Once the
598       // connection completes we'll register for read events.
599       readCallback_ = callback;
600       return;
601     case StateEnum::ESTABLISHED:
602     {
603       readCallback_ = callback;
604       uint16_t oldFlags = eventFlags_;
605       if (readCallback_) {
606         eventFlags_ |= EventHandler::READ;
607       } else {
608         eventFlags_ &= ~EventHandler::READ;
609       }
610
611       // Update our registration if our flags have changed
612       if (eventFlags_ != oldFlags) {
613         // We intentionally ignore the return value here.
614         // updateEventRegistration() will move us into the error state if it
615         // fails, and we don't need to do anything else here afterwards.
616         (void)updateEventRegistration();
617       }
618
619       if (readCallback_) {
620         checkForImmediateRead();
621       }
622       return;
623     }
624     case StateEnum::CLOSED:
625     case StateEnum::ERROR:
626       // We should never reach here.  SHUT_READ should always be set
627       // if we are in STATE_CLOSED or STATE_ERROR.
628       assert(false);
629       return invalidState(callback);
630     case StateEnum::UNINIT:
631       // We do not allow setReadCallback() to be called before we start
632       // connecting.
633       return invalidState(callback);
634   }
635
636   // We don't put a default case in the switch statement, so that the compiler
637   // will warn us to update the switch statement if a new state is added.
638   return invalidState(callback);
639 }
640
641 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
642   return readCallback_;
643 }
644
645 void AsyncSocket::write(WriteCallback* callback,
646                          const void* buf, size_t bytes, WriteFlags flags) {
647   iovec op;
648   op.iov_base = const_cast<void*>(buf);
649   op.iov_len = bytes;
650   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
651 }
652
653 void AsyncSocket::writev(WriteCallback* callback,
654                           const iovec* vec,
655                           size_t count,
656                           WriteFlags flags) {
657   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
658 }
659
660 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
661                               WriteFlags flags) {
662   constexpr size_t kSmallSizeMax = 64;
663   size_t count = buf->countChainElements();
664   if (count <= kSmallSizeMax) {
665     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
666     writeChainImpl(callback, vec, count, std::move(buf), flags);
667   } else {
668     iovec* vec = new iovec[count];
669     writeChainImpl(callback, vec, count, std::move(buf), flags);
670     delete[] vec;
671   }
672 }
673
674 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
675     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
676   size_t veclen = buf->fillIov(vec, count);
677   writeImpl(callback, vec, veclen, std::move(buf), flags);
678 }
679
680 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
681                              size_t count, unique_ptr<IOBuf>&& buf,
682                              WriteFlags flags) {
683   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
684           << ", callback=" << callback << ", count=" << count
685           << ", state=" << state_;
686   DestructorGuard dg(this);
687   unique_ptr<IOBuf>ioBuf(std::move(buf));
688   assert(eventBase_->isInEventBaseThread());
689
690   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
691     // No new writes may be performed after the write side of the socket has
692     // been shutdown.
693     //
694     // We could just call callback->writeError() here to fail just this write.
695     // However, fail hard and use invalidState() to fail all outstanding
696     // callbacks and move the socket into the error state.  There's most likely
697     // a bug in the caller's code, so we abort everything rather than trying to
698     // proceed as best we can.
699     return invalidState(callback);
700   }
701
702   uint32_t countWritten = 0;
703   uint32_t partialWritten = 0;
704   int bytesWritten = 0;
705   bool mustRegister = false;
706   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
707       !connecting()) {
708     if (writeReqHead_ == nullptr) {
709       // If we are established and there are no other writes pending,
710       // we can attempt to perform the write immediately.
711       assert(writeReqTail_ == nullptr);
712       assert((eventFlags_ & EventHandler::WRITE) == 0);
713
714       auto writeResult =
715           performWrite(vec, count, flags, &countWritten, &partialWritten);
716       bytesWritten = writeResult.writeReturn;
717       if (bytesWritten < 0) {
718         auto errnoCopy = errno;
719         if (writeResult.exception) {
720           return failWrite(__func__, callback, 0, *writeResult.exception);
721         }
722         AsyncSocketException ex(
723             AsyncSocketException::INTERNAL_ERROR,
724             withAddr("writev failed"),
725             errnoCopy);
726         return failWrite(__func__, callback, 0, ex);
727       } else if (countWritten == count) {
728         // We successfully wrote everything.
729         // Invoke the callback and return.
730         if (callback) {
731           callback->writeSuccess();
732         }
733         return;
734       } else { // continue writing the next writeReq
735         if (bufferCallback_) {
736           bufferCallback_->onEgressBuffered();
737         }
738       }
739       if (!connecting()) {
740         // Writes might put the socket back into connecting state
741         // if TFO is enabled, and using TFO fails.
742         // This means that write timeouts would not be active, however
743         // connect timeouts would affect this stage.
744         mustRegister = true;
745       }
746     }
747   } else if (!connecting()) {
748     // Invalid state for writing
749     return invalidState(callback);
750   }
751
752   // Create a new WriteRequest to add to the queue
753   WriteRequest* req;
754   try {
755     req = BytesWriteRequest::newRequest(this, callback, vec + countWritten,
756                                         count - countWritten, partialWritten,
757                                         bytesWritten, std::move(ioBuf), flags);
758   } catch (const std::exception& ex) {
759     // we mainly expect to catch std::bad_alloc here
760     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
761         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
762     return failWrite(__func__, callback, bytesWritten, tex);
763   }
764   req->consume();
765   if (writeReqTail_ == nullptr) {
766     assert(writeReqHead_ == nullptr);
767     writeReqHead_ = writeReqTail_ = req;
768   } else {
769     writeReqTail_->append(req);
770     writeReqTail_ = req;
771   }
772
773   // Register for write events if are established and not currently
774   // waiting on write events
775   if (mustRegister) {
776     assert(state_ == StateEnum::ESTABLISHED);
777     assert((eventFlags_ & EventHandler::WRITE) == 0);
778     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
779       assert(state_ == StateEnum::ERROR);
780       return;
781     }
782     if (sendTimeout_ > 0) {
783       // Schedule a timeout to fire if the write takes too long.
784       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
785         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
786                                withAddr("failed to schedule send timeout"));
787         return failWrite(__func__, ex);
788       }
789     }
790   }
791 }
792
793 void AsyncSocket::writeRequest(WriteRequest* req) {
794   if (writeReqTail_ == nullptr) {
795     assert(writeReqHead_ == nullptr);
796     writeReqHead_ = writeReqTail_ = req;
797     req->start();
798   } else {
799     writeReqTail_->append(req);
800     writeReqTail_ = req;
801   }
802 }
803
804 void AsyncSocket::close() {
805   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
806           << ", state=" << state_ << ", shutdownFlags="
807           << std::hex << (int) shutdownFlags_;
808
809   // close() is only different from closeNow() when there are pending writes
810   // that need to drain before we can close.  In all other cases, just call
811   // closeNow().
812   //
813   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
814   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
815   // is still running.  (e.g., If there are multiple pending writes, and we
816   // call writeError() on the first one, it may call close().  In this case we
817   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
818   // writes will still be in the queue.)
819   //
820   // We only need to drain pending writes if we are still in STATE_CONNECTING
821   // or STATE_ESTABLISHED
822   if ((writeReqHead_ == nullptr) ||
823       !(state_ == StateEnum::CONNECTING ||
824       state_ == StateEnum::ESTABLISHED)) {
825     closeNow();
826     return;
827   }
828
829   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
830   // destroyed until close() returns.
831   DestructorGuard dg(this);
832   assert(eventBase_->isInEventBaseThread());
833
834   // Since there are write requests pending, we have to set the
835   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
836   // connect finishes and we finish writing these requests.
837   //
838   // Set SHUT_READ to indicate that reads are shut down, and set the
839   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
840   // pending writes complete.
841   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
842
843   // If a read callback is set, invoke readEOF() immediately to inform it that
844   // the socket has been closed and no more data can be read.
845   if (readCallback_) {
846     // Disable reads if they are enabled
847     if (!updateEventRegistration(0, EventHandler::READ)) {
848       // We're now in the error state; callbacks have been cleaned up
849       assert(state_ == StateEnum::ERROR);
850       assert(readCallback_ == nullptr);
851     } else {
852       ReadCallback* callback = readCallback_;
853       readCallback_ = nullptr;
854       callback->readEOF();
855     }
856   }
857 }
858
859 void AsyncSocket::closeNow() {
860   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
861           << ", state=" << state_ << ", shutdownFlags="
862           << std::hex << (int) shutdownFlags_;
863   DestructorGuard dg(this);
864   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
865
866   switch (state_) {
867     case StateEnum::ESTABLISHED:
868     case StateEnum::CONNECTING:
869     case StateEnum::FAST_OPEN: {
870       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
871       state_ = StateEnum::CLOSED;
872
873       // If the write timeout was set, cancel it.
874       writeTimeout_.cancelTimeout();
875
876       // If we are registered for I/O events, unregister.
877       if (eventFlags_ != EventHandler::NONE) {
878         eventFlags_ = EventHandler::NONE;
879         if (!updateEventRegistration()) {
880           // We will have been moved into the error state.
881           assert(state_ == StateEnum::ERROR);
882           return;
883         }
884       }
885
886       if (immediateReadHandler_.isLoopCallbackScheduled()) {
887         immediateReadHandler_.cancelLoopCallback();
888       }
889
890       if (fd_ >= 0) {
891         ioHandler_.changeHandlerFD(-1);
892         doClose();
893       }
894
895       invokeConnectErr(socketClosedLocallyEx);
896
897       failAllWrites(socketClosedLocallyEx);
898
899       if (readCallback_) {
900         ReadCallback* callback = readCallback_;
901         readCallback_ = nullptr;
902         callback->readEOF();
903       }
904       return;
905     }
906     case StateEnum::CLOSED:
907       // Do nothing.  It's possible that we are being called recursively
908       // from inside a callback that we invoked inside another call to close()
909       // that is still running.
910       return;
911     case StateEnum::ERROR:
912       // Do nothing.  The error handling code has performed (or is performing)
913       // cleanup.
914       return;
915     case StateEnum::UNINIT:
916       assert(eventFlags_ == EventHandler::NONE);
917       assert(connectCallback_ == nullptr);
918       assert(readCallback_ == nullptr);
919       assert(writeReqHead_ == nullptr);
920       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
921       state_ = StateEnum::CLOSED;
922       return;
923   }
924
925   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
926               << ") called in unknown state " << state_;
927 }
928
929 void AsyncSocket::closeWithReset() {
930   // Enable SO_LINGER, with the linger timeout set to 0.
931   // This will trigger a TCP reset when we close the socket.
932   if (fd_ >= 0) {
933     struct linger optLinger = {1, 0};
934     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
935       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
936               << "on " << fd_ << ": errno=" << errno;
937     }
938   }
939
940   // Then let closeNow() take care of the rest
941   closeNow();
942 }
943
944 void AsyncSocket::shutdownWrite() {
945   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
946           << ", state=" << state_ << ", shutdownFlags="
947           << std::hex << (int) shutdownFlags_;
948
949   // If there are no pending writes, shutdownWrite() is identical to
950   // shutdownWriteNow().
951   if (writeReqHead_ == nullptr) {
952     shutdownWriteNow();
953     return;
954   }
955
956   assert(eventBase_->isInEventBaseThread());
957
958   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
959   // shutdown will be performed once all writes complete.
960   shutdownFlags_ |= SHUT_WRITE_PENDING;
961 }
962
963 void AsyncSocket::shutdownWriteNow() {
964   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
965           << ", fd=" << fd_ << ", state=" << state_
966           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
967
968   if (shutdownFlags_ & SHUT_WRITE) {
969     // Writes are already shutdown; nothing else to do.
970     return;
971   }
972
973   // If SHUT_READ is already set, just call closeNow() to completely
974   // close the socket.  This can happen if close() was called with writes
975   // pending, and then shutdownWriteNow() is called before all pending writes
976   // complete.
977   if (shutdownFlags_ & SHUT_READ) {
978     closeNow();
979     return;
980   }
981
982   DestructorGuard dg(this);
983   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
984
985   switch (static_cast<StateEnum>(state_)) {
986     case StateEnum::ESTABLISHED:
987     {
988       shutdownFlags_ |= SHUT_WRITE;
989
990       // If the write timeout was set, cancel it.
991       writeTimeout_.cancelTimeout();
992
993       // If we are registered for write events, unregister.
994       if (!updateEventRegistration(0, EventHandler::WRITE)) {
995         // We will have been moved into the error state.
996         assert(state_ == StateEnum::ERROR);
997         return;
998       }
999
1000       // Shutdown writes on the file descriptor
1001       ::shutdown(fd_, SHUT_WR);
1002
1003       // Immediately fail all write requests
1004       failAllWrites(socketShutdownForWritesEx);
1005       return;
1006     }
1007     case StateEnum::CONNECTING:
1008     {
1009       // Set the SHUT_WRITE_PENDING flag.
1010       // When the connection completes, it will check this flag,
1011       // shutdown the write half of the socket, and then set SHUT_WRITE.
1012       shutdownFlags_ |= SHUT_WRITE_PENDING;
1013
1014       // Immediately fail all write requests
1015       failAllWrites(socketShutdownForWritesEx);
1016       return;
1017     }
1018     case StateEnum::UNINIT:
1019       // Callers normally shouldn't call shutdownWriteNow() before the socket
1020       // even starts connecting.  Nonetheless, go ahead and set
1021       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1022       // immediately shut down the write side of the socket.
1023       shutdownFlags_ |= SHUT_WRITE_PENDING;
1024       return;
1025     case StateEnum::FAST_OPEN:
1026       // In fast open state we haven't call connected yet, and if we shutdown
1027       // the writes, we will never try to call connect, so shut everything down
1028       shutdownFlags_ |= SHUT_WRITE;
1029       // Immediately fail all write requests
1030       failAllWrites(socketShutdownForWritesEx);
1031       return;
1032     case StateEnum::CLOSED:
1033     case StateEnum::ERROR:
1034       // We should never get here.  SHUT_WRITE should always be set
1035       // in STATE_CLOSED and STATE_ERROR.
1036       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1037                  << ", fd=" << fd_ << ") in unexpected state " << state_
1038                  << " with SHUT_WRITE not set ("
1039                  << std::hex << (int) shutdownFlags_ << ")";
1040       assert(false);
1041       return;
1042   }
1043
1044   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
1045               << fd_ << ") called in unknown state " << state_;
1046 }
1047
1048 bool AsyncSocket::readable() const {
1049   if (fd_ == -1) {
1050     return false;
1051   }
1052   struct pollfd fds[1];
1053   fds[0].fd = fd_;
1054   fds[0].events = POLLIN;
1055   fds[0].revents = 0;
1056   int rc = poll(fds, 1, 0);
1057   return rc == 1;
1058 }
1059
1060 bool AsyncSocket::isPending() const {
1061   return ioHandler_.isPending();
1062 }
1063
1064 bool AsyncSocket::hangup() const {
1065   if (fd_ == -1) {
1066     // sanity check, no one should ask for hangup if we are not connected.
1067     assert(false);
1068     return false;
1069   }
1070 #ifdef POLLRDHUP // Linux-only
1071   struct pollfd fds[1];
1072   fds[0].fd = fd_;
1073   fds[0].events = POLLRDHUP|POLLHUP;
1074   fds[0].revents = 0;
1075   poll(fds, 1, 0);
1076   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1077 #else
1078   return false;
1079 #endif
1080 }
1081
1082 bool AsyncSocket::good() const {
1083   return (
1084       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1085        state_ == StateEnum::ESTABLISHED) &&
1086       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1087 }
1088
1089 bool AsyncSocket::error() const {
1090   return (state_ == StateEnum::ERROR);
1091 }
1092
1093 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1094   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1095           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1096           << ", state=" << state_ << ", events="
1097           << std::hex << eventFlags_ << ")";
1098   assert(eventBase_ == nullptr);
1099   assert(eventBase->isInEventBaseThread());
1100
1101   eventBase_ = eventBase;
1102   ioHandler_.attachEventBase(eventBase);
1103   writeTimeout_.attachEventBase(eventBase);
1104 }
1105
1106 void AsyncSocket::detachEventBase() {
1107   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1108           << ", old evb=" << eventBase_ << ", state=" << state_
1109           << ", events=" << std::hex << eventFlags_ << ")";
1110   assert(eventBase_ != nullptr);
1111   assert(eventBase_->isInEventBaseThread());
1112
1113   eventBase_ = nullptr;
1114   ioHandler_.detachEventBase();
1115   writeTimeout_.detachEventBase();
1116 }
1117
1118 bool AsyncSocket::isDetachable() const {
1119   DCHECK(eventBase_ != nullptr);
1120   DCHECK(eventBase_->isInEventBaseThread());
1121
1122   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1123 }
1124
1125 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1126   if (!localAddr_.isInitialized()) {
1127     localAddr_.setFromLocalAddress(fd_);
1128   }
1129   *address = localAddr_;
1130 }
1131
1132 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1133   if (!addr_.isInitialized()) {
1134     addr_.setFromPeerAddress(fd_);
1135   }
1136   *address = addr_;
1137 }
1138
1139 int AsyncSocket::setNoDelay(bool noDelay) {
1140   if (fd_ < 0) {
1141     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1142                << this << "(state=" << state_ << ")";
1143     return EINVAL;
1144
1145   }
1146
1147   int value = noDelay ? 1 : 0;
1148   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1149     int errnoCopy = errno;
1150     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1151             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1152             << strerror(errnoCopy);
1153     return errnoCopy;
1154   }
1155
1156   return 0;
1157 }
1158
1159 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1160
1161   #ifndef TCP_CONGESTION
1162   #define TCP_CONGESTION  13
1163   #endif
1164
1165   if (fd_ < 0) {
1166     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1167                << "socket " << this << "(state=" << state_ << ")";
1168     return EINVAL;
1169
1170   }
1171
1172   if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
1173         cname.length() + 1) != 0) {
1174     int errnoCopy = errno;
1175     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1176             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1177             << strerror(errnoCopy);
1178     return errnoCopy;
1179   }
1180
1181   return 0;
1182 }
1183
1184 int AsyncSocket::setQuickAck(bool quickack) {
1185   if (fd_ < 0) {
1186     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1187                << this << "(state=" << state_ << ")";
1188     return EINVAL;
1189
1190   }
1191
1192 #ifdef TCP_QUICKACK // Linux-only
1193   int value = quickack ? 1 : 0;
1194   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1195     int errnoCopy = errno;
1196     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1197             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1198             << strerror(errnoCopy);
1199     return errnoCopy;
1200   }
1201
1202   return 0;
1203 #else
1204   return ENOSYS;
1205 #endif
1206 }
1207
1208 int AsyncSocket::setSendBufSize(size_t bufsize) {
1209   if (fd_ < 0) {
1210     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1211                << this << "(state=" << state_ << ")";
1212     return EINVAL;
1213   }
1214
1215   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1216     int errnoCopy = errno;
1217     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1218             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1219             << strerror(errnoCopy);
1220     return errnoCopy;
1221   }
1222
1223   return 0;
1224 }
1225
1226 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1227   if (fd_ < 0) {
1228     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1229                << this << "(state=" << state_ << ")";
1230     return EINVAL;
1231   }
1232
1233   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1234     int errnoCopy = errno;
1235     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1236             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1237             << strerror(errnoCopy);
1238     return errnoCopy;
1239   }
1240
1241   return 0;
1242 }
1243
1244 int AsyncSocket::setTCPProfile(int profd) {
1245   if (fd_ < 0) {
1246     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1247                << this << "(state=" << state_ << ")";
1248     return EINVAL;
1249   }
1250
1251   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1252     int errnoCopy = errno;
1253     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1254             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1255             << strerror(errnoCopy);
1256     return errnoCopy;
1257   }
1258
1259   return 0;
1260 }
1261
1262 void AsyncSocket::ioReady(uint16_t events) noexcept {
1263   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1264           << ", events=" << std::hex << events << ", state=" << state_;
1265   DestructorGuard dg(this);
1266   assert(events & EventHandler::READ_WRITE);
1267   assert(eventBase_->isInEventBaseThread());
1268
1269   uint16_t relevantEvents = events & EventHandler::READ_WRITE;
1270   if (relevantEvents == EventHandler::READ) {
1271     handleRead();
1272   } else if (relevantEvents == EventHandler::WRITE) {
1273     handleWrite();
1274   } else if (relevantEvents == EventHandler::READ_WRITE) {
1275     EventBase* originalEventBase = eventBase_;
1276     // If both read and write events are ready, process writes first.
1277     handleWrite();
1278
1279     // Return now if handleWrite() detached us from our EventBase
1280     if (eventBase_ != originalEventBase) {
1281       return;
1282     }
1283
1284     // Only call handleRead() if a read callback is still installed.
1285     // (It's possible that the read callback was uninstalled during
1286     // handleWrite().)
1287     if (readCallback_) {
1288       handleRead();
1289     }
1290   } else {
1291     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1292                << std::hex << events << "(this=" << this << ")";
1293     abort();
1294   }
1295 }
1296
1297 AsyncSocket::ReadResult
1298 AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
1299   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
1300           << ", buflen=" << *buflen;
1301
1302   int recvFlags = 0;
1303   if (peek_) {
1304     recvFlags |= MSG_PEEK;
1305   }
1306
1307   ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
1308   if (bytes < 0) {
1309     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1310       // No more data to read right now.
1311       return ReadResult(READ_BLOCKING);
1312     } else {
1313       return ReadResult(READ_ERROR);
1314     }
1315   } else {
1316     appBytesReceived_ += bytes;
1317     return ReadResult(bytes);
1318   }
1319 }
1320
1321 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) noexcept {
1322   // no matter what, buffer should be preapared for non-ssl socket
1323   CHECK(readCallback_);
1324   readCallback_->getReadBuffer(buf, buflen);
1325 }
1326
1327 void AsyncSocket::handleRead() noexcept {
1328   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1329           << ", state=" << state_;
1330   assert(state_ == StateEnum::ESTABLISHED);
1331   assert((shutdownFlags_ & SHUT_READ) == 0);
1332   assert(readCallback_ != nullptr);
1333   assert(eventFlags_ & EventHandler::READ);
1334
1335   // Loop until:
1336   // - a read attempt would block
1337   // - readCallback_ is uninstalled
1338   // - the number of loop iterations exceeds the optional maximum
1339   // - this AsyncSocket is moved to another EventBase
1340   //
1341   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1342   // which is why need to check for it here.
1343   //
1344   // The last bullet point is slightly subtle.  readDataAvailable() may also
1345   // detach this socket from this EventBase.  However, before
1346   // readDataAvailable() returns another thread may pick it up, attach it to
1347   // a different EventBase, and install another readCallback_.  We need to
1348   // exit immediately after readDataAvailable() returns if the eventBase_ has
1349   // changed.  (The caller must perform some sort of locking to transfer the
1350   // AsyncSocket between threads properly.  This will be sufficient to ensure
1351   // that this thread sees the updated eventBase_ variable after
1352   // readDataAvailable() returns.)
1353   uint16_t numReads = 0;
1354   EventBase* originalEventBase = eventBase_;
1355   while (readCallback_ && eventBase_ == originalEventBase) {
1356     // Get the buffer to read into.
1357     void* buf = nullptr;
1358     size_t buflen = 0, offset = 0;
1359     try {
1360       prepareReadBuffer(&buf, &buflen);
1361       VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
1362     } catch (const AsyncSocketException& ex) {
1363       return failRead(__func__, ex);
1364     } catch (const std::exception& ex) {
1365       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1366                               string("ReadCallback::getReadBuffer() "
1367                                      "threw exception: ") +
1368                               ex.what());
1369       return failRead(__func__, tex);
1370     } catch (...) {
1371       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1372                              "ReadCallback::getReadBuffer() threw "
1373                              "non-exception type");
1374       return failRead(__func__, ex);
1375     }
1376     if (!isBufferMovable_ && (buf == nullptr || buflen == 0)) {
1377       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1378                              "ReadCallback::getReadBuffer() returned "
1379                              "empty buffer");
1380       return failRead(__func__, ex);
1381     }
1382
1383     // Perform the read
1384     auto readResult = performRead(&buf, &buflen, &offset);
1385     auto bytesRead = readResult.readReturn;
1386     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
1387             << bytesRead << " bytes";
1388     if (bytesRead > 0) {
1389       if (!isBufferMovable_) {
1390         readCallback_->readDataAvailable(bytesRead);
1391       } else {
1392         CHECK(kOpenSslModeMoveBufferOwnership);
1393         VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
1394                 << "buf=" << buf << ", " << bytesRead << "/" << buflen
1395                 << ", offset=" << offset;
1396         auto readBuf = folly::IOBuf::takeOwnership(buf, buflen);
1397         readBuf->trimStart(offset);
1398         readBuf->trimEnd(buflen - offset - bytesRead);
1399         readCallback_->readBufferAvailable(std::move(readBuf));
1400       }
1401
1402       // Fall through and continue around the loop if the read
1403       // completely filled the available buffer.
1404       // Note that readCallback_ may have been uninstalled or changed inside
1405       // readDataAvailable().
1406       if (size_t(bytesRead) < buflen) {
1407         return;
1408       }
1409     } else if (bytesRead == READ_BLOCKING) {
1410         // No more data to read right now.
1411         return;
1412     } else if (bytesRead == READ_ERROR) {
1413       readErr_ = READ_ERROR;
1414       if (readResult.exception) {
1415         return failRead(__func__, *readResult.exception);
1416       }
1417       auto errnoCopy = errno;
1418       AsyncSocketException ex(
1419           AsyncSocketException::INTERNAL_ERROR,
1420           withAddr("recv() failed"),
1421           errnoCopy);
1422       return failRead(__func__, ex);
1423     } else {
1424       assert(bytesRead == READ_EOF);
1425       readErr_ = READ_EOF;
1426       // EOF
1427       shutdownFlags_ |= SHUT_READ;
1428       if (!updateEventRegistration(0, EventHandler::READ)) {
1429         // we've already been moved into STATE_ERROR
1430         assert(state_ == StateEnum::ERROR);
1431         assert(readCallback_ == nullptr);
1432         return;
1433       }
1434
1435       ReadCallback* callback = readCallback_;
1436       readCallback_ = nullptr;
1437       callback->readEOF();
1438       return;
1439     }
1440     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1441       if (readCallback_ != nullptr) {
1442         // We might still have data in the socket.
1443         // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
1444         scheduleImmediateRead();
1445       }
1446       return;
1447     }
1448   }
1449 }
1450
1451 /**
1452  * This function attempts to write as much data as possible, until no more data
1453  * can be written.
1454  *
1455  * - If it sends all available data, it unregisters for write events, and stops
1456  *   the writeTimeout_.
1457  *
1458  * - If not all of the data can be sent immediately, it reschedules
1459  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1460  *   registered for write events.
1461  */
1462 void AsyncSocket::handleWrite() noexcept {
1463   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1464           << ", state=" << state_;
1465   DestructorGuard dg(this);
1466
1467   if (state_ == StateEnum::CONNECTING) {
1468     handleConnect();
1469     return;
1470   }
1471
1472   // Normal write
1473   assert(state_ == StateEnum::ESTABLISHED);
1474   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1475   assert(writeReqHead_ != nullptr);
1476
1477   // Loop until we run out of write requests,
1478   // or until this socket is moved to another EventBase.
1479   // (See the comment in handleRead() explaining how this can happen.)
1480   EventBase* originalEventBase = eventBase_;
1481   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1482     auto writeResult = writeReqHead_->performWrite();
1483     if (writeResult.writeReturn < 0) {
1484       if (writeResult.exception) {
1485         return failWrite(__func__, *writeResult.exception);
1486       }
1487       auto errnoCopy = errno;
1488       AsyncSocketException ex(
1489           AsyncSocketException::INTERNAL_ERROR,
1490           withAddr("writev() failed"),
1491           errnoCopy);
1492       return failWrite(__func__, ex);
1493     } else if (writeReqHead_->isComplete()) {
1494       // We finished this request
1495       WriteRequest* req = writeReqHead_;
1496       writeReqHead_ = req->getNext();
1497
1498       if (writeReqHead_ == nullptr) {
1499         writeReqTail_ = nullptr;
1500         // This is the last write request.
1501         // Unregister for write events and cancel the send timer
1502         // before we invoke the callback.  We have to update the state properly
1503         // before calling the callback, since it may want to detach us from
1504         // the EventBase.
1505         if (eventFlags_ & EventHandler::WRITE) {
1506           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1507             assert(state_ == StateEnum::ERROR);
1508             return;
1509           }
1510           // Stop the send timeout
1511           writeTimeout_.cancelTimeout();
1512         }
1513         assert(!writeTimeout_.isScheduled());
1514
1515         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1516         // we finish sending the last write request.
1517         //
1518         // We have to do this before invoking writeSuccess(), since
1519         // writeSuccess() may detach us from our EventBase.
1520         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1521           assert(connectCallback_ == nullptr);
1522           shutdownFlags_ |= SHUT_WRITE;
1523
1524           if (shutdownFlags_ & SHUT_READ) {
1525             // Reads have already been shutdown.  Fully close the socket and
1526             // move to STATE_CLOSED.
1527             //
1528             // Note: This code currently moves us to STATE_CLOSED even if
1529             // close() hasn't ever been called.  This can occur if we have
1530             // received EOF from the peer and shutdownWrite() has been called
1531             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1532             // case, until close() is actually called?  I can't think of a
1533             // reason why we would need to do so.  No other operations besides
1534             // calling close() or destroying the socket can be performed at
1535             // this point.
1536             assert(readCallback_ == nullptr);
1537             state_ = StateEnum::CLOSED;
1538             if (fd_ >= 0) {
1539               ioHandler_.changeHandlerFD(-1);
1540               doClose();
1541             }
1542           } else {
1543             // Reads are still enabled, so we are only doing a half-shutdown
1544             ::shutdown(fd_, SHUT_WR);
1545           }
1546         }
1547       }
1548
1549       // Invoke the callback
1550       WriteCallback* callback = req->getCallback();
1551       req->destroy();
1552       if (callback) {
1553         callback->writeSuccess();
1554       }
1555       // We'll continue around the loop, trying to write another request
1556     } else {
1557       // Partial write.
1558       if (bufferCallback_) {
1559         bufferCallback_->onEgressBuffered();
1560       }
1561       writeReqHead_->consume();
1562       // Stop after a partial write; it's highly likely that a subsequent write
1563       // attempt will just return EAGAIN.
1564       //
1565       // Ensure that we are registered for write events.
1566       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1567         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1568           assert(state_ == StateEnum::ERROR);
1569           return;
1570         }
1571       }
1572
1573       // Reschedule the send timeout, since we have made some write progress.
1574       if (sendTimeout_ > 0) {
1575         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1576           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1577               withAddr("failed to reschedule write timeout"));
1578           return failWrite(__func__, ex);
1579         }
1580       }
1581       return;
1582     }
1583   }
1584   if (!writeReqHead_ && bufferCallback_) {
1585     bufferCallback_->onEgressBufferCleared();
1586   }
1587 }
1588
1589 void AsyncSocket::checkForImmediateRead() noexcept {
1590   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1591   // (However, note that some subclasses do override this method.)
1592   //
1593   // Simply calling handleRead() here would be bad, as this would call
1594   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1595   // buffer even though no data may be available.  This would waste lots of
1596   // memory, since the buffer will sit around unused until the socket actually
1597   // becomes readable.
1598   //
1599   // Checking if the socket is readable now also seems like it would probably
1600   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1601   // would just waste an extra system call.  Even if it is readable, waiting to
1602   // find out from libevent on the next event loop doesn't seem that bad.
1603 }
1604
1605 void AsyncSocket::handleInitialReadWrite() noexcept {
1606   // Our callers should already be holding a DestructorGuard, but grab
1607   // one here just to make sure, in case one of our calling code paths ever
1608   // changes.
1609   DestructorGuard dg(this);
1610
1611   // If we have a readCallback_, make sure we enable read events.  We
1612   // may already be registered for reads if connectSuccess() set
1613   // the read calback.
1614   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1615     assert(state_ == StateEnum::ESTABLISHED);
1616     assert((shutdownFlags_ & SHUT_READ) == 0);
1617     if (!updateEventRegistration(EventHandler::READ, 0)) {
1618       assert(state_ == StateEnum::ERROR);
1619       return;
1620     }
1621     checkForImmediateRead();
1622   } else if (readCallback_ == nullptr) {
1623     // Unregister for read events.
1624     updateEventRegistration(0, EventHandler::READ);
1625   }
1626
1627   // If we have write requests pending, try to send them immediately.
1628   // Since we just finished accepting, there is a very good chance that we can
1629   // write without blocking.
1630   //
1631   // However, we only process them if EventHandler::WRITE is not already set,
1632   // which means that we're already blocked on a write attempt.  (This can
1633   // happen if connectSuccess() called write() before returning.)
1634   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1635     // Call handleWrite() to perform write processing.
1636     handleWrite();
1637   } else if (writeReqHead_ == nullptr) {
1638     // Unregister for write event.
1639     updateEventRegistration(0, EventHandler::WRITE);
1640   }
1641 }
1642
1643 void AsyncSocket::handleConnect() noexcept {
1644   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1645           << ", state=" << state_;
1646   assert(state_ == StateEnum::CONNECTING);
1647   // SHUT_WRITE can never be set while we are still connecting;
1648   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1649   // finishes
1650   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1651
1652   // In case we had a connect timeout, cancel the timeout
1653   writeTimeout_.cancelTimeout();
1654   // We don't use a persistent registration when waiting on a connect event,
1655   // so we have been automatically unregistered now.  Update eventFlags_ to
1656   // reflect reality.
1657   assert(eventFlags_ == EventHandler::WRITE);
1658   eventFlags_ = EventHandler::NONE;
1659
1660   // Call getsockopt() to check if the connect succeeded
1661   int error;
1662   socklen_t len = sizeof(error);
1663   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1664   if (rv != 0) {
1665     auto errnoCopy = errno;
1666     AsyncSocketException ex(
1667         AsyncSocketException::INTERNAL_ERROR,
1668         withAddr("error calling getsockopt() after connect"),
1669         errnoCopy);
1670     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1671                << fd_ << " host=" << addr_.describe()
1672                << ") exception:" << ex.what();
1673     return failConnect(__func__, ex);
1674   }
1675
1676   if (error != 0) {
1677     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1678                            "connect failed", error);
1679     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1680             << fd_ << " host=" << addr_.describe()
1681             << ") exception: " << ex.what();
1682     return failConnect(__func__, ex);
1683   }
1684
1685   // Move into STATE_ESTABLISHED
1686   state_ = StateEnum::ESTABLISHED;
1687
1688   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1689   // perform, immediately shutdown the write half of the socket.
1690   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1691     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1692     // are still connecting we just abort the connect rather than waiting for
1693     // it to complete.
1694     assert((shutdownFlags_ & SHUT_READ) == 0);
1695     ::shutdown(fd_, SHUT_WR);
1696     shutdownFlags_ |= SHUT_WRITE;
1697   }
1698
1699   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1700           << "successfully connected; state=" << state_;
1701
1702   // Remember the EventBase we are attached to, before we start invoking any
1703   // callbacks (since the callbacks may call detachEventBase()).
1704   EventBase* originalEventBase = eventBase_;
1705
1706   invokeConnectSuccess();
1707   // Note that the connect callback may have changed our state.
1708   // (set or unset the read callback, called write(), closed the socket, etc.)
1709   // The following code needs to handle these situations correctly.
1710   //
1711   // If the socket has been closed, readCallback_ and writeReqHead_ will
1712   // always be nullptr, so that will prevent us from trying to read or write.
1713   //
1714   // The main thing to check for is if eventBase_ is still originalEventBase.
1715   // If not, we have been detached from this event base, so we shouldn't
1716   // perform any more operations.
1717   if (eventBase_ != originalEventBase) {
1718     return;
1719   }
1720
1721   handleInitialReadWrite();
1722 }
1723
1724 void AsyncSocket::timeoutExpired() noexcept {
1725   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1726           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1727   DestructorGuard dg(this);
1728   assert(eventBase_->isInEventBaseThread());
1729
1730   if (state_ == StateEnum::CONNECTING) {
1731     // connect() timed out
1732     // Unregister for I/O events.
1733     if (connectCallback_) {
1734       AsyncSocketException ex(
1735           AsyncSocketException::TIMED_OUT, "connect timed out");
1736       failConnect(__func__, ex);
1737     } else {
1738       // we faced a connect error without a connect callback, which could
1739       // happen due to TFO.
1740       AsyncSocketException ex(
1741           AsyncSocketException::TIMED_OUT, "write timed out during connection");
1742       failWrite(__func__, ex);
1743     }
1744   } else {
1745     // a normal write operation timed out
1746     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
1747     failWrite(__func__, ex);
1748   }
1749 }
1750
1751 ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
1752   return detail::tfo_sendmsg(fd, msg, msg_flags);
1753 }
1754
1755 AsyncSocket::WriteResult
1756 AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
1757   ssize_t totalWritten = 0;
1758   if (state_ == StateEnum::FAST_OPEN) {
1759     sockaddr_storage addr;
1760     auto len = addr_.getAddress(&addr);
1761     msg->msg_name = &addr;
1762     msg->msg_namelen = len;
1763     totalWritten = tfoSendMsg(fd_, msg, msg_flags);
1764     if (totalWritten >= 0) {
1765       tfoFinished_ = true;
1766       state_ = StateEnum::ESTABLISHED;
1767       handleInitialReadWrite();
1768     } else if (errno == EINPROGRESS) {
1769       VLOG(4) << "TFO falling back to connecting";
1770       // A normal sendmsg doesn't return EINPROGRESS, however
1771       // TFO might fallback to connecting if there is no
1772       // cookie.
1773       state_ = StateEnum::CONNECTING;
1774       try {
1775         scheduleConnectTimeoutAndRegisterForEvents();
1776       } catch (const AsyncSocketException& ex) {
1777         return WriteResult(
1778             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1779       }
1780       // Let's fake it that no bytes were written and return an errno.
1781       errno = EAGAIN;
1782       totalWritten = -1;
1783     } else if (errno == EOPNOTSUPP) {
1784       VLOG(4) << "TFO not supported";
1785       // Try falling back to connecting.
1786       state_ = StateEnum::CONNECTING;
1787       try {
1788         int ret = socketConnect((const sockaddr*)&addr, len);
1789         if (ret == 0) {
1790           // connect succeeded immediately
1791           // Treat this like no data was written.
1792           state_ = StateEnum::ESTABLISHED;
1793           handleInitialReadWrite();
1794         }
1795         // If there was no exception during connections,
1796         // we would return that no bytes were written.
1797         errno = EAGAIN;
1798         totalWritten = -1;
1799       } catch (const AsyncSocketException& ex) {
1800         return WriteResult(
1801             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1802       }
1803     } else if (errno == EAGAIN) {
1804       // Normally sendmsg would indicate that the write would block.
1805       // However in the fast open case, it would indicate that sendmsg
1806       // fell back to a connect. This is a return code from connect()
1807       // instead, and is an error condition indicating no fds available.
1808       return WriteResult(
1809           WRITE_ERROR,
1810           folly::make_unique<AsyncSocketException>(
1811               AsyncSocketException::UNKNOWN, "No more free local ports"));
1812     }
1813   } else {
1814     totalWritten = ::sendmsg(fd, msg, msg_flags);
1815   }
1816   return WriteResult(totalWritten);
1817 }
1818
1819 AsyncSocket::WriteResult AsyncSocket::performWrite(
1820     const iovec* vec,
1821     uint32_t count,
1822     WriteFlags flags,
1823     uint32_t* countWritten,
1824     uint32_t* partialWritten) {
1825   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1826   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1827   // (since it may terminate the program if the main program doesn't explicitly
1828   // ignore it).
1829   struct msghdr msg;
1830   msg.msg_name = nullptr;
1831   msg.msg_namelen = 0;
1832   msg.msg_iov = const_cast<iovec *>(vec);
1833   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
1834   msg.msg_control = nullptr;
1835   msg.msg_controllen = 0;
1836   msg.msg_flags = 0;
1837
1838   int msg_flags = MSG_DONTWAIT;
1839
1840 #ifdef MSG_NOSIGNAL // Linux-only
1841   msg_flags |= MSG_NOSIGNAL;
1842   if (isSet(flags, WriteFlags::CORK)) {
1843     // MSG_MORE tells the kernel we have more data to send, so wait for us to
1844     // give it the rest of the data rather than immediately sending a partial
1845     // frame, even when TCP_NODELAY is enabled.
1846     msg_flags |= MSG_MORE;
1847   }
1848 #endif
1849   if (isSet(flags, WriteFlags::EOR)) {
1850     // marks that this is the last byte of a record (response)
1851     msg_flags |= MSG_EOR;
1852   }
1853   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
1854   auto totalWritten = writeResult.writeReturn;
1855   if (totalWritten < 0) {
1856     if (!writeResult.exception && errno == EAGAIN) {
1857       // TCP buffer is full; we can't write any more data right now.
1858       *countWritten = 0;
1859       *partialWritten = 0;
1860       return WriteResult(0);
1861     }
1862     // error
1863     *countWritten = 0;
1864     *partialWritten = 0;
1865     return writeResult;
1866   }
1867
1868   appBytesWritten_ += totalWritten;
1869
1870   uint32_t bytesWritten;
1871   uint32_t n;
1872   for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
1873     const iovec* v = vec + n;
1874     if (v->iov_len > bytesWritten) {
1875       // Partial write finished in the middle of this iovec
1876       *countWritten = n;
1877       *partialWritten = bytesWritten;
1878       return WriteResult(totalWritten);
1879     }
1880
1881     bytesWritten -= v->iov_len;
1882   }
1883
1884   assert(bytesWritten == 0);
1885   *countWritten = n;
1886   *partialWritten = 0;
1887   return WriteResult(totalWritten);
1888 }
1889
1890 /**
1891  * Re-register the EventHandler after eventFlags_ has changed.
1892  *
1893  * If an error occurs, fail() is called to move the socket into the error state
1894  * and call all currently installed callbacks.  After an error, the
1895  * AsyncSocket is completely unregistered.
1896  *
1897  * @return Returns true on succcess, or false on error.
1898  */
1899 bool AsyncSocket::updateEventRegistration() {
1900   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
1901           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
1902           << ", events=" << std::hex << eventFlags_;
1903   assert(eventBase_->isInEventBaseThread());
1904   if (eventFlags_ == EventHandler::NONE) {
1905     ioHandler_.unregisterHandler();
1906     return true;
1907   }
1908
1909   // Always register for persistent events, so we don't have to re-register
1910   // after being called back.
1911   if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
1912     eventFlags_ = EventHandler::NONE; // we're not registered after error
1913     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1914         withAddr("failed to update AsyncSocket event registration"));
1915     fail("updateEventRegistration", ex);
1916     return false;
1917   }
1918
1919   return true;
1920 }
1921
1922 bool AsyncSocket::updateEventRegistration(uint16_t enable,
1923                                            uint16_t disable) {
1924   uint16_t oldFlags = eventFlags_;
1925   eventFlags_ |= enable;
1926   eventFlags_ &= ~disable;
1927   if (eventFlags_ == oldFlags) {
1928     return true;
1929   } else {
1930     return updateEventRegistration();
1931   }
1932 }
1933
1934 void AsyncSocket::startFail() {
1935   // startFail() should only be called once
1936   assert(state_ != StateEnum::ERROR);
1937   assert(getDestructorGuardCount() > 0);
1938   state_ = StateEnum::ERROR;
1939   // Ensure that SHUT_READ and SHUT_WRITE are set,
1940   // so all future attempts to read or write will be rejected
1941   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1942
1943   if (eventFlags_ != EventHandler::NONE) {
1944     eventFlags_ = EventHandler::NONE;
1945     ioHandler_.unregisterHandler();
1946   }
1947   writeTimeout_.cancelTimeout();
1948
1949   if (fd_ >= 0) {
1950     ioHandler_.changeHandlerFD(-1);
1951     doClose();
1952   }
1953 }
1954
1955 void AsyncSocket::finishFail() {
1956   assert(state_ == StateEnum::ERROR);
1957   assert(getDestructorGuardCount() > 0);
1958
1959   AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1960                          withAddr("socket closing after error"));
1961   invokeConnectErr(ex);
1962   failAllWrites(ex);
1963
1964   if (readCallback_) {
1965     ReadCallback* callback = readCallback_;
1966     readCallback_ = nullptr;
1967     callback->readErr(ex);
1968   }
1969 }
1970
1971 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
1972   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1973              << state_ << " host=" << addr_.describe()
1974              << "): failed in " << fn << "(): "
1975              << ex.what();
1976   startFail();
1977   finishFail();
1978 }
1979
1980 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
1981   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1982                << state_ << " host=" << addr_.describe()
1983                << "): failed while connecting in " << fn << "(): "
1984                << ex.what();
1985   startFail();
1986
1987   invokeConnectErr(ex);
1988   finishFail();
1989 }
1990
1991 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
1992   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1993                << state_ << " host=" << addr_.describe()
1994                << "): failed while reading in " << fn << "(): "
1995                << ex.what();
1996   startFail();
1997
1998   if (readCallback_ != nullptr) {
1999     ReadCallback* callback = readCallback_;
2000     readCallback_ = nullptr;
2001     callback->readErr(ex);
2002   }
2003
2004   finishFail();
2005 }
2006
2007 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
2008   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2009                << state_ << " host=" << addr_.describe()
2010                << "): failed while writing in " << fn << "(): "
2011                << ex.what();
2012   startFail();
2013
2014   // Only invoke the first write callback, since the error occurred while
2015   // writing this request.  Let any other pending write callbacks be invoked in
2016   // finishFail().
2017   if (writeReqHead_ != nullptr) {
2018     WriteRequest* req = writeReqHead_;
2019     writeReqHead_ = req->getNext();
2020     WriteCallback* callback = req->getCallback();
2021     uint32_t bytesWritten = req->getTotalBytesWritten();
2022     req->destroy();
2023     if (callback) {
2024       callback->writeErr(bytesWritten, ex);
2025     }
2026   }
2027
2028   finishFail();
2029 }
2030
2031 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
2032                              size_t bytesWritten,
2033                              const AsyncSocketException& ex) {
2034   // This version of failWrite() is used when the failure occurs before
2035   // we've added the callback to writeReqHead_.
2036   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2037              << state_ << " host=" << addr_.describe()
2038              <<"): failed while writing in " << fn << "(): "
2039              << ex.what();
2040   startFail();
2041
2042   if (callback != nullptr) {
2043     callback->writeErr(bytesWritten, ex);
2044   }
2045
2046   finishFail();
2047 }
2048
2049 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
2050   // Invoke writeError() on all write callbacks.
2051   // This is used when writes are forcibly shutdown with write requests
2052   // pending, or when an error occurs with writes pending.
2053   while (writeReqHead_ != nullptr) {
2054     WriteRequest* req = writeReqHead_;
2055     writeReqHead_ = req->getNext();
2056     WriteCallback* callback = req->getCallback();
2057     if (callback) {
2058       callback->writeErr(req->getTotalBytesWritten(), ex);
2059     }
2060     req->destroy();
2061   }
2062 }
2063
2064 void AsyncSocket::invalidState(ConnectCallback* callback) {
2065   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
2066              << "): connect() called in invalid state " << state_;
2067
2068   /*
2069    * The invalidState() methods don't use the normal failure mechanisms,
2070    * since we don't know what state we are in.  We don't want to call
2071    * startFail()/finishFail() recursively if we are already in the middle of
2072    * cleaning up.
2073    */
2074
2075   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
2076                          "connect() called with socket in invalid state");
2077   connectEndTime_ = std::chrono::steady_clock::now();
2078   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2079     if (callback) {
2080       callback->connectErr(ex);
2081     }
2082   } else {
2083     // We can't use failConnect() here since connectCallback_
2084     // may already be set to another callback.  Invoke this ConnectCallback
2085     // here; any other connectCallback_ will be invoked in finishFail()
2086     startFail();
2087     if (callback) {
2088       callback->connectErr(ex);
2089     }
2090     finishFail();
2091   }
2092 }
2093
2094 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
2095   connectEndTime_ = std::chrono::steady_clock::now();
2096   if (connectCallback_) {
2097     ConnectCallback* callback = connectCallback_;
2098     connectCallback_ = nullptr;
2099     callback->connectErr(ex);
2100   }
2101 }
2102
2103 void AsyncSocket::invokeConnectSuccess() {
2104   connectEndTime_ = std::chrono::steady_clock::now();
2105   if (connectCallback_) {
2106     ConnectCallback* callback = connectCallback_;
2107     connectCallback_ = nullptr;
2108     callback->connectSuccess();
2109   }
2110 }
2111
2112 void AsyncSocket::invalidState(ReadCallback* callback) {
2113   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2114              << "): setReadCallback(" << callback
2115              << ") called in invalid state " << state_;
2116
2117   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2118                          "setReadCallback() called with socket in "
2119                          "invalid state");
2120   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2121     if (callback) {
2122       callback->readErr(ex);
2123     }
2124   } else {
2125     startFail();
2126     if (callback) {
2127       callback->readErr(ex);
2128     }
2129     finishFail();
2130   }
2131 }
2132
2133 void AsyncSocket::invalidState(WriteCallback* callback) {
2134   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2135              << "): write() called in invalid state " << state_;
2136
2137   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2138                          withAddr("write() called with socket in invalid state"));
2139   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2140     if (callback) {
2141       callback->writeErr(0, ex);
2142     }
2143   } else {
2144     startFail();
2145     if (callback) {
2146       callback->writeErr(0, ex);
2147     }
2148     finishFail();
2149   }
2150 }
2151
2152 void AsyncSocket::doClose() {
2153   if (fd_ == -1) return;
2154   if (shutdownSocketSet_) {
2155     shutdownSocketSet_->close(fd_);
2156   } else {
2157     ::close(fd_);
2158   }
2159   fd_ = -1;
2160 }
2161
2162 std::ostream& operator << (std::ostream& os,
2163                            const AsyncSocket::StateEnum& state) {
2164   os << static_cast<int>(state);
2165   return os;
2166 }
2167
2168 std::string AsyncSocket::withAddr(const std::string& s) {
2169   // Don't use addr_ directly because it may not be initialized
2170   // e.g. if constructed from fd
2171   folly::SocketAddress peer, local;
2172   try {
2173     getPeerAddress(&peer);
2174     getLocalAddress(&local);
2175   } catch (const std::exception&) {
2176     // ignore
2177   } catch (...) {
2178     // ignore
2179   }
2180   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
2181 }
2182
2183 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
2184   bufferCallback_ = cb;
2185 }
2186
2187 } // folly