Reduce AsyncSocket code size
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22
23 #include <poll.h>
24 #include <errno.h>
25 #include <limits.h>
26 #include <unistd.h>
27 #include <fcntl.h>
28 #include <sys/types.h>
29 #include <sys/socket.h>
30 #include <netinet/in.h>
31 #include <netinet/tcp.h>
32
33 using std::string;
34 using std::unique_ptr;
35
36 namespace folly {
37
38 // static members initializers
39 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
40 const folly::SocketAddress AsyncSocket::anyAddress =
41   folly::SocketAddress("0.0.0.0", 0);
42
43 const AsyncSocketException socketClosedLocallyEx(
44     AsyncSocketException::END_OF_FILE, "socket closed locally");
45 const AsyncSocketException socketShutdownForWritesEx(
46     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
47
48 // TODO: It might help performance to provide a version of WriteRequest that
49 // users could derive from, so we can avoid the extra allocation for each call
50 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
51 // protocols are currently templatized for transports.
52 //
53 // We would need the version for external users where they provide the iovec
54 // storage space, and only our internal version would allocate it at the end of
55 // the WriteRequest.
56
57 /**
58  * A WriteRequest object tracks information about a pending write() or writev()
59  * operation.
60  *
61  * A new WriteRequest operation is allocated on the heap for all write
62  * operations that cannot be completed immediately.
63  */
64 class AsyncSocket::WriteRequest {
65  public:
66   static WriteRequest* newRequest(WriteCallback* callback,
67                                   const iovec* ops,
68                                   uint32_t opCount,
69                                   unique_ptr<IOBuf>&& ioBuf,
70                                   WriteFlags flags) {
71     assert(opCount > 0);
72     // Since we put a variable size iovec array at the end
73     // of each WriteRequest, we have to manually allocate the memory.
74     void* buf = malloc(sizeof(WriteRequest) +
75                        (opCount * sizeof(struct iovec)));
76     if (buf == nullptr) {
77       throw std::bad_alloc();
78     }
79
80     return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf),
81                                  flags);
82   }
83
84   void destroy() {
85     this->~WriteRequest();
86     free(this);
87   }
88
89   bool cork() const {
90     return isSet(flags_, WriteFlags::CORK);
91   }
92
93   WriteFlags flags() const {
94     return flags_;
95   }
96
97   WriteRequest* getNext() const {
98     return next_;
99   }
100
101   WriteCallback* getCallback() const {
102     return callback_;
103   }
104
105   uint32_t getBytesWritten() const {
106     return bytesWritten_;
107   }
108
109   const struct iovec* getOps() const {
110     assert(opCount_ > opIndex_);
111     return writeOps_ + opIndex_;
112   }
113
114   uint32_t getOpCount() const {
115     assert(opCount_ > opIndex_);
116     return opCount_ - opIndex_;
117   }
118
119   void consume(uint32_t wholeOps, uint32_t partialBytes,
120                uint32_t totalBytesWritten) {
121     // Advance opIndex_ forward by wholeOps
122     opIndex_ += wholeOps;
123     assert(opIndex_ < opCount_);
124
125     // If we've finished writing any IOBufs, release them
126     if (ioBuf_) {
127       for (uint32_t i = wholeOps; i != 0; --i) {
128         assert(ioBuf_);
129         ioBuf_ = ioBuf_->pop();
130       }
131     }
132
133     // Move partialBytes forward into the current iovec buffer
134     struct iovec* currentOp = writeOps_ + opIndex_;
135     assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0));
136     currentOp->iov_base =
137       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes;
138     currentOp->iov_len -= partialBytes;
139
140     // Increment the bytesWritten_ count by totalBytesWritten
141     bytesWritten_ += totalBytesWritten;
142   }
143
144   void append(WriteRequest* next) {
145     assert(next_ == nullptr);
146     next_ = next;
147   }
148
149  private:
150   WriteRequest(WriteCallback* callback,
151                const struct iovec* ops,
152                uint32_t opCount,
153                unique_ptr<IOBuf>&& ioBuf,
154                WriteFlags flags)
155     : next_(nullptr)
156     , callback_(callback)
157     , bytesWritten_(0)
158     , opCount_(opCount)
159     , opIndex_(0)
160     , flags_(flags)
161     , ioBuf_(std::move(ioBuf)) {
162     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
163   }
164
165   // Private destructor, to ensure callers use destroy()
166   ~WriteRequest() {}
167
168   WriteRequest* next_;          ///< pointer to next WriteRequest
169   WriteCallback* callback_;     ///< completion callback
170   uint32_t bytesWritten_;       ///< bytes written
171   uint32_t opCount_;            ///< number of entries in writeOps_
172   uint32_t opIndex_;            ///< current index into writeOps_
173   WriteFlags flags_;            ///< set for WriteFlags
174   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
175   struct iovec writeOps_[];     ///< write operation(s) list
176 };
177
178 AsyncSocket::AsyncSocket()
179   : eventBase_(nullptr)
180   , writeTimeout_(this, nullptr)
181   , ioHandler_(this, nullptr) {
182   VLOG(5) << "new AsyncSocket()";
183   init();
184 }
185
186 AsyncSocket::AsyncSocket(EventBase* evb)
187   : eventBase_(evb)
188   , writeTimeout_(this, evb)
189   , ioHandler_(this, evb) {
190   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
191   init();
192 }
193
194 AsyncSocket::AsyncSocket(EventBase* evb,
195                            const folly::SocketAddress& address,
196                            uint32_t connectTimeout)
197   : AsyncSocket(evb) {
198   connect(nullptr, address, connectTimeout);
199 }
200
201 AsyncSocket::AsyncSocket(EventBase* evb,
202                            const std::string& ip,
203                            uint16_t port,
204                            uint32_t connectTimeout)
205   : AsyncSocket(evb) {
206   connect(nullptr, ip, port, connectTimeout);
207 }
208
209 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
210   : eventBase_(evb)
211   , writeTimeout_(this, evb)
212   , ioHandler_(this, evb, fd) {
213   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
214           << fd << ")";
215   init();
216   fd_ = fd;
217   setCloseOnExec();
218   state_ = StateEnum::ESTABLISHED;
219 }
220
221 // init() method, since constructor forwarding isn't supported in most
222 // compilers yet.
223 void AsyncSocket::init() {
224   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
225   shutdownFlags_ = 0;
226   state_ = StateEnum::UNINIT;
227   eventFlags_ = EventHandler::NONE;
228   fd_ = -1;
229   sendTimeout_ = 0;
230   maxReadsPerEvent_ = 16;
231   connectCallback_ = nullptr;
232   readCallback_ = nullptr;
233   writeReqHead_ = nullptr;
234   writeReqTail_ = nullptr;
235   shutdownSocketSet_ = nullptr;
236   appBytesWritten_ = 0;
237   appBytesReceived_ = 0;
238 }
239
240 AsyncSocket::~AsyncSocket() {
241   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
242           << ", evb=" << eventBase_ << ", fd=" << fd_
243           << ", state=" << state_ << ")";
244 }
245
246 void AsyncSocket::destroy() {
247   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
248           << ", fd=" << fd_ << ", state=" << state_;
249   // When destroy is called, close the socket immediately
250   closeNow();
251
252   // Then call DelayedDestruction::destroy() to take care of
253   // whether or not we need immediate or delayed destruction
254   DelayedDestruction::destroy();
255 }
256
257 int AsyncSocket::detachFd() {
258   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
259           << ", evb=" << eventBase_ << ", state=" << state_
260           << ", events=" << std::hex << eventFlags_ << ")";
261   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
262   // actually close the descriptor.
263   if (shutdownSocketSet_) {
264     shutdownSocketSet_->remove(fd_);
265   }
266   int fd = fd_;
267   fd_ = -1;
268   // Call closeNow() to invoke all pending callbacks with an error.
269   closeNow();
270   // Update the EventHandler to stop using this fd.
271   // This can only be done after closeNow() unregisters the handler.
272   ioHandler_.changeHandlerFD(-1);
273   return fd;
274 }
275
276 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
277   if (shutdownSocketSet_ == newSS) {
278     return;
279   }
280   if (shutdownSocketSet_ && fd_ != -1) {
281     shutdownSocketSet_->remove(fd_);
282   }
283   shutdownSocketSet_ = newSS;
284   if (shutdownSocketSet_ && fd_ != -1) {
285     shutdownSocketSet_->add(fd_);
286   }
287 }
288
289 void AsyncSocket::setCloseOnExec() {
290   int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
291   if (rv != 0) {
292     throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
293                                withAddr("failed to set close-on-exec flag"),
294                                errno);
295   }
296 }
297
298 void AsyncSocket::connect(ConnectCallback* callback,
299                            const folly::SocketAddress& address,
300                            int timeout,
301                            const OptionMap &options,
302                            const folly::SocketAddress& bindAddr) noexcept {
303   DestructorGuard dg(this);
304   assert(eventBase_->isInEventBaseThread());
305
306   addr_ = address;
307
308   // Make sure we're in the uninitialized state
309   if (state_ != StateEnum::UNINIT) {
310     return invalidState(callback);
311   }
312
313   assert(fd_ == -1);
314   state_ = StateEnum::CONNECTING;
315   connectCallback_ = callback;
316
317   sockaddr_storage addrStorage;
318   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
319
320   try {
321     // Create the socket
322     // Technically the first parameter should actually be a protocol family
323     // constant (PF_xxx) rather than an address family (AF_xxx), but the
324     // distinction is mainly just historical.  In pretty much all
325     // implementations the PF_foo and AF_foo constants are identical.
326     fd_ = socket(address.getFamily(), SOCK_STREAM, 0);
327     if (fd_ < 0) {
328       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
329                                 withAddr("failed to create socket"), errno);
330     }
331     if (shutdownSocketSet_) {
332       shutdownSocketSet_->add(fd_);
333     }
334     ioHandler_.changeHandlerFD(fd_);
335
336     setCloseOnExec();
337
338     // Put the socket in non-blocking mode
339     int flags = fcntl(fd_, F_GETFL, 0);
340     if (flags == -1) {
341       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
342                                 withAddr("failed to get socket flags"), errno);
343     }
344     int rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
345     if (rv == -1) {
346       throw AsyncSocketException(
347           AsyncSocketException::INTERNAL_ERROR,
348           withAddr("failed to put socket in non-blocking mode"),
349           errno);
350     }
351
352 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
353     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
354     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
355     if (rv == -1) {
356       throw AsyncSocketException(
357           AsyncSocketException::INTERNAL_ERROR,
358           "failed to enable F_SETNOSIGPIPE on socket",
359           errno);
360     }
361 #endif
362
363     // By default, turn on TCP_NODELAY
364     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
365     // setNoDelay() will log an error message if it fails.
366     if (address.getFamily() != AF_UNIX) {
367       (void)setNoDelay(true);
368     }
369
370     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
371             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
372
373     // bind the socket
374     if (bindAddr != anyAddress) {
375       int one = 1;
376       if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
377         doClose();
378         throw AsyncSocketException(
379           AsyncSocketException::NOT_OPEN,
380           "failed to setsockopt prior to bind on " + bindAddr.describe(),
381           errno);
382       }
383
384       bindAddr.getAddress(&addrStorage);
385
386       if (::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
387         doClose();
388         throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
389                                   "failed to bind to async socket: " +
390                                   bindAddr.describe(),
391                                   errno);
392       }
393     }
394
395     // Apply the additional options if any.
396     for (const auto& opt: options) {
397       int rv = opt.first.apply(fd_, opt.second);
398       if (rv != 0) {
399         throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
400                                   withAddr("failed to set socket option"),
401                                   errno);
402       }
403     }
404
405     // Perform the connect()
406     address.getAddress(&addrStorage);
407
408     rv = ::connect(fd_, saddr, address.getActualSize());
409     if (rv < 0) {
410       if (errno == EINPROGRESS) {
411         // Connection in progress.
412         if (timeout > 0) {
413           // Start a timer in case the connection takes too long.
414           if (!writeTimeout_.scheduleTimeout(timeout)) {
415             throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
416                 withAddr("failed to schedule AsyncSocket connect timeout"));
417           }
418         }
419
420         // Register for write events, so we'll
421         // be notified when the connection finishes/fails.
422         // Note that we don't register for a persistent event here.
423         assert(eventFlags_ == EventHandler::NONE);
424         eventFlags_ = EventHandler::WRITE;
425         if (!ioHandler_.registerHandler(eventFlags_)) {
426           throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
427               withAddr("failed to register AsyncSocket connect handler"));
428         }
429         return;
430       } else {
431         throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
432                                   "connect failed (immediately)", errno);
433       }
434     }
435
436     // If we're still here the connect() succeeded immediately.
437     // Fall through to call the callback outside of this try...catch block
438   } catch (const AsyncSocketException& ex) {
439     return failConnect(__func__, ex);
440   } catch (const std::exception& ex) {
441     // shouldn't happen, but handle it just in case
442     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
443                << "): unexpected " << typeid(ex).name() << " exception: "
444                << ex.what();
445     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
446                             withAddr(string("unexpected exception: ") +
447                                      ex.what()));
448     return failConnect(__func__, tex);
449   }
450
451   // The connection succeeded immediately
452   // The read callback may not have been set yet, and no writes may be pending
453   // yet, so we don't have to register for any events at the moment.
454   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
455   assert(readCallback_ == nullptr);
456   assert(writeReqHead_ == nullptr);
457   state_ = StateEnum::ESTABLISHED;
458   if (callback) {
459     connectCallback_ = nullptr;
460     callback->connectSuccess();
461   }
462 }
463
464 void AsyncSocket::connect(ConnectCallback* callback,
465                            const string& ip, uint16_t port,
466                            int timeout,
467                            const OptionMap &options) noexcept {
468   DestructorGuard dg(this);
469   try {
470     connectCallback_ = callback;
471     connect(callback, folly::SocketAddress(ip, port), timeout, options);
472   } catch (const std::exception& ex) {
473     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
474                             ex.what());
475     return failConnect(__func__, tex);
476   }
477 }
478
479 void AsyncSocket::cancelConnect() {
480   connectCallback_ = nullptr;
481   if (state_ == StateEnum::CONNECTING) {
482     closeNow();
483   }
484 }
485
486 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
487   sendTimeout_ = milliseconds;
488   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
489
490   // If we are currently pending on write requests, immediately update
491   // writeTimeout_ with the new value.
492   if ((eventFlags_ & EventHandler::WRITE) &&
493       (state_ != StateEnum::CONNECTING)) {
494     assert(state_ == StateEnum::ESTABLISHED);
495     assert((shutdownFlags_ & SHUT_WRITE) == 0);
496     if (sendTimeout_ > 0) {
497       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
498         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
499             withAddr("failed to reschedule send timeout in setSendTimeout"));
500         return failWrite(__func__, ex);
501       }
502     } else {
503       writeTimeout_.cancelTimeout();
504     }
505   }
506 }
507
508 void AsyncSocket::setReadCB(ReadCallback *callback) {
509   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
510           << ", callback=" << callback << ", state=" << state_;
511
512   // Short circuit if callback is the same as the existing readCallback_.
513   //
514   // Note that this is needed for proper functioning during some cleanup cases.
515   // During cleanup we allow setReadCallback(nullptr) to be called even if the
516   // read callback is already unset and we have been detached from an event
517   // base.  This check prevents us from asserting
518   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
519   if (callback == readCallback_) {
520     return;
521   }
522
523   if (shutdownFlags_ & SHUT_READ) {
524     // Reads have already been shut down on this socket.
525     //
526     // Allow setReadCallback(nullptr) to be called in this case, but don't
527     // allow a new callback to be set.
528     //
529     // For example, setReadCallback(nullptr) can happen after an error if we
530     // invoke some other error callback before invoking readError().  The other
531     // error callback that is invoked first may go ahead and clear the read
532     // callback before we get a chance to invoke readError().
533     if (callback != nullptr) {
534       return invalidState(callback);
535     }
536     assert((eventFlags_ & EventHandler::READ) == 0);
537     readCallback_ = nullptr;
538     return;
539   }
540
541   DestructorGuard dg(this);
542   assert(eventBase_->isInEventBaseThread());
543
544   switch ((StateEnum)state_) {
545     case StateEnum::CONNECTING:
546       // For convenience, we allow the read callback to be set while we are
547       // still connecting.  We just store the callback for now.  Once the
548       // connection completes we'll register for read events.
549       readCallback_ = callback;
550       return;
551     case StateEnum::ESTABLISHED:
552     {
553       readCallback_ = callback;
554       uint16_t oldFlags = eventFlags_;
555       if (readCallback_) {
556         eventFlags_ |= EventHandler::READ;
557       } else {
558         eventFlags_ &= ~EventHandler::READ;
559       }
560
561       // Update our registration if our flags have changed
562       if (eventFlags_ != oldFlags) {
563         // We intentionally ignore the return value here.
564         // updateEventRegistration() will move us into the error state if it
565         // fails, and we don't need to do anything else here afterwards.
566         (void)updateEventRegistration();
567       }
568
569       if (readCallback_) {
570         checkForImmediateRead();
571       }
572       return;
573     }
574     case StateEnum::CLOSED:
575     case StateEnum::ERROR:
576       // We should never reach here.  SHUT_READ should always be set
577       // if we are in STATE_CLOSED or STATE_ERROR.
578       assert(false);
579       return invalidState(callback);
580     case StateEnum::UNINIT:
581       // We do not allow setReadCallback() to be called before we start
582       // connecting.
583       return invalidState(callback);
584   }
585
586   // We don't put a default case in the switch statement, so that the compiler
587   // will warn us to update the switch statement if a new state is added.
588   return invalidState(callback);
589 }
590
591 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
592   return readCallback_;
593 }
594
595 void AsyncSocket::write(WriteCallback* callback,
596                          const void* buf, size_t bytes, WriteFlags flags) {
597   iovec op;
598   op.iov_base = const_cast<void*>(buf);
599   op.iov_len = bytes;
600   writeImpl(callback, &op, 1, std::move(unique_ptr<IOBuf>()), flags);
601 }
602
603 void AsyncSocket::writev(WriteCallback* callback,
604                           const iovec* vec,
605                           size_t count,
606                           WriteFlags flags) {
607   writeImpl(callback, vec, count, std::move(unique_ptr<IOBuf>()), flags);
608 }
609
610 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
611                               WriteFlags flags) {
612   size_t count = buf->countChainElements();
613   if (count <= 64) {
614     iovec vec[count];
615     writeChainImpl(callback, vec, count, std::move(buf), flags);
616   } else {
617     iovec* vec = new iovec[count];
618     writeChainImpl(callback, vec, count, std::move(buf), flags);
619     delete[] vec;
620   }
621 }
622
623 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
624     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
625   const IOBuf* head = buf.get();
626   const IOBuf* next = head;
627   unsigned i = 0;
628   do {
629     vec[i].iov_base = const_cast<uint8_t *>(next->data());
630     vec[i].iov_len = next->length();
631     // IOBuf can get confused by empty iovec buffers, so increment the
632     // output pointer only if the iovec buffer is non-empty.  We could
633     // end the loop with i < count, but that's ok.
634     if (vec[i].iov_len != 0) {
635       i++;
636     }
637     next = next->next();
638   } while (next != head);
639   writeImpl(callback, vec, i, std::move(buf), flags);
640 }
641
642 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
643                              size_t count, unique_ptr<IOBuf>&& buf,
644                              WriteFlags flags) {
645   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
646           << ", callback=" << callback << ", count=" << count
647           << ", state=" << state_;
648   DestructorGuard dg(this);
649   unique_ptr<IOBuf>ioBuf(std::move(buf));
650   assert(eventBase_->isInEventBaseThread());
651
652   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
653     // No new writes may be performed after the write side of the socket has
654     // been shutdown.
655     //
656     // We could just call callback->writeError() here to fail just this write.
657     // However, fail hard and use invalidState() to fail all outstanding
658     // callbacks and move the socket into the error state.  There's most likely
659     // a bug in the caller's code, so we abort everything rather than trying to
660     // proceed as best we can.
661     return invalidState(callback);
662   }
663
664   uint32_t countWritten = 0;
665   uint32_t partialWritten = 0;
666   int bytesWritten = 0;
667   bool mustRegister = false;
668   if (state_ == StateEnum::ESTABLISHED && !connecting()) {
669     if (writeReqHead_ == nullptr) {
670       // If we are established and there are no other writes pending,
671       // we can attempt to perform the write immediately.
672       assert(writeReqTail_ == nullptr);
673       assert((eventFlags_ & EventHandler::WRITE) == 0);
674
675       bytesWritten = performWrite(vec, count, flags,
676                                   &countWritten, &partialWritten);
677       if (bytesWritten < 0) {
678         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
679                                withAddr("writev failed"), errno);
680         return failWrite(__func__, callback, 0, ex);
681       } else if (countWritten == count) {
682         // We successfully wrote everything.
683         // Invoke the callback and return.
684         if (callback) {
685           callback->writeSuccess();
686         }
687         return;
688       } // else { continue writing the next writeReq }
689       mustRegister = true;
690     }
691   } else if (!connecting()) {
692     // Invalid state for writing
693     return invalidState(callback);
694   }
695
696   // Create a new WriteRequest to add to the queue
697   WriteRequest* req;
698   try {
699     req = WriteRequest::newRequest(callback, vec + countWritten,
700                                    count - countWritten, std::move(ioBuf),
701                                    flags);
702   } catch (const std::exception& ex) {
703     // we mainly expect to catch std::bad_alloc here
704     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
705         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
706     return failWrite(__func__, callback, bytesWritten, tex);
707   }
708   req->consume(0, partialWritten, bytesWritten);
709   if (writeReqTail_ == nullptr) {
710     assert(writeReqHead_ == nullptr);
711     writeReqHead_ = writeReqTail_ = req;
712   } else {
713     writeReqTail_->append(req);
714     writeReqTail_ = req;
715   }
716
717   // Register for write events if are established and not currently
718   // waiting on write events
719   if (mustRegister) {
720     assert(state_ == StateEnum::ESTABLISHED);
721     assert((eventFlags_ & EventHandler::WRITE) == 0);
722     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
723       assert(state_ == StateEnum::ERROR);
724       return;
725     }
726     if (sendTimeout_ > 0) {
727       // Schedule a timeout to fire if the write takes too long.
728       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
729         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
730                                withAddr("failed to schedule send timeout"));
731         return failWrite(__func__, ex);
732       }
733     }
734   }
735 }
736
737 void AsyncSocket::close() {
738   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
739           << ", state=" << state_ << ", shutdownFlags="
740           << std::hex << (int) shutdownFlags_;
741
742   // close() is only different from closeNow() when there are pending writes
743   // that need to drain before we can close.  In all other cases, just call
744   // closeNow().
745   //
746   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
747   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
748   // is still running.  (e.g., If there are multiple pending writes, and we
749   // call writeError() on the first one, it may call close().  In this case we
750   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
751   // writes will still be in the queue.)
752   //
753   // We only need to drain pending writes if we are still in STATE_CONNECTING
754   // or STATE_ESTABLISHED
755   if ((writeReqHead_ == nullptr) ||
756       !(state_ == StateEnum::CONNECTING ||
757       state_ == StateEnum::ESTABLISHED)) {
758     closeNow();
759     return;
760   }
761
762   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
763   // destroyed until close() returns.
764   DestructorGuard dg(this);
765   assert(eventBase_->isInEventBaseThread());
766
767   // Since there are write requests pending, we have to set the
768   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
769   // connect finishes and we finish writing these requests.
770   //
771   // Set SHUT_READ to indicate that reads are shut down, and set the
772   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
773   // pending writes complete.
774   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
775
776   // If a read callback is set, invoke readEOF() immediately to inform it that
777   // the socket has been closed and no more data can be read.
778   if (readCallback_) {
779     // Disable reads if they are enabled
780     if (!updateEventRegistration(0, EventHandler::READ)) {
781       // We're now in the error state; callbacks have been cleaned up
782       assert(state_ == StateEnum::ERROR);
783       assert(readCallback_ == nullptr);
784     } else {
785       ReadCallback* callback = readCallback_;
786       readCallback_ = nullptr;
787       callback->readEOF();
788     }
789   }
790 }
791
792 void AsyncSocket::closeNow() {
793   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
794           << ", state=" << state_ << ", shutdownFlags="
795           << std::hex << (int) shutdownFlags_;
796   DestructorGuard dg(this);
797   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
798
799   switch (state_) {
800     case StateEnum::ESTABLISHED:
801     case StateEnum::CONNECTING:
802     {
803       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
804       state_ = StateEnum::CLOSED;
805
806       // If the write timeout was set, cancel it.
807       writeTimeout_.cancelTimeout();
808
809       // If we are registered for I/O events, unregister.
810       if (eventFlags_ != EventHandler::NONE) {
811         eventFlags_ = EventHandler::NONE;
812         if (!updateEventRegistration()) {
813           // We will have been moved into the error state.
814           assert(state_ == StateEnum::ERROR);
815           return;
816         }
817       }
818
819       if (fd_ >= 0) {
820         ioHandler_.changeHandlerFD(-1);
821         doClose();
822       }
823
824       if (connectCallback_) {
825         ConnectCallback* callback = connectCallback_;
826         connectCallback_ = nullptr;
827         callback->connectErr(socketClosedLocallyEx);
828       }
829
830       failAllWrites(socketClosedLocallyEx);
831
832       if (readCallback_) {
833         ReadCallback* callback = readCallback_;
834         readCallback_ = nullptr;
835         callback->readEOF();
836       }
837       return;
838     }
839     case StateEnum::CLOSED:
840       // Do nothing.  It's possible that we are being called recursively
841       // from inside a callback that we invoked inside another call to close()
842       // that is still running.
843       return;
844     case StateEnum::ERROR:
845       // Do nothing.  The error handling code has performed (or is performing)
846       // cleanup.
847       return;
848     case StateEnum::UNINIT:
849       assert(eventFlags_ == EventHandler::NONE);
850       assert(connectCallback_ == nullptr);
851       assert(readCallback_ == nullptr);
852       assert(writeReqHead_ == nullptr);
853       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
854       state_ = StateEnum::CLOSED;
855       return;
856   }
857
858   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
859               << ") called in unknown state " << state_;
860 }
861
862 void AsyncSocket::closeWithReset() {
863   // Enable SO_LINGER, with the linger timeout set to 0.
864   // This will trigger a TCP reset when we close the socket.
865   if (fd_ >= 0) {
866     struct linger optLinger = {1, 0};
867     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
868       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
869               << "on " << fd_ << ": errno=" << errno;
870     }
871   }
872
873   // Then let closeNow() take care of the rest
874   closeNow();
875 }
876
877 void AsyncSocket::shutdownWrite() {
878   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
879           << ", state=" << state_ << ", shutdownFlags="
880           << std::hex << (int) shutdownFlags_;
881
882   // If there are no pending writes, shutdownWrite() is identical to
883   // shutdownWriteNow().
884   if (writeReqHead_ == nullptr) {
885     shutdownWriteNow();
886     return;
887   }
888
889   assert(eventBase_->isInEventBaseThread());
890
891   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
892   // shutdown will be performed once all writes complete.
893   shutdownFlags_ |= SHUT_WRITE_PENDING;
894 }
895
896 void AsyncSocket::shutdownWriteNow() {
897   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
898           << ", fd=" << fd_ << ", state=" << state_
899           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
900
901   if (shutdownFlags_ & SHUT_WRITE) {
902     // Writes are already shutdown; nothing else to do.
903     return;
904   }
905
906   // If SHUT_READ is already set, just call closeNow() to completely
907   // close the socket.  This can happen if close() was called with writes
908   // pending, and then shutdownWriteNow() is called before all pending writes
909   // complete.
910   if (shutdownFlags_ & SHUT_READ) {
911     closeNow();
912     return;
913   }
914
915   DestructorGuard dg(this);
916   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
917
918   switch (static_cast<StateEnum>(state_)) {
919     case StateEnum::ESTABLISHED:
920     {
921       shutdownFlags_ |= SHUT_WRITE;
922
923       // If the write timeout was set, cancel it.
924       writeTimeout_.cancelTimeout();
925
926       // If we are registered for write events, unregister.
927       if (!updateEventRegistration(0, EventHandler::WRITE)) {
928         // We will have been moved into the error state.
929         assert(state_ == StateEnum::ERROR);
930         return;
931       }
932
933       // Shutdown writes on the file descriptor
934       ::shutdown(fd_, SHUT_WR);
935
936       // Immediately fail all write requests
937       failAllWrites(socketShutdownForWritesEx);
938       return;
939     }
940     case StateEnum::CONNECTING:
941     {
942       // Set the SHUT_WRITE_PENDING flag.
943       // When the connection completes, it will check this flag,
944       // shutdown the write half of the socket, and then set SHUT_WRITE.
945       shutdownFlags_ |= SHUT_WRITE_PENDING;
946
947       // Immediately fail all write requests
948       failAllWrites(socketShutdownForWritesEx);
949       return;
950     }
951     case StateEnum::UNINIT:
952       // Callers normally shouldn't call shutdownWriteNow() before the socket
953       // even starts connecting.  Nonetheless, go ahead and set
954       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
955       // immediately shut down the write side of the socket.
956       shutdownFlags_ |= SHUT_WRITE_PENDING;
957       return;
958     case StateEnum::CLOSED:
959     case StateEnum::ERROR:
960       // We should never get here.  SHUT_WRITE should always be set
961       // in STATE_CLOSED and STATE_ERROR.
962       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
963                  << ", fd=" << fd_ << ") in unexpected state " << state_
964                  << " with SHUT_WRITE not set ("
965                  << std::hex << (int) shutdownFlags_ << ")";
966       assert(false);
967       return;
968   }
969
970   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
971               << fd_ << ") called in unknown state " << state_;
972 }
973
974 bool AsyncSocket::readable() const {
975   if (fd_ == -1) {
976     return false;
977   }
978   struct pollfd fds[1];
979   fds[0].fd = fd_;
980   fds[0].events = POLLIN;
981   fds[0].revents = 0;
982   int rc = poll(fds, 1, 0);
983   return rc == 1;
984 }
985
986 bool AsyncSocket::isPending() const {
987   return ioHandler_.isPending();
988 }
989
990 bool AsyncSocket::hangup() const {
991   if (fd_ == -1) {
992     // sanity check, no one should ask for hangup if we are not connected.
993     assert(false);
994     return false;
995   }
996 #ifdef POLLRDHUP // Linux-only
997   struct pollfd fds[1];
998   fds[0].fd = fd_;
999   fds[0].events = POLLRDHUP|POLLHUP;
1000   fds[0].revents = 0;
1001   poll(fds, 1, 0);
1002   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1003 #else
1004   return false;
1005 #endif
1006 }
1007
1008 bool AsyncSocket::good() const {
1009   return ((state_ == StateEnum::CONNECTING ||
1010           state_ == StateEnum::ESTABLISHED) &&
1011           (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1012 }
1013
1014 bool AsyncSocket::error() const {
1015   return (state_ == StateEnum::ERROR);
1016 }
1017
1018 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1019   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1020           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1021           << ", state=" << state_ << ", events="
1022           << std::hex << eventFlags_ << ")";
1023   assert(eventBase_ == nullptr);
1024   assert(eventBase->isInEventBaseThread());
1025
1026   eventBase_ = eventBase;
1027   ioHandler_.attachEventBase(eventBase);
1028   writeTimeout_.attachEventBase(eventBase);
1029 }
1030
1031 void AsyncSocket::detachEventBase() {
1032   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1033           << ", old evb=" << eventBase_ << ", state=" << state_
1034           << ", events=" << std::hex << eventFlags_ << ")";
1035   assert(eventBase_ != nullptr);
1036   assert(eventBase_->isInEventBaseThread());
1037
1038   eventBase_ = nullptr;
1039   ioHandler_.detachEventBase();
1040   writeTimeout_.detachEventBase();
1041 }
1042
1043 bool AsyncSocket::isDetachable() const {
1044   DCHECK(eventBase_ != nullptr);
1045   DCHECK(eventBase_->isInEventBaseThread());
1046
1047   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1048 }
1049
1050 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1051   address->setFromLocalAddress(fd_);
1052 }
1053
1054 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1055   if (!addr_.isInitialized()) {
1056     addr_.setFromPeerAddress(fd_);
1057   }
1058   *address = addr_;
1059 }
1060
1061 int AsyncSocket::setNoDelay(bool noDelay) {
1062   if (fd_ < 0) {
1063     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1064                << this << "(state=" << state_ << ")";
1065     return EINVAL;
1066
1067   }
1068
1069   int value = noDelay ? 1 : 0;
1070   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1071     int errnoCopy = errno;
1072     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1073             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1074             << strerror(errnoCopy);
1075     return errnoCopy;
1076   }
1077
1078   return 0;
1079 }
1080
1081 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1082
1083   #ifndef TCP_CONGESTION
1084   #define TCP_CONGESTION  13
1085   #endif
1086
1087   if (fd_ < 0) {
1088     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1089                << "socket " << this << "(state=" << state_ << ")";
1090     return EINVAL;
1091
1092   }
1093
1094   if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
1095         cname.length() + 1) != 0) {
1096     int errnoCopy = errno;
1097     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1098             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1099             << strerror(errnoCopy);
1100     return errnoCopy;
1101   }
1102
1103   return 0;
1104 }
1105
1106 int AsyncSocket::setQuickAck(bool quickack) {
1107   if (fd_ < 0) {
1108     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1109                << this << "(state=" << state_ << ")";
1110     return EINVAL;
1111
1112   }
1113
1114 #ifdef TCP_QUICKACK // Linux-only
1115   int value = quickack ? 1 : 0;
1116   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1117     int errnoCopy = errno;
1118     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1119             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1120             << strerror(errnoCopy);
1121     return errnoCopy;
1122   }
1123
1124   return 0;
1125 #else
1126   return ENOSYS;
1127 #endif
1128 }
1129
1130 int AsyncSocket::setSendBufSize(size_t bufsize) {
1131   if (fd_ < 0) {
1132     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1133                << this << "(state=" << state_ << ")";
1134     return EINVAL;
1135   }
1136
1137   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1138     int errnoCopy = errno;
1139     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1140             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1141             << strerror(errnoCopy);
1142     return errnoCopy;
1143   }
1144
1145   return 0;
1146 }
1147
1148 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1149   if (fd_ < 0) {
1150     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1151                << this << "(state=" << state_ << ")";
1152     return EINVAL;
1153   }
1154
1155   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1156     int errnoCopy = errno;
1157     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1158             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1159             << strerror(errnoCopy);
1160     return errnoCopy;
1161   }
1162
1163   return 0;
1164 }
1165
1166 int AsyncSocket::setTCPProfile(int profd) {
1167   if (fd_ < 0) {
1168     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1169                << this << "(state=" << state_ << ")";
1170     return EINVAL;
1171   }
1172
1173   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1174     int errnoCopy = errno;
1175     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1176             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1177             << strerror(errnoCopy);
1178     return errnoCopy;
1179   }
1180
1181   return 0;
1182 }
1183
1184 void AsyncSocket::ioReady(uint16_t events) noexcept {
1185   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1186           << ", events=" << std::hex << events << ", state=" << state_;
1187   DestructorGuard dg(this);
1188   assert(events & EventHandler::READ_WRITE);
1189   assert(eventBase_->isInEventBaseThread());
1190
1191   uint16_t relevantEvents = events & EventHandler::READ_WRITE;
1192   if (relevantEvents == EventHandler::READ) {
1193     handleRead();
1194   } else if (relevantEvents == EventHandler::WRITE) {
1195     handleWrite();
1196   } else if (relevantEvents == EventHandler::READ_WRITE) {
1197     EventBase* originalEventBase = eventBase_;
1198     // If both read and write events are ready, process writes first.
1199     handleWrite();
1200
1201     // Return now if handleWrite() detached us from our EventBase
1202     if (eventBase_ != originalEventBase) {
1203       return;
1204     }
1205
1206     // Only call handleRead() if a read callback is still installed.
1207     // (It's possible that the read callback was uninstalled during
1208     // handleWrite().)
1209     if (readCallback_) {
1210       handleRead();
1211     }
1212   } else {
1213     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1214                << std::hex << events << "(this=" << this << ")";
1215     abort();
1216   }
1217 }
1218
1219 ssize_t AsyncSocket::performRead(void* buf, size_t buflen) {
1220   ssize_t bytes = recv(fd_, buf, buflen, MSG_DONTWAIT);
1221   if (bytes < 0) {
1222     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1223       // No more data to read right now.
1224       return READ_BLOCKING;
1225     } else {
1226       return READ_ERROR;
1227     }
1228   } else {
1229     appBytesReceived_ += bytes;
1230     return bytes;
1231   }
1232 }
1233
1234 void AsyncSocket::handleRead() noexcept {
1235   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1236           << ", state=" << state_;
1237   assert(state_ == StateEnum::ESTABLISHED);
1238   assert((shutdownFlags_ & SHUT_READ) == 0);
1239   assert(readCallback_ != nullptr);
1240   assert(eventFlags_ & EventHandler::READ);
1241
1242   // Loop until:
1243   // - a read attempt would block
1244   // - readCallback_ is uninstalled
1245   // - the number of loop iterations exceeds the optional maximum
1246   // - this AsyncSocket is moved to another EventBase
1247   //
1248   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1249   // which is why need to check for it here.
1250   //
1251   // The last bullet point is slightly subtle.  readDataAvailable() may also
1252   // detach this socket from this EventBase.  However, before
1253   // readDataAvailable() returns another thread may pick it up, attach it to
1254   // a different EventBase, and install another readCallback_.  We need to
1255   // exit immediately after readDataAvailable() returns if the eventBase_ has
1256   // changed.  (The caller must perform some sort of locking to transfer the
1257   // AsyncSocket between threads properly.  This will be sufficient to ensure
1258   // that this thread sees the updated eventBase_ variable after
1259   // readDataAvailable() returns.)
1260   uint16_t numReads = 0;
1261   EventBase* originalEventBase = eventBase_;
1262   while (readCallback_ && eventBase_ == originalEventBase) {
1263     // Get the buffer to read into.
1264     void* buf = nullptr;
1265     size_t buflen = 0;
1266     try {
1267       readCallback_->getReadBuffer(&buf, &buflen);
1268     } catch (const AsyncSocketException& ex) {
1269       return failRead(__func__, ex);
1270     } catch (const std::exception& ex) {
1271       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1272                               string("ReadCallback::getReadBuffer() "
1273                                      "threw exception: ") +
1274                               ex.what());
1275       return failRead(__func__, tex);
1276     } catch (...) {
1277       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1278                              "ReadCallback::getReadBuffer() threw "
1279                              "non-exception type");
1280       return failRead(__func__, ex);
1281     }
1282     if (buf == nullptr || buflen == 0) {
1283       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1284                              "ReadCallback::getReadBuffer() returned "
1285                              "empty buffer");
1286       return failRead(__func__, ex);
1287     }
1288
1289     // Perform the read
1290     ssize_t bytesRead = performRead(buf, buflen);
1291     if (bytesRead > 0) {
1292       readCallback_->readDataAvailable(bytesRead);
1293       // Fall through and continue around the loop if the read
1294       // completely filled the available buffer.
1295       // Note that readCallback_ may have been uninstalled or changed inside
1296       // readDataAvailable().
1297       if (size_t(bytesRead) < buflen) {
1298         return;
1299       }
1300     } else if (bytesRead == READ_BLOCKING) {
1301         // No more data to read right now.
1302         return;
1303     } else if (bytesRead == READ_ERROR) {
1304       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1305                              withAddr("recv() failed"), errno);
1306       return failRead(__func__, ex);
1307     } else {
1308       assert(bytesRead == READ_EOF);
1309       // EOF
1310       shutdownFlags_ |= SHUT_READ;
1311       if (!updateEventRegistration(0, EventHandler::READ)) {
1312         // we've already been moved into STATE_ERROR
1313         assert(state_ == StateEnum::ERROR);
1314         assert(readCallback_ == nullptr);
1315         return;
1316       }
1317
1318       ReadCallback* callback = readCallback_;
1319       readCallback_ = nullptr;
1320       callback->readEOF();
1321       return;
1322     }
1323     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1324       return;
1325     }
1326   }
1327 }
1328
1329 /**
1330  * This function attempts to write as much data as possible, until no more data
1331  * can be written.
1332  *
1333  * - If it sends all available data, it unregisters for write events, and stops
1334  *   the writeTimeout_.
1335  *
1336  * - If not all of the data can be sent immediately, it reschedules
1337  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1338  *   registered for write events.
1339  */
1340 void AsyncSocket::handleWrite() noexcept {
1341   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1342           << ", state=" << state_;
1343   if (state_ == StateEnum::CONNECTING) {
1344     handleConnect();
1345     return;
1346   }
1347
1348   // Normal write
1349   assert(state_ == StateEnum::ESTABLISHED);
1350   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1351   assert(writeReqHead_ != nullptr);
1352
1353   // Loop until we run out of write requests,
1354   // or until this socket is moved to another EventBase.
1355   // (See the comment in handleRead() explaining how this can happen.)
1356   EventBase* originalEventBase = eventBase_;
1357   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1358     uint32_t countWritten;
1359     uint32_t partialWritten;
1360     WriteFlags writeFlags = writeReqHead_->flags();
1361     if (writeReqHead_->getNext() != nullptr) {
1362       writeFlags = writeFlags | WriteFlags::CORK;
1363     }
1364     int bytesWritten = performWrite(writeReqHead_->getOps(),
1365                                     writeReqHead_->getOpCount(),
1366                                     writeFlags, &countWritten, &partialWritten);
1367     if (bytesWritten < 0) {
1368       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1369                              withAddr("writev() failed"), errno);
1370       return failWrite(__func__, ex);
1371     } else if (countWritten == writeReqHead_->getOpCount()) {
1372       // We finished this request
1373       WriteRequest* req = writeReqHead_;
1374       writeReqHead_ = req->getNext();
1375
1376       if (writeReqHead_ == nullptr) {
1377         writeReqTail_ = nullptr;
1378         // This is the last write request.
1379         // Unregister for write events and cancel the send timer
1380         // before we invoke the callback.  We have to update the state properly
1381         // before calling the callback, since it may want to detach us from
1382         // the EventBase.
1383         if (eventFlags_ & EventHandler::WRITE) {
1384           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1385             assert(state_ == StateEnum::ERROR);
1386             return;
1387           }
1388           // Stop the send timeout
1389           writeTimeout_.cancelTimeout();
1390         }
1391         assert(!writeTimeout_.isScheduled());
1392
1393         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1394         // we finish sending the last write request.
1395         //
1396         // We have to do this before invoking writeSuccess(), since
1397         // writeSuccess() may detach us from our EventBase.
1398         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1399           assert(connectCallback_ == nullptr);
1400           shutdownFlags_ |= SHUT_WRITE;
1401
1402           if (shutdownFlags_ & SHUT_READ) {
1403             // Reads have already been shutdown.  Fully close the socket and
1404             // move to STATE_CLOSED.
1405             //
1406             // Note: This code currently moves us to STATE_CLOSED even if
1407             // close() hasn't ever been called.  This can occur if we have
1408             // received EOF from the peer and shutdownWrite() has been called
1409             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1410             // case, until close() is actually called?  I can't think of a
1411             // reason why we would need to do so.  No other operations besides
1412             // calling close() or destroying the socket can be performed at
1413             // this point.
1414             assert(readCallback_ == nullptr);
1415             state_ = StateEnum::CLOSED;
1416             if (fd_ >= 0) {
1417               ioHandler_.changeHandlerFD(-1);
1418               doClose();
1419             }
1420           } else {
1421             // Reads are still enabled, so we are only doing a half-shutdown
1422             ::shutdown(fd_, SHUT_WR);
1423           }
1424         }
1425       }
1426
1427       // Invoke the callback
1428       WriteCallback* callback = req->getCallback();
1429       req->destroy();
1430       if (callback) {
1431         callback->writeSuccess();
1432       }
1433       // We'll continue around the loop, trying to write another request
1434     } else {
1435       // Partial write.
1436       writeReqHead_->consume(countWritten, partialWritten, bytesWritten);
1437       // Stop after a partial write; it's highly likely that a subsequent write
1438       // attempt will just return EAGAIN.
1439       //
1440       // Ensure that we are registered for write events.
1441       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1442         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1443           assert(state_ == StateEnum::ERROR);
1444           return;
1445         }
1446       }
1447
1448       // Reschedule the send timeout, since we have made some write progress.
1449       if (sendTimeout_ > 0) {
1450         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1451           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1452               withAddr("failed to reschedule write timeout"));
1453           return failWrite(__func__, ex);
1454         }
1455       }
1456       return;
1457     }
1458   }
1459 }
1460
1461 void AsyncSocket::checkForImmediateRead() noexcept {
1462   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1463   // (However, note that some subclasses do override this method.)
1464   //
1465   // Simply calling handleRead() here would be bad, as this would call
1466   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1467   // buffer even though no data may be available.  This would waste lots of
1468   // memory, since the buffer will sit around unused until the socket actually
1469   // becomes readable.
1470   //
1471   // Checking if the socket is readable now also seems like it would probably
1472   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1473   // would just waste an extra system call.  Even if it is readable, waiting to
1474   // find out from libevent on the next event loop doesn't seem that bad.
1475 }
1476
1477 void AsyncSocket::handleInitialReadWrite() noexcept {
1478   // Our callers should already be holding a DestructorGuard, but grab
1479   // one here just to make sure, in case one of our calling code paths ever
1480   // changes.
1481   DestructorGuard dg(this);
1482
1483   // If we have a readCallback_, make sure we enable read events.  We
1484   // may already be registered for reads if connectSuccess() set
1485   // the read calback.
1486   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1487     assert(state_ == StateEnum::ESTABLISHED);
1488     assert((shutdownFlags_ & SHUT_READ) == 0);
1489     if (!updateEventRegistration(EventHandler::READ, 0)) {
1490       assert(state_ == StateEnum::ERROR);
1491       return;
1492     }
1493     checkForImmediateRead();
1494   } else if (readCallback_ == nullptr) {
1495     // Unregister for read events.
1496     updateEventRegistration(0, EventHandler::READ);
1497   }
1498
1499   // If we have write requests pending, try to send them immediately.
1500   // Since we just finished accepting, there is a very good chance that we can
1501   // write without blocking.
1502   //
1503   // However, we only process them if EventHandler::WRITE is not already set,
1504   // which means that we're already blocked on a write attempt.  (This can
1505   // happen if connectSuccess() called write() before returning.)
1506   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1507     // Call handleWrite() to perform write processing.
1508     handleWrite();
1509   } else if (writeReqHead_ == nullptr) {
1510     // Unregister for write event.
1511     updateEventRegistration(0, EventHandler::WRITE);
1512   }
1513 }
1514
1515 void AsyncSocket::handleConnect() noexcept {
1516   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1517           << ", state=" << state_;
1518   assert(state_ == StateEnum::CONNECTING);
1519   // SHUT_WRITE can never be set while we are still connecting;
1520   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1521   // finishes
1522   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1523
1524   // In case we had a connect timeout, cancel the timeout
1525   writeTimeout_.cancelTimeout();
1526   // We don't use a persistent registration when waiting on a connect event,
1527   // so we have been automatically unregistered now.  Update eventFlags_ to
1528   // reflect reality.
1529   assert(eventFlags_ == EventHandler::WRITE);
1530   eventFlags_ = EventHandler::NONE;
1531
1532   // Call getsockopt() to check if the connect succeeded
1533   int error;
1534   socklen_t len = sizeof(error);
1535   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1536   if (rv != 0) {
1537     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1538                            withAddr("error calling getsockopt() after connect"),
1539                            errno);
1540     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1541                << fd_ << " host=" << addr_.describe()
1542                << ") exception:" << ex.what();
1543     return failConnect(__func__, ex);
1544   }
1545
1546   if (error != 0) {
1547     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1548                            "connect failed", error);
1549     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1550             << fd_ << " host=" << addr_.describe()
1551             << ") exception: " << ex.what();
1552     return failConnect(__func__, ex);
1553   }
1554
1555   // Move into STATE_ESTABLISHED
1556   state_ = StateEnum::ESTABLISHED;
1557
1558   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1559   // perform, immediately shutdown the write half of the socket.
1560   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1561     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1562     // are still connecting we just abort the connect rather than waiting for
1563     // it to complete.
1564     assert((shutdownFlags_ & SHUT_READ) == 0);
1565     ::shutdown(fd_, SHUT_WR);
1566     shutdownFlags_ |= SHUT_WRITE;
1567   }
1568
1569   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1570           << "successfully connected; state=" << state_;
1571
1572   // Remember the EventBase we are attached to, before we start invoking any
1573   // callbacks (since the callbacks may call detachEventBase()).
1574   EventBase* originalEventBase = eventBase_;
1575
1576   // Call the connect callback.
1577   if (connectCallback_) {
1578     ConnectCallback* callback = connectCallback_;
1579     connectCallback_ = nullptr;
1580     callback->connectSuccess();
1581   }
1582
1583   // Note that the connect callback may have changed our state.
1584   // (set or unset the read callback, called write(), closed the socket, etc.)
1585   // The following code needs to handle these situations correctly.
1586   //
1587   // If the socket has been closed, readCallback_ and writeReqHead_ will
1588   // always be nullptr, so that will prevent us from trying to read or write.
1589   //
1590   // The main thing to check for is if eventBase_ is still originalEventBase.
1591   // If not, we have been detached from this event base, so we shouldn't
1592   // perform any more operations.
1593   if (eventBase_ != originalEventBase) {
1594     return;
1595   }
1596
1597   handleInitialReadWrite();
1598 }
1599
1600 void AsyncSocket::timeoutExpired() noexcept {
1601   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1602           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1603   DestructorGuard dg(this);
1604   assert(eventBase_->isInEventBaseThread());
1605
1606   if (state_ == StateEnum::CONNECTING) {
1607     // connect() timed out
1608     // Unregister for I/O events.
1609     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
1610                            "connect timed out");
1611     failConnect(__func__, ex);
1612   } else {
1613     // a normal write operation timed out
1614     assert(state_ == StateEnum::ESTABLISHED);
1615     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
1616     failWrite(__func__, ex);
1617   }
1618 }
1619
1620 ssize_t AsyncSocket::performWrite(const iovec* vec,
1621                                    uint32_t count,
1622                                    WriteFlags flags,
1623                                    uint32_t* countWritten,
1624                                    uint32_t* partialWritten) {
1625   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1626   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1627   // (since it may terminate the program if the main program doesn't explicitly
1628   // ignore it).
1629   struct msghdr msg;
1630   msg.msg_name = nullptr;
1631   msg.msg_namelen = 0;
1632   msg.msg_iov = const_cast<iovec *>(vec);
1633 #ifdef IOV_MAX // not defined on Android
1634   msg.msg_iovlen = std::min(count, (uint32_t)IOV_MAX);
1635 #else
1636   msg.msg_iovlen = std::min(count, (uint32_t)UIO_MAXIOV);
1637 #endif
1638   msg.msg_control = nullptr;
1639   msg.msg_controllen = 0;
1640   msg.msg_flags = 0;
1641
1642   int msg_flags = MSG_DONTWAIT;
1643
1644 #ifdef MSG_NOSIGNAL // Linux-only
1645   msg_flags |= MSG_NOSIGNAL;
1646   if (isSet(flags, WriteFlags::CORK)) {
1647     // MSG_MORE tells the kernel we have more data to send, so wait for us to
1648     // give it the rest of the data rather than immediately sending a partial
1649     // frame, even when TCP_NODELAY is enabled.
1650     msg_flags |= MSG_MORE;
1651   }
1652 #endif
1653   if (isSet(flags, WriteFlags::EOR)) {
1654     // marks that this is the last byte of a record (response)
1655     msg_flags |= MSG_EOR;
1656   }
1657   ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
1658   if (totalWritten < 0) {
1659     if (errno == EAGAIN) {
1660       // TCP buffer is full; we can't write any more data right now.
1661       *countWritten = 0;
1662       *partialWritten = 0;
1663       return 0;
1664     }
1665     // error
1666     *countWritten = 0;
1667     *partialWritten = 0;
1668     return -1;
1669   }
1670
1671   appBytesWritten_ += totalWritten;
1672
1673   uint32_t bytesWritten;
1674   uint32_t n;
1675   for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
1676     const iovec* v = vec + n;
1677     if (v->iov_len > bytesWritten) {
1678       // Partial write finished in the middle of this iovec
1679       *countWritten = n;
1680       *partialWritten = bytesWritten;
1681       return totalWritten;
1682     }
1683
1684     bytesWritten -= v->iov_len;
1685   }
1686
1687   assert(bytesWritten == 0);
1688   *countWritten = n;
1689   *partialWritten = 0;
1690   return totalWritten;
1691 }
1692
1693 /**
1694  * Re-register the EventHandler after eventFlags_ has changed.
1695  *
1696  * If an error occurs, fail() is called to move the socket into the error state
1697  * and call all currently installed callbacks.  After an error, the
1698  * AsyncSocket is completely unregistered.
1699  *
1700  * @return Returns true on succcess, or false on error.
1701  */
1702 bool AsyncSocket::updateEventRegistration() {
1703   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
1704           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
1705           << ", events=" << std::hex << eventFlags_;
1706   assert(eventBase_->isInEventBaseThread());
1707   if (eventFlags_ == EventHandler::NONE) {
1708     ioHandler_.unregisterHandler();
1709     return true;
1710   }
1711
1712   // Always register for persistent events, so we don't have to re-register
1713   // after being called back.
1714   if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
1715     eventFlags_ = EventHandler::NONE; // we're not registered after error
1716     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1717         withAddr("failed to update AsyncSocket event registration"));
1718     fail("updateEventRegistration", ex);
1719     return false;
1720   }
1721
1722   return true;
1723 }
1724
1725 bool AsyncSocket::updateEventRegistration(uint16_t enable,
1726                                            uint16_t disable) {
1727   uint16_t oldFlags = eventFlags_;
1728   eventFlags_ |= enable;
1729   eventFlags_ &= ~disable;
1730   if (eventFlags_ == oldFlags) {
1731     return true;
1732   } else {
1733     return updateEventRegistration();
1734   }
1735 }
1736
1737 void AsyncSocket::startFail() {
1738   // startFail() should only be called once
1739   assert(state_ != StateEnum::ERROR);
1740   assert(getDestructorGuardCount() > 0);
1741   state_ = StateEnum::ERROR;
1742   // Ensure that SHUT_READ and SHUT_WRITE are set,
1743   // so all future attempts to read or write will be rejected
1744   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1745
1746   if (eventFlags_ != EventHandler::NONE) {
1747     eventFlags_ = EventHandler::NONE;
1748     ioHandler_.unregisterHandler();
1749   }
1750   writeTimeout_.cancelTimeout();
1751
1752   if (fd_ >= 0) {
1753     ioHandler_.changeHandlerFD(-1);
1754     doClose();
1755   }
1756 }
1757
1758 void AsyncSocket::finishFail() {
1759   assert(state_ == StateEnum::ERROR);
1760   assert(getDestructorGuardCount() > 0);
1761
1762   AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1763                          withAddr("socket closing after error"));
1764   if (connectCallback_) {
1765     ConnectCallback* callback = connectCallback_;
1766     connectCallback_ = nullptr;
1767     callback->connectErr(ex);
1768   }
1769
1770   failAllWrites(ex);
1771
1772   if (readCallback_) {
1773     ReadCallback* callback = readCallback_;
1774     readCallback_ = nullptr;
1775     callback->readErr(ex);
1776   }
1777 }
1778
1779 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
1780   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1781              << state_ << " host=" << addr_.describe()
1782              << "): failed in " << fn << "(): "
1783              << ex.what();
1784   startFail();
1785   finishFail();
1786 }
1787
1788 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
1789   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1790                << state_ << " host=" << addr_.describe()
1791                << "): failed while connecting in " << fn << "(): "
1792                << ex.what();
1793   startFail();
1794
1795   if (connectCallback_ != nullptr) {
1796     ConnectCallback* callback = connectCallback_;
1797     connectCallback_ = nullptr;
1798     callback->connectErr(ex);
1799   }
1800
1801   finishFail();
1802 }
1803
1804 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
1805   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1806                << state_ << " host=" << addr_.describe()
1807                << "): failed while reading in " << fn << "(): "
1808                << ex.what();
1809   startFail();
1810
1811   if (readCallback_ != nullptr) {
1812     ReadCallback* callback = readCallback_;
1813     readCallback_ = nullptr;
1814     callback->readErr(ex);
1815   }
1816
1817   finishFail();
1818 }
1819
1820 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
1821   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1822                << state_ << " host=" << addr_.describe()
1823                << "): failed while writing in " << fn << "(): "
1824                << ex.what();
1825   startFail();
1826
1827   // Only invoke the first write callback, since the error occurred while
1828   // writing this request.  Let any other pending write callbacks be invoked in
1829   // finishFail().
1830   if (writeReqHead_ != nullptr) {
1831     WriteRequest* req = writeReqHead_;
1832     writeReqHead_ = req->getNext();
1833     WriteCallback* callback = req->getCallback();
1834     uint32_t bytesWritten = req->getBytesWritten();
1835     req->destroy();
1836     if (callback) {
1837       callback->writeErr(bytesWritten, ex);
1838     }
1839   }
1840
1841   finishFail();
1842 }
1843
1844 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
1845                              size_t bytesWritten,
1846                              const AsyncSocketException& ex) {
1847   // This version of failWrite() is used when the failure occurs before
1848   // we've added the callback to writeReqHead_.
1849   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1850              << state_ << " host=" << addr_.describe()
1851              <<"): failed while writing in " << fn << "(): "
1852              << ex.what();
1853   startFail();
1854
1855   if (callback != nullptr) {
1856     callback->writeErr(bytesWritten, ex);
1857   }
1858
1859   finishFail();
1860 }
1861
1862 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
1863   // Invoke writeError() on all write callbacks.
1864   // This is used when writes are forcibly shutdown with write requests
1865   // pending, or when an error occurs with writes pending.
1866   while (writeReqHead_ != nullptr) {
1867     WriteRequest* req = writeReqHead_;
1868     writeReqHead_ = req->getNext();
1869     WriteCallback* callback = req->getCallback();
1870     if (callback) {
1871       callback->writeErr(req->getBytesWritten(), ex);
1872     }
1873     req->destroy();
1874   }
1875 }
1876
1877 void AsyncSocket::invalidState(ConnectCallback* callback) {
1878   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
1879              << "): connect() called in invalid state " << state_;
1880
1881   /*
1882    * The invalidState() methods don't use the normal failure mechanisms,
1883    * since we don't know what state we are in.  We don't want to call
1884    * startFail()/finishFail() recursively if we are already in the middle of
1885    * cleaning up.
1886    */
1887
1888   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
1889                          "connect() called with socket in invalid state");
1890   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1891     if (callback) {
1892       callback->connectErr(ex);
1893     }
1894   } else {
1895     // We can't use failConnect() here since connectCallback_
1896     // may already be set to another callback.  Invoke this ConnectCallback
1897     // here; any other connectCallback_ will be invoked in finishFail()
1898     startFail();
1899     if (callback) {
1900       callback->connectErr(ex);
1901     }
1902     finishFail();
1903   }
1904 }
1905
1906 void AsyncSocket::invalidState(ReadCallback* callback) {
1907   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
1908              << "): setReadCallback(" << callback
1909              << ") called in invalid state " << state_;
1910
1911   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1912                          "setReadCallback() called with socket in "
1913                          "invalid state");
1914   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1915     if (callback) {
1916       callback->readErr(ex);
1917     }
1918   } else {
1919     startFail();
1920     if (callback) {
1921       callback->readErr(ex);
1922     }
1923     finishFail();
1924   }
1925 }
1926
1927 void AsyncSocket::invalidState(WriteCallback* callback) {
1928   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
1929              << "): write() called in invalid state " << state_;
1930
1931   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1932                          withAddr("write() called with socket in invalid state"));
1933   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1934     if (callback) {
1935       callback->writeErr(0, ex);
1936     }
1937   } else {
1938     startFail();
1939     if (callback) {
1940       callback->writeErr(0, ex);
1941     }
1942     finishFail();
1943   }
1944 }
1945
1946 void AsyncSocket::doClose() {
1947   if (fd_ == -1) return;
1948   if (shutdownSocketSet_) {
1949     shutdownSocketSet_->close(fd_);
1950   } else {
1951     ::close(fd_);
1952   }
1953   fd_ = -1;
1954 }
1955
1956 std::ostream& operator << (std::ostream& os,
1957                            const AsyncSocket::StateEnum& state) {
1958   os << static_cast<int>(state);
1959   return os;
1960 }
1961
1962 std::string AsyncSocket::withAddr(const std::string& s) {
1963   // Don't use addr_ directly because it may not be initialized
1964   // e.g. if constructed from fd
1965   folly::SocketAddress peer, local;
1966   try {
1967     getPeerAddress(&peer);
1968     getLocalAddress(&local);
1969   } catch (const std::exception&) {
1970     // ignore
1971   } catch (...) {
1972     // ignore
1973   }
1974   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
1975 }
1976
1977 } // folly