Move SSL socket to folly
[folly.git] / folly / io / async / AsyncSocket.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/async/AsyncSocket.h>
18
19 #include <folly/io/async/EventBase.h>
20 #include <folly/SocketAddress.h>
21 #include <folly/io/IOBuf.h>
22
23 #include <poll.h>
24 #include <errno.h>
25 #include <limits.h>
26 #include <unistd.h>
27 #include <fcntl.h>
28 #include <sys/types.h>
29 #include <sys/socket.h>
30 #include <netinet/in.h>
31 #include <netinet/tcp.h>
32
33 using std::string;
34 using std::unique_ptr;
35
36 namespace folly {
37
38 // static members initializers
39 const AsyncSocket::OptionMap AsyncSocket::emptyOptionMap;
40 const folly::SocketAddress AsyncSocket::anyAddress =
41   folly::SocketAddress("0.0.0.0", 0);
42
43 const AsyncSocketException socketClosedLocallyEx(
44     AsyncSocketException::END_OF_FILE, "socket closed locally");
45 const AsyncSocketException socketShutdownForWritesEx(
46     AsyncSocketException::END_OF_FILE, "socket shutdown for writes");
47
48 // TODO: It might help performance to provide a version of WriteRequest that
49 // users could derive from, so we can avoid the extra allocation for each call
50 // to write()/writev().  We could templatize TFramedAsyncChannel just like the
51 // protocols are currently templatized for transports.
52 //
53 // We would need the version for external users where they provide the iovec
54 // storage space, and only our internal version would allocate it at the end of
55 // the WriteRequest.
56
57 /**
58  * A WriteRequest object tracks information about a pending write() or writev()
59  * operation.
60  *
61  * A new WriteRequest operation is allocated on the heap for all write
62  * operations that cannot be completed immediately.
63  */
64 class AsyncSocket::WriteRequest {
65  public:
66   static WriteRequest* newRequest(WriteCallback* callback,
67                                   const iovec* ops,
68                                   uint32_t opCount,
69                                   unique_ptr<IOBuf>&& ioBuf,
70                                   WriteFlags flags) {
71     assert(opCount > 0);
72     // Since we put a variable size iovec array at the end
73     // of each WriteRequest, we have to manually allocate the memory.
74     void* buf = malloc(sizeof(WriteRequest) +
75                        (opCount * sizeof(struct iovec)));
76     if (buf == nullptr) {
77       throw std::bad_alloc();
78     }
79
80     return new(buf) WriteRequest(callback, ops, opCount, std::move(ioBuf),
81                                  flags);
82   }
83
84   void destroy() {
85     this->~WriteRequest();
86     free(this);
87   }
88
89   bool cork() const {
90     return isSet(flags_, WriteFlags::CORK);
91   }
92
93   WriteFlags flags() const {
94     return flags_;
95   }
96
97   WriteRequest* getNext() const {
98     return next_;
99   }
100
101   WriteCallback* getCallback() const {
102     return callback_;
103   }
104
105   uint32_t getBytesWritten() const {
106     return bytesWritten_;
107   }
108
109   const struct iovec* getOps() const {
110     assert(opCount_ > opIndex_);
111     return writeOps_ + opIndex_;
112   }
113
114   uint32_t getOpCount() const {
115     assert(opCount_ > opIndex_);
116     return opCount_ - opIndex_;
117   }
118
119   void consume(uint32_t wholeOps, uint32_t partialBytes,
120                uint32_t totalBytesWritten) {
121     // Advance opIndex_ forward by wholeOps
122     opIndex_ += wholeOps;
123     assert(opIndex_ < opCount_);
124
125     // If we've finished writing any IOBufs, release them
126     if (ioBuf_) {
127       for (uint32_t i = wholeOps; i != 0; --i) {
128         assert(ioBuf_);
129         ioBuf_ = ioBuf_->pop();
130       }
131     }
132
133     // Move partialBytes forward into the current iovec buffer
134     struct iovec* currentOp = writeOps_ + opIndex_;
135     assert((partialBytes < currentOp->iov_len) || (currentOp->iov_len == 0));
136     currentOp->iov_base =
137       reinterpret_cast<uint8_t*>(currentOp->iov_base) + partialBytes;
138     currentOp->iov_len -= partialBytes;
139
140     // Increment the bytesWritten_ count by totalBytesWritten
141     bytesWritten_ += totalBytesWritten;
142   }
143
144   void append(WriteRequest* next) {
145     assert(next_ == nullptr);
146     next_ = next;
147   }
148
149  private:
150   WriteRequest(WriteCallback* callback,
151                const struct iovec* ops,
152                uint32_t opCount,
153                unique_ptr<IOBuf>&& ioBuf,
154                WriteFlags flags)
155     : next_(nullptr)
156     , callback_(callback)
157     , bytesWritten_(0)
158     , opCount_(opCount)
159     , opIndex_(0)
160     , flags_(flags)
161     , ioBuf_(std::move(ioBuf)) {
162     memcpy(writeOps_, ops, sizeof(*ops) * opCount_);
163   }
164
165   // Private destructor, to ensure callers use destroy()
166   ~WriteRequest() {}
167
168   WriteRequest* next_;          ///< pointer to next WriteRequest
169   WriteCallback* callback_;     ///< completion callback
170   uint32_t bytesWritten_;       ///< bytes written
171   uint32_t opCount_;            ///< number of entries in writeOps_
172   uint32_t opIndex_;            ///< current index into writeOps_
173   WriteFlags flags_;            ///< set for WriteFlags
174   unique_ptr<IOBuf> ioBuf_;     ///< underlying IOBuf, or nullptr if N/A
175   struct iovec writeOps_[];     ///< write operation(s) list
176 };
177
178 AsyncSocket::AsyncSocket()
179   : eventBase_(nullptr)
180   , writeTimeout_(this, nullptr)
181   , ioHandler_(this, nullptr) {
182   VLOG(5) << "new AsyncSocket(" << ")";
183   init();
184 }
185
186 AsyncSocket::AsyncSocket(EventBase* evb)
187   : eventBase_(evb)
188   , writeTimeout_(this, evb)
189   , ioHandler_(this, evb) {
190   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
191   init();
192 }
193
194 AsyncSocket::AsyncSocket(EventBase* evb,
195                            const folly::SocketAddress& address,
196                            uint32_t connectTimeout)
197   : eventBase_(evb)
198   , writeTimeout_(this, evb)
199   , ioHandler_(this, evb) {
200   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
201   init();
202   connect(nullptr, address, connectTimeout);
203 }
204
205 AsyncSocket::AsyncSocket(EventBase* evb,
206                            const std::string& ip,
207                            uint16_t port,
208                            uint32_t connectTimeout)
209   : eventBase_(evb)
210   , writeTimeout_(this, evb)
211   , ioHandler_(this, evb) {
212   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ")";
213   init();
214   connect(nullptr, ip, port, connectTimeout);
215 }
216
217 AsyncSocket::AsyncSocket(EventBase* evb, int fd)
218   : eventBase_(evb)
219   , writeTimeout_(this, evb)
220   , ioHandler_(this, evb, fd) {
221   VLOG(5) << "new AsyncSocket(" << this << ", evb=" << evb << ", fd="
222           << fd << ")";
223   init();
224   fd_ = fd;
225   state_ = StateEnum::ESTABLISHED;
226 }
227
228 // init() method, since constructor forwarding isn't supported in most
229 // compilers yet.
230 void AsyncSocket::init() {
231   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
232   shutdownFlags_ = 0;
233   state_ = StateEnum::UNINIT;
234   eventFlags_ = EventHandler::NONE;
235   fd_ = -1;
236   sendTimeout_ = 0;
237   maxReadsPerEvent_ = 0;
238   connectCallback_ = nullptr;
239   readCallback_ = nullptr;
240   writeReqHead_ = nullptr;
241   writeReqTail_ = nullptr;
242   shutdownSocketSet_ = nullptr;
243   appBytesWritten_ = 0;
244   appBytesReceived_ = 0;
245 }
246
247 AsyncSocket::~AsyncSocket() {
248   VLOG(7) << "actual destruction of AsyncSocket(this=" << this
249           << ", evb=" << eventBase_ << ", fd=" << fd_
250           << ", state=" << state_ << ")";
251 }
252
253 void AsyncSocket::destroy() {
254   VLOG(5) << "AsyncSocket::destroy(this=" << this << ", evb=" << eventBase_
255           << ", fd=" << fd_ << ", state=" << state_;
256   // When destroy is called, close the socket immediately
257   closeNow();
258
259   // Then call DelayedDestruction::destroy() to take care of
260   // whether or not we need immediate or delayed destruction
261   DelayedDestruction::destroy();
262 }
263
264 int AsyncSocket::detachFd() {
265   VLOG(6) << "AsyncSocket::detachFd(this=" << this << ", fd=" << fd_
266           << ", evb=" << eventBase_ << ", state=" << state_
267           << ", events=" << std::hex << eventFlags_ << ")";
268   // Extract the fd, and set fd_ to -1 first, so closeNow() won't
269   // actually close the descriptor.
270   if (shutdownSocketSet_) {
271     shutdownSocketSet_->remove(fd_);
272   }
273   int fd = fd_;
274   fd_ = -1;
275   // Call closeNow() to invoke all pending callbacks with an error.
276   closeNow();
277   // Update the EventHandler to stop using this fd.
278   // This can only be done after closeNow() unregisters the handler.
279   ioHandler_.changeHandlerFD(-1);
280   return fd;
281 }
282
283 void AsyncSocket::setShutdownSocketSet(ShutdownSocketSet* newSS) {
284   if (shutdownSocketSet_ == newSS) {
285     return;
286   }
287   if (shutdownSocketSet_ && fd_ != -1) {
288     shutdownSocketSet_->remove(fd_);
289   }
290   shutdownSocketSet_ = newSS;
291   if (shutdownSocketSet_ && fd_ != -1) {
292     shutdownSocketSet_->add(fd_);
293   }
294 }
295
296 void AsyncSocket::connect(ConnectCallback* callback,
297                            const folly::SocketAddress& address,
298                            int timeout,
299                            const OptionMap &options,
300                            const folly::SocketAddress& bindAddr) noexcept {
301   DestructorGuard dg(this);
302   assert(eventBase_->isInEventBaseThread());
303
304   addr_ = address;
305
306   // Make sure we're in the uninitialized state
307   if (state_ != StateEnum::UNINIT) {
308     return invalidState(callback);
309   }
310
311   assert(fd_ == -1);
312   state_ = StateEnum::CONNECTING;
313   connectCallback_ = callback;
314
315   sockaddr_storage addrStorage;
316   sockaddr* saddr = reinterpret_cast<sockaddr*>(&addrStorage);
317
318   try {
319     // Create the socket
320     // Technically the first parameter should actually be a protocol family
321     // constant (PF_xxx) rather than an address family (AF_xxx), but the
322     // distinction is mainly just historical.  In pretty much all
323     // implementations the PF_foo and AF_foo constants are identical.
324     fd_ = socket(address.getFamily(), SOCK_STREAM, 0);
325     if (fd_ < 0) {
326       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
327                                 withAddr("failed to create socket"), errno);
328     }
329     if (shutdownSocketSet_) {
330       shutdownSocketSet_->add(fd_);
331     }
332     ioHandler_.changeHandlerFD(fd_);
333
334     // Set the FD_CLOEXEC flag so that the socket will be closed if the program
335     // later forks and execs.
336     int rv = fcntl(fd_, F_SETFD, FD_CLOEXEC);
337     if (rv != 0) {
338       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
339                                 withAddr("failed to set close-on-exec flag"),
340                                 errno);
341     }
342
343     // Put the socket in non-blocking mode
344     int flags = fcntl(fd_, F_GETFL, 0);
345     if (flags == -1) {
346       throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
347                                 withAddr("failed to get socket flags"), errno);
348     }
349     rv = fcntl(fd_, F_SETFL, flags | O_NONBLOCK);
350     if (rv == -1) {
351       throw AsyncSocketException(
352           AsyncSocketException::INTERNAL_ERROR,
353           withAddr("failed to put socket in non-blocking mode"),
354           errno);
355     }
356
357 #if !defined(MSG_NOSIGNAL) && defined(F_SETNOSIGPIPE)
358     // iOS and OS X don't support MSG_NOSIGNAL; set F_SETNOSIGPIPE instead
359     rv = fcntl(fd_, F_SETNOSIGPIPE, 1);
360     if (rv == -1) {
361       throw AsyncSocketException(
362           AsyncSocketException::INTERNAL_ERROR,
363           "failed to enable F_SETNOSIGPIPE on socket",
364           errno);
365     }
366 #endif
367
368     // By default, turn on TCP_NODELAY
369     // If setNoDelay() fails, we continue anyway; this isn't a fatal error.
370     // setNoDelay() will log an error message if it fails.
371     if (address.getFamily() != AF_UNIX) {
372       (void)setNoDelay(true);
373     }
374
375     VLOG(5) << "AsyncSocket::connect(this=" << this << ", evb=" << eventBase_
376             << ", fd=" << fd_ << ", host=" << address.describe().c_str();
377
378     // bind the socket
379     if (bindAddr != anyAddress) {
380       int one = 1;
381       if (::setsockopt(fd_, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one))) {
382         doClose();
383         throw AsyncSocketException(
384           AsyncSocketException::NOT_OPEN,
385           "failed to setsockopt prior to bind on " + bindAddr.describe(),
386           errno);
387       }
388
389       bindAddr.getAddress(&addrStorage);
390
391       if (::bind(fd_, saddr, bindAddr.getActualSize()) != 0) {
392         doClose();
393         throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
394                                   "failed to bind to async socket: " +
395                                   bindAddr.describe(),
396                                   errno);
397       }
398     }
399
400     // Apply the additional options if any.
401     for (const auto& opt: options) {
402       int rv = opt.first.apply(fd_, opt.second);
403       if (rv != 0) {
404         throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
405                                   withAddr("failed to set socket option"),
406                                   errno);
407       }
408     }
409
410     // Perform the connect()
411     address.getAddress(&addrStorage);
412
413     rv = ::connect(fd_, saddr, address.getActualSize());
414     if (rv < 0) {
415       if (errno == EINPROGRESS) {
416         // Connection in progress.
417         if (timeout > 0) {
418           // Start a timer in case the connection takes too long.
419           if (!writeTimeout_.scheduleTimeout(timeout)) {
420             throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
421                 withAddr("failed to schedule AsyncSocket connect timeout"));
422           }
423         }
424
425         // Register for write events, so we'll
426         // be notified when the connection finishes/fails.
427         // Note that we don't register for a persistent event here.
428         assert(eventFlags_ == EventHandler::NONE);
429         eventFlags_ = EventHandler::WRITE;
430         if (!ioHandler_.registerHandler(eventFlags_)) {
431           throw AsyncSocketException(AsyncSocketException::INTERNAL_ERROR,
432               withAddr("failed to register AsyncSocket connect handler"));
433         }
434         return;
435       } else {
436         throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
437                                   "connect failed (immediately)", errno);
438       }
439     }
440
441     // If we're still here the connect() succeeded immediately.
442     // Fall through to call the callback outside of this try...catch block
443   } catch (const AsyncSocketException& ex) {
444     return failConnect(__func__, ex);
445   } catch (const std::exception& ex) {
446     // shouldn't happen, but handle it just in case
447     VLOG(4) << "AsyncSocket::connect(this=" << this << ", fd=" << fd_
448                << "): unexpected " << typeid(ex).name() << " exception: "
449                << ex.what();
450     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
451                             withAddr(string("unexpected exception: ") +
452                                      ex.what()));
453     return failConnect(__func__, tex);
454   }
455
456   // The connection succeeded immediately
457   // The read callback may not have been set yet, and no writes may be pending
458   // yet, so we don't have to register for any events at the moment.
459   VLOG(8) << "AsyncSocket::connect succeeded immediately; this=" << this;
460   assert(readCallback_ == nullptr);
461   assert(writeReqHead_ == nullptr);
462   state_ = StateEnum::ESTABLISHED;
463   if (callback) {
464     connectCallback_ = nullptr;
465     callback->connectSuccess();
466   }
467 }
468
469 void AsyncSocket::connect(ConnectCallback* callback,
470                            const string& ip, uint16_t port,
471                            int timeout,
472                            const OptionMap &options) noexcept {
473   DestructorGuard dg(this);
474   try {
475     connectCallback_ = callback;
476     connect(callback, folly::SocketAddress(ip, port), timeout, options);
477   } catch (const std::exception& ex) {
478     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
479                             ex.what());
480     return failConnect(__func__, tex);
481   }
482 }
483
484 void AsyncSocket::setSendTimeout(uint32_t milliseconds) {
485   sendTimeout_ = milliseconds;
486   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
487
488   // If we are currently pending on write requests, immediately update
489   // writeTimeout_ with the new value.
490   if ((eventFlags_ & EventHandler::WRITE) &&
491       (state_ != StateEnum::CONNECTING)) {
492     assert(state_ == StateEnum::ESTABLISHED);
493     assert((shutdownFlags_ & SHUT_WRITE) == 0);
494     if (sendTimeout_ > 0) {
495       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
496         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
497             withAddr("failed to reschedule send timeout in setSendTimeout"));
498         return failWrite(__func__, ex);
499       }
500     } else {
501       writeTimeout_.cancelTimeout();
502     }
503   }
504 }
505
506 void AsyncSocket::setReadCB(ReadCallback *callback) {
507   VLOG(6) << "AsyncSocket::setReadCallback() this=" << this << ", fd=" << fd_
508           << ", callback=" << callback << ", state=" << state_;
509
510   // Short circuit if callback is the same as the existing readCallback_.
511   //
512   // Note that this is needed for proper functioning during some cleanup cases.
513   // During cleanup we allow setReadCallback(nullptr) to be called even if the
514   // read callback is already unset and we have been detached from an event
515   // base.  This check prevents us from asserting
516   // eventBase_->isInEventBaseThread() when eventBase_ is nullptr.
517   if (callback == readCallback_) {
518     return;
519   }
520
521   if (shutdownFlags_ & SHUT_READ) {
522     // Reads have already been shut down on this socket.
523     //
524     // Allow setReadCallback(nullptr) to be called in this case, but don't
525     // allow a new callback to be set.
526     //
527     // For example, setReadCallback(nullptr) can happen after an error if we
528     // invoke some other error callback before invoking readError().  The other
529     // error callback that is invoked first may go ahead and clear the read
530     // callback before we get a chance to invoke readError().
531     if (callback != nullptr) {
532       return invalidState(callback);
533     }
534     assert((eventFlags_ & EventHandler::READ) == 0);
535     readCallback_ = nullptr;
536     return;
537   }
538
539   DestructorGuard dg(this);
540   assert(eventBase_->isInEventBaseThread());
541
542   switch ((StateEnum)state_) {
543     case StateEnum::CONNECTING:
544       // For convenience, we allow the read callback to be set while we are
545       // still connecting.  We just store the callback for now.  Once the
546       // connection completes we'll register for read events.
547       readCallback_ = callback;
548       return;
549     case StateEnum::ESTABLISHED:
550     {
551       readCallback_ = callback;
552       uint16_t oldFlags = eventFlags_;
553       if (readCallback_) {
554         eventFlags_ |= EventHandler::READ;
555       } else {
556         eventFlags_ &= ~EventHandler::READ;
557       }
558
559       // Update our registration if our flags have changed
560       if (eventFlags_ != oldFlags) {
561         // We intentionally ignore the return value here.
562         // updateEventRegistration() will move us into the error state if it
563         // fails, and we don't need to do anything else here afterwards.
564         (void)updateEventRegistration();
565       }
566
567       if (readCallback_) {
568         checkForImmediateRead();
569       }
570       return;
571     }
572     case StateEnum::CLOSED:
573     case StateEnum::ERROR:
574       // We should never reach here.  SHUT_READ should always be set
575       // if we are in STATE_CLOSED or STATE_ERROR.
576       assert(false);
577       return invalidState(callback);
578     case StateEnum::UNINIT:
579       // We do not allow setReadCallback() to be called before we start
580       // connecting.
581       return invalidState(callback);
582   }
583
584   // We don't put a default case in the switch statement, so that the compiler
585   // will warn us to update the switch statement if a new state is added.
586   return invalidState(callback);
587 }
588
589 AsyncSocket::ReadCallback* AsyncSocket::getReadCallback() const {
590   return readCallback_;
591 }
592
593 void AsyncSocket::write(WriteCallback* callback,
594                          const void* buf, size_t bytes, WriteFlags flags) {
595   iovec op;
596   op.iov_base = const_cast<void*>(buf);
597   op.iov_len = bytes;
598   writeImpl(callback, &op, 1, std::move(unique_ptr<IOBuf>()), flags);
599 }
600
601 void AsyncSocket::writev(WriteCallback* callback,
602                           const iovec* vec,
603                           size_t count,
604                           WriteFlags flags) {
605   writeImpl(callback, vec, count, std::move(unique_ptr<IOBuf>()), flags);
606 }
607
608 void AsyncSocket::writeChain(WriteCallback* callback, unique_ptr<IOBuf>&& buf,
609                               WriteFlags flags) {
610   size_t count = buf->countChainElements();
611   if (count <= 64) {
612     iovec vec[count];
613     writeChainImpl(callback, vec, count, std::move(buf), flags);
614   } else {
615     iovec* vec = new iovec[count];
616     writeChainImpl(callback, vec, count, std::move(buf), flags);
617     delete[] vec;
618   }
619 }
620
621 void AsyncSocket::writeChainImpl(WriteCallback* callback, iovec* vec,
622     size_t count, unique_ptr<IOBuf>&& buf, WriteFlags flags) {
623   const IOBuf* head = buf.get();
624   const IOBuf* next = head;
625   unsigned i = 0;
626   do {
627     vec[i].iov_base = const_cast<uint8_t *>(next->data());
628     vec[i].iov_len = next->length();
629     // IOBuf can get confused by empty iovec buffers, so increment the
630     // output pointer only if the iovec buffer is non-empty.  We could
631     // end the loop with i < count, but that's ok.
632     if (vec[i].iov_len != 0) {
633       i++;
634     }
635     next = next->next();
636   } while (next != head);
637   writeImpl(callback, vec, i, std::move(buf), flags);
638 }
639
640 void AsyncSocket::writeImpl(WriteCallback* callback, const iovec* vec,
641                              size_t count, unique_ptr<IOBuf>&& buf,
642                              WriteFlags flags) {
643   VLOG(6) << "AsyncSocket::writev() this=" << this << ", fd=" << fd_
644           << ", callback=" << callback << ", count=" << count
645           << ", state=" << state_;
646   DestructorGuard dg(this);
647   unique_ptr<IOBuf>ioBuf(std::move(buf));
648   assert(eventBase_->isInEventBaseThread());
649
650   if (shutdownFlags_ & (SHUT_WRITE | SHUT_WRITE_PENDING)) {
651     // No new writes may be performed after the write side of the socket has
652     // been shutdown.
653     //
654     // We could just call callback->writeError() here to fail just this write.
655     // However, fail hard and use invalidState() to fail all outstanding
656     // callbacks and move the socket into the error state.  There's most likely
657     // a bug in the caller's code, so we abort everything rather than trying to
658     // proceed as best we can.
659     return invalidState(callback);
660   }
661
662   uint32_t countWritten = 0;
663   uint32_t partialWritten = 0;
664   int bytesWritten = 0;
665   bool mustRegister = false;
666   if (state_ == StateEnum::ESTABLISHED && !connecting()) {
667     if (writeReqHead_ == nullptr) {
668       // If we are established and there are no other writes pending,
669       // we can attempt to perform the write immediately.
670       assert(writeReqTail_ == nullptr);
671       assert((eventFlags_ & EventHandler::WRITE) == 0);
672
673       bytesWritten = performWrite(vec, count, flags,
674                                   &countWritten, &partialWritten);
675       if (bytesWritten < 0) {
676         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
677                                withAddr("writev failed"), errno);
678         return failWrite(__func__, callback, 0, ex);
679       } else if (countWritten == count) {
680         // We successfully wrote everything.
681         // Invoke the callback and return.
682         if (callback) {
683           callback->writeSuccess();
684         }
685         return;
686       } // else { continue writing the next writeReq }
687       mustRegister = true;
688     }
689   } else if (!connecting()) {
690     // Invalid state for writing
691     return invalidState(callback);
692   }
693
694   // Create a new WriteRequest to add to the queue
695   WriteRequest* req;
696   try {
697     req = WriteRequest::newRequest(callback, vec + countWritten,
698                                    count - countWritten, std::move(ioBuf),
699                                    flags);
700   } catch (const std::exception& ex) {
701     // we mainly expect to catch std::bad_alloc here
702     AsyncSocketException tex(AsyncSocketException::INTERNAL_ERROR,
703         withAddr(string("failed to append new WriteRequest: ") + ex.what()));
704     return failWrite(__func__, callback, bytesWritten, tex);
705   }
706   req->consume(0, partialWritten, bytesWritten);
707   if (writeReqTail_ == nullptr) {
708     assert(writeReqHead_ == nullptr);
709     writeReqHead_ = writeReqTail_ = req;
710   } else {
711     writeReqTail_->append(req);
712     writeReqTail_ = req;
713   }
714
715   // Register for write events if are established and not currently
716   // waiting on write events
717   if (mustRegister) {
718     assert(state_ == StateEnum::ESTABLISHED);
719     assert((eventFlags_ & EventHandler::WRITE) == 0);
720     if (!updateEventRegistration(EventHandler::WRITE, 0)) {
721       assert(state_ == StateEnum::ERROR);
722       return;
723     }
724     if (sendTimeout_ > 0) {
725       // Schedule a timeout to fire if the write takes too long.
726       if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
727         AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
728                                withAddr("failed to schedule send timeout"));
729         return failWrite(__func__, ex);
730       }
731     }
732   }
733 }
734
735 void AsyncSocket::close() {
736   VLOG(5) << "AsyncSocket::close(): this=" << this << ", fd_=" << fd_
737           << ", state=" << state_ << ", shutdownFlags="
738           << std::hex << (int) shutdownFlags_;
739
740   // close() is only different from closeNow() when there are pending writes
741   // that need to drain before we can close.  In all other cases, just call
742   // closeNow().
743   //
744   // Note that writeReqHead_ can be non-nullptr even in STATE_CLOSED or
745   // STATE_ERROR if close() is invoked while a previous closeNow() or failure
746   // is still running.  (e.g., If there are multiple pending writes, and we
747   // call writeError() on the first one, it may call close().  In this case we
748   // will already be in STATE_CLOSED or STATE_ERROR, but the remaining pending
749   // writes will still be in the queue.)
750   //
751   // We only need to drain pending writes if we are still in STATE_CONNECTING
752   // or STATE_ESTABLISHED
753   if ((writeReqHead_ == nullptr) ||
754       !(state_ == StateEnum::CONNECTING ||
755       state_ == StateEnum::ESTABLISHED)) {
756     closeNow();
757     return;
758   }
759
760   // Declare a DestructorGuard to ensure that the AsyncSocket cannot be
761   // destroyed until close() returns.
762   DestructorGuard dg(this);
763   assert(eventBase_->isInEventBaseThread());
764
765   // Since there are write requests pending, we have to set the
766   // SHUT_WRITE_PENDING flag, and wait to perform the real close until the
767   // connect finishes and we finish writing these requests.
768   //
769   // Set SHUT_READ to indicate that reads are shut down, and set the
770   // SHUT_WRITE_PENDING flag to mark that we want to shutdown once the
771   // pending writes complete.
772   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE_PENDING);
773
774   // If a read callback is set, invoke readEOF() immediately to inform it that
775   // the socket has been closed and no more data can be read.
776   if (readCallback_) {
777     // Disable reads if they are enabled
778     if (!updateEventRegistration(0, EventHandler::READ)) {
779       // We're now in the error state; callbacks have been cleaned up
780       assert(state_ == StateEnum::ERROR);
781       assert(readCallback_ == nullptr);
782     } else {
783       ReadCallback* callback = readCallback_;
784       readCallback_ = nullptr;
785       callback->readEOF();
786     }
787   }
788 }
789
790 void AsyncSocket::closeNow() {
791   VLOG(5) << "AsyncSocket::closeNow(): this=" << this << ", fd_=" << fd_
792           << ", state=" << state_ << ", shutdownFlags="
793           << std::hex << (int) shutdownFlags_;
794   DestructorGuard dg(this);
795   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
796
797   switch (state_) {
798     case StateEnum::ESTABLISHED:
799     case StateEnum::CONNECTING:
800     {
801       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
802       state_ = StateEnum::CLOSED;
803
804       // If the write timeout was set, cancel it.
805       writeTimeout_.cancelTimeout();
806
807       // If we are registered for I/O events, unregister.
808       if (eventFlags_ != EventHandler::NONE) {
809         eventFlags_ = EventHandler::NONE;
810         if (!updateEventRegistration()) {
811           // We will have been moved into the error state.
812           assert(state_ == StateEnum::ERROR);
813           return;
814         }
815       }
816
817       if (fd_ >= 0) {
818         ioHandler_.changeHandlerFD(-1);
819         doClose();
820       }
821
822       if (connectCallback_) {
823         ConnectCallback* callback = connectCallback_;
824         connectCallback_ = nullptr;
825         callback->connectErr(socketClosedLocallyEx);
826       }
827
828       failAllWrites(socketClosedLocallyEx);
829
830       if (readCallback_) {
831         ReadCallback* callback = readCallback_;
832         readCallback_ = nullptr;
833         callback->readEOF();
834       }
835       return;
836     }
837     case StateEnum::CLOSED:
838       // Do nothing.  It's possible that we are being called recursively
839       // from inside a callback that we invoked inside another call to close()
840       // that is still running.
841       return;
842     case StateEnum::ERROR:
843       // Do nothing.  The error handling code has performed (or is performing)
844       // cleanup.
845       return;
846     case StateEnum::UNINIT:
847       assert(eventFlags_ == EventHandler::NONE);
848       assert(connectCallback_ == nullptr);
849       assert(readCallback_ == nullptr);
850       assert(writeReqHead_ == nullptr);
851       shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
852       state_ = StateEnum::CLOSED;
853       return;
854   }
855
856   LOG(DFATAL) << "AsyncSocket::closeNow() (this=" << this << ", fd=" << fd_
857               << ") called in unknown state " << state_;
858 }
859
860 void AsyncSocket::closeWithReset() {
861   // Enable SO_LINGER, with the linger timeout set to 0.
862   // This will trigger a TCP reset when we close the socket.
863   if (fd_ >= 0) {
864     struct linger optLinger = {1, 0};
865     if (setSockOpt(SOL_SOCKET, SO_LINGER, &optLinger) != 0) {
866       VLOG(2) << "AsyncSocket::closeWithReset(): error setting SO_LINGER "
867               << "on " << fd_ << ": errno=" << errno;
868     }
869   }
870
871   // Then let closeNow() take care of the rest
872   closeNow();
873 }
874
875 void AsyncSocket::shutdownWrite() {
876   VLOG(5) << "AsyncSocket::shutdownWrite(): this=" << this << ", fd=" << fd_
877           << ", state=" << state_ << ", shutdownFlags="
878           << std::hex << (int) shutdownFlags_;
879
880   // If there are no pending writes, shutdownWrite() is identical to
881   // shutdownWriteNow().
882   if (writeReqHead_ == nullptr) {
883     shutdownWriteNow();
884     return;
885   }
886
887   assert(eventBase_->isInEventBaseThread());
888
889   // There are pending writes.  Set SHUT_WRITE_PENDING so that the actual
890   // shutdown will be performed once all writes complete.
891   shutdownFlags_ |= SHUT_WRITE_PENDING;
892 }
893
894 void AsyncSocket::shutdownWriteNow() {
895   VLOG(5) << "AsyncSocket::shutdownWriteNow(): this=" << this
896           << ", fd=" << fd_ << ", state=" << state_
897           << ", shutdownFlags=" << std::hex << (int) shutdownFlags_;
898
899   if (shutdownFlags_ & SHUT_WRITE) {
900     // Writes are already shutdown; nothing else to do.
901     return;
902   }
903
904   // If SHUT_READ is already set, just call closeNow() to completely
905   // close the socket.  This can happen if close() was called with writes
906   // pending, and then shutdownWriteNow() is called before all pending writes
907   // complete.
908   if (shutdownFlags_ & SHUT_READ) {
909     closeNow();
910     return;
911   }
912
913   DestructorGuard dg(this);
914   assert(eventBase_ == nullptr || eventBase_->isInEventBaseThread());
915
916   switch (static_cast<StateEnum>(state_)) {
917     case StateEnum::ESTABLISHED:
918     {
919       shutdownFlags_ |= SHUT_WRITE;
920
921       // If the write timeout was set, cancel it.
922       writeTimeout_.cancelTimeout();
923
924       // If we are registered for write events, unregister.
925       if (!updateEventRegistration(0, EventHandler::WRITE)) {
926         // We will have been moved into the error state.
927         assert(state_ == StateEnum::ERROR);
928         return;
929       }
930
931       // Shutdown writes on the file descriptor
932       ::shutdown(fd_, SHUT_WR);
933
934       // Immediately fail all write requests
935       failAllWrites(socketShutdownForWritesEx);
936       return;
937     }
938     case StateEnum::CONNECTING:
939     {
940       // Set the SHUT_WRITE_PENDING flag.
941       // When the connection completes, it will check this flag,
942       // shutdown the write half of the socket, and then set SHUT_WRITE.
943       shutdownFlags_ |= SHUT_WRITE_PENDING;
944
945       // Immediately fail all write requests
946       failAllWrites(socketShutdownForWritesEx);
947       return;
948     }
949     case StateEnum::UNINIT:
950       // Callers normally shouldn't call shutdownWriteNow() before the socket
951       // even starts connecting.  Nonetheless, go ahead and set
952       // SHUT_WRITE_PENDING.  Once the socket eventually connects it will
953       // immediately shut down the write side of the socket.
954       shutdownFlags_ |= SHUT_WRITE_PENDING;
955       return;
956     case StateEnum::CLOSED:
957     case StateEnum::ERROR:
958       // We should never get here.  SHUT_WRITE should always be set
959       // in STATE_CLOSED and STATE_ERROR.
960       VLOG(4) << "AsyncSocket::shutdownWriteNow() (this=" << this
961                  << ", fd=" << fd_ << ") in unexpected state " << state_
962                  << " with SHUT_WRITE not set ("
963                  << std::hex << (int) shutdownFlags_ << ")";
964       assert(false);
965       return;
966   }
967
968   LOG(DFATAL) << "AsyncSocket::shutdownWriteNow() (this=" << this << ", fd="
969               << fd_ << ") called in unknown state " << state_;
970 }
971
972 bool AsyncSocket::readable() const {
973   if (fd_ == -1) {
974     return false;
975   }
976   struct pollfd fds[1];
977   fds[0].fd = fd_;
978   fds[0].events = POLLIN;
979   fds[0].revents = 0;
980   int rc = poll(fds, 1, 0);
981   return rc == 1;
982 }
983
984 bool AsyncSocket::isPending() const {
985   return ioHandler_.isPending();
986 }
987
988 bool AsyncSocket::hangup() const {
989   if (fd_ == -1) {
990     // sanity check, no one should ask for hangup if we are not connected.
991     assert(false);
992     return false;
993   }
994 #ifdef POLLRDHUP // Linux-only
995   struct pollfd fds[1];
996   fds[0].fd = fd_;
997   fds[0].events = POLLRDHUP|POLLHUP;
998   fds[0].revents = 0;
999   poll(fds, 1, 0);
1000   return (fds[0].revents & (POLLRDHUP|POLLHUP)) != 0;
1001 #else
1002   return false;
1003 #endif
1004 }
1005
1006 bool AsyncSocket::good() const {
1007   return ((state_ == StateEnum::CONNECTING ||
1008           state_ == StateEnum::ESTABLISHED) &&
1009           (shutdownFlags_ == 0) && (eventBase_ != nullptr));
1010 }
1011
1012 bool AsyncSocket::error() const {
1013   return (state_ == StateEnum::ERROR);
1014 }
1015
1016 void AsyncSocket::attachEventBase(EventBase* eventBase) {
1017   VLOG(5) << "AsyncSocket::attachEventBase(this=" << this << ", fd=" << fd_
1018           << ", old evb=" << eventBase_ << ", new evb=" << eventBase
1019           << ", state=" << state_ << ", events="
1020           << std::hex << eventFlags_ << ")";
1021   assert(eventBase_ == nullptr);
1022   assert(eventBase->isInEventBaseThread());
1023
1024   eventBase_ = eventBase;
1025   ioHandler_.attachEventBase(eventBase);
1026   writeTimeout_.attachEventBase(eventBase);
1027 }
1028
1029 void AsyncSocket::detachEventBase() {
1030   VLOG(5) << "AsyncSocket::detachEventBase(this=" << this << ", fd=" << fd_
1031           << ", old evb=" << eventBase_ << ", state=" << state_
1032           << ", events=" << std::hex << eventFlags_ << ")";
1033   assert(eventBase_ != nullptr);
1034   assert(eventBase_->isInEventBaseThread());
1035
1036   eventBase_ = nullptr;
1037   ioHandler_.detachEventBase();
1038   writeTimeout_.detachEventBase();
1039 }
1040
1041 bool AsyncSocket::isDetachable() const {
1042   DCHECK(eventBase_ != nullptr);
1043   DCHECK(eventBase_->isInEventBaseThread());
1044
1045   return !ioHandler_.isHandlerRegistered() && !writeTimeout_.isScheduled();
1046 }
1047
1048 void AsyncSocket::getLocalAddress(folly::SocketAddress* address) const {
1049   address->setFromLocalAddress(fd_);
1050 }
1051
1052 void AsyncSocket::getPeerAddress(folly::SocketAddress* address) const {
1053   if (!addr_.isInitialized()) {
1054     addr_.setFromPeerAddress(fd_);
1055   }
1056   *address = addr_;
1057 }
1058
1059 int AsyncSocket::setNoDelay(bool noDelay) {
1060   if (fd_ < 0) {
1061     VLOG(4) << "AsyncSocket::setNoDelay() called on non-open socket "
1062                << this << "(state=" << state_ << ")";
1063     return EINVAL;
1064
1065   }
1066
1067   int value = noDelay ? 1 : 0;
1068   if (setsockopt(fd_, IPPROTO_TCP, TCP_NODELAY, &value, sizeof(value)) != 0) {
1069     int errnoCopy = errno;
1070     VLOG(2) << "failed to update TCP_NODELAY option on AsyncSocket "
1071             << this << " (fd=" << fd_ << ", state=" << state_ << "): "
1072             << strerror(errnoCopy);
1073     return errnoCopy;
1074   }
1075
1076   return 0;
1077 }
1078
1079 int AsyncSocket::setCongestionFlavor(const std::string &cname) {
1080
1081   #ifndef TCP_CONGESTION
1082   #define TCP_CONGESTION  13
1083   #endif
1084
1085   if (fd_ < 0) {
1086     VLOG(4) << "AsyncSocket::setCongestionFlavor() called on non-open "
1087                << "socket " << this << "(state=" << state_ << ")";
1088     return EINVAL;
1089
1090   }
1091
1092   if (setsockopt(fd_, IPPROTO_TCP, TCP_CONGESTION, cname.c_str(),
1093         cname.length() + 1) != 0) {
1094     int errnoCopy = errno;
1095     VLOG(2) << "failed to update TCP_CONGESTION option on AsyncSocket "
1096             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1097             << strerror(errnoCopy);
1098     return errnoCopy;
1099   }
1100
1101   return 0;
1102 }
1103
1104 int AsyncSocket::setQuickAck(bool quickack) {
1105   if (fd_ < 0) {
1106     VLOG(4) << "AsyncSocket::setQuickAck() called on non-open socket "
1107                << this << "(state=" << state_ << ")";
1108     return EINVAL;
1109
1110   }
1111
1112 #ifdef TCP_QUICKACK // Linux-only
1113   int value = quickack ? 1 : 0;
1114   if (setsockopt(fd_, IPPROTO_TCP, TCP_QUICKACK, &value, sizeof(value)) != 0) {
1115     int errnoCopy = errno;
1116     VLOG(2) << "failed to update TCP_QUICKACK option on AsyncSocket"
1117             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1118             << strerror(errnoCopy);
1119     return errnoCopy;
1120   }
1121
1122   return 0;
1123 #else
1124   return ENOSYS;
1125 #endif
1126 }
1127
1128 int AsyncSocket::setSendBufSize(size_t bufsize) {
1129   if (fd_ < 0) {
1130     VLOG(4) << "AsyncSocket::setSendBufSize() called on non-open socket "
1131                << this << "(state=" << state_ << ")";
1132     return EINVAL;
1133   }
1134
1135   if (setsockopt(fd_, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)) !=0) {
1136     int errnoCopy = errno;
1137     VLOG(2) << "failed to update SO_SNDBUF option on AsyncSocket"
1138             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1139             << strerror(errnoCopy);
1140     return errnoCopy;
1141   }
1142
1143   return 0;
1144 }
1145
1146 int AsyncSocket::setRecvBufSize(size_t bufsize) {
1147   if (fd_ < 0) {
1148     VLOG(4) << "AsyncSocket::setRecvBufSize() called on non-open socket "
1149                << this << "(state=" << state_ << ")";
1150     return EINVAL;
1151   }
1152
1153   if (setsockopt(fd_, SOL_SOCKET, SO_RCVBUF, &bufsize, sizeof(bufsize)) !=0) {
1154     int errnoCopy = errno;
1155     VLOG(2) << "failed to update SO_RCVBUF option on AsyncSocket"
1156             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1157             << strerror(errnoCopy);
1158     return errnoCopy;
1159   }
1160
1161   return 0;
1162 }
1163
1164 int AsyncSocket::setTCPProfile(int profd) {
1165   if (fd_ < 0) {
1166     VLOG(4) << "AsyncSocket::setTCPProfile() called on non-open socket "
1167                << this << "(state=" << state_ << ")";
1168     return EINVAL;
1169   }
1170
1171   if (setsockopt(fd_, SOL_SOCKET, SO_SET_NAMESPACE, &profd, sizeof(int)) !=0) {
1172     int errnoCopy = errno;
1173     VLOG(2) << "failed to set socket namespace option on AsyncSocket"
1174             << this << "(fd=" << fd_ << ", state=" << state_ << "): "
1175             << strerror(errnoCopy);
1176     return errnoCopy;
1177   }
1178
1179   return 0;
1180 }
1181
1182 void AsyncSocket::ioReady(uint16_t events) noexcept {
1183   VLOG(7) << "AsyncSocket::ioRead() this=" << this << ", fd" << fd_
1184           << ", events=" << std::hex << events << ", state=" << state_;
1185   DestructorGuard dg(this);
1186   assert(events & EventHandler::READ_WRITE);
1187   assert(eventBase_->isInEventBaseThread());
1188
1189   uint16_t relevantEvents = events & EventHandler::READ_WRITE;
1190   if (relevantEvents == EventHandler::READ) {
1191     handleRead();
1192   } else if (relevantEvents == EventHandler::WRITE) {
1193     handleWrite();
1194   } else if (relevantEvents == EventHandler::READ_WRITE) {
1195     EventBase* originalEventBase = eventBase_;
1196     // If both read and write events are ready, process writes first.
1197     handleWrite();
1198
1199     // Return now if handleWrite() detached us from our EventBase
1200     if (eventBase_ != originalEventBase) {
1201       return;
1202     }
1203
1204     // Only call handleRead() if a read callback is still installed.
1205     // (It's possible that the read callback was uninstalled during
1206     // handleWrite().)
1207     if (readCallback_) {
1208       handleRead();
1209     }
1210   } else {
1211     VLOG(4) << "AsyncSocket::ioRead() called with unexpected events "
1212                << std::hex << events << "(this=" << this << ")";
1213     abort();
1214   }
1215 }
1216
1217 ssize_t AsyncSocket::performRead(void* buf, size_t buflen) {
1218   ssize_t bytes = recv(fd_, buf, buflen, MSG_DONTWAIT);
1219   if (bytes < 0) {
1220     if (errno == EAGAIN || errno == EWOULDBLOCK) {
1221       // No more data to read right now.
1222       return READ_BLOCKING;
1223     } else {
1224       return READ_ERROR;
1225     }
1226   } else {
1227     appBytesReceived_ += bytes;
1228     return bytes;
1229   }
1230 }
1231
1232 void AsyncSocket::handleRead() noexcept {
1233   VLOG(5) << "AsyncSocket::handleRead() this=" << this << ", fd=" << fd_
1234           << ", state=" << state_;
1235   assert(state_ == StateEnum::ESTABLISHED);
1236   assert((shutdownFlags_ & SHUT_READ) == 0);
1237   assert(readCallback_ != nullptr);
1238   assert(eventFlags_ & EventHandler::READ);
1239
1240   // Loop until:
1241   // - a read attempt would block
1242   // - readCallback_ is uninstalled
1243   // - the number of loop iterations exceeds the optional maximum
1244   // - this AsyncSocket is moved to another EventBase
1245   //
1246   // When we invoke readDataAvailable() it may uninstall the readCallback_,
1247   // which is why need to check for it here.
1248   //
1249   // The last bullet point is slightly subtle.  readDataAvailable() may also
1250   // detach this socket from this EventBase.  However, before
1251   // readDataAvailable() returns another thread may pick it up, attach it to
1252   // a different EventBase, and install another readCallback_.  We need to
1253   // exit immediately after readDataAvailable() returns if the eventBase_ has
1254   // changed.  (The caller must perform some sort of locking to transfer the
1255   // AsyncSocket between threads properly.  This will be sufficient to ensure
1256   // that this thread sees the updated eventBase_ variable after
1257   // readDataAvailable() returns.)
1258   uint16_t numReads = 0;
1259   EventBase* originalEventBase = eventBase_;
1260   while (readCallback_ && eventBase_ == originalEventBase) {
1261     // Get the buffer to read into.
1262     void* buf = nullptr;
1263     size_t buflen = 0;
1264     try {
1265       readCallback_->getReadBuffer(&buf, &buflen);
1266     } catch (const AsyncSocketException& ex) {
1267       return failRead(__func__, ex);
1268     } catch (const std::exception& ex) {
1269       AsyncSocketException tex(AsyncSocketException::BAD_ARGS,
1270                               string("ReadCallback::getReadBuffer() "
1271                                      "threw exception: ") +
1272                               ex.what());
1273       return failRead(__func__, tex);
1274     } catch (...) {
1275       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1276                              "ReadCallback::getReadBuffer() threw "
1277                              "non-exception type");
1278       return failRead(__func__, ex);
1279     }
1280     if (buf == nullptr || buflen == 0) {
1281       AsyncSocketException ex(AsyncSocketException::BAD_ARGS,
1282                              "ReadCallback::getReadBuffer() returned "
1283                              "empty buffer");
1284       return failRead(__func__, ex);
1285     }
1286
1287     // Perform the read
1288     ssize_t bytesRead = performRead(buf, buflen);
1289     if (bytesRead > 0) {
1290       readCallback_->readDataAvailable(bytesRead);
1291       // Fall through and continue around the loop if the read
1292       // completely filled the available buffer.
1293       // Note that readCallback_ may have been uninstalled or changed inside
1294       // readDataAvailable().
1295       if (size_t(bytesRead) < buflen) {
1296         return;
1297       }
1298     } else if (bytesRead == READ_BLOCKING) {
1299         // No more data to read right now.
1300         return;
1301     } else if (bytesRead == READ_ERROR) {
1302       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1303                              withAddr("recv() failed"), errno);
1304       return failRead(__func__, ex);
1305     } else {
1306       assert(bytesRead == READ_EOF);
1307       // EOF
1308       shutdownFlags_ |= SHUT_READ;
1309       if (!updateEventRegistration(0, EventHandler::READ)) {
1310         // we've already been moved into STATE_ERROR
1311         assert(state_ == StateEnum::ERROR);
1312         assert(readCallback_ == nullptr);
1313         return;
1314       }
1315
1316       ReadCallback* callback = readCallback_;
1317       readCallback_ = nullptr;
1318       callback->readEOF();
1319       return;
1320     }
1321     if (maxReadsPerEvent_ && (++numReads >= maxReadsPerEvent_)) {
1322       return;
1323     }
1324   }
1325 }
1326
1327 /**
1328  * This function attempts to write as much data as possible, until no more data
1329  * can be written.
1330  *
1331  * - If it sends all available data, it unregisters for write events, and stops
1332  *   the writeTimeout_.
1333  *
1334  * - If not all of the data can be sent immediately, it reschedules
1335  *   writeTimeout_ (if a non-zero timeout is set), and ensures the handler is
1336  *   registered for write events.
1337  */
1338 void AsyncSocket::handleWrite() noexcept {
1339   VLOG(5) << "AsyncSocket::handleWrite() this=" << this << ", fd=" << fd_
1340           << ", state=" << state_;
1341   if (state_ == StateEnum::CONNECTING) {
1342     handleConnect();
1343     return;
1344   }
1345
1346   // Normal write
1347   assert(state_ == StateEnum::ESTABLISHED);
1348   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1349   assert(writeReqHead_ != nullptr);
1350
1351   // Loop until we run out of write requests,
1352   // or until this socket is moved to another EventBase.
1353   // (See the comment in handleRead() explaining how this can happen.)
1354   EventBase* originalEventBase = eventBase_;
1355   while (writeReqHead_ != nullptr && eventBase_ == originalEventBase) {
1356     uint32_t countWritten;
1357     uint32_t partialWritten;
1358     WriteFlags writeFlags = writeReqHead_->flags();
1359     if (writeReqHead_->getNext() != nullptr) {
1360       writeFlags = writeFlags | WriteFlags::CORK;
1361     }
1362     int bytesWritten = performWrite(writeReqHead_->getOps(),
1363                                     writeReqHead_->getOpCount(),
1364                                     writeFlags, &countWritten, &partialWritten);
1365     if (bytesWritten < 0) {
1366       AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1367                              withAddr("writev() failed"), errno);
1368       return failWrite(__func__, ex);
1369     } else if (countWritten == writeReqHead_->getOpCount()) {
1370       // We finished this request
1371       WriteRequest* req = writeReqHead_;
1372       writeReqHead_ = req->getNext();
1373
1374       if (writeReqHead_ == nullptr) {
1375         writeReqTail_ = nullptr;
1376         // This is the last write request.
1377         // Unregister for write events and cancel the send timer
1378         // before we invoke the callback.  We have to update the state properly
1379         // before calling the callback, since it may want to detach us from
1380         // the EventBase.
1381         if (eventFlags_ & EventHandler::WRITE) {
1382           if (!updateEventRegistration(0, EventHandler::WRITE)) {
1383             assert(state_ == StateEnum::ERROR);
1384             return;
1385           }
1386           // Stop the send timeout
1387           writeTimeout_.cancelTimeout();
1388         }
1389         assert(!writeTimeout_.isScheduled());
1390
1391         // If SHUT_WRITE_PENDING is set, we should shutdown the socket after
1392         // we finish sending the last write request.
1393         //
1394         // We have to do this before invoking writeSuccess(), since
1395         // writeSuccess() may detach us from our EventBase.
1396         if (shutdownFlags_ & SHUT_WRITE_PENDING) {
1397           assert(connectCallback_ == nullptr);
1398           shutdownFlags_ |= SHUT_WRITE;
1399
1400           if (shutdownFlags_ & SHUT_READ) {
1401             // Reads have already been shutdown.  Fully close the socket and
1402             // move to STATE_CLOSED.
1403             //
1404             // Note: This code currently moves us to STATE_CLOSED even if
1405             // close() hasn't ever been called.  This can occur if we have
1406             // received EOF from the peer and shutdownWrite() has been called
1407             // locally.  Should we bother staying in STATE_ESTABLISHED in this
1408             // case, until close() is actually called?  I can't think of a
1409             // reason why we would need to do so.  No other operations besides
1410             // calling close() or destroying the socket can be performed at
1411             // this point.
1412             assert(readCallback_ == nullptr);
1413             state_ = StateEnum::CLOSED;
1414             if (fd_ >= 0) {
1415               ioHandler_.changeHandlerFD(-1);
1416               doClose();
1417             }
1418           } else {
1419             // Reads are still enabled, so we are only doing a half-shutdown
1420             ::shutdown(fd_, SHUT_WR);
1421           }
1422         }
1423       }
1424
1425       // Invoke the callback
1426       WriteCallback* callback = req->getCallback();
1427       req->destroy();
1428       if (callback) {
1429         callback->writeSuccess();
1430       }
1431       // We'll continue around the loop, trying to write another request
1432     } else {
1433       // Partial write.
1434       writeReqHead_->consume(countWritten, partialWritten, bytesWritten);
1435       // Stop after a partial write; it's highly likely that a subsequent write
1436       // attempt will just return EAGAIN.
1437       //
1438       // Ensure that we are registered for write events.
1439       if ((eventFlags_ & EventHandler::WRITE) == 0) {
1440         if (!updateEventRegistration(EventHandler::WRITE, 0)) {
1441           assert(state_ == StateEnum::ERROR);
1442           return;
1443         }
1444       }
1445
1446       // Reschedule the send timeout, since we have made some write progress.
1447       if (sendTimeout_ > 0) {
1448         if (!writeTimeout_.scheduleTimeout(sendTimeout_)) {
1449           AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1450               withAddr("failed to reschedule write timeout"));
1451           return failWrite(__func__, ex);
1452         }
1453       }
1454       return;
1455     }
1456   }
1457 }
1458
1459 void AsyncSocket::checkForImmediateRead() noexcept {
1460   // We currently don't attempt to perform optimistic reads in AsyncSocket.
1461   // (However, note that some subclasses do override this method.)
1462   //
1463   // Simply calling handleRead() here would be bad, as this would call
1464   // readCallback_->getReadBuffer(), forcing the callback to allocate a read
1465   // buffer even though no data may be available.  This would waste lots of
1466   // memory, since the buffer will sit around unused until the socket actually
1467   // becomes readable.
1468   //
1469   // Checking if the socket is readable now also seems like it would probably
1470   // be a pessimism.  In most cases it probably wouldn't be readable, and we
1471   // would just waste an extra system call.  Even if it is readable, waiting to
1472   // find out from libevent on the next event loop doesn't seem that bad.
1473 }
1474
1475 void AsyncSocket::handleInitialReadWrite() noexcept {
1476   // Our callers should already be holding a DestructorGuard, but grab
1477   // one here just to make sure, in case one of our calling code paths ever
1478   // changes.
1479   DestructorGuard dg(this);
1480
1481   // If we have a readCallback_, make sure we enable read events.  We
1482   // may already be registered for reads if connectSuccess() set
1483   // the read calback.
1484   if (readCallback_ && !(eventFlags_ & EventHandler::READ)) {
1485     assert(state_ == StateEnum::ESTABLISHED);
1486     assert((shutdownFlags_ & SHUT_READ) == 0);
1487     if (!updateEventRegistration(EventHandler::READ, 0)) {
1488       assert(state_ == StateEnum::ERROR);
1489       return;
1490     }
1491     checkForImmediateRead();
1492   } else if (readCallback_ == nullptr) {
1493     // Unregister for read events.
1494     updateEventRegistration(0, EventHandler::READ);
1495   }
1496
1497   // If we have write requests pending, try to send them immediately.
1498   // Since we just finished accepting, there is a very good chance that we can
1499   // write without blocking.
1500   //
1501   // However, we only process them if EventHandler::WRITE is not already set,
1502   // which means that we're already blocked on a write attempt.  (This can
1503   // happen if connectSuccess() called write() before returning.)
1504   if (writeReqHead_ && !(eventFlags_ & EventHandler::WRITE)) {
1505     // Call handleWrite() to perform write processing.
1506     handleWrite();
1507   } else if (writeReqHead_ == nullptr) {
1508     // Unregister for write event.
1509     updateEventRegistration(0, EventHandler::WRITE);
1510   }
1511 }
1512
1513 void AsyncSocket::handleConnect() noexcept {
1514   VLOG(5) << "AsyncSocket::handleConnect() this=" << this << ", fd=" << fd_
1515           << ", state=" << state_;
1516   assert(state_ == StateEnum::CONNECTING);
1517   // SHUT_WRITE can never be set while we are still connecting;
1518   // SHUT_WRITE_PENDING may be set, be we only set SHUT_WRITE once the connect
1519   // finishes
1520   assert((shutdownFlags_ & SHUT_WRITE) == 0);
1521
1522   // In case we had a connect timeout, cancel the timeout
1523   writeTimeout_.cancelTimeout();
1524   // We don't use a persistent registration when waiting on a connect event,
1525   // so we have been automatically unregistered now.  Update eventFlags_ to
1526   // reflect reality.
1527   assert(eventFlags_ == EventHandler::WRITE);
1528   eventFlags_ = EventHandler::NONE;
1529
1530   // Call getsockopt() to check if the connect succeeded
1531   int error;
1532   socklen_t len = sizeof(error);
1533   int rv = getsockopt(fd_, SOL_SOCKET, SO_ERROR, &error, &len);
1534   if (rv != 0) {
1535     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1536                            withAddr("error calling getsockopt() after connect"),
1537                            errno);
1538     VLOG(4) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1539                << fd_ << " host=" << addr_.describe()
1540                << ") exception:" << ex.what();
1541     return failConnect(__func__, ex);
1542   }
1543
1544   if (error != 0) {
1545     AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1546                            "connect failed", error);
1547     VLOG(1) << "AsyncSocket::handleConnect(this=" << this << ", fd="
1548             << fd_ << " host=" << addr_.describe()
1549             << ") exception: " << ex.what();
1550     return failConnect(__func__, ex);
1551   }
1552
1553   // Move into STATE_ESTABLISHED
1554   state_ = StateEnum::ESTABLISHED;
1555
1556   // If SHUT_WRITE_PENDING is set and we don't have any write requests to
1557   // perform, immediately shutdown the write half of the socket.
1558   if ((shutdownFlags_ & SHUT_WRITE_PENDING) && writeReqHead_ == nullptr) {
1559     // SHUT_READ shouldn't be set.  If close() is called on the socket while we
1560     // are still connecting we just abort the connect rather than waiting for
1561     // it to complete.
1562     assert((shutdownFlags_ & SHUT_READ) == 0);
1563     ::shutdown(fd_, SHUT_WR);
1564     shutdownFlags_ |= SHUT_WRITE;
1565   }
1566
1567   VLOG(7) << "AsyncSocket " << this << ": fd " << fd_
1568           << "successfully connected; state=" << state_;
1569
1570   // Remember the EventBase we are attached to, before we start invoking any
1571   // callbacks (since the callbacks may call detachEventBase()).
1572   EventBase* originalEventBase = eventBase_;
1573
1574   // Call the connect callback.
1575   if (connectCallback_) {
1576     ConnectCallback* callback = connectCallback_;
1577     connectCallback_ = nullptr;
1578     callback->connectSuccess();
1579   }
1580
1581   // Note that the connect callback may have changed our state.
1582   // (set or unset the read callback, called write(), closed the socket, etc.)
1583   // The following code needs to handle these situations correctly.
1584   //
1585   // If the socket has been closed, readCallback_ and writeReqHead_ will
1586   // always be nullptr, so that will prevent us from trying to read or write.
1587   //
1588   // The main thing to check for is if eventBase_ is still originalEventBase.
1589   // If not, we have been detached from this event base, so we shouldn't
1590   // perform any more operations.
1591   if (eventBase_ != originalEventBase) {
1592     return;
1593   }
1594
1595   handleInitialReadWrite();
1596 }
1597
1598 void AsyncSocket::timeoutExpired() noexcept {
1599   VLOG(7) << "AsyncSocket " << this << ", fd " << fd_ << ": timeout expired: "
1600           << "state=" << state_ << ", events=" << std::hex << eventFlags_;
1601   DestructorGuard dg(this);
1602   assert(eventBase_->isInEventBaseThread());
1603
1604   if (state_ == StateEnum::CONNECTING) {
1605     // connect() timed out
1606     // Unregister for I/O events.
1607     AsyncSocketException ex(AsyncSocketException::TIMED_OUT,
1608                            "connect timed out");
1609     failConnect(__func__, ex);
1610   } else {
1611     // a normal write operation timed out
1612     assert(state_ == StateEnum::ESTABLISHED);
1613     AsyncSocketException ex(AsyncSocketException::TIMED_OUT, "write timed out");
1614     failWrite(__func__, ex);
1615   }
1616 }
1617
1618 ssize_t AsyncSocket::performWrite(const iovec* vec,
1619                                    uint32_t count,
1620                                    WriteFlags flags,
1621                                    uint32_t* countWritten,
1622                                    uint32_t* partialWritten) {
1623   // We use sendmsg() instead of writev() so that we can pass in MSG_NOSIGNAL
1624   // We correctly handle EPIPE errors, so we never want to receive SIGPIPE
1625   // (since it may terminate the program if the main program doesn't explicitly
1626   // ignore it).
1627   struct msghdr msg;
1628   msg.msg_name = nullptr;
1629   msg.msg_namelen = 0;
1630   msg.msg_iov = const_cast<iovec *>(vec);
1631 #ifdef IOV_MAX // not defined on Android
1632   msg.msg_iovlen = std::min(count, (uint32_t)IOV_MAX);
1633 #else
1634   msg.msg_iovlen = std::min(count, (uint32_t)UIO_MAXIOV);
1635 #endif
1636   msg.msg_control = nullptr;
1637   msg.msg_controllen = 0;
1638   msg.msg_flags = 0;
1639
1640   int msg_flags = MSG_DONTWAIT;
1641
1642 #ifdef MSG_NOSIGNAL // Linux-only
1643   msg_flags |= MSG_NOSIGNAL;
1644   if (isSet(flags, WriteFlags::CORK)) {
1645     // MSG_MORE tells the kernel we have more data to send, so wait for us to
1646     // give it the rest of the data rather than immediately sending a partial
1647     // frame, even when TCP_NODELAY is enabled.
1648     msg_flags |= MSG_MORE;
1649   }
1650 #endif
1651   if (isSet(flags, WriteFlags::EOR)) {
1652     // marks that this is the last byte of a record (response)
1653     msg_flags |= MSG_EOR;
1654   }
1655   ssize_t totalWritten = ::sendmsg(fd_, &msg, msg_flags);
1656   if (totalWritten < 0) {
1657     if (errno == EAGAIN) {
1658       // TCP buffer is full; we can't write any more data right now.
1659       *countWritten = 0;
1660       *partialWritten = 0;
1661       return 0;
1662     }
1663     // error
1664     *countWritten = 0;
1665     *partialWritten = 0;
1666     return -1;
1667   }
1668
1669   appBytesWritten_ += totalWritten;
1670
1671   uint32_t bytesWritten;
1672   uint32_t n;
1673   for (bytesWritten = totalWritten, n = 0; n < count; ++n) {
1674     const iovec* v = vec + n;
1675     if (v->iov_len > bytesWritten) {
1676       // Partial write finished in the middle of this iovec
1677       *countWritten = n;
1678       *partialWritten = bytesWritten;
1679       return totalWritten;
1680     }
1681
1682     bytesWritten -= v->iov_len;
1683   }
1684
1685   assert(bytesWritten == 0);
1686   *countWritten = n;
1687   *partialWritten = 0;
1688   return totalWritten;
1689 }
1690
1691 /**
1692  * Re-register the EventHandler after eventFlags_ has changed.
1693  *
1694  * If an error occurs, fail() is called to move the socket into the error state
1695  * and call all currently installed callbacks.  After an error, the
1696  * AsyncSocket is completely unregistered.
1697  *
1698  * @return Returns true on succcess, or false on error.
1699  */
1700 bool AsyncSocket::updateEventRegistration() {
1701   VLOG(5) << "AsyncSocket::updateEventRegistration(this=" << this
1702           << ", fd=" << fd_ << ", evb=" << eventBase_ << ", state=" << state_
1703           << ", events=" << std::hex << eventFlags_;
1704   assert(eventBase_->isInEventBaseThread());
1705   if (eventFlags_ == EventHandler::NONE) {
1706     ioHandler_.unregisterHandler();
1707     return true;
1708   }
1709
1710   // Always register for persistent events, so we don't have to re-register
1711   // after being called back.
1712   if (!ioHandler_.registerHandler(eventFlags_ | EventHandler::PERSIST)) {
1713     eventFlags_ = EventHandler::NONE; // we're not registered after error
1714     AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1715         withAddr("failed to update AsyncSocket event registration"));
1716     fail("updateEventRegistration", ex);
1717     return false;
1718   }
1719
1720   return true;
1721 }
1722
1723 bool AsyncSocket::updateEventRegistration(uint16_t enable,
1724                                            uint16_t disable) {
1725   uint16_t oldFlags = eventFlags_;
1726   eventFlags_ |= enable;
1727   eventFlags_ &= ~disable;
1728   if (eventFlags_ == oldFlags) {
1729     return true;
1730   } else {
1731     return updateEventRegistration();
1732   }
1733 }
1734
1735 void AsyncSocket::startFail() {
1736   // startFail() should only be called once
1737   assert(state_ != StateEnum::ERROR);
1738   assert(getDestructorGuardCount() > 0);
1739   state_ = StateEnum::ERROR;
1740   // Ensure that SHUT_READ and SHUT_WRITE are set,
1741   // so all future attempts to read or write will be rejected
1742   shutdownFlags_ |= (SHUT_READ | SHUT_WRITE);
1743
1744   if (eventFlags_ != EventHandler::NONE) {
1745     eventFlags_ = EventHandler::NONE;
1746     ioHandler_.unregisterHandler();
1747   }
1748   writeTimeout_.cancelTimeout();
1749
1750   if (fd_ >= 0) {
1751     ioHandler_.changeHandlerFD(-1);
1752     doClose();
1753   }
1754 }
1755
1756 void AsyncSocket::finishFail() {
1757   assert(state_ == StateEnum::ERROR);
1758   assert(getDestructorGuardCount() > 0);
1759
1760   AsyncSocketException ex(AsyncSocketException::INTERNAL_ERROR,
1761                          withAddr("socket closing after error"));
1762   if (connectCallback_) {
1763     ConnectCallback* callback = connectCallback_;
1764     connectCallback_ = nullptr;
1765     callback->connectErr(ex);
1766   }
1767
1768   failAllWrites(ex);
1769
1770   if (readCallback_) {
1771     ReadCallback* callback = readCallback_;
1772     readCallback_ = nullptr;
1773     callback->readErr(ex);
1774   }
1775 }
1776
1777 void AsyncSocket::fail(const char* fn, const AsyncSocketException& ex) {
1778   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1779              << state_ << " host=" << addr_.describe()
1780              << "): failed in " << fn << "(): "
1781              << ex.what();
1782   startFail();
1783   finishFail();
1784 }
1785
1786 void AsyncSocket::failConnect(const char* fn, const AsyncSocketException& ex) {
1787   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1788                << state_ << " host=" << addr_.describe()
1789                << "): failed while connecting in " << fn << "(): "
1790                << ex.what();
1791   startFail();
1792
1793   if (connectCallback_ != nullptr) {
1794     ConnectCallback* callback = connectCallback_;
1795     connectCallback_ = nullptr;
1796     callback->connectErr(ex);
1797   }
1798
1799   finishFail();
1800 }
1801
1802 void AsyncSocket::failRead(const char* fn, const AsyncSocketException& ex) {
1803   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1804                << state_ << " host=" << addr_.describe()
1805                << "): failed while reading in " << fn << "(): "
1806                << ex.what();
1807   startFail();
1808
1809   if (readCallback_ != nullptr) {
1810     ReadCallback* callback = readCallback_;
1811     readCallback_ = nullptr;
1812     callback->readErr(ex);
1813   }
1814
1815   finishFail();
1816 }
1817
1818 void AsyncSocket::failWrite(const char* fn, const AsyncSocketException& ex) {
1819   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1820                << state_ << " host=" << addr_.describe()
1821                << "): failed while writing in " << fn << "(): "
1822                << ex.what();
1823   startFail();
1824
1825   // Only invoke the first write callback, since the error occurred while
1826   // writing this request.  Let any other pending write callbacks be invoked in
1827   // finishFail().
1828   if (writeReqHead_ != nullptr) {
1829     WriteRequest* req = writeReqHead_;
1830     writeReqHead_ = req->getNext();
1831     WriteCallback* callback = req->getCallback();
1832     uint32_t bytesWritten = req->getBytesWritten();
1833     req->destroy();
1834     if (callback) {
1835       callback->writeErr(bytesWritten, ex);
1836     }
1837   }
1838
1839   finishFail();
1840 }
1841
1842 void AsyncSocket::failWrite(const char* fn, WriteCallback* callback,
1843                              size_t bytesWritten,
1844                              const AsyncSocketException& ex) {
1845   // This version of failWrite() is used when the failure occurs before
1846   // we've added the callback to writeReqHead_.
1847   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_ << ", state="
1848              << state_ << " host=" << addr_.describe()
1849              <<"): failed while writing in " << fn << "(): "
1850              << ex.what();
1851   startFail();
1852
1853   if (callback != nullptr) {
1854     callback->writeErr(bytesWritten, ex);
1855   }
1856
1857   finishFail();
1858 }
1859
1860 void AsyncSocket::failAllWrites(const AsyncSocketException& ex) {
1861   // Invoke writeError() on all write callbacks.
1862   // This is used when writes are forcibly shutdown with write requests
1863   // pending, or when an error occurs with writes pending.
1864   while (writeReqHead_ != nullptr) {
1865     WriteRequest* req = writeReqHead_;
1866     writeReqHead_ = req->getNext();
1867     WriteCallback* callback = req->getCallback();
1868     if (callback) {
1869       callback->writeErr(req->getBytesWritten(), ex);
1870     }
1871     req->destroy();
1872   }
1873 }
1874
1875 void AsyncSocket::invalidState(ConnectCallback* callback) {
1876   VLOG(5) << "AsyncSocket(this=" << this << ", fd=" << fd_
1877              << "): connect() called in invalid state " << state_;
1878
1879   /*
1880    * The invalidState() methods don't use the normal failure mechanisms,
1881    * since we don't know what state we are in.  We don't want to call
1882    * startFail()/finishFail() recursively if we are already in the middle of
1883    * cleaning up.
1884    */
1885
1886   AsyncSocketException ex(AsyncSocketException::ALREADY_OPEN,
1887                          "connect() called with socket in invalid state");
1888   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1889     if (callback) {
1890       callback->connectErr(ex);
1891     }
1892   } else {
1893     // We can't use failConnect() here since connectCallback_
1894     // may already be set to another callback.  Invoke this ConnectCallback
1895     // here; any other connectCallback_ will be invoked in finishFail()
1896     startFail();
1897     if (callback) {
1898       callback->connectErr(ex);
1899     }
1900     finishFail();
1901   }
1902 }
1903
1904 void AsyncSocket::invalidState(ReadCallback* callback) {
1905   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
1906              << "): setReadCallback(" << callback
1907              << ") called in invalid state " << state_;
1908
1909   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1910                          "setReadCallback() called with socket in "
1911                          "invalid state");
1912   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1913     if (callback) {
1914       callback->readErr(ex);
1915     }
1916   } else {
1917     startFail();
1918     if (callback) {
1919       callback->readErr(ex);
1920     }
1921     finishFail();
1922   }
1923 }
1924
1925 void AsyncSocket::invalidState(WriteCallback* callback) {
1926   VLOG(4) << "AsyncSocket(this=" << this << ", fd=" << fd_
1927              << "): write() called in invalid state " << state_;
1928
1929   AsyncSocketException ex(AsyncSocketException::NOT_OPEN,
1930                          withAddr("write() called with socket in invalid state"));
1931   if (state_ == StateEnum::CLOSED || state_ == StateEnum::ERROR) {
1932     if (callback) {
1933       callback->writeErr(0, ex);
1934     }
1935   } else {
1936     startFail();
1937     if (callback) {
1938       callback->writeErr(0, ex);
1939     }
1940     finishFail();
1941   }
1942 }
1943
1944 void AsyncSocket::doClose() {
1945   if (fd_ == -1) return;
1946   if (shutdownSocketSet_) {
1947     shutdownSocketSet_->close(fd_);
1948   } else {
1949     ::close(fd_);
1950   }
1951   fd_ = -1;
1952 }
1953
1954 std::ostream& operator << (std::ostream& os,
1955                            const AsyncSocket::StateEnum& state) {
1956   os << static_cast<int>(state);
1957   return os;
1958 }
1959
1960 std::string AsyncSocket::withAddr(const std::string& s) {
1961   // Don't use addr_ directly because it may not be initialized
1962   // e.g. if constructed from fd
1963   folly::SocketAddress peer, local;
1964   try {
1965     getPeerAddress(&peer);
1966     getLocalAddress(&local);
1967   } catch (const std::exception&) {
1968     // ignore
1969   } catch (...) {
1970     // ignore
1971   }
1972   return s + " (peer=" + peer.describe() + ", local=" + local.describe() + ")";
1973 }
1974
1975 } // folly