fix ZSTD support
author Philip Pronin <philipp@fb.com>
Tue, 20 Sep 2016 03:52:08 +0000 (20:52 -0700)
committer Facebook Github Bot 4 <facebook-github-bot-4-bot@fb.com>
Tue, 20 Sep 2016 03:53:41 +0000 (20:53 -0700)
Summary:
The existing logic is broken (it cannot correctly handle a chained `IOBuf`
in either `compress` or `uncompress`) and has an unnecessarily strict
`needsUncompressedLength() == true` requirement.

This diff switches `ZSTDCodec` to use streaming to handle chained `IOBuf`s,
and drops the `needsUncompressedLength() == true` requirement.

Reviewed By: luciang

Differential Revision: D3827579

fbshipit-source-id: 0ef6a9ea664ef585d0e181bff6ca17166b28efc2

folly/io/Compression.cpp
folly/io/test/CompressionTest.cpp

index d83a7b5cabca70c0162e7bd1b5b264af7d4ae023..0b7257037cf27a645b1b8489f7ed24197fabdd85 100644 (file)
@@ -983,55 +983,137 @@ ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
 }
 
 bool ZSTDCodec::doNeedsUncompressedLength() const {
-  return true;
+  return false;
+}
+
+void zstdThrowIfError(size_t rc) {
+  if (!ZSTD_isError(rc)) {
+    return;
+  }
+  throw std::runtime_error(
+      to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
 }
 
 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
-  size_t rc;
-  size_t maxCompressedLength = ZSTD_compressBound(data->length());
-  auto out = IOBuf::createCombined(maxCompressedLength);
+  // Support earlier versions of the codec (working with a single IOBuf,
+  // and using ZSTD_decompress which requires ZSTD frame to contain size,
+  // which isn't populated by streaming API).
+  if (!data->isChained()) {
+    auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
+    const auto rc = ZSTD_compress(
+        out->writableData(),
+        out->capacity(),
+        data->data(),
+        data->length(),
+        level_);
+    zstdThrowIfError(rc);
+    out->append(rc);
+    return out;
+  }
 
-  CHECK_EQ(out->length(), 0);
+  auto zcs = ZSTD_createCStream();
+  SCOPE_EXIT {
+    ZSTD_freeCStream(zcs);
+  };
 
-  rc = ZSTD_compress(out->writableTail(),
-                     out->capacity(),
-                     data->data(),
-                     data->length(),
-                     level_);
+  auto rc = ZSTD_initCStream(zcs, level_);
+  zstdThrowIfError(rc);
 
-  if (ZSTD_isError(rc)) {
-    throw std::runtime_error(to<std::string>(
-          "ZSTD compression returned an error: ",
-          ZSTD_getErrorName(rc)));
+  Cursor cursor(data);
+  auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
+
+  ZSTD_outBuffer out;
+  out.dst = result->writableTail();
+  out.size = result->capacity();
+  out.pos = 0;
+
+  for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
+    ZSTD_inBuffer in;
+    in.src = buffer.data();
+    in.size = buffer.size();
+    for (in.pos = 0; in.pos != in.size;) {
+      rc = ZSTD_compressStream(zcs, &out, &in);
+      zstdThrowIfError(rc);
+    }
+    cursor.skip(in.size);
+    buffer = cursor.peekBytes();
   }
 
-  out->append(rc);
-  CHECK_EQ(out->length(), rc);
+  rc = ZSTD_endStream(zcs, &out);
+  zstdThrowIfError(rc);
+  CHECK_EQ(rc, 0);
 
-  return out;
+  result->append(out.pos);
+  return result;
 }
 
