From e1d2ddd5dca9f7a4a3c83ae7bc8c0c59c58333c1 Mon Sep 17 00:00:00 2001 From: Nick Terrell Date: Wed, 24 May 2017 14:59:42 -0700 Subject: [PATCH] Add zstd streaming interface Summary: * Add streaming interface to the `ZstdCodec` * Implement `ZstdCodec::doCompress()` and `ZstdCodec::doUncompress()` using the streaming interface. [fbgs CodecType::ZSTD](https://fburl.com/pr8chg64) and check that no caller requires thread-safety. Reviewed By: yfeldblum Differential Revision: D5026558 fbshipit-source-id: 61faa25c71f5aef06ca2d7e0700f43214353c650 --- folly/io/Compression.cpp | 336 +++++++++++++++++------------- folly/io/test/CompressionTest.cpp | 29 +++ 2 files changed, 217 insertions(+), 148 deletions(-) diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index 2c47093b..b523e697 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -40,6 +40,7 @@ #endif #if FOLLY_HAVE_LIBZSTD +#define ZSTD_STATIC_LINKING_ONLY #include #endif @@ -1584,13 +1585,24 @@ std::unique_ptr LZMA2Codec::doUncompress( #ifdef FOLLY_HAVE_LIBZSTD +namespace { +void zstdFreeCStream(ZSTD_CStream* zcs) { + ZSTD_freeCStream(zcs); +} + +void zstdFreeDStream(ZSTD_DStream* zds) { + ZSTD_freeDStream(zds); +} +} + /** * ZSTD compression */ -class ZSTDCodec final : public Codec { +class ZSTDStreamCodec final : public StreamCodec { public: - static std::unique_ptr create(int level, CodecType); - explicit ZSTDCodec(int level, CodecType type); + static std::unique_ptr createCodec(int level, CodecType); + static std::unique_ptr createStream(int level, CodecType); + explicit ZSTDStreamCodec(int level, CodecType type); std::vector validPrefixes() const override; bool canUncompress(const IOBuf* data, Optional uncompressedLength) @@ -1599,29 +1611,61 @@ class ZSTDCodec final : public Codec { private: bool doNeedsUncompressedLength() const override; uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override; - std::unique_ptr doCompress(const IOBuf* data) override; - std::unique_ptr doUncompress( - const IOBuf* data, - Optional uncompressedLength) override; + Optional doGetUncompressedLength( + IOBuf const* data, + Optional uncompressedLength) const override; + + void doResetStream() override; + bool doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + bool doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) override; + + void resetCStream(); + void resetDStream(); + + bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const; + bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const; int level_; + bool needReset_{true}; + std::unique_ptr< + ZSTD_CStream, + folly::static_function_deleter> + cstream_{nullptr}; + std::unique_ptr< + ZSTD_DStream, + folly::static_function_deleter> + dstream_{nullptr}; }; static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528; -std::vector ZSTDCodec::validPrefixes() const { +std::vector ZSTDStreamCodec::validPrefixes() const { return {prefixToStringLE(kZSTDMagicLE)}; } -bool ZSTDCodec::canUncompress(const IOBuf* data, Optional) const { +bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional) + const { return dataStartsWithLE(data, kZSTDMagicLE); } -std::unique_ptr ZSTDCodec::create(int level, CodecType type) { - return std::make_unique(level, type); +std::unique_ptr ZSTDStreamCodec::createCodec(int level, CodecType type) { + return make_unique(level, type); +} + +std::unique_ptr ZSTDStreamCodec::createStream( + int level, + CodecType type) { + return make_unique(level, type); } -ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { +ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type) + : StreamCodec(type) { DCHECK(type == CodecType::ZSTD); switch (level) { case COMPRESSION_LEVEL_FASTEST: @@ -1641,11 +1685,12 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { level_ = level; } -bool ZSTDCodec::doNeedsUncompressedLength() const { +bool ZSTDStreamCodec::doNeedsUncompressedLength() const { return false; } -uint64_t ZSTDCodec::doMaxCompressedLength(uint64_t uncompressedLength) const { +uint64_t ZSTDStreamCodec::doMaxCompressedLength( + uint64_t uncompressedLength) const { return ZSTD_compressBound(uncompressedLength); } @@ -1657,163 +1702,158 @@ void zstdThrowIfError(size_t rc) { to("ZSTD returned an error: ", ZSTD_getErrorName(rc))); } -std::unique_ptr ZSTDCodec::doCompress(const IOBuf* data) { - // Support earlier versions of the codec (working with a single IOBuf, - // and using ZSTD_decompress which requires ZSTD frame to contain size, - // which isn't populated by streaming API). - if (!data->isChained()) { - auto out = IOBuf::createCombined(ZSTD_compressBound(data->length())); - const auto rc = ZSTD_compress( - out->writableData(), - out->capacity(), - data->data(), - data->length(), - level_); - zstdThrowIfError(rc); - out->append(rc); - return out; - } - - auto zcs = ZSTD_createCStream(); - SCOPE_EXIT { - ZSTD_freeCStream(zcs); - }; - - auto rc = ZSTD_initCStream(zcs, level_); - zstdThrowIfError(rc); - - Cursor cursor(data); - auto result = - IOBuf::createCombined(maxCompressedLength(cursor.totalLength())); - - ZSTD_outBuffer out; - out.dst = result->writableTail(); - out.size = result->capacity(); - out.pos = 0; - - for (auto buffer = cursor.peekBytes(); !buffer.empty();) { - ZSTD_inBuffer in; - in.src = buffer.data(); - in.size = buffer.size(); - for (in.pos = 0; in.pos != in.size;) { - rc = ZSTD_compressStream(zcs, &out, &in); - zstdThrowIfError(rc); +Optional ZSTDStreamCodec::doGetUncompressedLength( + IOBuf const* data, + Optional uncompressedLength) const { + // Read decompressed size from frame if available in first IOBuf. + auto const decompressedSize = + ZSTD_getDecompressedSize(data->data(), data->length()); + if (decompressedSize != 0) { + if (uncompressedLength && *uncompressedLength != decompressedSize) { + throw std::runtime_error("ZSTD: invalid uncompressed length"); } - cursor.skip(in.size); - buffer = cursor.peekBytes(); + uncompressedLength = decompressedSize; } + return uncompressedLength; +} - rc = ZSTD_endStream(zcs, &out); - zstdThrowIfError(rc); - CHECK_EQ(rc, 0); +void ZSTDStreamCodec::doResetStream() { + needReset_ = true; +} - result->append(out.pos); - return result; +bool ZSTDStreamCodec::tryBlockCompress( + ByteRange& input, + MutableByteRange& output) const { + DCHECK(needReset_); + // We need to know that we have enough output space to use block compression + if (output.size() < ZSTD_compressBound(input.size())) { + return false; + } + size_t const length = ZSTD_compress( + output.data(), output.size(), input.data(), input.size(), level_); + zstdThrowIfError(length); + input.uncheckedAdvance(input.size()); + output.uncheckedAdvance(length); + return true; } -static std::unique_ptr zstdUncompressBuffer( - const IOBuf* data, - Optional uncompressedLength) { - // Check preconditions - DCHECK(!data->isChained()); - DCHECK(uncompressedLength.hasValue()); - - auto uncompressed = IOBuf::create(*uncompressedLength); - const auto decompressedSize = ZSTD_decompress( - uncompressed->writableTail(), - uncompressed->tailroom(), - data->data(), - data->length()); - zstdThrowIfError(decompressedSize); - if (decompressedSize != uncompressedLength) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); +void ZSTDStreamCodec::resetCStream() { + if (!cstream_) { + cstream_.reset(ZSTD_createCStream()); + if (!cstream_) { + throw std::bad_alloc{}; + } } - uncompressed->append(decompressedSize); - return uncompressed; + // Advanced API usage works for all supported versions of zstd. + // Required to set contentSizeFlag. + auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0); + params.fParams.contentSizeFlag = uncompressedLength().hasValue(); + zstdThrowIfError(ZSTD_initCStream_advanced( + cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0))); } -static std::unique_ptr zstdUncompressStream( - const IOBuf* data, - Optional uncompressedLength) { - auto zds = ZSTD_createDStream(); +bool ZSTDStreamCodec::doCompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + // If we are given all the input in one chunk try to use block compression + if (flushOp == StreamCodec::FlushOp::END && + tryBlockCompress(input, output)) { + return true; + } + resetCStream(); + needReset_ = false; + } + ZSTD_inBuffer in = {input.data(), input.size(), 0}; + ZSTD_outBuffer out = {output.data(), output.size(), 0}; SCOPE_EXIT { - ZSTD_freeDStream(zds); + input.uncheckedAdvance(in.pos); + output.uncheckedAdvance(out.pos); }; - - auto rc = ZSTD_initDStream(zds); - zstdThrowIfError(rc); - - ZSTD_outBuffer out{}; - ZSTD_inBuffer in{}; - - auto outputSize = uncompressedLength.value_or(ZSTD_DStreamOutSize()); - - IOBufQueue queue(IOBufQueue::cacheChainLength()); - - Cursor cursor(data); - for (rc = 0;;) { - if (in.pos == in.size) { - auto buffer = cursor.peekBytes(); - in.src = buffer.data(); - in.size = buffer.size(); - in.pos = 0; - cursor.skip(in.size); - if (rc > 1 && in.size == 0) { - throw std::runtime_error(to("ZSTD: incomplete input")); - } - } - if (out.pos == out.size) { - if (out.pos != 0) { - queue.postallocate(out.pos); - } - auto buffer = queue.preallocate(outputSize, outputSize); - out.dst = buffer.first; - out.size = buffer.second; - out.pos = 0; - outputSize = ZSTD_DStreamOutSize(); + if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) { + zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in)); + } + if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) { + size_t rc; + switch (flushOp) { + case StreamCodec::FlushOp::FLUSH: + rc = ZSTD_flushStream(cstream_.get(), &out); + break; + case StreamCodec::FlushOp::END: + rc = ZSTD_endStream(cstream_.get(), &out); + break; + default: + throw std::invalid_argument("ZSTD: invalid FlushOp"); } - rc = ZSTD_decompressStream(zds, &out, &in); zstdThrowIfError(rc); if (rc == 0) { - break; + return true; } } - if (out.pos != 0) { - queue.postallocate(out.pos); - } - if (in.pos != in.size || !cursor.isAtEnd()) { - throw std::runtime_error("ZSTD: junk after end of data"); - } - if (uncompressedLength && queue.chainLength() != *uncompressedLength) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); - } + return false; +} - return queue.move(); +bool ZSTDStreamCodec::tryBlockUncompress( + ByteRange& input, + MutableByteRange& output) const { + DCHECK(needReset_); +#if ZSTD_VERSION_NUMBER < 10104 + // We require ZSTD_findFrameCompressedSize() to perform this optimization. + return false; +#else + // We need to know the uncompressed length and have enough output space. + if (!uncompressedLength() || output.size() < *uncompressedLength()) { + return false; + } + size_t const compressedLength = + ZSTD_findFrameCompressedSize(input.data(), input.size()); + zstdThrowIfError(compressedLength); + size_t const length = ZSTD_decompress( + output.data(), *uncompressedLength(), input.data(), compressedLength); + zstdThrowIfError(length); + DCHECK_EQ(length, *uncompressedLength()); + input.uncheckedAdvance(compressedLength); + output.uncheckedAdvance(length); + return true; +#endif } -std::unique_ptr ZSTDCodec::doUncompress( - const IOBuf* data, - Optional uncompressedLength) { - { - // Read decompressed size from frame if available in first IOBuf. - const auto decompressedSize = - ZSTD_getDecompressedSize(data->data(), data->length()); - if (decompressedSize != 0) { - if (uncompressedLength && *uncompressedLength != decompressedSize) { - throw std::runtime_error("ZSTD: invalid uncompressed length"); - } - uncompressedLength = decompressedSize; +void ZSTDStreamCodec::resetDStream() { + if (!dstream_) { + dstream_.reset(ZSTD_createDStream()); + if (!dstream_) { + throw std::bad_alloc{}; } } - // Faster to decompress using ZSTD_decompress() if we can. - if (uncompressedLength && !data->isChained()) { - return zstdUncompressBuffer(data, uncompressedLength); + zstdThrowIfError(ZSTD_initDStream(dstream_.get())); +} + +bool ZSTDStreamCodec::doUncompressStream( + ByteRange& input, + MutableByteRange& output, + StreamCodec::FlushOp flushOp) { + if (needReset_) { + // If we are given all the input in one chunk try to use block uncompression + if (flushOp == StreamCodec::FlushOp::END && + tryBlockUncompress(input, output)) { + return true; + } + resetDStream(); + needReset_ = false; } - // Fall back to slower streaming decompression. - return zstdUncompressStream(data, uncompressedLength); + ZSTD_inBuffer in = {input.data(), input.size(), 0}; + ZSTD_outBuffer out = {output.data(), output.size(), 0}; + SCOPE_EXIT { + input.uncheckedAdvance(in.pos); + output.uncheckedAdvance(out.pos); + }; + size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in); + zstdThrowIfError(rc); + return rc == 0; } -#endif // FOLLY_HAVE_LIBZSTD +#endif // FOLLY_HAVE_LIBZSTD #if FOLLY_HAVE_LIBBZ2 @@ -2229,7 +2269,7 @@ constexpr Factory #endif #if FOLLY_HAVE_LIBZSTD - {ZSTDCodec::create, nullptr}, + {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream}, #else {}, #endif diff --git a/folly/io/test/CompressionTest.cpp b/folly/io/test/CompressionTest.cpp index d1d3a177..4c45648a 100644 --- a/folly/io/test/CompressionTest.cpp +++ b/folly/io/test/CompressionTest.cpp @@ -34,6 +34,10 @@ #include #include +#if FOLLY_HAVE_LIBZSTD +#include +#endif + namespace folly { namespace io { namespace test { class DataHolder : private boost::noncopyable { @@ -1084,6 +1088,31 @@ TEST(CheckCompatibleTest, ZlibIsPrefix) { EXPECT_THROW_IF_DEBUG( getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument); } + +#if FOLLY_HAVE_LIBZSTD + +TEST(ZstdTest, BackwardCompatible) { + auto codec = getCodec(CodecType::ZSTD); + { + auto const data = IOBuf::wrapBuffer(randomDataHolder.data(size_t(1) << 20)); + auto compressed = codec->compress(data.get()); + compressed->coalesce(); + EXPECT_EQ( + data->length(), + ZSTD_getDecompressedSize(compressed->data(), compressed->length())); + } + { + auto const data = + IOBuf::wrapBuffer(randomDataHolder.data(size_t(100) << 20)); + auto compressed = codec->compress(data.get()); + compressed->coalesce(); + EXPECT_EQ( + data->length(), + ZSTD_getDecompressedSize(compressed->data(), compressed->length())); + } +} + +#endif }}} // namespaces int main(int argc, char *argv[]) { -- 2.34.1