Fix decompression of truncated data
[folly.git] / folly / io / test / CompressionTest.cpp
index 6d12f57cf6b4d915818d49d9befdd3cb8993fd71..e15a18da39f459c40086ef04ba25af06a626ac17 100644 (file)
 
 #include <folly/io/Compression.h>
 
+#include <algorithm>
 #include <random>
 #include <set>
 #include <thread>
 #include <unordered_map>
+#include <utility>
 
 #include <boost/noncopyable.hpp>
 #include <glog/logging.h>
 #include <folly/io/IOBufQueue.h>
 #include <folly/portability/GTest.h>
 
+#if FOLLY_HAVE_LIBZSTD
+#include <zstd.h>
+#endif
+
 namespace folly { namespace io { namespace test {
 
 class DataHolder : private boost::noncopyable {
@@ -83,23 +89,22 @@ class RandomDataHolder : public DataHolder {
 
 RandomDataHolder::RandomDataHolder(size_t sizeLog2)
   : DataHolder(sizeLog2) {
-  constexpr size_t numThreadsLog2 = 3;
-  constexpr size_t numThreads = size_t(1) << numThreadsLog2;
+  static constexpr size_t numThreadsLog2 = 3;
+  static constexpr size_t numThreads = size_t(1) << numThreadsLog2;
 
   uint32_t seed = randomNumberSeed();
 
   std::vector<std::thread> threads;
   threads.reserve(numThreads);
   for (size_t t = 0; t < numThreads; ++t) {
-    threads.emplace_back(
-        [this, seed, t, numThreadsLog2, sizeLog2] () {
-          std::mt19937 rng(seed + t);
-          size_t countLog2 = sizeLog2 - numThreadsLog2;
-          size_t start = size_t(t) << countLog2;
-          for (size_t i = 0; i < countLog2; ++i) {
-            this->data_[start + i] = rng();
-          }
-        });
+    threads.emplace_back([this, seed, t, sizeLog2] {
+      std::mt19937 rng(seed + t);
+      size_t countLog2 = sizeLog2 - numThreadsLog2;
+      size_t start = size_t(t) << countLog2;
+      for (size_t i = 0; i < countLog2; ++i) {
+        this->data_[start + i] = rng();
+      }
+    });
   }
 
   for (auto& t : threads) {
@@ -148,6 +153,19 @@ static std::vector<CodecType> availableCodecs() {
   return codecs;
 }
 
+// Returns every CodecType below NUM_CODEC_TYPES for which hasStreamCodec()
+// is true, in increasing enum order.
+static std::vector<CodecType> availableStreamCodecs() {
+  auto const numTypes = static_cast<size_t>(CodecType::NUM_CODEC_TYPES);
+  std::vector<CodecType> streamCodecs;
+
+  for (size_t idx = 0; idx != numTypes; ++idx) {
+    auto const candidate = static_cast<CodecType>(idx);
+    if (!hasStreamCodec(candidate)) {
+      continue;
+    }
+    streamCodecs.push_back(candidate);
+  }
+
+  return streamCodecs;
+}
+
 TEST(CompressionTestNeedsUncompressedLength, Simple) {
   static const struct { CodecType type; bool needsUncompressedLength; }
     expectations[] = {
@@ -365,15 +383,29 @@ void CompressionCorruptionTest::runSimpleTest(const DataHolder& dh) {
   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength + 1),
                std::runtime_error);
 
+  auto corrupted = compressed->clone();
+  corrupted->unshare();
+  // Truncate the last character
+  corrupted->prev()->trimEnd(1);
+  if (!codec_->needsUncompressedLength()) {
+    EXPECT_THROW(codec_->uncompress(corrupted.get()),
+                 std::runtime_error);
+  }
+
+  EXPECT_THROW(codec_->uncompress(corrupted.get(), uncompressedLength),
+               std::runtime_error);
+
+  corrupted = compressed->clone();
+  corrupted->unshare();
   // Corrupt the first character
-  ++(compressed->writableData()[0]);
+  ++(corrupted->writableData()[0]);
 
   if (!codec_->needsUncompressedLength()) {
-    EXPECT_THROW(codec_->uncompress(compressed.get()),
+    EXPECT_THROW(codec_->uncompress(corrupted.get()),
                  std::runtime_error);
   }
 
-  EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength),
+  EXPECT_THROW(codec_->uncompress(corrupted.get(), uncompressedLength),
                std::runtime_error);
 }
 
@@ -400,6 +432,422 @@ INSTANTIATE_TEST_CASE_P(
             CodecType::BZIP2,
         })));
 
+// Fixture for the per-codec streaming unit tests; the test parameter selects
+// which stream codec is created.
+class StreamingUnitTest : public testing::TestWithParam<CodecType> {
+ protected:
+  // Creates the stream codec for the parameterized codec type.
+  void SetUp() override {
+    auto const codecType = GetParam();
+    codec_ = getStreamCodec(codecType);
+  }
+
+  // Codec under test, assigned in SetUp().
+  std::unique_ptr<StreamCodec> codec_;
+};
+
+// maxCompressedLength() must be 0 for empty input and at least the input size
+// for each power of ten from 1 to 1000000.
+TEST_P(StreamingUnitTest, maxCompressedLength) {
+  EXPECT_EQ(0, codec_->maxCompressedLength(0));
+  for (uint64_t inputSize = 1; inputSize <= 1000000; inputSize *= 10) {
+    auto const bound = codec_->maxCompressedLength(inputSize);
+    EXPECT_GE(bound, inputSize);
+  }
+}
+
+TEST_P(StreamingUnitTest, getUncompressedLength) {  // Checks getUncompressedLength() with and without a caller-supplied length.
+  auto const empty = IOBuf::create(0);
+  EXPECT_EQ(uint64_t(0), codec_->getUncompressedLength(empty.get()));  // Empty buffer, no hint: expects 0.
+  EXPECT_EQ(uint64_t(0), codec_->getUncompressedLength(empty.get(), 0));  // Empty buffer, hint 0: expects 0.
+
+  auto const data = IOBuf::wrapBuffer(randomDataHolder.data(100));
+  auto const compressed = codec_->compress(data.get());  // NOTE(review): `compressed` is never read below; every query uses `data` -- confirm this is intended.
+
+  EXPECT_ANY_THROW(codec_->getUncompressedLength(data.get(), 0));  // A hint of 0 for the 100-byte buffer is expected to throw.
+  if (auto const length = codec_->getUncompressedLength(data.get())) {  // Only checked when the codec reports a length at all.
+    EXPECT_EQ(100, *length);
+  }
+  EXPECT_EQ(uint64_t(100), codec_->getUncompressedLength(data.get(), 100));  // The matching hint is expected to be returned.
+  // If the uncompressed length is stored in the frame, then make sure it throws
+  // when it is given the wrong length.
+  if (codec_->getUncompressedLength(data.get()) == uint64_t(100)) {
+    EXPECT_ANY_THROW(codec_->getUncompressedLength(data.get(), 200));
+  }
+}
+
+// Exercises compressStream() and uncompressStream() with zero bytes of input,
+// using both an empty output range and a range backed by a real buffer.
+TEST_P(StreamingUnitTest, emptyData) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  ByteRange noInput{};
+  auto scratch = IOBuf::create(1);
+  scratch->append(scratch->capacity());
+  // Output range covering the whole scratch buffer.
+  auto const wholeScratch = [&scratch] {
+    return MutableByteRange{scratch->writableData(), scratch->length()};
+  };
+  MutableByteRange sink{};
+
+  // Test compressing empty data in one pass: on the fresh stream, after a
+  // reset with length 0, and after a plain reset with output space available.
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::END));
+  codec_->resetStream(0);
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::END));
+  codec_->resetStream();
+  sink = wholeScratch();
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::END));
+  // The output range must still span the whole buffer.
+  EXPECT_EQ(scratch->length(), sink.size());
+
+  // Test compressing empty data with multiple calls to compressStream(),
+  // first with no output space at all.
+  codec_->resetStream();
+  sink = MutableByteRange{};
+  EXPECT_FALSE(codec_->compressStream(noInput, sink));
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::FLUSH));
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::END));
+  // Same call sequence, this time with the scratch buffer as output.
+  codec_->resetStream();
+  sink = wholeScratch();
+  EXPECT_FALSE(codec_->compressStream(noInput, sink));
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::FLUSH));
+  EXPECT_TRUE(codec_->compressStream(noInput, sink, FlushOp::END));
+  EXPECT_EQ(scratch->length(), sink.size());
+
+  // Test uncompressing empty data: each flush mode after a plain reset ...
+  sink = MutableByteRange{};
+  codec_->resetStream();
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink));
+  codec_->resetStream();
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink, FlushOp::FLUSH));
+  codec_->resetStream();
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink, FlushOp::END));
+  // ... and again after a reset with length 0.
+  codec_->resetStream(0);
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink));
+  codec_->resetStream(0);
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink, FlushOp::FLUSH));
+  codec_->resetStream(0);
+  EXPECT_TRUE(codec_->uncompressStream(noInput, sink, FlushOp::END));
+}
+
+// Checks that calls which cannot make progress (empty input and empty output)
+// are accepted by compressStream() and uncompressStream() after real data has
+// been fed to the stream.
+TEST_P(StreamingUnitTest, noForwardProgressOkay) {
+  auto inBuffer = IOBuf::create(2);
+  // Initialize both bytes that append(2) exposes; the second write previously
+  // targeted index 0 again and left byte 1 uninitialized.
+  inBuffer->writableData()[0] = 'a';
+  inBuffer->writableData()[1] = 'a';
+  inBuffer->append(2);
+  auto input = inBuffer->coalesce();
+  auto compressed = codec_->compress(inBuffer.get());
+
+  auto outBuffer = IOBuf::create(codec_->maxCompressedLength(2));
+  MutableByteRange output{outBuffer->writableTail(), outBuffer->tailroom()};
+
+  ByteRange emptyInput;
+  MutableByteRange emptyOutput;
+
+  // Compress some data to avoid empty data special casing
+  codec_->resetStream();
+  while (!input.empty()) {
+    codec_->compressStream(input, output);
+  }
+  // empty input and output is okay for flush NONE and FLUSH.
+  codec_->compressStream(emptyInput, emptyOutput);
+  codec_->compressStream(emptyInput, emptyOutput, StreamCodec::FlushOp::FLUSH);
+
+  // Start over so the END case is tried on a stream that has not been flushed.
+  codec_->resetStream();
+  input = inBuffer->coalesce();
+  output = {outBuffer->writableTail(), outBuffer->tailroom()};
+  while (!input.empty()) {
+    codec_->compressStream(input, output);
+  }
+  // empty input and output is okay for flush END.
+  codec_->compressStream(emptyInput, emptyOutput, StreamCodec::FlushOp::END);
+
+  codec_->resetStream();
+  input = compressed->coalesce();
+  input.uncheckedSubtract(1); // Remove last byte so the operation is incomplete
+  output = {inBuffer->writableData(), inBuffer->length()};
+  // Uncompress some data to avoid empty data special casing
+  while (!input.empty()) {
+    EXPECT_FALSE(codec_->uncompressStream(input, output));
+  }
+  // empty input and output is okay for all flush values.
+  EXPECT_FALSE(codec_->uncompressStream(emptyInput, emptyOutput));
+  EXPECT_FALSE(codec_->uncompressStream(
+      emptyInput, emptyOutput, StreamCodec::FlushOp::FLUSH));
+  EXPECT_FALSE(codec_->uncompressStream(
+      emptyInput, emptyOutput, StreamCodec::FlushOp::END));
+}
+
+// Walks the stream codec through sequences of compressStream(),
+// uncompressStream() and one-shot calls, checking which orderings succeed and
+// which throw std::logic_error.
+TEST_P(StreamingUnitTest, stateTransitions) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  auto plain = IOBuf::create(1);
+  plain->writableData()[0] = 'a';
+  plain->append(1);
+  auto frame = codec_->compress(plain.get());
+  // The compressed frame is the input for both helpers below.
+  ByteRange const source = frame->coalesce();
+  auto sinkBuffer = IOBuf::create(codec_->maxCompressedLength(source.size()));
+  MutableByteRange const sink{
+      sinkBuffer->writableTail(), sinkBuffer->tailroom()};
+
+  // Each helper works on fresh copies of the ranges, so every call sees the
+  // full input; `noOutput` swaps in an empty output range.
+  auto feedCompress = [&](
+      FlushOp flushOp = FlushOp::NONE,
+      bool noOutput = false) {
+    ByteRange input = source;
+    MutableByteRange output = noOutput ? MutableByteRange{} : sink;
+    return codec_->compressStream(input, output, flushOp);
+  };
+  auto feedUncompress = [&](
+      FlushOp flushOp = FlushOp::NONE,
+      bool noOutput = false) {
+    ByteRange input = source;
+    MutableByteRange output = noOutput ? MutableByteRange{} : sink;
+    return codec_->uncompressStream(input, output, flushOp);
+  };
+
+  // compression flow
+  codec_->resetStream();
+  EXPECT_FALSE(feedCompress());
+  EXPECT_FALSE(feedCompress());
+  EXPECT_TRUE(feedCompress(FlushOp::FLUSH));
+  EXPECT_FALSE(feedCompress());
+  EXPECT_TRUE(feedCompress(FlushOp::END));
+  // uncompression flow: with no output space the frame cannot complete
+  for (auto const op :
+       {FlushOp::NONE, FlushOp::FLUSH, FlushOp::NONE, FlushOp::NONE}) {
+    codec_->resetStream();
+    EXPECT_FALSE(feedUncompress(op, true));
+  }
+  codec_->resetStream();
+  EXPECT_TRUE(feedUncompress(FlushOp::FLUSH));
+  // compress -> uncompress
+  codec_->resetStream();
+  EXPECT_FALSE(feedCompress());
+  EXPECT_THROW(feedUncompress(), std::logic_error);
+  // uncompress -> compress
+  codec_->resetStream();
+  EXPECT_TRUE(feedUncompress(FlushOp::FLUSH));
+  EXPECT_THROW(feedCompress(), std::logic_error);
+  // end -> compress
+  codec_->resetStream();
+  EXPECT_FALSE(feedCompress());
+  EXPECT_TRUE(feedCompress(FlushOp::END));
+  EXPECT_THROW(feedCompress(), std::logic_error);
+  // end -> uncompress
+  codec_->resetStream();
+  EXPECT_TRUE(feedUncompress(FlushOp::FLUSH));
+  EXPECT_THROW(feedUncompress(), std::logic_error);
+  // flush -> compress
+  codec_->resetStream();
+  EXPECT_FALSE(feedCompress(FlushOp::FLUSH, true));
+  EXPECT_THROW(feedCompress(), std::logic_error);
+  // flush -> end
+  codec_->resetStream();
+  EXPECT_FALSE(feedCompress(FlushOp::FLUSH, true));
+  EXPECT_THROW(feedCompress(FlushOp::END), std::logic_error);
+  // undefined -> compress
+  codec_->compress(plain.get());
+  EXPECT_THROW(feedCompress(), std::logic_error);
+  codec_->uncompress(frame.get());
+  EXPECT_THROW(feedCompress(), std::logic_error);
+  // undefined -> undefined
+  codec_->uncompress(frame.get());
+  codec_->compress(plain.get());
+}
+
+INSTANTIATE_TEST_CASE_P(  // Instantiates StreamingUnitTest for each codec type from availableStreamCodecs().
+    StreamingUnitTest,
+    StreamingUnitTest,
+    testing::ValuesIn(availableStreamCodecs()));
+
+// Fixture parameterized on (log2 of the uncompressed length, log2 of the
+// chunk size, codec type).
+class StreamingCompressionTest
+    : public testing::TestWithParam<std::tuple<int, int, CodecType>> {
+ protected:
+  // Decodes the parameter tuple into sizes and creates the stream codec.
+  void SetUp() override {
+    auto const params = GetParam();
+    auto const lengthLog2 = std::get<0>(params);
+    auto const chunkLog2 = std::get<1>(params);
+    uncompressedLength_ = uint64_t(1) << lengthLog2;
+    chunkSize_ = size_t(1) << chunkLog2;
+    codec_ = getStreamCodec(std::get<2>(params));
+  }
+
+  // Scenario drivers; each TEST_P runs one of them on the data holders.
+  void runResetStreamTest(DataHolder const& dh);
+  void runCompressStreamTest(DataHolder const& dh);
+  void runUncompressStreamTest(DataHolder const& dh);
+  void runFlushTest(DataHolder const& dh);
+
+ private:
+  // Cuts `data` into consecutive pieces of at most chunkSize_ bytes.
+  std::vector<ByteRange> split(ByteRange data) const;
+
+  uint64_t uncompressedLength_; // set in SetUp()
+  size_t chunkSize_; // set in SetUp()
+  std::unique_ptr<StreamCodec> codec_;
+};
+
+// Cuts `data` into consecutive pieces of at most chunkSize_ bytes each; only
+// the last piece can be shorter. Empty input yields an empty vector.
+std::vector<ByteRange> StreamingCompressionTest::split(ByteRange data) const {
+  std::vector<ByteRange> chunks;
+  // One slot more than the number of full chunks, for a trailing partial one.
+  chunks.reserve(std::max<size_t>(1, data.size() / chunkSize_) + 1);
+  for (; !data.empty();) {
+    size_t const take = std::min(data.size(), chunkSize_);
+    chunks.push_back(data.subpiece(0, take));
+    data.uncheckedAdvance(take);
+  }
+  return chunks;
+}
+
+// Feeds `data` to codec->compressStream() using fresh output buffers of
+// `bufferSize` bytes and returns the collected output as one chain. Stops
+// when compressStream() returns true, or, for FlushOp::NONE, as soon as the
+// input is consumed. Expects all of the input to have been consumed.
+static std::unique_ptr<IOBuf> compressSome(
+    StreamCodec* codec,
+    ByteRange data,
+    uint64_t bufferSize,
+    StreamCodec::FlushOp flush) {
+  IOBufQueue chunks;
+  for (;;) {
+    auto scratch = IOBuf::create(bufferSize);
+    scratch->append(scratch->capacity());
+    MutableByteRange remaining{scratch->writableData(), scratch->length()};
+
+    bool const done = codec->compressStream(data, remaining, flush);
+    // Trim off whatever is still left in `remaining` after the call.
+    scratch->trimEnd(remaining.size());
+    chunks.append(std::move(scratch));
+
+    bool const inputDrained =
+        flush == StreamCodec::FlushOp::NONE && data.empty();
+    if (inputDrained || done) {
+      break;
+    }
+  }
+  EXPECT_TRUE(data.empty());
+  return chunks.move();
+}
+
+// Calls codec->uncompressStream() on `data` with fresh output buffers of
+// `bufferSize` bytes, advancing the caller's `data` range. Keeps going while
+// the queue reports no tailroom and the codec has not returned true. Returns
+// the last uncompressStream() result together with the collected output.
+static std::pair<bool, std::unique_ptr<IOBuf>> uncompressSome(
+    StreamCodec* codec,
+    ByteRange& data,
+    uint64_t bufferSize,
+    StreamCodec::FlushOp flush) {
+  bool frameDone = false;
+  IOBufQueue chunks;
+  for (;;) {
+    auto scratch = IOBuf::create(bufferSize);
+    scratch->append(scratch->capacity());
+    MutableByteRange remaining{scratch->writableData(), scratch->length()};
+
+    frameDone = codec->uncompressStream(data, remaining, flush);
+    // Trim off whatever is still left in `remaining` after the call.
+    scratch->trimEnd(remaining.size());
+    chunks.append(std::move(scratch));
+
+    if (chunks.tailroom() != 0 || frameDone) {
+      break;
+    }
+  }
+  return std::make_pair(frameDone, chunks.move());
+}
+
+// Starts a compression without finishing it, resets the stream, compresses
+// the same input to the end, and checks that the result round-trips through
+// uncompress().
+void StreamingCompressionTest::runResetStreamTest(DataHolder const& dh) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  auto const plain = dh.data(uncompressedLength_);
+  // Compress some but leave state unclean
+  codec_->resetStream(uncompressedLength_);
+  compressSome(codec_.get(), plain, chunkSize_, FlushOp::NONE);
+  // Reset stream and compress all
+  codec_->resetStream();
+  auto frame = compressSome(codec_.get(), plain, chunkSize_, FlushOp::END);
+  auto const roundTripped = codec_->uncompress(frame.get(), plain.size());
+  auto const expectedHash = dh.hash(uncompressedLength_);
+  auto const actualHash = hashIOBuf(roundTripped.get());
+  EXPECT_EQ(expectedHash, actualHash);
+}
+
+TEST_P(StreamingCompressionTest, resetStream) {  // Runs the reset-stream scenario on both data holders.
+  runResetStreamTest(constantDataHolder);
+  runResetStreamTest(randomDataHolder);
+}
+
+// Compresses the input piece by piece without flushing, ends the frame with
+// an empty input, and checks that the concatenated output round-trips through
+// uncompress().
+void StreamingCompressionTest::runCompressStreamTest(
+    const folly::io::test::DataHolder& dh) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  auto const pieces = split(dh.data(uncompressedLength_));
+
+  IOBufQueue frame;
+  codec_->resetStream(uncompressedLength_);
+  // Compress many inputs in a row
+  for (auto const piece : pieces) {
+    auto partial =
+        compressSome(codec_.get(), piece, chunkSize_, FlushOp::NONE);
+    frame.append(std::move(partial));
+  }
+  // Finish the operation with empty input.
+  ByteRange nothing;
+  auto trailer = compressSome(codec_.get(), nothing, chunkSize_, FlushOp::END);
+  frame.append(std::move(trailer));
+
+  auto const roundTripped = codec_->uncompress(frame.front());
+  auto const expectedHash = dh.hash(uncompressedLength_);
+  auto const actualHash = hashIOBuf(roundTripped.get());
+  EXPECT_EQ(expectedHash, actualHash);
+}
+
+TEST_P(StreamingCompressionTest, compressStream) {  // Runs the chunked-compression scenario on both data holders.
+  runCompressStreamTest(constantDataHolder);
+  runCompressStreamTest(randomDataHolder);
+}
+
+// Concatenates three compressed copies of the same data into one input range
+// and uncompresses them one frame at a time, resetting the stream before
+// each frame.
+void StreamingCompressionTest::runUncompressStreamTest(
+    const folly::io::test::DataHolder& dh) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  auto const plain = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
+  // Concatenate 3 compressed frames in a row
+  auto frames = codec_->compress(plain.get());
+  frames->prependChain(codec_->compress(plain.get()));
+  frames->prependChain(codec_->compress(plain.get()));
+  // Pass all 3 compressed frames in one input buffer
+  auto pending = frames->coalesce();
+  // Uncompress the first frame (stream reset with the known length)
+  codec_->resetStream(plain->computeChainDataLength());
+  {
+    auto const first =
+        uncompressSome(codec_.get(), pending, chunkSize_, FlushOp::FLUSH);
+    ASSERT_TRUE(first.first);
+    ASSERT_EQ(hashIOBuf(plain.get()), hashIOBuf(first.second.get()));
+  }
+  // Uncompress the second frame
+  codec_->resetStream();
+  {
+    auto const second =
+        uncompressSome(codec_.get(), pending, chunkSize_, FlushOp::END);
+    ASSERT_TRUE(second.first);
+    ASSERT_EQ(hashIOBuf(plain.get()), hashIOBuf(second.second.get()));
+  }
+  // Uncompress the third frame
+  codec_->resetStream();
+  {
+    auto const third =
+        uncompressSome(codec_.get(), pending, chunkSize_, FlushOp::FLUSH);
+    ASSERT_TRUE(third.first);
+    ASSERT_EQ(hashIOBuf(plain.get()), hashIOBuf(third.second.get()));
+  }
+  // All three frames must have been consumed.
+  EXPECT_TRUE(pending.empty());
+}
+
+TEST_P(StreamingCompressionTest, uncompressStream) {  // Runs the multi-frame uncompression scenario on both data holders.
+  runUncompressStreamTest(constantDataHolder);
+  runUncompressStreamTest(randomDataHolder);
+}
+
+// Compresses the input piece by piece with FlushOp::FLUSH and checks, after
+// every piece, that a second codec instance can recover exactly that piece
+// from the flushed output.
+void StreamingCompressionTest::runFlushTest(DataHolder const& dh) {
+  using FlushOp = StreamCodec::FlushOp;
+
+  auto const pieces = split(dh.data(uncompressedLength_));
+  // Separate codec instance used for the uncompression side.
+  auto decoder = getStreamCodec(codec_->type());
+
+  codec_->resetStream();
+  for (auto piece : pieces) {
+    // Compress some data and flush the stream
+    auto flushed =
+        compressSome(codec_.get(), piece, chunkSize_, FlushOp::FLUSH);
+    auto flushedRange = flushed->coalesce();
+    // Uncompress the compressed data
+    auto decoded =
+        uncompressSome(decoder.get(), flushedRange, chunkSize_, FlushOp::FLUSH);
+    // All compressed data should have been consumed
+    EXPECT_TRUE(flushedRange.empty());
+    // The frame isn't complete
+    EXPECT_FALSE(decoded.first);
+    // The uncompressed data should be exactly the input data
+    EXPECT_EQ(piece.size(), decoded.second->computeChainDataLength());
+    auto const expected = IOBuf::wrapBuffer(piece);
+    EXPECT_EQ(hashIOBuf(expected.get()), hashIOBuf(decoded.second.get()));
+  }
+}
+
+TEST_P(StreamingCompressionTest, testFlush) {  // Runs the flush scenario on both data holders.
+  runFlushTest(constantDataHolder);
+  runFlushTest(randomDataHolder);
+}
+
+INSTANTIATE_TEST_CASE_P(  // Instantiates StreamingCompressionTest over every combination of the three value lists below.
+    StreamingCompressionTest,
+    StreamingCompressionTest,
+    testing::Combine(
+        testing::Values(0, 1, 12, 22, 27),  // tuple element 0: used in SetUp() as log2 of the uncompressed length
+        testing::Values(12, 17, 20),  // tuple element 1: used in SetUp() as log2 of the chunk size
+        testing::ValuesIn(availableStreamCodecs())));  // tuple element 2: codec type
+
 class AutomaticCodecTest : public testing::TestWithParam<CodecType> {
  protected:
   void SetUp() override {
@@ -488,7 +936,7 @@ namespace {
 class CustomCodec : public Codec {
  public:
   static std::unique_ptr<Codec> create(std::string prefix, CodecType type) {
-    return make_unique<CustomCodec>(std::move(prefix), type);
+    return std::make_unique<CustomCodec>(std::move(prefix), type);
   }
   explicit CustomCodec(std::string prefix, CodecType type)
       : Codec(CodecType::USER_DEFINED),
@@ -500,6 +948,10 @@ class CustomCodec : public Codec {
     return {prefix_};
   }
 
+  // Bound for this codec: codec_'s own bound plus the length of prefix_.
+  uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override {
+    auto const innerBound = codec_->maxCompressedLength(uncompressedLength);
+    return innerBound + prefix_.size();
+  }
+
   bool canUncompress(const IOBuf* data, Optional<uint64_t>) const override {
     auto clone = data->cloneCoalescedAsValue();
     if (clone.length() < prefix_.size()) {
@@ -650,6 +1102,31 @@ TEST(CheckCompatibleTest, ZlibIsPrefix) {
   EXPECT_THROW_IF_DEBUG(
       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
 }
+
+#if FOLLY_HAVE_LIBZSTD
+
+// For inputs of size_t(1) << 20 and size_t(100) << 20 bytes, the size that
+// zstd's ZSTD_getDecompressedSize() reads from the codec's output must equal
+// the input length.
+TEST(ZstdTest, BackwardCompatible) {
+  auto codec = getCodec(CodecType::ZSTD);
+  auto const expectSizeInFrame = [&codec](size_t inputSize) {
+    auto const data = IOBuf::wrapBuffer(randomDataHolder.data(inputSize));
+    auto compressed = codec->compress(data.get());
+    // Coalesce so data()/length() cover the whole compressed output.
+    compressed->coalesce();
+    EXPECT_EQ(
+        data->length(),
+        ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
+  };
+  expectSizeInFrame(size_t(1) << 20);
+  expectSizeInFrame(size_t(100) << 20);
+}
+
+#endif
 }}}  // namespaces
 
 int main(int argc, char *argv[]) {