-std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
-                                               uint64_t uncompressedLength) {
-  size_t rc;
-  auto out = IOBuf::createCombined(uncompressedLength);
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+    const IOBuf* data,
+    uint64_t uncompressedLength) {
+  auto zds = ZSTD_createDStream();
+  SCOPE_EXIT {
+    ZSTD_freeDStream(zds);
+  };
 
-  CHECK_GE(out->capacity(), uncompressedLength);
-  CHECK_EQ(out->length(), 0);
+  auto rc = ZSTD_initDStream(zds);
+  zstdThrowIfError(rc);
 
-  rc = ZSTD_decompress(
-      out->writableTail(), out->capacity(), data->data(), data->length());
+  ZSTD_outBuffer out{};
+  ZSTD_inBuffer in{};
 
-  if (ZSTD_isError(rc)) {
-    throw std::runtime_error(to<std::string>(
-          "ZSTD decompression returned an error: ",
-          ZSTD_getErrorName(rc)));
+  auto outputSize = ZSTD_DStreamOutSize();
+  if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
+    outputSize = uncompressedLength;
+  } else {
+    auto decompressedSize =
+        ZSTD_getDecompressedSize(data->data(), data->length());
+    if (decompressedSize != 0 && decompressedSize < outputSize) {
+      outputSize = decompressedSize;
+    }
   }
 
-  out->append(rc);
-  CHECK_EQ(out->length(), rc);
+  IOBufQueue queue(IOBufQueue::cacheChainLength());
+
+  Cursor cursor(data);
+  for (rc = 0;;) {
+    if (in.pos == in.size) {
+      auto buffer = cursor.peekBytes();
+      in.src = buffer.data();
+      in.size = buffer.size();
+      in.pos = 0;
+      cursor.skip(in.size);
+      if (rc > 1 && in.size == 0) {
+        throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
+      }
+    }
+    if (out.pos == out.size) {
+      if (out.pos != 0) {
+        queue.postallocate(out.pos);
+      }
+      auto buffer = queue.preallocate(outputSize, outputSize);
+      out.dst = buffer.first;
+      out.size = buffer.second;
+      out.pos = 0;
+      outputSize = ZSTD_DStreamOutSize();
+    }
+    rc = ZSTD_decompressStream(zds, &out, &in);
+    zstdThrowIfError(rc);
+    if (rc == 0) {
+      break;
+    }
+  }
+  if (out.pos != 0) {
+    queue.postallocate(out.pos);
+  }
+  if (in.pos != in.size || !cursor.isAtEnd()) {
+    throw std::runtime_error("ZSTD: junk after end of data");
+  }
+  if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+      queue.chainLength() != uncompressedLength) {
+    throw std::runtime_error("ZSTD: invalid uncompressed length");
+  }
 
-  return out;
+  return queue.move();
 }
 
 #endif  // FOLLY_HAVE_LIBZSTD
index 7de76177e9f4fe95218ecbd86d42300fef6c547f..f3d66500cb4efc8e38a063272c32ceb05bf48884 100644 (file)
@@ -17,6 +17,7 @@
 #include <folly/io/Compression.h>
 
 #include <random>
+#include <set>
 #include <thread>
 #include <unordered_map>
 
@@ -128,31 +129,35 @@ TEST(CompressionTestNeedsUncompressedLength, Simple) {
   EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength());
   EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE)
     ->needsUncompressedLength());
-  EXPECT_TRUE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
+  EXPECT_FALSE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
   EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength());
 }
 
 class CompressionTest
-    : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
-  protected:
-   void SetUp() override {
-     auto tup = GetParam();
-     uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
-     codec_ = getCodec(std::tr1::get<1>(tup));
-   }
+    : public testing::TestWithParam<std::tr1::tuple<int, int, CodecType>> {
+ protected:
+  void SetUp() override {
+    auto tup = GetParam();
+    uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
+    chunks_ = std::tr1::get<1>(tup);
+    codec_ = getCodec(std::tr1::get<2>(tup));
+  }
 
-   void runSimpleTest(const DataHolder& dh);
+  void runSimpleTest(const DataHolder& dh);
 
-   uint64_t uncompressedLength_;
-   std::unique_ptr<Codec> codec_;
+ private:
+  std::unique_ptr<IOBuf> split(std::unique_ptr<IOBuf> data) const;
+
+  uint64_t uncompressedLength_;
+  size_t chunks_;
+  std::unique_ptr<Codec> codec_;
 };
 
 void CompressionTest::runSimpleTest(const DataHolder& dh) {
-  auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
-  auto compressed = codec_->compress(original.get());
+  const auto original = split(IOBuf::wrapBuffer(dh.data(uncompressedLength_)));
+  const auto compressed = split(codec_->compress(original.get()));
   if (!codec_->needsUncompressedLength()) {
     auto uncompressed = codec_->uncompress(compressed.get());
-
     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
   }
@@ -164,6 +169,32 @@ void CompressionTest::runSimpleTest(const DataHolder& dh) {
   }
 }
 
