make IOBuf::gather() safe
[folly.git] / folly / io / Cursor.h
index 73e37965713924dc905f23c67ef8bd0c8c1c8fa6..0b79c23f26b501f9b35583163a52ac32239e7bab 100644 (file)
@@ -54,13 +54,30 @@ class CursorBase {
     return crtBuf_->data() + offset_;
   }
 
-  // Space available in the current IOBuf.  May be 0; use peek() instead which
-  // will always point to a non-empty chunk of data or at the end of the
-  // chain.
+  /*
+   * Return the remaining space available in the current IOBuf.
+   *
+   * May return 0 if the cursor is at the end of an IOBuf.  Use peek() instead
+   * if you want to avoid this.  peek() will advance to the next non-empty
+   * IOBuf (up to the end of the chain) if the cursor is currently pointing at
+   * the end of a buffer.
+   */
   size_t length() const {
     return crtBuf_->length() - offset_;
   }
 
+  /*
+   * Return the space available until the end of the entire IOBuf chain.
+   */
+  size_t totalLength() const {
+    if (crtBuf_ == buffer_) {
+      return crtBuf_->computeChainDataLength() - offset_;
+    }
+    CursorBase end(buffer_->prev());
+    end.offset_ = end.buffer_->length();
+    return end - *this;
+  }
+
   Derived& operator+=(size_t offset) {
     Derived* p = static_cast<Derived*>(this);
     p->skip(offset);
@@ -344,6 +361,10 @@ class CursorBase {
 
   ~CursorBase(){}
 
+  BufType* head() {
+    return buffer_;
+  }
+
   bool tryAdvanceBuffer() {
     BufType* nextBuf = crtBuf_->next();
     if (UNLIKELY(nextBuf == buffer_)) {
@@ -431,8 +452,23 @@ class RWCursor
    * by coalescing subsequent buffers from the chain as necessary.
    */
   void gather(size_t n) {
+    // Forbid attempts to gather beyond the end of this IOBuf chain.
+    // Otherwise we could try to coalesce the head of the chain and end up
+    // accidentally freeing it, invalidating the pointer owned by external
+    // code.
+    //
+    // If crtBuf_ == head() then IOBuf::gather() will perform all necessary
+    // checking.  We only have to perform an explicit check here when calling
+    // gather() on a non-head element.
+    if (this->crtBuf_ != this->head() && this->totalLength() < n) {
+      throw std::overflow_error("cannot gather() past the end of the chain");
+    }
     this->crtBuf_->gather(this->offset_ + n);
   }
+  void gatherAtMost(size_t n) {
+    size_t size = std::min(n, this->totalLength());
+    return this->crtBuf_->gather(this->offset_ + size);
+  }
 
   size_t pushAtMost(const uint8_t* buf, size_t len) {
     size_t copied = 0;