From: Philip Pronin Date: Tue, 20 Sep 2016 03:52:08 +0000 (-0700) Subject: fix ZSTD support X-Git-Tag: v2016.09.26.00~13 X-Git-Url: http://plrg.eecs.uci.edu/git/?a=commitdiff_plain;h=151e22b290a371ee1dfcf605b5bfa0e58fe2d9ae;p=folly.git fix ZSTD support Summary: Existing logic is broken (unable to correctly handle chained `IOBuf` in case of both `compress` and `uncompress`) and has unnecessarly strict `needsUncompressedLength() == true` requirement. This diff switches `ZSTDCodec` to use streaming to handle chained `IOBuf`, drops `needsUncompressedLength() == true`. Reviewed By: luciang Differential Revision: D3827579 fbshipit-source-id: 0ef6a9ea664ef585d0e181bff6ca17166b28efc2 --- diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index d83a7b5c..0b725703 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -983,55 +983,137 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { } bool ZSTDCodec::doNeedsUncompressedLength() const { - return true; + return false; +} + +void zstdThrowIfError(size_t rc) { + if (!ZSTD_isError(rc)) { + return; + } + throw std::runtime_error( + to("ZSTD returned an error: ", ZSTD_getErrorName(rc))); } std::unique_ptr ZSTDCodec::doCompress(const IOBuf* data) { - size_t rc; - size_t maxCompressedLength = ZSTD_compressBound(data->length()); - auto out = IOBuf::createCombined(maxCompressedLength); + // Support earlier versions of the codec (working with a single IOBuf, + // and using ZSTD_decompress which requires ZSTD frame to contain size, + // which isn't populated by streaming API). + if (!data->isChained()) { + auto out = IOBuf::createCombined(ZSTD_compressBound(data->length())); + const auto rc = ZSTD_compress( + out->writableData(), + out->capacity(), + data->data(), + data->length(), + level_); + zstdThrowIfError(rc); + out->append(rc); + return out; + } - CHECK_EQ(out->length(), 0); + auto zcs = ZSTD_createCStream(); + SCOPE_EXIT { + ZSTD_freeCStream(zcs); + }; - rc = ZSTD_compress(out->writableTail(), - out->capacity(), - data->data(), - data->length(), - level_); + auto rc = ZSTD_initCStream(zcs, level_); + zstdThrowIfError(rc); - if (ZSTD_isError(rc)) { - throw std::runtime_error(to( - "ZSTD compression returned an error: ", - ZSTD_getErrorName(rc))); + Cursor cursor(data); + auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength())); + + ZSTD_outBuffer out; + out.dst = result->writableTail(); + out.size = result->capacity(); + out.pos = 0; + + for (auto buffer = cursor.peekBytes(); !buffer.empty();) { + ZSTD_inBuffer in; + in.src = buffer.data(); + in.size = buffer.size(); + for (in.pos = 0; in.pos != in.size;) { + rc = ZSTD_compressStream(zcs, &out, &in); + zstdThrowIfError(rc); + } + cursor.skip(in.size); + buffer = cursor.peekBytes(); } - out->append(rc); - CHECK_EQ(out->length(), rc); + rc = ZSTD_endStream(zcs, &out); + zstdThrowIfError(rc); + CHECK_EQ(rc, 0); - return out; + result->append(out.pos); + return result; } -std::unique_ptr ZSTDCodec::doUncompress(const IOBuf* data, - uint64_t uncompressedLength) { - size_t rc; - auto out = IOBuf::createCombined(uncompressedLength); +std::unique_ptr ZSTDCodec::doUncompress( + const IOBuf* data, + uint64_t uncompressedLength) { + auto zds = ZSTD_createDStream(); + SCOPE_EXIT { + ZSTD_freeDStream(zds); + }; - CHECK_GE(out->capacity(), uncompressedLength); - CHECK_EQ(out->length(), 0); + auto rc = ZSTD_initDStream(zds); + zstdThrowIfError(rc); - rc = ZSTD_decompress( - out->writableTail(), out->capacity(), data->data(), data->length()); + ZSTD_outBuffer out{}; + ZSTD_inBuffer in{}; - if (ZSTD_isError(rc)) { - throw std::runtime_error(to( - "ZSTD decompression returned an error: ", - ZSTD_getErrorName(rc))); + auto outputSize = ZSTD_DStreamOutSize(); + if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) { + outputSize = uncompressedLength; + } else { + auto decompressedSize = + ZSTD_getDecompressedSize(data->data(), data->length()); + if (decompressedSize != 0 && decompressedSize < outputSize) { + outputSize = decompressedSize; + } } - out->append(rc); - CHECK_EQ(out->length(), rc); + IOBufQueue queue(IOBufQueue::cacheChainLength()); + + Cursor cursor(data); + for (rc = 0;;) { + if (in.pos == in.size) { + auto buffer = cursor.peekBytes(); + in.src = buffer.data(); + in.size = buffer.size(); + in.pos = 0; + cursor.skip(in.size); + if (rc > 1 && in.size == 0) { + throw std::runtime_error(to("ZSTD: incomplete input")); + } + } + if (out.pos == out.size) { + if (out.pos != 0) { + queue.postallocate(out.pos); + } + auto buffer = queue.preallocate(outputSize, outputSize); + out.dst = buffer.first; + out.size = buffer.second; + out.pos = 0; + outputSize = ZSTD_DStreamOutSize(); + } + rc = ZSTD_decompressStream(zds, &out, &in); + zstdThrowIfError(rc); + if (rc == 0) { + break; + } + } + if (out.pos != 0) { + queue.postallocate(out.pos); + } + if (in.pos != in.size || !cursor.isAtEnd()) { + throw std::runtime_error("ZSTD: junk after end of data"); + } + if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && + queue.chainLength() != uncompressedLength) { + throw std::runtime_error("ZSTD: invalid uncompressed length"); + } - return out; + return queue.move(); } #endif // FOLLY_HAVE_LIBZSTD diff --git a/folly/io/test/CompressionTest.cpp b/folly/io/test/CompressionTest.cpp index 7de76177..f3d66500 100644 --- a/folly/io/test/CompressionTest.cpp +++ b/folly/io/test/CompressionTest.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -128,31 +129,35 @@ TEST(CompressionTestNeedsUncompressedLength, Simple) { EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength()); EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE) ->needsUncompressedLength()); - EXPECT_TRUE(getCodec(CodecType::ZSTD)->needsUncompressedLength()); + EXPECT_FALSE(getCodec(CodecType::ZSTD)->needsUncompressedLength()); EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength()); } class CompressionTest - : public testing::TestWithParam> { - protected: - void SetUp() override { - auto tup = GetParam(); - uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup); - codec_ = getCodec(std::tr1::get<1>(tup)); - } + : public testing::TestWithParam> { + protected: + void SetUp() override { + auto tup = GetParam(); + uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup); + chunks_ = std::tr1::get<1>(tup); + codec_ = getCodec(std::tr1::get<2>(tup)); + } - void runSimpleTest(const DataHolder& dh); + void runSimpleTest(const DataHolder& dh); - uint64_t uncompressedLength_; - std::unique_ptr codec_; + private: + std::unique_ptr split(std::unique_ptr data) const; + + uint64_t uncompressedLength_; + size_t chunks_; + std::unique_ptr codec_; }; void CompressionTest::runSimpleTest(const DataHolder& dh) { - auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_)); - auto compressed = codec_->compress(original.get()); + const auto original = split(IOBuf::wrapBuffer(dh.data(uncompressedLength_))); + const auto compressed = split(codec_->compress(original.get())); if (!codec_->needsUncompressedLength()) { auto uncompressed = codec_->uncompress(compressed.get()); - EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength()); EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get())); } @@ -164,6 +169,32 @@ void CompressionTest::runSimpleTest(const DataHolder& dh) { } } +// Uniformly split data into (potentially empty) chunks. +std::unique_ptr CompressionTest::split( + std::unique_ptr data) const { + if (data->isChained()) { + data->coalesce(); + } + + const size_t size = data->computeChainDataLength(); + + std::multiset splits; + for (size_t i = 1; i < chunks_; ++i) { + splits.insert(Random::rand64(size)); + } + + folly::IOBufQueue result; + + size_t offset = 0; + for (size_t split : splits) { + result.append(IOBuf::copyBuffer(data->data() + offset, split - offset)); + offset = split; + } + result.append(IOBuf::copyBuffer(data->data() + offset, size - offset)); + + return result.move(); +} + TEST_P(CompressionTest, RandomData) { runSimpleTest(randomDataHolder); } @@ -175,16 +206,19 @@ TEST_P(CompressionTest, ConstantData) { INSTANTIATE_TEST_CASE_P( CompressionTest, CompressionTest, - testing::Combine(testing::Values(0, 1, 12, 22, 25, 27), - testing::Values(CodecType::NO_COMPRESSION, - CodecType::LZ4, - CodecType::SNAPPY, - CodecType::ZLIB, - CodecType::LZ4_VARINT_SIZE, - CodecType::LZMA2, - CodecType::LZMA2_VARINT_SIZE, - CodecType::ZSTD, - CodecType::GZIP))); + testing::Combine( + testing::Values(0, 1, 12, 22, 25, 27), + testing::Values(1, 2, 3, 8, 65), + testing::Values( + CodecType::NO_COMPRESSION, + CodecType::LZ4, + CodecType::SNAPPY, + CodecType::ZLIB, + CodecType::LZ4_VARINT_SIZE, + CodecType::LZMA2, + CodecType::LZMA2_VARINT_SIZE, + CodecType::ZSTD, + CodecType::GZIP))); class CompressionVarintTest : public testing::TestWithParam> { @@ -227,7 +261,9 @@ void CompressionVarintTest::runSimpleTest(const DataHolder& dh) { EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get())); } -TEST_P(CompressionVarintTest, RandomData) { runSimpleTest(randomDataHolder); } +TEST_P(CompressionVarintTest, RandomData) { + runSimpleTest(randomDataHolder); +} TEST_P(CompressionVarintTest, ConstantData) { runSimpleTest(constantDataHolder); @@ -236,9 +272,11 @@ TEST_P(CompressionVarintTest, ConstantData) { INSTANTIATE_TEST_CASE_P( CompressionVarintTest, CompressionVarintTest, - testing::Combine(testing::Values(0, 1, 12, 22, 25, 27), - testing::Values(CodecType::LZ4_VARINT_SIZE, - CodecType::LZMA2_VARINT_SIZE))); + testing::Combine( + testing::Values(0, 1, 12, 22, 25, 27), + testing::Values( + CodecType::LZ4_VARINT_SIZE, + CodecType::LZMA2_VARINT_SIZE))); class CompressionCorruptionTest : public testing::TestWithParam { protected: