Add evb change callback to SSL Socket
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/ExceptionWrapper.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22 #include <folly/Portability.h>
23 #include <folly/portability/Fcntl.h>
24 #include <folly/portability/Sockets.h>
25 #include <folly/portability/SysUio.h>
26 #include <folly/portability/Unistd.h>
27
28 #include <errno.h>
29 #include <limits.h>
30 #include <thread>
31 #include <sys/types.h>
32 #include <boost/preprocessor/control/if.hpp>
33
34 using std::string;
35 using std::unique_ptr;
36
37 namespace fsp = folly::portability::sockets;
38
39 namespace folly {
40
41 // static members initializers
42 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
43
44 const AsyncSocketException socketClosedLocallyEx(
45     AsyncSocketException::END_OF_FILE, "socket closed locally");
46 const AsyncSocketException socketShutdownForWritesEx(
47     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
48
49 // TODO: It might help performance to provide a version of BytesWriteRequest that
50 // users could derive from, so we can avoid the extra allocation for each call
51 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
52 // protocols are currently templatized for transports.
53 //
54 // We would need the version for external users where they provide the iovec
55 // storage space, and only our internal version would allocate it at the end of
56 // the WriteRequest.
57
58 /* The default WriteRequest implementation, used for write(), writev() and
59  * writeChain()
60  *
61  * A new BytesWriteRequest operation is allocated on the heap for all write
62  * operations that cannot be completed immediately.
63  */
64 class AsyncSocket::BytesWriteRequest : public AsyncSocket::WriteRequest {
65  public:
66   static BytesWriteRequest* newRequest(AsyncSocket* socket,
67                                        WriteCallback* callback,
68                                        const iovec* ops,
69                                        uint32_t opCount,
70                                        uint32_t partialWritten,
71                                        uint32_t bytesWritten,
72                                        unique_ptr<IOBuf>&& ioBuf,
73                                        WriteFlags flags) {
74     assert(opCount > 0);
75     // Since we put a variable size iovec array at the end
76     // of each BytesWriteRequest, we have to manually allocate the memory.
77     void* buf = malloc(sizeof(BytesWriteRequest) +
78                        (opCount * sizeof(struct iovec)));
79     if (buf == nullptr) {
80       throw std::bad_alloc();
81     }
82
83     return new(buf) BytesWriteRequest(socket, callback, ops, opCount,
84                                       partialWritten, bytesWritten,
85                                       std::move(ioBuf), flags);
86   }
87
88   void destroy() override {
89     this->~BytesWriteRequest();
90     free(this);
91   }
92
93   WriteResult performWrite() override {
94     WriteFlags writeFlags = flags_;
95     if (getNext() != nullptr) {
96       writeFlags = writeFlags | WriteFlags::CORK;
97     }
98     auto writeResult = socket_->performWrite(
99         getOps(), getOpCount(), writeFlags, &opsWritten_, &partialBytes_);
100     bytesWritten_ = writeResult.writeReturn > 0 ? writeResult.writeReturn : 0;
101     return writeResult;
102   }
103
104   bool isComplete() override {
105     return opsWritten_ == getOpCount();
106   }
107
108   void consume() override {
109     // Advance opIndex_ forward by opsWritten_
110     opIndex_ += opsWritten_;
111     assert(opIndex_ < opCount_);
112
113     // If we've finished writing any IOBufs, release them
114     if (ioBuf_) {
115       for (uint32_t i = opsWritten_; i != 0; --i) {
116         assert(ioBuf_);
117         ioBuf_ = ioBuf_->pop();
118       }
119     }
120
121     // Move partialBytes_ forward into the current iovec buffer
122     struct iovec* currentOp = writeOps_ + opIndex_;
123     assert((partialBytes_ < currentOp->iov_len) || (currentOp->iov_len == 0));
124     currentOp->iov_base =
125       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes_;
126     currentOp->iov_len -= partialBytes_;
127
128     // Increment the totalBytesWritten_ count by bytesWritten_;
129     assert(bytesWritten_ >= 0);
130     totalBytesWritten_ += uint32_t(bytesWritten_);
131   }
132
133  private:
134   BytesWriteRequest(AsyncSocket* socket,
135                     WriteCallback* callback,
136                     const struct iovec* ops,
137                     uint32_t opCount,
138                     uint32_t partialBytes,
139                     uint32_t bytesWritten,
140                     unique_ptr<IOBuf>&& ioBuf,
141                     WriteFlags flags)
142     : AsyncSocket::WriteRequest(socket, callback)
143     , opCount_(opCount)
144     , opIndex_(0)
145     , flags_(flags)
146     , ioBuf_(std::move(ioBuf))
147     , opsWritten_(0)
148     , partialBytes_(partialBytes)
149     , bytesWritten_(bytesWritten) {
150     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
151   }
152
153   // private destructor, to ensure callers use destroy()
154   ~BytesWriteRequest() override = default;
155
156   const struct iovec* getOps() const {
157     assert(opCount_ > opIndex_);
158     return writeOps_ + opIndex_;
159   }
160
161   uint32_t getOpCount() const {
162     assert(opCount_ > opIndex_);
163     return opCount_ - opIndex_;
164   }
165
166   uint32_t opCount_;            ///< number of entries in writeOps_
167   uint32_t opIndex_;            ///< current index into writeOps_
168   WriteFlags flags_;            ///< set for WriteFlags
169   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
170
171   // for consume(), how much we wrote on the last write
172   uint32_t opsWritten_;         ///< complete ops written
173   uint32_t partialBytes_;       ///< partial bytes of incomplete op written
174   ssize_t bytesWritten_;        ///< bytes written altogether
175
176   struct iovec writeOps_[];     ///< write operation(s) list
177 };
178
179 AsyncSocket::AsyncSocket()
180     : eventBase_(nullptr),
181       writeTimeout_(this, nullptr),
182       ioHandler_(this, nullptr),
183       immediateReadHandler_(this) {
184   VLOG(5) << "new AsyncSocket()";
185   init();
186 }
187
188 AsyncSocket::AsyncSocket(EventBase* evb)
189     : eventBase_(evb),
190       writeTimeout_(this, evb),
191       ioHandler_(this, evb),
192       immediateReadHandler_(this) {
193   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
194   init();
195 }
196
197 AsyncSocket::AsyncSocket(EventBase* evb,
198                            const folly::SocketAddress& address,
199                            uint32_t connectTimeout)
200   : AsyncSocket(evb) {
201   connect(nullptr, address, connectTimeout);
202 }
203
204 AsyncSocket::AsyncSocket(EventBase* evb,
205                            const std::string& ip,
206                            uint16_t port,
207                            uint32_t connectTimeout)
208   : AsyncSocket(evb) {
209   connect(nullptr, ip, port, connectTimeout);
210 }
211
212 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
213     : eventBase_(evb),
214       writeTimeout_(this, evb),
215       ioHandler_(this, evb, fd),
216       immediateReadHandler_(this) {
217   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
218           << fd << ")";
219   init();
220   fd_ = fd;
221   setCloseOnExec();
222   state_ = StateEnum::ESTABLISHED;
223 }
224
225 // init() method, since constructor forwarding isn't supported in most
226 // compilers yet.
227 void AsyncSocket::init() {
228   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
229   shutdownFlags_ = 0;
230   state_ = StateEnum::UNINIT;
231   eventFlags_ = EventHandler::NONE;
232   fd_ = -1;
233   sendTimeout_ = 0;
234   maxReadsPerEvent_ = 16;
235   connectCallback_ = nullptr;
236   readCallback_ = nullptr;
237   writeReqHead_ = nullptr;
238   writeReqTail_ = nullptr;
239   shutdownSocketSet_ = nullptr;
240   appBytesWritten_ = 0;
241   appBytesReceived_ = 0;
242 }
243
244 AsyncSocket::~AsyncSocket() {
245   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
246           << ", evb=" << eventBase_ << ", fd=" << fd_
247           << ", state=" << state_ << ")";
248 }
249
250 void AsyncSocket::destroy() {
251   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
252           << ", fd=" << fd_ << ", state=" << state_;
253   // When destroy is called, close the socket immediately
254   closeNow();
255
256   // Then call DelayedDestruction::destroy() to take care of
257   // whether or not we need immediate or delayed destruction
258   DelayedDestruction::destroy();
259 }
260
261 int AsyncSocket::detachFd() {
262   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
263           << ", evb=" << eventBase_ << ", state=" << state_
264           << ", events=" << std::hex << eventFlags_ << ")";
265   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
266   // actually close the descriptor.
267   if (shutdownSocketSet_) {
268     shutdownSocketSet_->remove(fd_);
269   }
270   int fd = fd_;
271   fd_ = -1;
272   // Call closeNow() to invoke all pending callbacks with an error.
273   closeNow();
274   // Update the EventHandler to stop using this fd.
275   // This can only be done after closeNow() unregisters the handler.
276   ioHandler_.changeHandlerFD(-1);
277   return fd;
278 }
279
280 const folly::SocketAddress& AsyncSocket::anyAddress() {
281   static const folly::SocketAddress anyAddress =
282     folly::SocketAddress("0.0.0.0", 0);
283   return anyAddress;
284 }
285
286 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
287   if (shutdownSocketSet_ == newSS) {
288     return;
289   }
290   if (shutdownSocketSet_ && fd_ != -1) {
291     shutdownSocketSet_->remove(fd_);
292   }
293   shutdownSocketSet_ = newSS;
294   if (shutdownSocketSet_ && fd_ != -1) {
295     shutdownSocketSet_->add(fd_);
296   }
297 }
298
299 void AsyncSocket::setCloseOnExec() {
300   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
301   if (rv != 0) {
302     auto errnoCopy = errno;
303     throw AsyncSocketException(
304         AsyncSocketException::INTERNAL_ERROR,
305         withAddr("failed to set close-on-exec flag"),
306         errnoCopy);
307   }
308 }
309
310 void AsyncSocket::connect(ConnectCallback* callback,
311                            const folly::SocketAddress& address,
312                            int timeout,
313                            const OptionMap &options,
314                            const folly::SocketAddress& bindAddr) noexcept {
315   DestructorGuard dg(this);
316   assert(eventBase_->isInEventBaseThread());
317
318   addr_ = address;
319
320   // Make sure we're in the uninitialized state
321   if (state_ != StateEnum::UNINIT) {
322     return invalidState(callback);
323   }
324
325   connectTimeout_ = std::chrono::milliseconds(timeout);
326   connectStartTime_ = std::chrono::steady_clock::now();
327   // Make connect end time at least >= connectStartTime.
328   connectEndTime_ = connectStartTime_;
329
330   assert(fd_ == -1);
331   state_ = StateEnum::CONNECTING;
332   connectCallback_ = callback;
333
334   sockaddr_storage addrStorage;
335   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
336
337   try {
338     // Create the socket
339     // Technically the first parameter should actually be a protocol family
340     // constant (PF_xxx) rather than an address family (AF_xxx), but the
341     // distinction is mainly just historical.  In pretty much all
342     // implementations the PF_foo and AF_foo constants are identical.
343     fd_ = fsp::socket(address.getFamily(), SOCK_STREAM, 0);
344     if (fd_ < 0) {
345       auto errnoCopy = errno;
346       throw AsyncSocketException(
347           AsyncSocketException::INTERNAL_ERROR,
348           withAddr("failed to create socket"),
349           errnoCopy);
350     }
351     if (shutdownSocketSet_) {
352       shutdownSocketSet_->add(fd_);
353     }
354     ioHandler_.changeHandlerFD(fd_);
355
356     setCloseOnExec();
357
358     // Put the socket in non-blocking mode
359     int flags = fcntl(fd_, F_GETFL, 0);
360     if (flags == -1) {
361       auto errnoCopy = errno;
362       throw AsyncSocketException(
363           AsyncSocketException::INTERNAL_ERROR,
364           withAddr("failed to get socket flags"),
365           errnoCopy);
366     }
367     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
368     if (rv == -1) {
369       auto errnoCopy = errno;
370       throw AsyncSocketException(
371           AsyncSocketException::INTERNAL_ERROR,
372           withAddr("failed to put socket in non-blocking mode"),
373           errnoCopy);
374     }
375
376 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
377     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
378     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
379     if (rv == -1) {
380       auto errnoCopy = errno;
381       throw AsyncSocketException(
382           AsyncSocketException::INTERNAL_ERROR,
383           "failed to enable F_SETNOSIGPIPE on socket",
384           errnoCopy);
385     }
386 #endif
387
388     // By default, turn on TCP_NODELAY
389     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
390     // setNoDelay() will log an error message if it fails.
391     if (address.getFamily() != AF_UNIX) {
392       (void)setNoDelay(true);
393     }
394
395     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
396             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
397
398     // bind the socket
399     if (bindAddr != anyAddress()) {
400       int one = 1;
401       if (setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
402         auto errnoCopy = errno;
403         doClose();
404         throw AsyncSocketException(
405             AsyncSocketException::NOT_OPEN,
406             "failed to setsockopt prior to bind on " + bindAddr.describe(),
407             errnoCopy);
408       }
409
410       bindAddr.getAddress(&addrStorage);
411
412       if (bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
413         auto errnoCopy = errno;
414         doClose();
415         throw AsyncSocketException(
416             AsyncSocketException::NOT_OPEN,
417             "failed to bind to async socket: " + bindAddr.describe(),
418             errnoCopy);
419       }
420     }
421
422     // Apply the additional options if any.
423     for (const auto& opt: options) {
424       rv = opt.first.apply(fd_, opt.second);
425       if (rv != 0) {
426         auto errnoCopy = errno;
427         throw AsyncSocketException(
428             AsyncSocketException::INTERNAL_ERROR,
429             withAddr("failed to set socket option"),
430             errnoCopy);
431       }
432     }
433
434     // Perform the connect()
435     address.getAddress(&addrStorage);
436
437     if (tfoEnabled_) {
438       state_ = StateEnum::FAST_OPEN;
439       tfoAttempted_ = true;
440     } else {
441       if (socketConnect(saddr, addr_.getActualSize()) < 0) {
442         return;
443       }
444     }
445
446     // If we're still here the connect() succeeded immediately.
447     // Fall through to call the callback outside of this try...catch block
448   } catch (const AsyncSocketException& ex) {
449     return failConnect(__func__, ex);
450   } catch (const std::exception& ex) {
451     // shouldn't happen, but handle it just in case
452     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
453                << "): unexpected " << typeid(ex).name() << " exception: "
454                << ex.what();
455     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
456                             withAddr(string("unexpected exception: ") +
457                                      ex.what()));
458     return failConnect(__func__, tex);
459   }
460
461   // The connection succeeded immediately
462   // The read callback may not have been set yet, and no writes may be pending
463   // yet, so we don't have to register for any events at the moment.
464   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
465   assert(readCallback_ == nullptr);
466   assert(writeReqHead_ == nullptr);
467   if (state_ != StateEnum::FAST_OPEN) {
468     state_ = StateEnum::ESTABLISHED;
469   }
470   invokeConnectSuccess();
471 }
472
473 int AsyncSocket::socketConnect(const struct sockaddr* saddr, socklen_t len) {
474   int rv = fsp::connect(fd_, saddr, len);
475   if (rv < 0) {
476     auto errnoCopy = errno;
477     if (errnoCopy == EINPROGRESS) {
478       scheduleConnectTimeout();
479       registerForConnectEvents();
480     } else {
481       throw AsyncSocketException(
482           AsyncSocketException::NOT_OPEN,
483           "connect failed (immediately)",
484           errnoCopy);
485     }
486   }
487   return rv;
488 }
489
490 void AsyncSocket::scheduleConnectTimeout() {
491   // Connection in progress.
492   auto timeout = connectTimeout_.count();
493   if (timeout > 0) {
494     // Start a timer in case the connection takes too long.
495     if (!writeTimeout_.scheduleTimeout(uint32_t(timeout))) {
496       throw AsyncSocketException(
497           AsyncSocketException::INTERNAL_ERROR,
498           withAddr("failed to schedule AsyncSocket connect timeout"));
499     }
500   }
501 }
502
503 void AsyncSocket::registerForConnectEvents() {
504   // Register for write events, so we'll
505   // be notified when the connection finishes/fails.
506   // Note that we don't register for a persistent event here.
507   assert(eventFlags_ == EventHandler::NONE);
508   eventFlags_ = EventHandler::WRITE;
509   if (!ioHandler_.registerHandler(eventFlags_)) {
510     throw AsyncSocketException(
511         AsyncSocketException::INTERNAL_ERROR,
512         withAddr("failed to register AsyncSocket connect handler"));
513   }
514 }
515
516 void AsyncSocket::connect(ConnectCallback* callback,
517                            const string& ip, uint16_t port,
518                            int timeout,
519                            const OptionMap &options) noexcept {
520   DestructorGuard dg(this);
521   try {
522     connectCallback_ = callback;
523     connect(callback, folly::SocketAddress(ip, port), timeout, options);
524   } catch (const std::exception& ex) {
525     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
526                             ex.what());
527     return failConnect(__func__, tex);
528   }
529 }
530
531 void AsyncSocket::cancelConnect() {
532   connectCallback_ = nullptr;
533   if (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN) {
534     closeNow();
535   }
536 }
537
538 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
539   sendTimeout_ = milliseconds;
540   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
541
542   // If we are currently pending on write requests, immediately update
543   // writeTimeout_ with the new value.
544   if ((eventFlags_ & EventHandler::WRITE) &&
545       (state_ != StateEnum::CONNECTING && state_ != StateEnum::FAST_OPEN)) {
546     assert(state_ == StateEnum::ESTABLISHED);
547     assert((shutdownFlags_ & SHUT_WRITE) == 0);
548     if (sendTimeout_ > 0) {
549       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
550         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
551             withAddr("failed to reschedule send timeout in setSendTimeout"));
552         return failWrite(__func__, ex);
553       }
554     } else {
555       writeTimeout_.cancelTimeout();
556     }
557   }
558 }
559
560 void AsyncSocket::setReadCB(ReadCallback *callback) {
561   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
562           << ", callback=" << callback << ", state=" << state_;
563
564   // Short circuit if callback is the same as the existing readCallback_.
565   //
566   // Note that this is needed for proper functioning during some cleanup cases.
567   // During cleanup we allow setReadCallback(nullptr) to be called even if the
568   // read callback is already unset and we have been detached from an event
569   // base.  This check prevents us from asserting
570   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
571   if (callback == readCallback_) {
572     return;
573   }
574
575   /* We are removing a read callback */
576   if (callback == nullptr &&
577       immediateReadHandler_.isLoopCallbackScheduled()) {
578     immediateReadHandler_.cancelLoopCallback();
579   }
580
581   if (shutdownFlags_ & SHUT_READ) {
582     // Reads have already been shut down on this socket.
583     //
584     // Allow setReadCallback(nullptr) to be called in this case, but don't
585     // allow a new callback to be set.
586     //
587     // For example, setReadCallback(nullptr) can happen after an error if we
588     // invoke some other error callback before invoking readError().  The other
589     // error callback that is invoked first may go ahead and clear the read
590     // callback before we get a chance to invoke readError().
591     if (callback != nullptr) {
592       return invalidState(callback);
593     }
594     assert((eventFlags_ & EventHandler::READ) == 0);
595     readCallback_ = nullptr;
596     return;
597   }
598
599   DestructorGuard dg(this);
600   assert(eventBase_->isInEventBaseThread());
601
602   switch ((StateEnum)state_) {
603     case StateEnum::CONNECTING:
604     case StateEnum::FAST_OPEN:
605       // For convenience, we allow the read callback to be set while we are
606       // still connecting.  We just store the callback for now.  Once the
607       // connection completes we'll register for read events.
608       readCallback_ = callback;
609       return;
610     case StateEnum::ESTABLISHED:
611     {
612       readCallback_ = callback;
613       uint16_t oldFlags = eventFlags_;
614       if (readCallback_) {
615         eventFlags_ |= EventHandler::READ;
616       } else {
617         eventFlags_ &= ~EventHandler::READ;
618       }
619
620       // Update our registration if our flags have changed
621       if (eventFlags_ != oldFlags) {
622         // We intentionally ignore the return value here.
623         // updateEventRegistration() will move us into the error state if it
624         // fails, and we don't need to do anything else here afterwards.
625         (void)updateEventRegistration();
626       }
627
628       if (readCallback_) {
629         checkForImmediateRead();
630       }
631       return;
632     }
633     case StateEnum::CLOSED:
634     case StateEnum::ERROR:
635       // We should never reach here.  SHUT_READ should always be set
636       // if we are in STATE_CLOSED or STATE_ERROR.
637       assert(false);
638       return invalidState(callback);
639     case StateEnum::UNINIT:
640       // We do not allow setReadCallback() to be called before we start
641       // connecting.
642       return invalidState(callback);
643   }
644
645   // We don't put a default case in the switch statement, so that the compiler
646   // will warn us to update the switch statement if a new state is added.
647   return invalidState(callback);
648 }
649
650 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
651   return readCallback_;
652 }
653
654 void AsyncSocket::write(WriteCallback* callback,
655                          const void* buf, size_t bytes, WriteFlags flags) {
656   iovec op;
657   op.iov_base = const_cast<void*>(buf);
658   op.iov_len = bytes;
659   writeImpl(callback, &op, 1, unique_ptr<IOBuf>(), flags);
660 }
661
662 void AsyncSocket::writev(WriteCallback* callback,
663                           const iovec* vec,
664                           size_t count,
665                           WriteFlags flags) {
666   writeImpl(callback, vec, count, unique_ptr<IOBuf>(), flags);
667 }
668
669 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
670                               WriteFlags flags) {
671   constexpr size_t kSmallSizeMax = 64;
672   size_t count = buf->countChainElements();
673   if (count <= kSmallSizeMax) {
674     // suppress "warning: variable length array 'vec' is used [-Wvla]"
675     FOLLY_PUSH_WARNING;
676     FOLLY_GCC_DISABLE_WARNING(vla);
677     iovec vec[BOOST_PP_IF(FOLLY_HAVE_VLA, count, kSmallSizeMax)];
678     FOLLY_POP_WARNING;
679
680     writeChainImpl(callback, vec, count, std::move(buf), flags);
681   } else {
682     iovec* vec = new iovec[count];
683     writeChainImpl(callback, vec, count, std::move(buf), flags);
684     delete[] vec;
685   }
686 }
687
688 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
689     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
690   size_t veclen = buf->fillIov(vec, count);
691   writeImpl(callback, vec, veclen, std::move(buf), flags);
692 }
693
694 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
695                              size_t count, unique_ptr<IOBuf>&& buf,
696                              WriteFlags flags) {
697   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
698           << ", callback=" << callback << ", count=" << count
699           << ", state=" << state_;
700   DestructorGuard dg(this);
701   unique_ptr<IOBuf>ioBuf(std::move(buf));
702   assert(eventBase_->isInEventBaseThread());
703
704   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
705     // No new writes may be performed after the write side of the socket has
706     // been shutdown.
707     //
708     // We could just call callback->writeError() here to fail just this write.
709     // However, fail hard and use invalidState() to fail all outstanding
710     // callbacks and move the socket into the error state.  There's most likely
711     // a bug in the caller's code, so we abort everything rather than trying to
712     // proceed as best we can.
713     return invalidState(callback);
714   }
715
716   uint32_t countWritten = 0;
717   uint32_t partialWritten = 0;
718   ssize_t bytesWritten = 0;
719   bool mustRegister = false;
720   if ((state_ == StateEnum::ESTABLISHED || state_ == StateEnum::FAST_OPEN) &&
721       !connecting()) {
722     if (writeReqHead_ == nullptr) {
723       // If we are established and there are no other writes pending,
724       // we can attempt to perform the write immediately.
725       assert(writeReqTail_ == nullptr);
726       assert((eventFlags_ & EventHandler::WRITE) == 0);
727
728       auto writeResult = performWrite(
729           vec, uint32_t(count), flags, &countWritten, &partialWritten);
730       bytesWritten = writeResult.writeReturn;
731       if (bytesWritten < 0) {
732         auto errnoCopy = errno;
733         if (writeResult.exception) {
734           return failWrite(__func__, callback, 0, *writeResult.exception);
735         }
736         AsyncSocketException ex(
737             AsyncSocketException::INTERNAL_ERROR,
738             withAddr("writev failed"),
739             errnoCopy);
740         return failWrite(__func__, callback, 0, ex);
741       } else if (countWritten == count) {
742         // We successfully wrote everything.
743         // Invoke the callback and return.
744         if (callback) {
745           callback->writeSuccess();
746         }
747         return;
748       } else { // continue writing the next writeReq
749         if (bufferCallback_) {
750           bufferCallback_->onEgressBuffered();
751         }
752       }
753       if (!connecting()) {
754         // Writes might put the socket back into connecting state
755         // if TFO is enabled, and using TFO fails.
756         // This means that write timeouts would not be active, however
757         // connect timeouts would affect this stage.
758         mustRegister = true;
759       }
760     }
761   } else if (!connecting()) {
762     // Invalid state for writing
763     return invalidState(callback);
764   }
765
766   // Create a new WriteRequest to add to the queue
767   WriteRequest* req;
768   try {
769     req = BytesWriteRequest::newRequest(
770         this,
771         callback,
772         vec + countWritten,
773         uint32_t(count - countWritten),
774         partialWritten,
775         uint32_t(bytesWritten),
776         std::move(ioBuf),
777         flags);
778   } catch (const std::exception& ex) {
779     // we mainly expect to catch std::bad_alloc here
780     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
781         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
782     return failWrite(__func__, callback, size_t(bytesWritten), tex);
783   }
784   req->consume();
785   if (writeReqTail_ == nullptr) {
786     assert(writeReqHead_ == nullptr);
787     writeReqHead_ = writeReqTail_ = req;
788   } else {
789     writeReqTail_->append(req);
790     writeReqTail_ = req;
791   }
792
793   // Register for write events if are established and not currently
794   // waiting on write events
795   if (mustRegister) {
796     assert(state_ == StateEnum::ESTABLISHED);
797     assert((eventFlags_ & EventHandler::WRITE) == 0);
798     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
799       assert(state_ == StateEnum::ERROR);
800       return;
801     }
802     if (sendTimeout_ > 0) {
803       // Schedule a timeout to fire if the write takes too long.
804       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
805         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
806                                withAddr("failed to schedule send timeout"));
807         return failWrite(__func__, ex);
808       }
809     }
810   }
811 }
812
813 void AsyncSocket::writeRequest(WriteRequest* req) {
814   if (writeReqTail_ == nullptr) {
815     assert(writeReqHead_ == nullptr);
816     writeReqHead_ = writeReqTail_ = req;
817     req->start();
818   } else {
819     writeReqTail_->append(req);
820     writeReqTail_ = req;
821   }
822 }
823
824 void AsyncSocket::close() {
825   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
826           << ", state=" << state_ << ", shutdownFlags="
827           << std::hex << (int) shutdownFlags_;
828
829   // close() is only different from closeNow() when there are pending writes
830   // that need to drain before we can close.  In all other cases, just call
831   // closeNow().
832   //
833   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
834   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
835   // is still running.  (e.g., If there are multiple pending writes, and we
836   // call writeError() on the first one, it may call close().  In this case we
837   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
838   // writes will still be in the queue.)
839   //
840   // We only need to drain pending writes if we are still in STATE_CONNECTING
841   // or STATE_ESTABLISHED
842   if ((writeReqHead_ == nullptr) ||
843       !(state_ == StateEnum::CONNECTING ||
844       state_ == StateEnum::ESTABLISHED)) {
845     closeNow();
846     return;
847   }
848
849   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
850   // destroyed until close() returns.
851   DestructorGuard dg(this);
852   assert(eventBase_->isInEventBaseThread());
853
854   // Since there are write requests pending, we have to set the
855   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
856   // connect finishes and we finish writing these requests.
857   //
858   // Set SHUT_READ to indicate that reads are shut down, and set the
859   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
860   // pending writes complete.
861   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
862
863   // If a read callback is set, invoke readEOF() immediately to inform it that
864   // the socket has been closed and no more data can be read.
865   if (readCallback_) {
866     // Disable reads if they are enabled
867     if (!updateEventRegistration(0, EventHandler::READ)) {
868       // We're now in the error state; callbacks have been cleaned up
869       assert(state_ == StateEnum::ERROR);
870       assert(readCallback_ == nullptr);
871     } else {
872       ReadCallback* callback = readCallback_;
873       readCallback_ = nullptr;
874       callback->readEOF();
875     }
876   }
877 }
878
879 void AsyncSocket::closeNow() {
880   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
881           << ", state=" << state_ << ", shutdownFlags="
882           << std::hex << (int) shutdownFlags_;
883   DestructorGuard dg(this);
884   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
885
886   switch (state_) {
887     case StateEnum::ESTABLISHED:
888     case StateEnum::CONNECTING:
889     case StateEnum::FAST_OPEN: {
890       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
891       state_ = StateEnum::CLOSED;
892
893       // If the write timeout was set, cancel it.
894       writeTimeout_.cancelTimeout();
895
896       // If we are registered for I/O events, unregister.
897       if (eventFlags_ != EventHandler::NONE) {
898         eventFlags_ = EventHandler::NONE;
899         if (!updateEventRegistration()) {
900           // We will have been moved into the error state.
901           assert(state_ == StateEnum::ERROR);
902           return;
903         }
904       }
905
906       if (immediateReadHandler_.isLoopCallbackScheduled()) {
907         immediateReadHandler_.cancelLoopCallback();
908       }
909
910       if (fd_ >= 0) {
911         ioHandler_.changeHandlerFD(-1);
912         doClose();
913       }
914
915       invokeConnectErr(socketClosedLocallyEx);
916
917       failAllWrites(socketClosedLocallyEx);
918
919       if (readCallback_) {
920         ReadCallback* callback = readCallback_;
921         readCallback_ = nullptr;
922         callback->readEOF();
923       }
924       return;
925     }
926     case StateEnum::CLOSED:
927       // Do nothing.  It's possible that we are being called recursively
928       // from inside a callback that we invoked inside another call to close()
929       // that is still running.
930       return;
931     case StateEnum::ERROR:
932       // Do nothing.  The error handling code has performed (or is performing)
933       // cleanup.
934       return;
935     case StateEnum::UNINIT:
936       assert(eventFlags_ == EventHandler::NONE);
937       assert(connectCallback_ == nullptr);
938       assert(readCallback_ == nullptr);
939       assert(writeReqHead_ == nullptr);
940       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
941       state_ = StateEnum::CLOSED;
942       return;
943   }
944
945   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
946               << ") called in unknown state " << state_;
947 }
948
949 void AsyncSocket::closeWithReset() {
950   // Enable SO_LINGER, with the linger timeout set to 0.
951   // This will trigger a TCP reset when we close the socket.
952   if (fd_ >= 0) {
953     struct linger optLinger = {1, 0};
954     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
955       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
956               << "on " << fd_ << ": errno=" << errno;
957     }
958   }
959
960   // Then let closeNow() take care of the rest
961   closeNow();
962 }
963
964 void AsyncSocket::shutdownWrite() {
965   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
966           << ", state=" << state_ << ", shutdownFlags="
967           << std::hex << (int) shutdownFlags_;
968
969   // If there are no pending writes, shutdownWrite() is identical to
970   // shutdownWriteNow().
971   if (writeReqHead_ == nullptr) {
972     shutdownWriteNow();
973     return;
974   }
975
976   assert(eventBase_->isInEventBaseThread());
977
978   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
979   // shutdown will be performed once all writes complete.
980   shutdownFlags_ |= SHUT_WRITE_PENDING;
981 }
982
983 void AsyncSocket::shutdownWriteNow() {
984   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
985           << ", fd=" << fd_ << ", state=" << state_
986           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
987
988   if (shutdownFlags_ & SHUT_WRITE) {
989     // Writes are already shutdown; nothing else to do.
990     return;
991   }
992
993   // If SHUT_READ is already set, just call closeNow() to completely
994   // close the socket.  This can happen if close() was called with writes
995   // pending, and then shutdownWriteNow() is called before all pending writes
996   // complete.
997   if (shutdownFlags_ & SHUT_READ) {
998     closeNow();
999     return;
1000   }
1001
1002   DestructorGuard dg(this);
1003   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
1004
1005   switch (static_cast<StateEnum>(state_)) {
1006     case StateEnum::ESTABLISHED:
1007     {
1008       shutdownFlags_ |= SHUT_WRITE;
1009
1010       // If the write timeout was set, cancel it.
1011       writeTimeout_.cancelTimeout();
1012
1013       // If we are registered for write events, unregister.
1014       if (!updateEventRegistration(0, EventHandler::WRITE)) {
1015         // We will have been moved into the error state.
1016         assert(state_ == StateEnum::ERROR);
1017         return;
1018       }
1019
1020       // Shutdown writes on the file descriptor
1021       shutdown(fd_, SHUT_WR);
1022
1023       // Immediately fail all write requests
1024       failAllWrites(socketShutdownForWritesEx);
1025       return;
1026     }
1027     case StateEnum::CONNECTING:
1028     {
1029       // Set the SHUT_WRITE_PENDING flag.
1030       // When the connection completes, it will check this flag,
1031       // shutdown the write half of the socket, and then set SHUT_WRITE.
1032       shutdownFlags_ |= SHUT_WRITE_PENDING;
1033
1034       // Immediately fail all write requests
1035       failAllWrites(socketShutdownForWritesEx);
1036       return;
1037     }
1038     case StateEnum::UNINIT:
1039       // Callers normally shouldn't call shutdownWriteNow() before the socket
1040       // even starts connecting.  Nonetheless, go ahead and set
1041       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
1042       // immediately shut down the write side of the socket.
1043       shutdownFlags_ |= SHUT_WRITE_PENDING;
1044       return;
1045     case StateEnum::FAST_OPEN:
1046       // In fast open state we haven't call connected yet, and if we shutdown
1047       // the writes, we will never try to call connect, so shut everything down
1048       shutdownFlags_ |= SHUT_WRITE;
1049       // Immediately fail all write requests
1050       failAllWrites(socketShutdownForWritesEx);
1051       return;
1052     case StateEnum::CLOSED:
1053     case StateEnum::ERROR:
1054       // We should never get here.  SHUT_WRITE should always be set
1055       // in STATE_CLOSED and STATE_ERROR.
1056       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
1057                  << ", fd=" << fd_ << ") in unexpected state " << state_
1058                  << " with SHUT_WRITE not set ("
1059                  << std::hex << (int) shutdownFlags_ << ")";
1060       assert(false);
1061       return;
1062   }
1063
1064   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
1065               << fd_ << ") called in unknown state " << state_;
1066 }
1067
1068 bool AsyncSocket::readable() const {
1069   if (fd_ == -1) {
1070     return false;
1071   }
1072   struct pollfd fds[1];
1073   fds[0].fd = fd_;
1074   fds[0].events = POLLIN;
1075   fds[0].revents = 0;
1076   int rc = poll(fds, 1, 0);
1077   return rc == 1;
1078 }
1079
1080 bool AsyncSocket::isPending() const {
1081   return ioHandler_.isPending();
1082 }
1083
1084 bool AsyncSocket::hangup() const {
1085   if (fd_ == -1) {
1086     // sanity check, no one should ask for hangup if we are not connected.
1087     assert(false);
1088     return false;
1089   }
1090 #ifdef POLLRDHUP // Linux-only
1091   struct pollfd fds[1];
1092   fds[0].fd = fd_;
1093   fds[0].events = POLLRDHUP|POLLHUP;
1094   fds[0].revents = 0;
1095   poll(fds, 1, 0);
1096   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1097 #else
1098   return false;
1099 #endif
1100 }
1101
1102 bool AsyncSocket::good() const {
1103   return (
1104       (state_ == StateEnum::CONNECTING || state_ == StateEnum::FAST_OPEN ||
1105        state_ == StateEnum::ESTABLISHED) &&
1106       (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1107 }
1108
1109 bool AsyncSocket::error() const {
1110   return (state_ == StateEnum::ERROR);
1111 }
1112
1113 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1114   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1115           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1116           << ", state=" << state_ << ", events="
1117           << std::hex << eventFlags_ << ")";
1118   assert(eventBase_ == nullptr);
1119   assert(eventBase->isInEventBaseThread());
1120
1121   eventBase_ = eventBase;
1122   ioHandler_.attachEventBase(eventBase);
1123   writeTimeout_.attachEventBase(eventBase);
1124   if (evbChangeCb_) {
1125     evbChangeCb_->evbAttached(this);
1126   }
1127 }
1128
1129 void AsyncSocket::detachEventBase() {
1130   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1131           << ", old evb=" << eventBase_ << ", state=" << state_
1132           << ", events=" << std::hex << eventFlags_ << ")";
1133   assert(eventBase_ != nullptr);
1134   assert(eventBase_->isInEventBaseThread());
1135
1136   eventBase_ = nullptr;
1137   ioHandler_.detachEventBase();
1138   writeTimeout_.detachEventBase();
1139   if (evbChangeCb_) {
1140     evbChangeCb_->evbDetached(this);
1141   }
1142 }
1143
1144 bool AsyncSocket::isDetachable() const {
1145   DCHECK(eventBase_ != nullptr);
1146   DCHECK(eventBase_->isInEventBaseThread());
1147
1148   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1149 }
1150
1151 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1152   if (!localAddr_.isInitialized()) {
1153     localAddr_.setFromLocalAddress(fd_);
1154   }
1155   *address = localAddr_;
1156 }
1157
1158 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1159   if (!addr_.isInitialized()) {
1160     addr_.setFromPeerAddress(fd_);
1161   }
1162   *address = addr_;
1163 }
1164
1165 bool AsyncSocket::getTFOSucceded() const {
1166   return detail::tfo_succeeded(fd_);
1167 }
1168
1169 int AsyncSocket::setNoDelay(bool noDelay) {
1170   if (fd_ < 0) {
1171     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1172                << this << "(state=" << state_ << ")";
1173     return EINVAL;
1174
1175   }
1176
1177   int value = noDelay ? 1 : 0;
1178   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1179     int errnoCopy = errno;
1180     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1181             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1182             << strerror(errnoCopy);
1183     return errnoCopy;
1184   }
1185
1186   return 0;
1187 }
1188
1189 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1190
1191   #ifndef TCP_CONGESTION
1192   #define TCP_CONGESTION  13
1193   #endif
1194
1195   if (fd_ < 0) {
1196     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1197                << "socket " << this << "(state=" << state_ << ")";
1198     return EINVAL;
1199
1200   }
1201
1202   if (setsockopt(
1203           fd_,
1204           IPPROTO_TCP,
1205           TCP_CONGESTION,
1206           cname.c_str(),
1207           socklen_t(cname.length() + 1)) != 0) {
1208     int errnoCopy = errno;
1209     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1210             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1211             << strerror(errnoCopy);
1212     return errnoCopy;
1213   }
1214
1215   return 0;
1216 }
1217
1218 int AsyncSocket::setQuickAck(bool quickack) {
1219   if (fd_ < 0) {
1220     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1221                << this << "(state=" << state_ << ")";
1222     return EINVAL;
1223
1224   }
1225
1226 #ifdef TCP_QUICKACK // Linux-only
1227   int value = quickack ? 1 : 0;
1228   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1229     int errnoCopy = errno;
1230     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1231             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1232             << strerror(errnoCopy);
1233     return errnoCopy;
1234   }
1235
1236   return 0;
1237 #else
1238   return ENOSYS;
1239 #endif
1240 }
1241
1242 int AsyncSocket::setSendBufSize(size_t bufsize) {
1243   if (fd_ < 0) {
1244     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1245                << this << "(state=" << state_ << ")";
1246     return EINVAL;
1247   }
1248
1249   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1250     int errnoCopy = errno;
1251     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1252             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1253             << strerror(errnoCopy);
1254     return errnoCopy;
1255   }
1256
1257   return 0;
1258 }
1259
1260 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1261   if (fd_ < 0) {
1262     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1263                << this << "(state=" << state_ << ")";
1264     return EINVAL;
1265   }
1266
1267   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1268     int errnoCopy = errno;
1269     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1270             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1271             << strerror(errnoCopy);
1272     return errnoCopy;
1273   }
1274
1275   return 0;
1276 }
1277
1278 int AsyncSocket::setTCPProfile(int profd) {
1279   if (fd_ < 0) {
1280     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1281                << this << "(state=" << state_ << ")";
1282     return EINVAL;
1283   }
1284
1285   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1286     int errnoCopy = errno;
1287     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1288             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1289             << strerror(errnoCopy);
1290     return errnoCopy;
1291   }
1292
1293   return 0;
1294 }
1295
1296 void AsyncSocket::ioReady(uint16_t events) noexcept {
1297   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1298           << ", events=" << std::hex << events << ", state=" << state_;
1299   DestructorGuard dg(this);
1300   assert(events & EventHandler::READ_WRITE);
1301   assert(eventBase_->isInEventBaseThread());
1302
1303   uint16_t relevantEvents = events & EventHandler::READ_WRITE;
1304   if (relevantEvents == EventHandler::READ) {
1305     handleRead();
1306   } else if (relevantEvents == EventHandler::WRITE) {
1307     handleWrite();
1308   } else if (relevantEvents == EventHandler::READ_WRITE) {
1309     EventBase* originalEventBase = eventBase_;
1310     // If both read and write events are ready, process writes first.
1311     handleWrite();
1312
1313     // Return now if handleWrite() detached us from our EventBase
1314     if (eventBase_ != originalEventBase) {
1315       return;
1316     }
1317
1318     // Only call handleRead() if a read callback is still installed.
1319     // (It's possible that the read callback was uninstalled during
1320     // handleWrite().)
1321     if (readCallback_) {
1322       handleRead();
1323     }
1324   } else {
1325     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1326                << std::hex << events << "(this=" << this << ")";
1327     abort();
1328   }
1329 }
1330
1331 AsyncSocket::ReadResult
1332 AsyncSocket::performRead(void** buf, size_t* buflen, size_t* /* offset */) {
1333   VLOG(5) << "AsyncSocket::performRead() this=" << this << ", buf=" << *buf
1334           << ", buflen=" << *buflen;
1335
1336   int recvFlags = 0;
1337   if (peek_) {
1338     recvFlags |= MSG_PEEK;
1339   }
1340
1341   ssize_t bytes = recv(fd_, *buf, *buflen, MSG_DONTWAIT | recvFlags);
1342   if (bytes < 0) {
1343     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1344       // No more data to read right now.
1345       return ReadResult(READ_BLOCKING);
1346     } else {
1347       return ReadResult(READ_ERROR);
1348     }
1349   } else {
1350     appBytesReceived_ += bytes;
1351     return ReadResult(bytes);
1352   }
1353 }
1354
1355 void AsyncSocket::prepareReadBuffer(void** buf, size_t* buflen) {
1356   // no matter what, buffer should be preapared for non-ssl socket
1357   CHECK(readCallback_);
1358   readCallback_->getReadBuffer(buf, buflen);
1359 }
1360
1361 void AsyncSocket::handleRead() noexcept {
1362   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1363           << ", state=" << state_;
1364   assert(state_ == StateEnum::ESTABLISHED);
1365   assert((shutdownFlags_ & SHUT_READ) == 0);
1366   assert(readCallback_ != nullptr);
1367   assert(eventFlags_ & EventHandler::READ);
1368
1369   // Loop until:
1370   // - a read attempt would block
1371   // - readCallback_ is uninstalled
1372   // - the number of loop iterations exceeds the optional maximum
1373   // - this AsyncSocket is moved to another EventBase
1374   //
1375   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1376   // which is why need to check for it here.
1377   //
1378   // The last bullet point is slightly subtle.  readDataAvailable() may also
1379   // detach this socket from this EventBase.  However, before
1380   // readDataAvailable() returns another thread may pick it up, attach it to
1381   // a different EventBase, and install another readCallback_.  We need to
1382   // exit immediately after readDataAvailable() returns if the eventBase_ has
1383   // changed.  (The caller must perform some sort of locking to transfer the
1384   // AsyncSocket between threads properly.  This will be sufficient to ensure
1385   // that this thread sees the updated eventBase_ variable after
1386   // readDataAvailable() returns.)
1387   uint16_t numReads = 0;
1388   EventBase* originalEventBase = eventBase_;
1389   while (readCallback_ && eventBase_ == originalEventBase) {
1390     // Get the buffer to read into.
1391     void* buf = nullptr;
1392     size_t buflen = 0, offset = 0;
1393     try {
1394       prepareReadBuffer(&buf, &buflen);
1395       VLOG(5) << "prepareReadBuffer() buf=" << buf << ", buflen=" << buflen;
1396     } catch (const AsyncSocketException& ex) {
1397       return failRead(__func__, ex);
1398     } catch (const std::exception& ex) {
1399       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1400                               string("ReadCallback::getReadBuffer() "
1401                                      "threw exception: ") +
1402                               ex.what());
1403       return failRead(__func__, tex);
1404     } catch (...) {
1405       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1406                              "ReadCallback::getReadBuffer() threw "
1407                              "non-exception type");
1408       return failRead(__func__, ex);
1409     }
1410     if (!isBufferMovable_ && (buf == nullptr || buflen == 0)) {
1411       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1412                              "ReadCallback::getReadBuffer() returned "
1413                              "empty buffer");
1414       return failRead(__func__, ex);
1415     }
1416
1417     // Perform the read
1418     auto readResult = performRead(&buf, &buflen, &offset);
1419     auto bytesRead = readResult.readReturn;
1420     VLOG(4) << "this=" << this << ", AsyncSocket::handleRead() got "
1421             << bytesRead << " bytes";
1422     if (bytesRead > 0) {
1423       if (!isBufferMovable_) {
1424         readCallback_->readDataAvailable(bytesRead);
1425       } else {
1426         CHECK(kOpenSslModeMoveBufferOwnership);
1427         VLOG(5) << "this=" << this << ", AsyncSocket::handleRead() got "
1428                 << "buf=" << buf << ", " << bytesRead << "/" << buflen
1429                 << ", offset=" << offset;
1430         auto readBuf = folly::IOBuf::takeOwnership(buf, buflen);
1431         readBuf->trimStart(offset);
1432         readBuf->trimEnd(buflen - offset - bytesRead);
1433         readCallback_->readBufferAvailable(std::move(readBuf));
1434       }
1435
1436       // Fall through and continue around the loop if the read
1437       // completely filled the available buffer.
1438       // Note that readCallback_ may have been uninstalled or changed inside
1439       // readDataAvailable().
1440       if (size_t(bytesRead) < buflen) {
1441         return;
1442       }
1443     } else if (bytesRead == READ_BLOCKING) {
1444         // No more data to read right now.
1445         return;
1446     } else if (bytesRead == READ_ERROR) {
1447       readErr_ = READ_ERROR;
1448       if (readResult.exception) {
1449         return failRead(__func__, *readResult.exception);
1450       }
1451       auto errnoCopy = errno;
1452       AsyncSocketException ex(
1453           AsyncSocketException::INTERNAL_ERROR,
1454           withAddr("recv() failed"),
1455           errnoCopy);
1456       return failRead(__func__, ex);
1457     } else {
1458       assert(bytesRead == READ_EOF);
1459       readErr_ = READ_EOF;
1460       // EOF
1461       shutdownFlags_ |= SHUT_READ;
1462       if (!updateEventRegistration(0, EventHandler::READ)) {
1463         // we've already been moved into STATE_ERROR
1464         assert(state_ == StateEnum::ERROR);
1465         assert(readCallback_ == nullptr);
1466         return;
1467       }
1468
1469       ReadCallback* callback = readCallback_;
1470       readCallback_ = nullptr;
1471       callback->readEOF();
1472       return;
1473     }
1474     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1475       if (readCallback_ != nullptr) {
1476         // We might still have data in the socket.
1477         // (e.g. see comment in AsyncSSLSocket::checkForImmediateRead)
1478         scheduleImmediateRead();
1479       }
1480       return;
1481     }
1482   }
1483 }
1484
1485 /**
1486  * This function attempts to write as much data as possible, until no more data
1487  * can be written.
1488  *
1489  * - If it sends all available data, it unregisters for write events, and stops
1490  *   the writeTimeout_.
1491  *
1492  * - If not all of the data can be sent immediately, it reschedules
1493  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1494  *   registered for write events.
1495  */
1496 void AsyncSocket::handleWrite() noexcept {
1497   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1498           << ", state=" << state_;
1499   DestructorGuard dg(this);
1500
1501   if (state_ == StateEnum::CONNECTING) {
1502     handleConnect();
1503     return;
1504   }
1505
1506   // Normal write
1507   assert(state_ == StateEnum::ESTABLISHED);
1508   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1509   assert(writeReqHead_ != nullptr);
1510
1511   // Loop until we run out of write requests,
1512   // or until this socket is moved to another EventBase.
1513   // (See the comment in handleRead() explaining how this can happen.)
1514   EventBase* originalEventBase = eventBase_;
1515   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1516     auto writeResult = writeReqHead_->performWrite();
1517     if (writeResult.writeReturn < 0) {
1518       if (writeResult.exception) {
1519         return failWrite(__func__, *writeResult.exception);
1520       }
1521       auto errnoCopy = errno;
1522       AsyncSocketException ex(
1523           AsyncSocketException::INTERNAL_ERROR,
1524           withAddr("writev() failed"),
1525           errnoCopy);
1526       return failWrite(__func__, ex);
1527     } else if (writeReqHead_->isComplete()) {
1528       // We finished this request
1529       WriteRequest* req = writeReqHead_;
1530       writeReqHead_ = req->getNext();
1531
1532       if (writeReqHead_ == nullptr) {
1533         writeReqTail_ = nullptr;
1534         // This is the last write request.
1535         // Unregister for write events and cancel the send timer
1536         // before we invoke the callback.  We have to update the state properly
1537         // before calling the callback, since it may want to detach us from
1538         // the EventBase.
1539         if (eventFlags_ & EventHandler::WRITE) {
1540           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1541             assert(state_ == StateEnum::ERROR);
1542             return;
1543           }
1544           // Stop the send timeout
1545           writeTimeout_.cancelTimeout();
1546         }
1547         assert(!writeTimeout_.isScheduled());
1548
1549         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1550         // we finish sending the last write request.
1551         //
1552         // We have to do this before invoking writeSuccess(), since
1553         // writeSuccess() may detach us from our EventBase.
1554         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1555           assert(connectCallback_ == nullptr);
1556           shutdownFlags_ |= SHUT_WRITE;
1557
1558           if (shutdownFlags_ & SHUT_READ) {
1559             // Reads have already been shutdown.  Fully close the socket and
1560             // move to STATE_CLOSED.
1561             //
1562             // Note: This code currently moves us to STATE_CLOSED even if
1563             // close() hasn't ever been called.  This can occur if we have
1564             // received EOF from the peer and shutdownWrite() has been called
1565             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1566             // case, until close() is actually called?  I can't think of a
1567             // reason why we would need to do so.  No other operations besides
1568             // calling close() or destroying the socket can be performed at
1569             // this point.
1570             assert(readCallback_ == nullptr);
1571             state_ = StateEnum::CLOSED;
1572             if (fd_ >= 0) {
1573               ioHandler_.changeHandlerFD(-1);
1574               doClose();
1575             }
1576           } else {
1577             // Reads are still enabled, so we are only doing a half-shutdown
1578             shutdown(fd_, SHUT_WR);
1579           }
1580         }
1581       }
1582
1583       // Invoke the callback
1584       WriteCallback* callback = req->getCallback();
1585       req->destroy();
1586       if (callback) {
1587         callback->writeSuccess();
1588       }
1589       // We'll continue around the loop, trying to write another request
1590     } else {
1591       // Partial write.
1592       if (bufferCallback_) {
1593         bufferCallback_->onEgressBuffered();
1594       }
1595       writeReqHead_->consume();
1596       // Stop after a partial write; it's highly likely that a subsequent write
1597       // attempt will just return EAGAIN.
1598       //
1599       // Ensure that we are registered for write events.
1600       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1601         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1602           assert(state_ == StateEnum::ERROR);
1603           return;
1604         }
1605       }
1606
1607       // Reschedule the send timeout, since we have made some write progress.
1608       if (sendTimeout_ > 0) {
1609         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1610           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1611               withAddr("failed to reschedule write timeout"));
1612           return failWrite(__func__, ex);
1613         }
1614       }
1615       return;
1616     }
1617   }
1618   if (!writeReqHead_ && bufferCallback_) {
1619     bufferCallback_->onEgressBufferCleared();
1620   }
1621 }
1622
1623 void AsyncSocket::checkForImmediateRead() noexcept {
1624   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1625   // (However, note that some subclasses do override this method.)
1626   //
1627   // Simply calling handleRead() here would be bad, as this would call
1628   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1629   // buffer even though no data may be available.  This would waste lots of
1630   // memory, since the buffer will sit around unused until the socket actually
1631   // becomes readable.
1632   //
1633   // Checking if the socket is readable now also seems like it would probably
1634   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1635   // would just waste an extra system call.  Even if it is readable, waiting to
1636   // find out from libevent on the next event loop doesn't seem that bad.
1637 }
1638
1639 void AsyncSocket::handleInitialReadWrite() noexcept {
1640   // Our callers should already be holding a DestructorGuard, but grab
1641   // one here just to make sure, in case one of our calling code paths ever
1642   // changes.
1643   DestructorGuard dg(this);
1644   // If we have a readCallback_, make sure we enable read events.  We
1645   // may already be registered for reads if connectSuccess() set
1646   // the read calback.
1647   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1648     assert(state_ == StateEnum::ESTABLISHED);
1649     assert((shutdownFlags_ & SHUT_READ) == 0);
1650     if (!updateEventRegistration(EventHandler::READ, 0)) {
1651       assert(state_ == StateEnum::ERROR);
1652       return;
1653     }
1654     checkForImmediateRead();
1655   } else if (readCallback_ == nullptr) {
1656     // Unregister for read events.
1657     updateEventRegistration(0, EventHandler::READ);
1658   }
1659
1660   // If we have write requests pending, try to send them immediately.
1661   // Since we just finished accepting, there is a very good chance that we can
1662   // write without blocking.
1663   //
1664   // However, we only process them if EventHandler::WRITE is not already set,
1665   // which means that we're already blocked on a write attempt.  (This can
1666   // happen if connectSuccess() called write() before returning.)
1667   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1668     // Call handleWrite() to perform write processing.
1669     handleWrite();
1670   } else if (writeReqHead_ == nullptr) {
1671     // Unregister for write event.
1672     updateEventRegistration(0, EventHandler::WRITE);
1673   }
1674 }
1675
1676 void AsyncSocket::handleConnect() noexcept {
1677   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1678           << ", state=" << state_;
1679   assert(state_ == StateEnum::CONNECTING);
1680   // SHUT_WRITE can never be set while we are still connecting;
1681   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1682   // finishes
1683   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1684
1685   // In case we had a connect timeout, cancel the timeout
1686   writeTimeout_.cancelTimeout();
1687   // We don't use a persistent registration when waiting on a connect event,
1688   // so we have been automatically unregistered now.  Update eventFlags_ to
1689   // reflect reality.
1690   assert(eventFlags_ == EventHandler::WRITE);
1691   eventFlags_ = EventHandler::NONE;
1692
1693   // Call getsockopt() to check if the connect succeeded
1694   int error;
1695   socklen_t len = sizeof(error);
1696   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1697   if (rv != 0) {
1698     auto errnoCopy = errno;
1699     AsyncSocketException ex(
1700         AsyncSocketException::INTERNAL_ERROR,
1701         withAddr("error calling getsockopt() after connect"),
1702         errnoCopy);
1703     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1704                << fd_ << " host=" << addr_.describe()
1705                << ") exception:" << ex.what();
1706     return failConnect(__func__, ex);
1707   }
1708
1709   if (error != 0) {
1710     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1711                            "connect failed", error);
1712     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1713             << fd_ << " host=" << addr_.describe()
1714             << ") exception: " << ex.what();
1715     return failConnect(__func__, ex);
1716   }
1717
1718   // Move into STATE_ESTABLISHED
1719   state_ = StateEnum::ESTABLISHED;
1720
1721   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1722   // perform, immediately shutdown the write half of the socket.
1723   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1724     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1725     // are still connecting we just abort the connect rather than waiting for
1726     // it to complete.
1727     assert((shutdownFlags_ & SHUT_READ) == 0);
1728     shutdown(fd_, SHUT_WR);
1729     shutdownFlags_ |= SHUT_WRITE;
1730   }
1731
1732   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1733           << "successfully connected; state=" << state_;
1734
1735   // Remember the EventBase we are attached to, before we start invoking any
1736   // callbacks (since the callbacks may call detachEventBase()).
1737   EventBase* originalEventBase = eventBase_;
1738
1739   invokeConnectSuccess();
1740   // Note that the connect callback may have changed our state.
1741   // (set or unset the read callback, called write(), closed the socket, etc.)
1742   // The following code needs to handle these situations correctly.
1743   //
1744   // If the socket has been closed, readCallback_ and writeReqHead_ will
1745   // always be nullptr, so that will prevent us from trying to read or write.
1746   //
1747   // The main thing to check for is if eventBase_ is still originalEventBase.
1748   // If not, we have been detached from this event base, so we shouldn't
1749   // perform any more operations.
1750   if (eventBase_ != originalEventBase) {
1751     return;
1752   }
1753
1754   handleInitialReadWrite();
1755 }
1756
1757 void AsyncSocket::timeoutExpired() noexcept {
1758   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1759           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1760   DestructorGuard dg(this);
1761   assert(eventBase_->isInEventBaseThread());
1762
1763   if (state_ == StateEnum::CONNECTING) {
1764     // connect() timed out
1765     // Unregister for I/O events.
1766     if (connectCallback_) {
1767       AsyncSocketException ex(
1768           AsyncSocketException::TIMED_OUT, "connect timed out");
1769       failConnect(__func__, ex);
1770     } else {
1771       // we faced a connect error without a connect callback, which could
1772       // happen due to TFO.
1773       AsyncSocketException ex(
1774           AsyncSocketException::TIMED_OUT, "write timed out during connection");
1775       failWrite(__func__, ex);
1776     }
1777   } else {
1778     // a normal write operation timed out
1779     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
1780     failWrite(__func__, ex);
1781   }
1782 }
1783
1784 ssize_t AsyncSocket::tfoSendMsg(int fd, struct msghdr* msg, int msg_flags) {
1785   return detail::tfo_sendmsg(fd, msg, msg_flags);
1786 }
1787
1788 AsyncSocket::WriteResult
1789 AsyncSocket::sendSocketMessage(int fd, struct msghdr* msg, int msg_flags) {
1790   ssize_t totalWritten = 0;
1791   if (state_ == StateEnum::FAST_OPEN) {
1792     sockaddr_storage addr;
1793     auto len = addr_.getAddress(&addr);
1794     msg->msg_name = &addr;
1795     msg->msg_namelen = len;
1796     totalWritten = tfoSendMsg(fd_, msg, msg_flags);
1797     if (totalWritten >= 0) {
1798       tfoFinished_ = true;
1799       state_ = StateEnum::ESTABLISHED;
1800       // We schedule this asynchrously so that we don't end up
1801       // invoking initial read or write while a write is in progress.
1802       scheduleInitialReadWrite();
1803     } else if (errno == EINPROGRESS) {
1804       VLOG(4) << "TFO falling back to connecting";
1805       // A normal sendmsg doesn't return EINPROGRESS, however
1806       // TFO might fallback to connecting if there is no
1807       // cookie.
1808       state_ = StateEnum::CONNECTING;
1809       try {
1810         scheduleConnectTimeout();
1811         registerForConnectEvents();
1812       } catch (const AsyncSocketException& ex) {
1813         return WriteResult(
1814             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1815       }
1816       // Let's fake it that no bytes were written and return an errno.
1817       errno = EAGAIN;
1818       totalWritten = -1;
1819     } else if (errno == EOPNOTSUPP) {
1820       // Try falling back to connecting.
1821       VLOG(4) << "TFO not supported";
1822       state_ = StateEnum::CONNECTING;
1823       try {
1824         int ret = socketConnect((const sockaddr*)&addr, len);
1825         if (ret == 0) {
1826           // connect succeeded immediately
1827           // Treat this like no data was written.
1828           state_ = StateEnum::ESTABLISHED;
1829           scheduleInitialReadWrite();
1830         }
1831         // If there was no exception during connections,
1832         // we would return that no bytes were written.
1833         errno = EAGAIN;
1834         totalWritten = -1;
1835       } catch (const AsyncSocketException& ex) {
1836         return WriteResult(
1837             WRITE_ERROR, folly::make_unique<AsyncSocketException>(ex));
1838       }
1839     } else if (errno == EAGAIN) {
1840       // Normally sendmsg would indicate that the write would block.
1841       // However in the fast open case, it would indicate that sendmsg
1842       // fell back to a connect. This is a return code from connect()
1843       // instead, and is an error condition indicating no fds available.
1844       return WriteResult(
1845           WRITE_ERROR,
1846           folly::make_unique<AsyncSocketException>(
1847               AsyncSocketException::UNKNOWN, "No more free local ports"));
1848     }
1849   } else {
1850     totalWritten = ::sendmsg(fd, msg, msg_flags);
1851   }
1852   return WriteResult(totalWritten);
1853 }
1854
1855 AsyncSocket::WriteResult AsyncSocket::performWrite(
1856     const iovec* vec,
1857     uint32_t count,
1858     WriteFlags flags,
1859     uint32_t* countWritten,
1860     uint32_t* partialWritten) {
1861   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1862   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1863   // (since it may terminate the program if the main program doesn't explicitly
1864   // ignore it).
1865   struct msghdr msg;
1866   msg.msg_name = nullptr;
1867   msg.msg_namelen = 0;
1868   msg.msg_iov = const_cast<iovec *>(vec);
1869   msg.msg_iovlen = std::min<size_t>(count, kIovMax);
1870   msg.msg_control = nullptr;
1871   msg.msg_controllen = 0;
1872   msg.msg_flags = 0;
1873
1874   int msg_flags = MSG_DONTWAIT;
1875
1876 #ifdef MSG_NOSIGNAL // Linux-only
1877   msg_flags |= MSG_NOSIGNAL;
1878   if (isSet(flags, WriteFlags::CORK)) {
1879     // MSG_MORE tells the kernel we have more data to send, so wait for us to
1880     // give it the rest of the data rather than immediately sending a partial
1881     // frame, even when TCP_NODELAY is enabled.
1882     msg_flags |= MSG_MORE;
1883   }
1884 #endif
1885   if (isSet(flags, WriteFlags::EOR)) {
1886     // marks that this is the last byte of a record (response)
1887     msg_flags |= MSG_EOR;
1888   }
1889   auto writeResult = sendSocketMessage(fd_, &msg, msg_flags);
1890   auto totalWritten = writeResult.writeReturn;
1891   if (totalWritten < 0) {
1892     bool tryAgain = (errno == EAGAIN);
1893 #ifdef __APPLE__
1894     // Apple has a bug where doing a second write on a socket which we
1895     // have opened with TFO causes an ENOTCONN to be thrown. However the
1896     // socket is really connected, so treat ENOTCONN as a EAGAIN until
1897     // this bug is fixed.
1898     tryAgain |= (errno == ENOTCONN);
1899 #endif
1900     if (!writeResult.exception && tryAgain) {
1901       // TCP buffer is full; we can't write any more data right now.
1902       *countWritten = 0;
1903       *partialWritten = 0;
1904       return WriteResult(0);
1905     }
1906     // error
1907     *countWritten = 0;
1908     *partialWritten = 0;
1909     return writeResult;
1910   }
1911
1912   appBytesWritten_ += totalWritten;
1913
1914   uint32_t bytesWritten;
1915   uint32_t n;
1916   for (bytesWritten = uint32_t(totalWritten), n = 0; n < count; ++n) {
1917     const iovec* v = vec + n;
1918     if (v->iov_len > bytesWritten) {
1919       // Partial write finished in the middle of this iovec
1920       *countWritten = n;
1921       *partialWritten = bytesWritten;
1922       return WriteResult(totalWritten);
1923     }
1924
1925     bytesWritten -= uint32_t(v->iov_len);
1926   }
1927
1928   assert(bytesWritten == 0);
1929   *countWritten = n;
1930   *partialWritten = 0;
1931   return WriteResult(totalWritten);
1932 }
1933
1934 /**
1935  * Re-register the EventHandler after eventFlags_ has changed.
1936  *
1937  * If an error occurs, fail() is called to move the socket into the error state
1938  * and call all currently installed callbacks.  After an error, the
1939  * AsyncSocket is completely unregistered.
1940  *
1941  * @return Returns true on succcess, or false on error.
1942  */
1943 bool AsyncSocket::updateEventRegistration() {
1944   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
1945           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
1946           << ", events=" << std::hex << eventFlags_;
1947   assert(eventBase_->isInEventBaseThread());
1948   if (eventFlags_ == EventHandler::NONE) {
1949     ioHandler_.unregisterHandler();
1950     return true;
1951   }
1952
1953   // Always register for persistent events, so we don't have to re-register
1954   // after being called back.
1955   if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
1956     eventFlags_ = EventHandler::NONE; // we're not registered after error
1957     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1958         withAddr("failed to update AsyncSocket event registration"));
1959     fail("updateEventRegistration", ex);
1960     return false;
1961   }
1962
1963   return true;
1964 }
1965
1966 bool AsyncSocket::updateEventRegistration(uint16_t enable,
1967                                            uint16_t disable) {
1968   uint16_t oldFlags = eventFlags_;
1969   eventFlags_ |= enable;
1970   eventFlags_ &= ~disable;
1971   if (eventFlags_ == oldFlags) {
1972     return true;
1973   } else {
1974     return updateEventRegistration();
1975   }
1976 }
1977
1978 void AsyncSocket::startFail() {
1979   // startFail() should only be called once
1980   assert(state_ != StateEnum::ERROR);
1981   assert(getDestructorGuardCount() > 0);
1982   state_ = StateEnum::ERROR;
1983   // Ensure that SHUT_READ and SHUT_WRITE are set,
1984   // so all future attempts to read or write will be rejected
1985   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1986
1987   if (eventFlags_ != EventHandler::NONE) {
1988     eventFlags_ = EventHandler::NONE;
1989     ioHandler_.unregisterHandler();
1990   }
1991   writeTimeout_.cancelTimeout();
1992
1993   if (fd_ >= 0) {
1994     ioHandler_.changeHandlerFD(-1);
1995     doClose();
1996   }
1997 }
1998
1999 void AsyncSocket::invokeAllErrors(const AsyncSocketException& ex) {
2000   invokeConnectErr(ex);
2001   failAllWrites(ex);
2002
2003   if (readCallback_) {
2004     ReadCallback* callback = readCallback_;
2005     readCallback_ = nullptr;
2006     callback->readErr(ex);
2007   }
2008 }
2009
2010 void AsyncSocket::finishFail() {
2011   assert(state_ == StateEnum::ERROR);
2012   assert(getDestructorGuardCount() > 0);
2013
2014   AsyncSocketException ex(
2015       AsyncSocketException::INTERNAL_ERROR,
2016       withAddr("socket closing after error"));
2017   invokeAllErrors(ex);
2018 }
2019
2020 void AsyncSocket::finishFail(const AsyncSocketException& ex) {
2021   assert(state_ == StateEnum::ERROR);
2022   assert(getDestructorGuardCount() > 0);
2023   invokeAllErrors(ex);
2024 }
2025
2026 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
2027   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2028              << state_ << " host=" << addr_.describe()
2029              << "): failed in " << fn << "(): "
2030              << ex.what();
2031   startFail();
2032   finishFail();
2033 }
2034
2035 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
2036   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2037                << state_ << " host=" << addr_.describe()
2038                << "): failed while connecting in " << fn << "(): "
2039                << ex.what();
2040   startFail();
2041
2042   invokeConnectErr(ex);
2043   finishFail(ex);
2044 }
2045
2046 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
2047   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2048                << state_ << " host=" << addr_.describe()
2049                << "): failed while reading in " << fn << "(): "
2050                << ex.what();
2051   startFail();
2052
2053   if (readCallback_ != nullptr) {
2054     ReadCallback* callback = readCallback_;
2055     readCallback_ = nullptr;
2056     callback->readErr(ex);
2057   }
2058
2059   finishFail();
2060 }
2061
2062 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
2063   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2064                << state_ << " host=" << addr_.describe()
2065                << "): failed while writing in " << fn << "(): "
2066                << ex.what();
2067   startFail();
2068
2069   // Only invoke the first write callback, since the error occurred while
2070   // writing this request.  Let any other pending write callbacks be invoked in
2071   // finishFail().
2072   if (writeReqHead_ != nullptr) {
2073     WriteRequest* req = writeReqHead_;
2074     writeReqHead_ = req->getNext();
2075     WriteCallback* callback = req->getCallback();
2076     uint32_t bytesWritten = req->getTotalBytesWritten();
2077     req->destroy();
2078     if (callback) {
2079       callback->writeErr(bytesWritten, ex);
2080     }
2081   }
2082
2083   finishFail();
2084 }
2085
2086 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
2087                              size_t bytesWritten,
2088                              const AsyncSocketException& ex) {
2089   // This version of failWrite() is used when the failure occurs before
2090   // we've added the callback to writeReqHead_.
2091   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
2092              << state_ << " host=" << addr_.describe()
2093              <<"): failed while writing in " << fn << "(): "
2094              << ex.what();
2095   startFail();
2096
2097   if (callback != nullptr) {
2098     callback->writeErr(bytesWritten, ex);
2099   }
2100
2101   finishFail();
2102 }
2103
2104 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
2105   // Invoke writeError() on all write callbacks.
2106   // This is used when writes are forcibly shutdown with write requests
2107   // pending, or when an error occurs with writes pending.
2108   while (writeReqHead_ != nullptr) {
2109     WriteRequest* req = writeReqHead_;
2110     writeReqHead_ = req->getNext();
2111     WriteCallback* callback = req->getCallback();
2112     if (callback) {
2113       callback->writeErr(req->getTotalBytesWritten(), ex);
2114     }
2115     req->destroy();
2116   }
2117 }
2118
2119 void AsyncSocket::invalidState(ConnectCallback* callback) {
2120   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
2121              << "): connect() called in invalid state " << state_;
2122
2123   /*
2124    * The invalidState() methods don't use the normal failure mechanisms,
2125    * since we don't know what state we are in.  We don't want to call
2126    * startFail()/finishFail() recursively if we are already in the middle of
2127    * cleaning up.
2128    */
2129
2130   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
2131                          "connect() called with socket in invalid state");
2132   connectEndTime_ = std::chrono::steady_clock::now();
2133   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2134     if (callback) {
2135       callback->connectErr(ex);
2136     }
2137   } else {
2138     // We can't use failConnect() here since connectCallback_
2139     // may already be set to another callback.  Invoke this ConnectCallback
2140     // here; any other connectCallback_ will be invoked in finishFail()
2141     startFail();
2142     if (callback) {
2143       callback->connectErr(ex);
2144     }
2145     finishFail();
2146   }
2147 }
2148
2149 void AsyncSocket::invokeConnectErr(const AsyncSocketException& ex) {
2150   connectEndTime_ = std::chrono::steady_clock::now();
2151   if (connectCallback_) {
2152     ConnectCallback* callback = connectCallback_;
2153     connectCallback_ = nullptr;
2154     callback->connectErr(ex);
2155   }
2156 }
2157
2158 void AsyncSocket::invokeConnectSuccess() {
2159   connectEndTime_ = std::chrono::steady_clock::now();
2160   if (connectCallback_) {
2161     ConnectCallback* callback = connectCallback_;
2162     connectCallback_ = nullptr;
2163     callback->connectSuccess();
2164   }
2165 }
2166
2167 void AsyncSocket::invalidState(ReadCallback* callback) {
2168   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2169              << "): setReadCallback(" << callback
2170              << ") called in invalid state " << state_;
2171
2172   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2173                          "setReadCallback() called with socket in "
2174                          "invalid state");
2175   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2176     if (callback) {
2177       callback->readErr(ex);
2178     }
2179   } else {
2180     startFail();
2181     if (callback) {
2182       callback->readErr(ex);
2183     }
2184     finishFail();
2185   }
2186 }
2187
2188 void AsyncSocket::invalidState(WriteCallback* callback) {
2189   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
2190              << "): write() called in invalid state " << state_;
2191
2192   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
2193                          withAddr("write() called with socket in invalid state"));
2194   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
2195     if (callback) {
2196       callback->writeErr(0, ex);
2197     }
2198   } else {
2199     startFail();
2200     if (callback) {
2201       callback->writeErr(0, ex);
2202     }
2203     finishFail();
2204   }
2205 }
2206
2207 void AsyncSocket::doClose() {
2208   if (fd_ == -1) return;
2209   if (shutdownSocketSet_) {
2210     shutdownSocketSet_->close(fd_);
2211   } else {
2212     ::close(fd_);
2213   }
2214   fd_ = -1;
2215 }
2216
2217 std::ostream& operator << (std::ostream& os,
2218                            const AsyncSocket::StateEnum& state) {
2219   os << static_cast<int>(state);
2220   return os;
2221 }
2222
2223 std::string AsyncSocket::withAddr(const std::string& s) {
2224   // Don't use addr_ directly because it may not be initialized
2225   // e.g. if constructed from fd
2226   folly::SocketAddress peer, local;
2227   try {
2228     getPeerAddress(&peer);
2229     getLocalAddress(&local);
2230   } catch (const std::exception&) {
2231     // ignore
2232   } catch (...) {
2233     // ignore
2234   }
2235   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
2236 }
2237
2238 void AsyncSocket::setBufferCallback(BufferCallback* cb) {
2239   bufferCallback_ = cb;
2240 }
2241
2242 } // folly