Add DCHECKs for checking that underlying IOBuf wasn't modified
[folly.git] / folly / io / Cursor.h
index ba0dd36f64e14a192ed2405953a542f474aa683e..be231d9b6246e31974274fff3cad1b2ae1c709f2 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2013-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #pragma once
 
 #include <cassert>
 #include <stdexcept>
 #include <type_traits>
 
-#include <folly/Bits.h>
 #include <folly/Likely.h>
 #include <folly/Memory.h>
 #include <folly/Portability.h>
 #include <folly/Range.h>
 #include <folly/io/IOBuf.h>
 #include <folly/io/IOBufQueue.h>
+#include <folly/lang/Bits.h>
 #include <folly/portability/BitsFunctexcept.h>
 
 /**
@@ -39,7 +38,7 @@
  *
  * RWPrivateCursor - Read-write access, assumes private access to IOBuf chain
  * RWUnshareCursor - Read-write access, calls unshare on write (COW)
- * Appender        - Write access, assumes private access to IOBuf chian
+ * Appender        - Write access, assumes private access to IOBuf chain
  *
  * Note that RW cursors write in the preallocated part of buffers (that is,
  * between the buffer's data() and tail()), while Appenders append to the end
@@ -58,7 +57,12 @@ class CursorBase {
   // Make all the templated classes friends for copy constructor.
   template <class D, typename B> friend class CursorBase;
  public:
-  explicit CursorBase(BufType* buf) : crtBuf_(buf), buffer_(buf) { }
+  explicit CursorBase(BufType* buf) : crtBuf_(buf), buffer_(buf) {
+    if (crtBuf_) {
+      crtPos_ = crtBegin_ = crtBuf_->data();
+      crtEnd_ = crtBuf_->tail();
+    }
+  }
 
   /**
    * Copy constructor.
@@ -68,9 +72,12 @@ class CursorBase {
    */
   template <class OtherDerived, class OtherBuf>
   explicit CursorBase(const CursorBase<OtherDerived, OtherBuf>& cursor)
-    : crtBuf_(cursor.crtBuf_),
-      offset_(cursor.offset_),
-      buffer_(cursor.buffer_) { }
+      : crtBuf_(cursor.crtBuf_),
+        buffer_(cursor.buffer_),
+        crtBegin_(cursor.crtBegin_),
+        crtEnd_(cursor.crtEnd_),
+        crtPos_(cursor.crtPos_),
+        absolutePos_(cursor.absolutePos_) {}
 
   /**
    * Reset cursor to point to a new buffer.
@@ -78,11 +85,24 @@ class CursorBase {
   void reset(BufType* buf) {
     crtBuf_ = buf;
     buffer_ = buf;
-    offset_ = 0;
+    absolutePos_ = 0;
+    if (crtBuf_) {
+      crtPos_ = crtBegin_ = crtBuf_->data();
+      crtEnd_ = crtBuf_->tail();
+    }
+  }
+
+  /**
+   * Get the current Cursor position relative to the head of IOBuf chain.
+   */
+  size_t getCurrentPosition() const {
+    dcheckIntegrity();
+    return (crtPos_ - crtBegin_) + absolutePos_;
   }
 
   const uint8_t* data() const {
-    return crtBuf_->data() + offset_;
+    dcheckIntegrity();
+    return crtPos_;
   }
 
   /**
@@ -94,7 +114,8 @@ class CursorBase {
    * pointing at the end of a buffer.
    */
   size_t length() const {
-    return crtBuf_->length() - offset_;
+    dcheckIntegrity();
+    return crtEnd_ - crtPos_;
   }
 
   /**
@@ -102,10 +123,10 @@ class CursorBase {
    */
   size_t totalLength() const {
     if (crtBuf_ == buffer_) {
-      return crtBuf_->computeChainDataLength() - offset_;
+      return crtBuf_->computeChainDataLength() - (crtPos_ - crtBegin_);
     }
     CursorBase end(buffer_->prev());
-    end.offset_ = end.buffer_->length();
+    end.crtPos_ = end.crtEnd_;
     return end - *this;
   }
 