+// Uniformly split data into (potentially empty) chunks.
+std::unique_ptr<IOBuf> CompressionTest::split(
+    std::unique_ptr<IOBuf> data) const {
+  if (data->isChained()) {
+    data->coalesce();
+  }
+
+  const size_t size = data->computeChainDataLength();
+
+  std::multiset<size_t> splits;
+  for (size_t i = 1; i < chunks_; ++i) {
+    splits.insert(Random::rand64(size));
+  }
+
+  folly::IOBufQueue result;
+
+  size_t offset = 0;
+  for (size_t split : splits) {
+    result.append(IOBuf::copyBuffer(data->data() + offset, split - offset));
+    offset = split;
+  }
+  result.append(IOBuf::copyBuffer(data->data() + offset, size - offset));
+
+  return result.move();
+}
+
 TEST_P(CompressionTest, RandomData) {
   runSimpleTest(randomDataHolder);
 }
@@ -175,16 +206,19 @@ TEST_P(CompressionTest, ConstantData) {
 INSTANTIATE_TEST_CASE_P(
     CompressionTest,
     CompressionTest,
-    testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
-                     testing::Values(CodecType::NO_COMPRESSION,
-                                     CodecType::LZ4,
-                                     CodecType::SNAPPY,
-                                     CodecType::ZLIB,
-                                     CodecType::LZ4_VARINT_SIZE,
-                                     CodecType::LZMA2,
-                                     CodecType::LZMA2_VARINT_SIZE,
-                                     CodecType::ZSTD,
-                                     CodecType::GZIP)));
+    testing::Combine(
+        testing::Values(0, 1, 12, 22, 25, 27),
+        testing::Values(1, 2, 3, 8, 65),
+        testing::Values(
+            CodecType::NO_COMPRESSION,
+            CodecType::LZ4,
+            CodecType::SNAPPY,
+            CodecType::ZLIB,
+            CodecType::LZ4_VARINT_SIZE,
+            CodecType::LZMA2,
+            CodecType::LZMA2_VARINT_SIZE,
+            CodecType::ZSTD,
+            CodecType::GZIP)));
 
 class CompressionVarintTest
     : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
@@ -227,7 +261,9 @@ void CompressionVarintTest::runSimpleTest(const DataHolder& dh) {
   EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
 }
 
-TEST_P(CompressionVarintTest, RandomData) { runSimpleTest(randomDataHolder); }
+TEST_P(CompressionVarintTest, RandomData) {
+  runSimpleTest(randomDataHolder);
+}
 
 TEST_P(CompressionVarintTest, ConstantData) {
   runSimpleTest(constantDataHolder);
@@ -236,9 +272,11 @@ TEST_P(CompressionVarintTest, ConstantData) {
 INSTANTIATE_TEST_CASE_P(
     CompressionVarintTest,
     CompressionVarintTest,
-    testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
-                     testing::Values(CodecType::LZ4_VARINT_SIZE,
-                                     CodecType::LZMA2_VARINT_SIZE)));
+    testing::Combine(
+        testing::Values(0, 1, 12, 22, 25, 27),
+        testing::Values(
+            CodecType::LZ4_VARINT_SIZE,
+            CodecType::LZMA2_VARINT_SIZE)));
 
 class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
  protected: