X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2FCompression.cpp;h=98f332b85eb55d34f047d47ef51e7cd368369e16;hb=94f59e2636b7a3b910cbe3a8ec8f541c1e9b1719;hp=b523e697d5dcc02303cde019cdc343c46253924c;hpb=e1d2ddd5dca9f7a4a3c83ae7bc8c0c59c58333c1;p=folly.git diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index b523e697..98f332b8 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -58,12 +58,16 @@ #include #include -namespace folly { namespace io { +namespace folly { +namespace io { Codec::Codec(CodecType type) : type_(type) { } // Ensure consistent behavior in the nullptr case std::unique_ptr Codec::compress(const IOBuf* data) { + if (data == nullptr) { + throw std::invalid_argument("Codec: data must not be nullptr"); + } uint64_t len = data->computeChainDataLength(); if (len == 0) { return IOBuf::create(0); @@ -90,6 +94,9 @@ std::string Codec::compress(const StringPiece data) { std::unique_ptr Codec::uncompress( const IOBuf* data, Optional uncompressedLength) { + if (data == nullptr) { + throw std::invalid_argument("Codec: data must not be nullptr"); + } if (!uncompressedLength) { if (needsUncompressedLength()) { throw std::invalid_argument("Codec: uncompressed length required"); @@ -341,6 +348,8 @@ std::unique_ptr StreamCodec::doCompress(IOBuf const* data) { if (output.empty()) { buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength)); } + size_t const inputSize = input.size(); + size_t const outputSize = output.size(); bool const done = compressStream(input, output, flushOp); if (done) { DCHECK(input.empty()); @@ -348,6 +357,9 @@ std::unique_ptr StreamCodec::doCompress(IOBuf const* data) { DCHECK_EQ(current->next(), data); break; } + if (inputSize == input.size() && outputSize == output.size()) { + throw std::runtime_error("Codec: No forward progress made"); + } } buffer->prev()->trimEnd(output.size()); return buffer; @@ -395,10 +407,15 @@ std::unique_ptr StreamCodec::doUncompress( if (output.empty()) { buffer->prependChain(addOutputBuffer(output, defaultBufferLength)); } + size_t const inputSize = input.size(); + size_t const outputSize = output.size(); bool const done = uncompressStream(input, output, flushOp); if (done) { break; } + if (inputSize == input.size() && outputSize == output.size()) { + throw std::runtime_error("Codec: Truncated data"); + } } if (!input.empty()) { throw std::runtime_error("Codec: Junk after end of data"); @@ -496,7 +513,7 @@ inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) { return val; } -} // namespace +} // namespace #endif // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA @@ -992,10 +1009,12 @@ std::unique_ptr SnappyCodec::doUncompress( /** * Zlib codec */ -class ZlibCodec final : public Codec { +class ZlibStreamCodec final : public StreamCodec { public: - static std::unique_ptr create(int level, CodecType type); - explicit ZlibCodec(int level, CodecType type); + static std::unique_ptr createCodec(int level, CodecType type); + static std::unique_ptr createStream(int level, CodecType type); + explicit ZlibStreamCodec(int level, CodecType type); + ~ZlibStreamCodec() override; std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) @@ -1003,20 +1022,29 @@ class ZlibCodec final : public Codec { private: uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; - std::unique_ptr addOutputBuffer(z_stream* stream, uint32_t length); - bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength); + void doResetStream() override; + bool doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flush) override; + bool doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flush) override; + + void resetDeflateStream(); + void resetInflateStream(); + Optional deflateStream_{}; + Optional inflateStream_{}; int level_; + bool needReset_{true}; }; static constexpr uint16_t kGZIPMagicLE = 0x8B1F; -std::vector ZlibCodec::validPrefixes() const { +std::vector ZlibStreamCodec::validPrefixes() const { if (type() == CodecType::ZLIB) { // Zlib streams start with a 2 byte header. // @@ -1060,7 +1088,8 @@ std::vector ZlibCodec::validPrefixes() const { } } -bool ZlibCodec::canUncompress(const IOBuf* data, Optional) const { +bool ZlibStreamCodec::canUncompress(const IOBuf* data, Optional) + const { if (type() == CodecType::ZLIB) { uint16_t value; Cursor cursor{data}; @@ -1074,80 +1103,68 @@ bool ZlibCodec::canUncompress(const IOBuf* data, Optional) const { } } -uint64_t ZlibCodec::doMaxCompressedLength(uint64_t uncompressedLength) const { +uint64_t ZlibStreamCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { return deflateBound(nullptr, uncompressedLength); } -std::unique_ptr ZlibCodec::create(int level, CodecType type) { - return std::make_unique(level, type); +std::unique_ptr ZlibStreamCodec::createCodec(int level, CodecType type) { + return std::make_unique(level, type); +} + +std::unique_ptr ZlibStreamCodec::createStream( + int level, + CodecType type) { + return std::make_unique(level, type); } -ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) { +ZlibStreamCodec::ZlibStreamCodec(int level, CodecType type) + : StreamCodec(type) { DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP); switch (level) { - case COMPRESSION_LEVEL_FASTEST: - level = 1; - break; - case COMPRESSION_LEVEL_DEFAULT: - level = Z_DEFAULT_COMPRESSION; - break; - case COMPRESSION_LEVEL_BEST: - level = 9; - break; + case COMPRESSION_LEVEL_FASTEST: + level = 1; + break; + case COMPRESSION_LEVEL_DEFAULT: + level = Z_DEFAULT_COMPRESSION; + break; + case COMPRESSION_LEVEL_BEST: + level = 9; + break; } if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) { - throw std::invalid_argument(to( - "ZlibCodec: invalid level: ", level)); + throw std::invalid_argument( + to("ZlibStreamCodec: invalid level: ", level)); } level_ = level; } -std::unique_ptr ZlibCodec::addOutputBuffer(z_stream* stream, - uint32_t length) { - CHECK_EQ(stream->avail_out, 0); - - auto buf = IOBuf::create(length); - buf->append(buf->capacity()); - - stream->next_out = buf->writableData(); - stream->avail_out = buf->length(); - - return buf; -} - -bool ZlibCodec::doInflate(z_stream* stream, - IOBuf* head, - uint32_t bufferLength) { - if (stream->avail_out == 0) { - head->prependChain(addOutputBuffer(stream, bufferLength)); +ZlibStreamCodec::~ZlibStreamCodec() { + if (deflateStream_) { + deflateEnd(deflateStream_.get_pointer()); + deflateStream_.clear(); } - - int rc = inflate(stream, Z_NO_FLUSH); - - switch (rc) { - case Z_OK: - break; - case Z_STREAM_END: - return true; - case Z_BUF_ERROR: - case Z_NEED_DICT: - case Z_DATA_ERROR: - case Z_MEM_ERROR: - throw std::runtime_error(to( - "ZlibCodec: inflate error: ", rc, ": ", stream->msg)); - default: - CHECK(false) << rc << ": " << stream->msg; + if (inflateStream_) { + inflateEnd(inflateStream_.get_pointer()); + inflateStream_.clear(); } - - return false; } -std::unique_ptr ZlibCodec::doCompress(const IOBuf* data) { - z_stream stream; - stream.zalloc = nullptr; - stream.zfree = nullptr; - stream.opaque = nullptr; +void ZlibStreamCodec::doResetStream() { + needReset_ = true; +} +void ZlibStreamCodec::resetDeflateStream() { + if (deflateStream_) { + int const rc = deflateReset(deflateStream_.get_pointer()); + if (rc != Z_OK) { + deflateStream_.clear(); + throw std::runtime_error( + to("ZlibStreamCodec: deflateReset error: ", rc)); + } + return; + } + deflateStream_ = z_stream{}; // Using deflateInit2() to support gzip. "The windowBits parameter is the // base two logarithm of the maximum window size (...) The default value is // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer @@ -1155,167 +1172,132 @@ std::unique_ptr ZlibCodec::doCompress(const IOBuf* data) { // will have no file name, no extra data, no comment, no modification time // (set to zero), no header crc, and the operating system will be set to 255 // (unknown)." - int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); + int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); // All other parameters (method, memLevel, strategy) get default values from // the zlib manual. - int rc = deflateInit2(&stream, - level_, - Z_DEFLATED, - windowBits, - /* memLevel */ 8, - Z_DEFAULT_STRATEGY); + int const rc = deflateInit2( + deflateStream_.get_pointer(), + level_, + Z_DEFLATED, + windowBits, + /* memLevel */ 8, + Z_DEFAULT_STRATEGY); if (rc != Z_OK) { - throw std::runtime_error(to( - "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg)); + deflateStream_.clear(); + throw std::runtime_error( + to("ZlibStreamCodec: deflateInit error: ", rc)); } +} - stream.next_in = stream.next_out = nullptr; - stream.avail_in = stream.avail_out = 0; - stream.total_in = stream.total_out = 0; - - bool success = false; - - SCOPE_EXIT { - rc = deflateEnd(&stream); - // If we're here because of an exception, it's okay if some data - // got dropped. - CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) - << rc << ": " << stream.msg; - }; - - uint64_t uncompressedLength = data->computeChainDataLength(); - uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength); - - // Max 64MiB in one go - constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB - constexpr uint32_t defaultBufferLength = uint32_t(4) << 20; // 4MiB - - auto out = addOutputBuffer( - &stream, - (maxCompressedLength <= maxSingleStepLength ? - maxCompressedLength : - defaultBufferLength)); - - for (auto& range : *data) { - uint64_t remaining = range.size(); - uint64_t written = 0; - while (remaining) { - uint32_t step = (remaining > maxSingleStepLength ? - maxSingleStepLength : remaining); - stream.next_in = const_cast(range.data() + written); - stream.avail_in = step; - remaining -= step; - written += step; - - while (stream.avail_in != 0) { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = deflate(&stream, Z_NO_FLUSH); - - CHECK_EQ(rc, Z_OK) << stream.msg; - } +void ZlibStreamCodec::resetInflateStream() { + if (inflateStream_) { + int const rc = inflateReset(inflateStream_.get_pointer()); + if (rc != Z_OK) { + inflateStream_.clear(); + throw std::runtime_error( + to("ZlibStreamCodec: inflateReset error: ", rc)); } + return; } - - do { - if (stream.avail_out == 0) { - out->prependChain(addOutputBuffer(&stream, defaultBufferLength)); - } - - rc = deflate(&stream, Z_FINISH); - } while (rc == Z_OK); - - CHECK_EQ(rc, Z_STREAM_END) << stream.msg; - - out->prev()->trimEnd(stream.avail_out); - - success = true; // we survived - - return out; -} - -std::unique_ptr ZlibCodec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - z_stream stream; - stream.zalloc = nullptr; - stream.zfree = nullptr; - stream.opaque = nullptr; - + inflateStream_ = z_stream{}; // "The windowBits parameter is the base two logarithm of the maximum window // size (...) The default value is 15 (...) add 16 to decode only the gzip // format (the zlib format will return a Z_DATA_ERROR)." - int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); - int rc = inflateInit2(&stream, windowBits); + int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); + int const rc = inflateInit2(inflateStream_.get_pointer(), windowBits); if (rc != Z_OK) { - throw std::runtime_error(to( - "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg)); + inflateStream_.clear(); + throw std::runtime_error( + to("ZlibStreamCodec: inflateInit error: ", rc)); } +} - stream.next_in = stream.next_out = nullptr; - stream.avail_in = stream.avail_out = 0; - stream.total_in = stream.total_out = 0; +static int zlibTranslateFlush(StreamCodec::FlushOp flush) { + switch (flush) { + case StreamCodec::FlushOp::NONE: + return Z_NO_FLUSH; + case StreamCodec::FlushOp::FLUSH: + return Z_SYNC_FLUSH; + case StreamCodec::FlushOp::END: + return Z_FINISH; + default: + throw std::invalid_argument("ZlibStreamCodec: Invalid flush"); + } +} - bool success = false; +static int zlibThrowOnError(int rc) { + switch (rc) { + case Z_OK: + case Z_BUF_ERROR: + case Z_STREAM_END: + return rc; + default: + throw std::runtime_error(to("ZlibStreamCodec: error: ", rc)); + } +} +bool ZlibStreamCodec::doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flush) { + if (needReset_) { + resetDeflateStream(); + needReset_ = false; + } + DCHECK(deflateStream_.hasValue()); + // zlib will return Z_STREAM_ERROR if output.data() is null. + if (output.data() == nullptr) { + return false; + } + deflateStream_->next_in = const_cast(input.data()); + deflateStream_->avail_in = input.size(); + deflateStream_->next_out = output.data(); + deflateStream_->avail_out = output.size(); SCOPE_EXIT { - rc = inflateEnd(&stream); - // If we're here because of an exception, it's okay if some data - // got dropped. - CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) - << rc << ": " << stream.msg; + input.uncheckedAdvance(input.size() - deflateStream_->avail_in); + output.uncheckedAdvance(output.size() - deflateStream_->avail_out); }; - - // Max 64MiB in one go - constexpr uint64_t maxSingleStepLength = uint64_t(64) << 20; // 64MiB - constexpr uint64_t kBlockSize = uint64_t(32) << 10; // 32 KiB - const uint64_t defaultBufferLength = - computeBufferLength(data->computeChainDataLength(), kBlockSize); - - auto out = addOutputBuffer( - &stream, - ((uncompressedLength && *uncompressedLength <= maxSingleStepLength) - ? *uncompressedLength - : defaultBufferLength)); - - bool streamEnd = false; - for (auto& range : *data) { - if (range.empty()) { - continue; - } - - stream.next_in = const_cast(range.data()); - stream.avail_in = range.size(); - - while (stream.avail_in != 0) { - if (streamEnd) { - throw std::runtime_error(to( - "ZlibCodec: junk after end of data")); - } - - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); - } + int const rc = zlibThrowOnError( + deflate(deflateStream_.get_pointer(), zlibTranslateFlush(flush))); + switch (flush) { + case StreamCodec::FlushOp::NONE: + return false; + case StreamCodec::FlushOp::FLUSH: + return deflateStream_->avail_in == 0 && deflateStream_->avail_out != 0; + case StreamCodec::FlushOp::END: + return rc == Z_STREAM_END; + default: + throw std::invalid_argument("ZlibStreamCodec: Invalid flush"); } +} - while (!streamEnd) { - streamEnd = doInflate(&stream, out.get(), defaultBufferLength); +bool ZlibStreamCodec::doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flush) { + if (needReset_) { + resetInflateStream(); + needReset_ = false; } - - out->prev()->trimEnd(stream.avail_out); - - if (uncompressedLength && *uncompressedLength != stream.total_out) { - throw std::runtime_error( - to("ZlibCodec: invalid uncompressed length")); + DCHECK(inflateStream_.hasValue()); + // zlib will return Z_STREAM_ERROR if output.data() is null. + if (output.data() == nullptr) { + return false; } - - success = true; // we survived - - return out; + inflateStream_->next_in = const_cast(input.data()); + inflateStream_->avail_in = input.size(); + inflateStream_->next_out = output.data(); + inflateStream_->avail_out = output.size(); + SCOPE_EXIT { + input.advance(input.size() - inflateStream_->avail_in); + output.advance(output.size() - inflateStream_->avail_out); + }; + int const rc = zlibThrowOnError( + inflate(inflateStream_.get_pointer(), zlibTranslateFlush(flush))); + return rc == Z_STREAM_END; } -#endif // FOLLY_HAVE_LIBZ +#endif // FOLLY_HAVE_LIBZ #if FOLLY_HAVE_LIBLZMA @@ -1812,7 +1794,9 @@ bool ZSTDStreamCodec::tryBlockUncompress( size_t const length = ZSTD_decompress( output.data(), *uncompressedLength(), input.data(), compressedLength); zstdThrowIfError(length); - DCHECK_EQ(length, *uncompressedLength()); + if (length != *uncompressedLength()) { + throw std::runtime_error("ZSTDStreamCodec: Incorrect uncompressed length"); + } input.uncheckedAdvance(compressedLength); output.uncheckedAdvance(length); return true; @@ -2043,8 +2027,11 @@ std::unique_ptr Bzip2Codec::doUncompress( if (stream.avail_out == 0) { out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength)); } - + size_t const outputSize = stream.avail_out; rc = bzCheck(BZ2_bzDecompress(&stream)); + if (outputSize == stream.avail_out) { + throw std::runtime_error("Bzip2Codec: Truncated input"); + } } out->prev()->trimEnd(stream.avail_out); @@ -2249,7 +2236,7 @@ constexpr Factory #endif #if FOLLY_HAVE_LIBZ - {ZlibCodec::create, nullptr}, + {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream}, #else {}, #endif @@ -2275,7 +2262,7 @@ constexpr Factory #endif #if FOLLY_HAVE_LIBZ - {ZlibCodec::create, nullptr}, + {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream}, #else {}, #endif @@ -2337,4 +2324,5 @@ std::unique_ptr getAutoUncompressionCodec( std::vector> customCodecs) { return AutomaticCodec::create(std::move(customCodecs)); } -}} // namespaces +} // namespace io +} // namespace folly