@@ -134,8 +155,9 @@ class CursorBase {
    * Return true if the cursor is at the end of the entire IOBuf chain.
    */
   bool isAtEnd() const {
+    dcheckIntegrity();
     // Check for the simple cases first.
-    if (offset_ != crtBuf_->length()) {
+    if (crtPos_ != crtEnd_) {
       return false;
     }
     if (crtBuf_ == buffer_->prev()) {
@@ -158,9 +180,21 @@ class CursorBase {
    * Advances the cursor to the end of the entire IOBuf chain.
    */
   void advanceToEnd() {
-    offset_ = buffer_->prev()->length();
-    if (crtBuf_ != buffer_->prev()) {
-      crtBuf_ = buffer_->prev();
+    // Simple case, we're already in the last IOBuf.
+    if (crtBuf_ == buffer_->prev()) {
+      crtPos_ = crtEnd_;
+      return;
+    }
+
+    auto* nextBuf = crtBuf_->next();
+    while (nextBuf != buffer_) {
+      absolutePos_ += crtEnd_ - crtBegin_;
+
+      crtBuf_ = nextBuf;
+      nextBuf = crtBuf_->next();
+      crtBegin_ = crtBuf_->data();
+      crtPos_ = crtEnd_ = crtBuf_->tail();
+
       static_cast<Derived*>(this)->advanceDone();
     }
   }
@@ -194,7 +228,23 @@ class CursorBase {
    * same IOBuf chain.
    */
   bool operator==(const Derived& other) const {
-    return (offset_ == other.offset_) && (crtBuf_ == other.crtBuf_);
+    const IOBuf* crtBuf = crtBuf_;
+    auto crtPos = crtPos_;
+    // We can be pointing to the end of a buffer chunk, find first non-empty.
+    while (crtPos == crtBuf->tail() && crtBuf != buffer_->prev()) {
+      crtBuf = crtBuf->next();
+      crtPos = crtBuf->data();
+    }
+
+    const IOBuf* crtBufOther = other.crtBuf_;
+    auto crtPosOther = other.crtPos_;
+    // We can be pointing to the end of a buffer chunk, find first non-empty.
+    while (crtPosOther == crtBufOther->tail() &&
+           crtBufOther != other.buffer_->prev()) {
+      crtBufOther = crtBufOther->next();
+      crtPosOther = crtBufOther->data();
+    }
+    return (crtPos == crtPosOther) && (crtBuf == crtBufOther);
   }
   bool operator!=(const Derived& other) const {
     return !operator==(other);
@@ -203,10 +253,9 @@ class CursorBase {
   template <class T>
   typename std::enable_if<std::is_arithmetic<T>::value, bool>::type tryRead(
       T& val) {
-    if (LIKELY(length() >= sizeof(T))) {
+    if (LIKELY(crtPos_ + sizeof(T) <= crtEnd_)) {
       val = loadUnaligned<T>(data());
-      offset_ += sizeof(T);
-      advanceBufferIfEmpty();
+      crtPos_ += sizeof(T);
       return true;
     }
     return pullAtMostSlow(&val, sizeof(T)) == sizeof(T);
@@ -228,11 +277,13 @@ class CursorBase {
 
   template <class T>
   T read() {
-    T val{};
-    if (!tryRead(val)) {
-      std::__throw_out_of_range("underflow");
+    if (LIKELY(crtPos_ + sizeof(T) <= crtEnd_)) {
+      T val = loadUnaligned<T>(data());
+      crtPos_ += sizeof(T);
+      return val;
+    } else {
+      return readSlow<T>();
     }
-    return val;
   }
 
   template <class T>
@@ -257,8 +308,7 @@ class CursorBase {
     str.reserve(len);
     if (LIKELY(length() >= len)) {
       str.append(reinterpret_cast<const char*>(data()), len);
-      offset_ += len;
-      advanceBufferIfEmpty();
+      crtPos_ += len;
     } else {
       readFixedStringSlow(&str, len);
     }
@@ -307,55 +357,66 @@ class CursorBase {
   void skipWhile(const Predicate& predicate);
 
   size_t skipAtMost(size_t len) {
-    if (LIKELY(length() >= len)) {
-      offset_ += len;
-      advanceBufferIfEmpty();
+    dcheckIntegrity();
+    if (LIKELY(crtPos_ + len < crtEnd_)) {
+      crtPos_ += len;
       return len;
     }
     return skipAtMostSlow(len);
   }
 
   void skip(size_t len) {
-    if (LIKELY(length() >= len)) {
-      offset_ += len;
-      advanceBufferIfEmpty();
+    dcheckIntegrity();
+    if (LIKELY(crtPos_ + len < crtEnd_)) {
+      crtPos_ += len;
     } else {
       skipSlow(len);
     }
   }
 
+  /**
+   * Skip bytes in the current IOBuf without advancing to the next one.
+   * Precondition: length() >= len
+   */
+  void skipNoAdvance(size_t len) {
+    DCHECK_LE(len, length());
+    crtPos_ += len;
+  }
+
   size_t retreatAtMost(size_t len) {
-    if (len <= offset_) {
-      offset_ -= len;
+    dcheckIntegrity();
+    if (len <= static_cast<size_t>(crtPos_ - crtBegin_)) {
+      crtPos_ -= len;
       return len;
     }
     return retreatAtMostSlow(len);
   }
 
   void retreat(size_t len) {
-    if (len <= offset_) {
-      offset_ -= len;
+    dcheckIntegrity();
+    if (len <= static_cast<size_t>(crtPos_ - crtBegin_)) {
+      crtPos_ -= len;
     } else {
       retreatSlow(len);
     }
   }
 
   size_t pullAtMost(void* buf, size_t len) {
+    dcheckIntegrity();
     // Fast path: it all fits in one buffer.
-    if (LIKELY(length() >= len)) {
+    if (LIKELY(crtPos_ + len <= crtEnd_)) {
       memcpy(buf, data(), len);
-      offset_ += len;
-      advanceBufferIfEmpty();
+      crtPos_ += len;
       return len;
     }
     return pullAtMostSlow(buf, len);
   }
 
   void pull(void* buf, size_t len) {
-    if (LIKELY(length() >= len)) {
+    dcheckIntegrity();
+    if (LIKELY(crtPos_ + len <= crtEnd_)) {
       memcpy(buf, data(), len);
-      offset_ += len;
-      advanceBufferIfEmpty();
+      crtPos_ += len;
     } else {
       pullSlow(buf, len);
     }
@@ -399,6 +460,9 @@ class CursorBase {
   }
 
   size_t cloneAtMost(folly::IOBuf& buf, size_t len) {
+    // We might be at the end of buffer.
+    advanceBufferIfEmpty();
+
     std::unique_ptr<folly::IOBuf> tmp;
     size_t copied = 0;
     for (int loopCount = 0; true; ++loopCount) {
@@ -407,26 +471,26 @@ class CursorBase {
       if (LIKELY(available >= len)) {
         if (loopCount == 0) {
           crtBuf_->cloneOneInto(buf);
-          buf.trimStart(offset_);
+          buf.trimStart(crtPos_ - crtBegin_);
           buf.trimEnd(buf.length() - len);
         } else {
           tmp = crtBuf_->cloneOne();
-          tmp->trimStart(offset_);
+          tmp->trimStart(crtPos_ - crtBegin_);
           tmp->trimEnd(tmp->length() - len);
           buf.prependChain(std::move(tmp));
         }
 
-        offset_ += len;
+        crtPos_ += len;
         advanceBufferIfEmpty();
         return copied + len;
       }
 
       if (loopCount == 0) {
         crtBuf_->cloneOneInto(buf);
-        buf.trimStart(offset_);
+        buf.trimStart(crtPos_ - crtBegin_);
       } else {
         tmp = crtBuf_->cloneOne();
-        tmp->trimStart(offset_);
+        tmp->trimStart(crtPos_ - crtBegin_);
         buf.prependChain(std::move(tmp));
       }
 
@@ -453,7 +517,7 @@ class CursorBase {
     size_t len = 0;
 
     if (otherBuf != crtBuf_) {
-      len += otherBuf->length() - other.offset_;
+      len += other.crtEnd_ - other.crtPos_;
 
       for (otherBuf = otherBuf->next();
            otherBuf != crtBuf_ && otherBuf != other.buffer_;
@@ -465,13 +529,13 @@ class CursorBase {
         std::__throw_out_of_range("wrap-around");
       }
 
-      len += offset_;
+      len += crtPos_ - crtBegin_;
     } else {
-      if (offset_ < other.offset_) {
+      if (crtPos_ < other.crtPos_) {
         std::__throw_out_of_range("underflow");
       }
 
-      len += offset_ - other.offset_;
+      len += crtPos_ - other.crtPos_;
     }
 
     return len;
@@ -492,11 +556,19 @@ class CursorBase {
       }
     }
 
-    len += offset_;
+    len += crtPos_ - crtBegin_;
     return len;
   }
 
  protected:
+  void dcheckIntegrity() const {
+    DCHECK(crtBegin_ <= crtPos_ && crtPos_ <= crtEnd_);
+    DCHECK(crtBuf_ == nullptr || crtBegin_ == crtBuf_->data());
+    DCHECK(
+        crtBuf_ == nullptr ||
+        (uint64_t)(crtEnd_ - crtBegin_) == crtBuf_->length());
+  }
+
   ~CursorBase() { }
 
   BufType* head() {
@@ -506,37 +578,53 @@ class CursorBase {
   bool tryAdvanceBuffer() {
     BufType* nextBuf = crtBuf_->next();
     if (UNLIKELY(nextBuf == buffer_)) {
-      offset_ = crtBuf_->length();
+      crtPos_ = crtEnd_;
       return false;
     }
 
-    offset_ = 0;
+    absolutePos_ += crtEnd_ - crtBegin_;
     crtBuf_ = nextBuf;
+    crtPos_ = crtBegin_ = crtBuf_->data();
+    crtEnd_ = crtBuf_->tail();
     static_cast<Derived*>(this)->advanceDone();
     return true;
   }
 
   bool tryRetreatBuffer() {
     if (UNLIKELY(crtBuf_ == buffer_)) {
-      offset_ = 0;
+      crtPos_ = crtBegin_;
       return false;
     }
     crtBuf_ = crtBuf_->prev();
-    offset_ = crtBuf_->length();
+    crtBegin_ = crtBuf_->data();
+    crtPos_ = crtEnd_ = crtBuf_->tail();
+    absolutePos_ -= crtEnd_ - crtBegin_;
     static_cast<Derived*>(this)->advanceDone();
     return true;
   }
 
   void advanceBufferIfEmpty() {
-    if (length() == 0) {
+    dcheckIntegrity();
+    if (crtPos_ == crtEnd_) {
       tryAdvanceBuffer();
     }
   }
 
   BufType* crtBuf_;
-  size_t offset_ = 0;
+  BufType* buffer_;
+  const uint8_t* crtBegin_{nullptr};
+  const uint8_t* crtEnd_{nullptr};
+  const uint8_t* crtPos_{nullptr};
+  size_t absolutePos_{0};
 
  private:
+  template <class T>
+  FOLLY_NOINLINE T readSlow() {
+    T val;
+    pullSlow(&val, sizeof(T));
+    return val;
+  }
+
   void readFixedStringSlow(std::string* str, size_t len) {
     for (size_t available; (available = length()) < len; ) {
       str->append(reinterpret_cast<const char*>(data()), available);
@@ -546,7 +634,7 @@ class CursorBase {
       len -= available;
     }
     str->append(reinterpret_cast<const char*>(data()), len);
-    offset_ += len;
+    crtPos_ += len;
     advanceBufferIfEmpty();
   }
 
@@ -563,7 +651,7 @@ class CursorBase {
       len -= available;
     }
     memcpy(p, data(), len);
-    offset_ += len;
+    crtPos_ += len;
     advanceBufferIfEmpty();
     return copied + len;
   }
@@ -583,7 +671,7 @@ class CursorBase {
       }
       len -= available;
     }
-    offset_ += len;
+    crtPos_ += len;
     advanceBufferIfEmpty();
     return skipped + len;
   }
@@ -596,14 +684,14 @@ class CursorBase {
 
   size_t retreatAtMostSlow(size_t len) {
     size_t retreated = 0;
-    for (size_t available; (available = offset_) < len;) {
+    for (size_t available; (available = crtPos_ - crtBegin_) < len;) {
       retreated += available;
       if (UNLIKELY(!tryRetreatBuffer())) {
         return retreated;
       }
       len -= available;
     }
-    offset_ -= len;
+    crtPos_ -= len;
     return retreated + len;
   }
 
@@ -615,8 +703,6 @@ class CursorBase {
 
   void advanceDone() {
   }
-
-  BufType* buffer_;
 };
 
 } // namespace detail
@@ -748,11 +834,20 @@ class RWCursor
     if (this->crtBuf_ != this->head() && this->totalLength() < n) {
       throw std::overflow_error("cannot gather() past the end of the chain");
     }
-    this->crtBuf_->gather(this->offset_ + n);
+    size_t offset = this->crtPos_ - this->crtBegin_;
+    this->crtBuf_->gather(offset + n);
+    this->crtBegin_ = this->crtBuf_->data();
+    this->crtEnd_ = this->crtBuf_->tail();
+    this->crtPos_ = this->crtBegin_ + offset;
   }
   void gatherAtMost(size_t n) {
+    this->dcheckIntegrity();
     size_t size = std::min(n, this->totalLength());
-    return this->crtBuf_->gather(this->offset_ + size);
+    size_t offset = this->crtPos_ - this->crtBegin_;
+    this->crtBuf_->gather(offset + size);
+    this->crtBegin_ = this->crtBuf_->data();
+    this->crtEnd_ = this->crtBuf_->tail();
+    this->crtPos_ = this->crtBegin_ + offset;
   }
 
   using detail::Writable<RWCursor<access>>::pushAtMost;
@@ -774,7 +869,7 @@ class RWCursor
           maybeUnshare();
         }
         memcpy(writableData(), buf, len);
-        this->offset_ += len;
+        this->crtPos_ += len;
         return copied + len;
       }
 
@@ -792,17 +887,18 @@ class RWCursor
   }
 
   void insert(std::unique_ptr<folly::IOBuf> buf) {
-    folly::IOBuf* nextBuf;
-    if (this->offset_ == 0) {
+    this->dcheckIntegrity();
+    this->absolutePos_ += buf->computeChainDataLength();
+    if (this->crtPos_ == this->crtBegin_ && this->crtBuf_ != this->buffer_) {
       // Can just prepend
-      nextBuf = this->crtBuf_;
       this->crtBuf_->prependChain(std::move(buf));
     } else {
+      IOBuf* nextBuf;
       std::unique_ptr<folly::IOBuf> remaining;
-      if (this->crtBuf_->length() - this->offset_ > 0) {
+      if (this->crtPos_ != this->crtEnd_) {
         // Need to split current IOBuf in two.
         remaining = this->crtBuf_->cloneOne();
-        remaining->trimStart(this->offset_);
+        remaining->trimStart(this->crtPos_ - this->crtBegin_);
         nextBuf = remaining.get();
         buf->prependChain(std::move(remaining));
       } else {
@@ -810,21 +906,38 @@ class RWCursor
         nextBuf = this->crtBuf_->next();
       }
       this->crtBuf_->trimEnd(this->length());
+      this->absolutePos_ += this->crtPos_ - this->crtBegin_;
       this->crtBuf_->appendChain(std::move(buf));
+
+      if (nextBuf == this->buffer_) {
+        // We've just appended to the end of the buffer, so advance to the end.
+        this->crtBuf_ = this->buffer_->prev();
+        this->crtBegin_ = this->crtBuf_->data();
+        this->crtPos_ = this->crtEnd_ = this->crtBuf_->tail();
+        // This has already been accounted for, so remove it.
+        this->absolutePos_ -= this->crtEnd_ - this->crtBegin_;
+      } else {
+        // Jump past the new links
+        this->crtBuf_ = nextBuf;
+        this->crtPos_ = this->crtBegin_ = this->crtBuf_->data();
+        this->crtEnd_ = this->crtBuf_->tail();
+      }
     }
-    // Jump past the new links
-    this->offset_ = 0;
-    this->crtBuf_ = nextBuf;
   }
 
   uint8_t* writableData() {
-    return this->crtBuf_->writableData() + this->offset_;
+    this->dcheckIntegrity();
+    return this->crtBuf_->writableData() + (this->crtPos_ - this->crtBegin_);
   }
 
  private:
   void maybeUnshare() {
     if (UNLIKELY(maybeShared_)) {
+      size_t offset = this->crtPos_ - this->crtBegin_;
       this->crtBuf_->unshareOne();
+      this->crtBegin_ = this->crtBuf_->data();
+      this->crtEnd_ = this->crtBuf_->tail();
+      this->crtPos_ = this->crtBegin_ + offset;
       maybeShared_ = false;
     }
   }