4bbf50271177b8aec31ec1054e7607c5349d9a08
[folly.git] / folly / compression / test / CompressionTest.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/compression/Compression.h>
18
19 #include <algorithm>
20 #include <random>
21 #include <set>
22 #include <thread>
23 #include <unordered_map>
24 #include <utility>
25
26 #include <boost/noncopyable.hpp>
27 #include <glog/logging.h>
28
29 #include <folly/Benchmark.h>
30 #include <folly/Hash.h>
31 #include <folly/Memory.h>
32 #include <folly/Random.h>
33 #include <folly/Varint.h>
34 #include <folly/io/IOBufQueue.h>
35 #include <folly/portability/GTest.h>
36
37 #if FOLLY_HAVE_LIBZSTD
38 #include <zstd.h>
39 #endif
40
41 #if FOLLY_HAVE_LIBZ
42 #include <folly/compression/Zlib.h>
43 #endif
44
45 namespace zlib = folly::io::zlib;
46
47 namespace folly {
48 namespace io {
49 namespace test {
50
51 class DataHolder : private boost::noncopyable {
52  public:
53   uint64_t hash(size_t size) const;
54   ByteRange data(size_t size) const;
55
56  protected:
57   explicit DataHolder(size_t sizeLog2);
58   const size_t size_;
59   std::unique_ptr<uint8_t[]> data_;
60   mutable std::unordered_map<uint64_t, uint64_t> hashCache_;
61 };
62
63 DataHolder::DataHolder(size_t sizeLog2)
64   : size_(size_t(1) << sizeLog2),
65     data_(new uint8_t[size_]) {
66 }
67
68 uint64_t DataHolder::hash(size_t size) const {
69   CHECK_LE(size, size_);
70   auto p = hashCache_.find(size);
71   if (p != hashCache_.end()) {
72     return p->second;
73   }
74
75   uint64_t h = folly::hash::fnv64_buf(data_.get(), size);
76   hashCache_[size] = h;
77   return h;
78 }
79
80 ByteRange DataHolder::data(size_t size) const {
81   CHECK_LE(size, size_);
82   return ByteRange(data_.get(), size);
83 }
84
85 uint64_t hashIOBuf(const IOBuf* buf) {
86   uint64_t h = folly::hash::FNV_64_HASH_START;
87   for (auto& range : *buf) {
88     h = folly::hash::fnv64_buf(range.data(), range.size(), h);
89   }
90   return h;
91 }
92
93 class RandomDataHolder : public DataHolder {
94  public:
95   explicit RandomDataHolder(size_t sizeLog2);
96 };
97
98 RandomDataHolder::RandomDataHolder(size_t sizeLog2)
99   : DataHolder(sizeLog2) {
100   static constexpr size_t numThreadsLog2 = 3;
101   static constexpr size_t numThreads = size_t(1) << numThreadsLog2;
102
103   uint32_t seed = randomNumberSeed();
104
105   std::vector<std::thread> threads;
106   threads.reserve(numThreads);
107   for (size_t t = 0; t < numThreads; ++t) {
108     threads.emplace_back([this, seed, t, sizeLog2] {
109       std::mt19937 rng(seed + t);
110       size_t countLog2 = sizeLog2 - numThreadsLog2;
111       size_t start = size_t(t) << countLog2;
112       for (size_t i = 0; i < countLog2; ++i) {
113         this->data_[start + i] = rng();
114       }
115     });
116   }
117
118   for (auto& t : threads) {
119     t.join();
120   }
121 }
122
123 class ConstantDataHolder : public DataHolder {
124  public:
125   explicit ConstantDataHolder(size_t sizeLog2);
126 };
127
128 ConstantDataHolder::ConstantDataHolder(size_t sizeLog2)
129   : DataHolder(sizeLog2) {
130   memset(data_.get(), 'a', size_);
131 }
132
133 constexpr size_t dataSizeLog2 = 27;  // 128MiB
134 RandomDataHolder randomDataHolder(dataSizeLog2);
135 ConstantDataHolder constantDataHolder(dataSizeLog2);
136
137 // The intersection of the provided codecs & those that are compiled in.
138 static std::vector<CodecType> supportedCodecs(std::vector<CodecType> const& v) {
139   std::vector<CodecType> supported;
140
141   std::copy_if(
142       std::begin(v),
143       std::end(v),
144       std::back_inserter(supported),
145       hasCodec);
146
147   return supported;
148 }
149
150 // All compiled-in compression codecs.
151 static std::vector<CodecType> availableCodecs() {
152   std::vector<CodecType> codecs;
153
154   for (size_t i = 0; i < static_cast<size_t>(CodecType::NUM_CODEC_TYPES); ++i) {
155     auto type = static_cast<CodecType>(i);
156     if (hasCodec(type)) {
157       codecs.push_back(type);
158     }
159   }
160
161   return codecs;
162 }
163
164 static std::vector<CodecType> availableStreamCodecs() {
165   std::vector<CodecType> codecs;
166
167   for (size_t i = 0; i < static_cast<size_t>(CodecType::NUM_CODEC_TYPES); ++i) {
168     auto type = static_cast<CodecType>(i);
169     if (hasStreamCodec(type)) {
170       codecs.push_back(type);
171     }
172   }
173
174   return codecs;
175 }
176
177 TEST(CompressionTestNeedsUncompressedLength, Simple) {
178   static const struct {
179     CodecType type;
180     bool needsUncompressedLength;
181   } expectations[] = {
182       {CodecType::NO_COMPRESSION, false},
183       {CodecType::LZ4, true},
184       {CodecType::SNAPPY, false},
185       {CodecType::ZLIB, false},
186       {CodecType::LZ4_VARINT_SIZE, false},
187       {CodecType::LZMA2, false},
188       {CodecType::LZMA2_VARINT_SIZE, false},
189       {CodecType::ZSTD, false},
190       {CodecType::GZIP, false},
191       {CodecType::LZ4_FRAME, false},
192       {CodecType::BZIP2, false},
193   };
194
195   for (auto const& test : expectations) {
196     if (hasCodec(test.type)) {
197       EXPECT_EQ(getCodec(test.type)->needsUncompressedLength(),
198                 test.needsUncompressedLength);
199     }
200   }
201 }
202
203 class CompressionTest
204     : public testing::TestWithParam<std::tr1::tuple<int, int, CodecType>> {
205  protected:
206   void SetUp() override {
207     auto tup = GetParam();
208     int lengthLog = std::tr1::get<0>(tup);
209     // Small hack to test empty data
210     uncompressedLength_ =
211         (lengthLog < 0) ? 0 : uint64_t(1) << std::tr1::get<0>(tup);
212     chunks_ = std::tr1::get<1>(tup);
213     codec_ = getCodec(std::tr1::get<2>(tup));
214   }
215
216   void runSimpleIOBufTest(const DataHolder& dh);
217
218   void runSimpleStringTest(const DataHolder& dh);
219
220  private:
221   std::unique_ptr<IOBuf> split(std::unique_ptr<IOBuf> data) const;
222
223   uint64_t uncompressedLength_;
224   size_t chunks_;
225   std::unique_ptr<Codec> codec_;
226 };
227
228 void CompressionTest::runSimpleIOBufTest(const DataHolder& dh) {
229   const auto original = split(IOBuf::wrapBuffer(dh.data(uncompressedLength_)));
230   const auto compressed = split(codec_->compress(original.get()));
231   EXPECT_LE(
232       compressed->computeChainDataLength(),
233       codec_->maxCompressedLength(uncompressedLength_));
234   if (!codec_->needsUncompressedLength()) {
235     auto uncompressed = codec_->uncompress(compressed.get());
236     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
237     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
238   }
239   {
240     auto uncompressed = codec_->uncompress(compressed.get(),
241                                            uncompressedLength_);
242     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
243     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
244   }
245 }
246
247 void CompressionTest::runSimpleStringTest(const DataHolder& dh) {
248   const auto original = std::string(
249       reinterpret_cast<const char*>(dh.data(uncompressedLength_).data()),
250       uncompressedLength_);
251   const auto compressed = codec_->compress(original);
252   EXPECT_LE(
253       compressed.length(), codec_->maxCompressedLength(uncompressedLength_));
254
255   if (!codec_->needsUncompressedLength()) {
256     auto uncompressed = codec_->uncompress(compressed);
257     EXPECT_EQ(uncompressedLength_, uncompressed.length());
258     EXPECT_EQ(uncompressed, original);
259   }
260   {
261     auto uncompressed = codec_->uncompress(compressed, uncompressedLength_);
262     EXPECT_EQ(uncompressedLength_, uncompressed.length());
263     EXPECT_EQ(uncompressed, original);
264   }
265 }
266
267 // Uniformly split data into (potentially empty) chunks.
268 std::unique_ptr<IOBuf> CompressionTest::split(
269     std::unique_ptr<IOBuf> data) const {
270   if (data->isChained()) {
271     data->coalesce();
272   }
273
274   const size_t size = data->computeChainDataLength();
275
276   std::multiset<size_t> splits;
277   for (size_t i = 1; i < chunks_; ++i) {
278     splits.insert(Random::rand64(size));
279   }
280
281   folly::IOBufQueue result;
282
283   size_t offset = 0;
284   for (size_t split : splits) {
285     result.append(IOBuf::copyBuffer(data->data() + offset, split - offset));
286     offset = split;
287   }
288   result.append(IOBuf::copyBuffer(data->data() + offset, size - offset));
289
290   return result.move();
291 }
292
293 TEST_P(CompressionTest, RandomData) {
294   runSimpleIOBufTest(randomDataHolder);
295 }
296
297 TEST_P(CompressionTest, ConstantData) {
298   runSimpleIOBufTest(constantDataHolder);
299 }
300
301 TEST_P(CompressionTest, RandomDataString) {
302   runSimpleStringTest(randomDataHolder);
303 }
304
305 TEST_P(CompressionTest, ConstantDataString) {
306   runSimpleStringTest(constantDataHolder);
307 }
308
309 INSTANTIATE_TEST_CASE_P(
310     CompressionTest,
311     CompressionTest,
312     testing::Combine(
313         testing::Values(-1, 0, 1, 12, 22, 25, 27),
314         testing::Values(1, 2, 3, 8, 65),
315         testing::ValuesIn(availableCodecs())));
316
317 class CompressionVarintTest
318     : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
319  protected:
320   void SetUp() override {
321     auto tup = GetParam();
322     uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
323     codec_ = getCodec(std::tr1::get<1>(tup));
324   }
325
326   void runSimpleTest(const DataHolder& dh);
327
328   uint64_t uncompressedLength_;
329   std::unique_ptr<Codec> codec_;
330 };
331
332 inline uint64_t oneBasedMsbPos(uint64_t number) {
333   uint64_t pos = 0;
334   for (; number > 0; ++pos, number >>= 1) {
335   }
336   return pos;
337 }
338
339 void CompressionVarintTest::runSimpleTest(const DataHolder& dh) {
340   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
341   auto compressed = codec_->compress(original.get());
342   auto breakPoint =
343       1UL +
344       Random::rand64(
345           std::max(uint64_t(9), oneBasedMsbPos(uncompressedLength_)) / 9UL);
346   auto tinyBuf = IOBuf::copyBuffer(compressed->data(),
347                                    std::min(compressed->length(), breakPoint));
348   compressed->trimStart(breakPoint);
349   tinyBuf->prependChain(std::move(compressed));
350   compressed = std::move(tinyBuf);
351
352   auto uncompressed = codec_->uncompress(compressed.get());
353
354   EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
355   EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
356 }
357
358 TEST_P(CompressionVarintTest, RandomData) {
359   runSimpleTest(randomDataHolder);
360 }
361
362 TEST_P(CompressionVarintTest, ConstantData) {
363   runSimpleTest(constantDataHolder);
364 }
365
366 INSTANTIATE_TEST_CASE_P(
367     CompressionVarintTest,
368     CompressionVarintTest,
369     testing::Combine(
370         testing::Values(0, 1, 12, 22, 25, 27),
371         testing::ValuesIn(supportedCodecs({
372             CodecType::LZ4_VARINT_SIZE,
373             CodecType::LZMA2_VARINT_SIZE,
374         }))));
375
376 TEST(LZMATest, UncompressBadVarint) {
377   if (hasStreamCodec(CodecType::LZMA2_VARINT_SIZE)) {
378     std::string const str(kMaxVarintLength64 * 2, '\xff');
379     ByteRange input((folly::StringPiece(str)));
380     auto codec = getStreamCodec(CodecType::LZMA2_VARINT_SIZE);
381     auto buffer = IOBuf::create(16);
382     buffer->append(buffer->capacity());
383     MutableByteRange output{buffer->writableData(), buffer->length()};
384     EXPECT_THROW(codec->uncompressStream(input, output), std::runtime_error);
385   }
386 }
387
388 class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
389  protected:
390   void SetUp() override { codec_ = getCodec(GetParam()); }
391
392   void runSimpleTest(const DataHolder& dh);
393
394   std::unique_ptr<Codec> codec_;
395 };
396
397 void CompressionCorruptionTest::runSimpleTest(const DataHolder& dh) {
398   constexpr uint64_t uncompressedLength = 42;
399   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength));
400   auto compressed = codec_->compress(original.get());
401
402   if (!codec_->needsUncompressedLength()) {
403     auto uncompressed = codec_->uncompress(compressed.get());
404     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
405     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
406   }
407   {
408     auto uncompressed = codec_->uncompress(compressed.get(),
409                                            uncompressedLength);
410     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
411     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
412   }
413
414   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength + 1),
415                std::runtime_error);
416
417   auto corrupted = compressed->clone();
418   corrupted->unshare();
419   // Truncate the last character
420   corrupted->prev()->trimEnd(1);
421   if (!codec_->needsUncompressedLength()) {
422     EXPECT_THROW(codec_->uncompress(corrupted.get()),
423                  std::runtime_error);
424   }
425
426   EXPECT_THROW(codec_->uncompress(corrupted.get(), uncompressedLength),
427                std::runtime_error);
428
429   corrupted = compressed->clone();
430   corrupted->unshare();
431   // Corrupt the first character
432   ++(corrupted->writableData()[0]);
433
434   if (!codec_->needsUncompressedLength()) {
435     EXPECT_THROW(codec_->uncompress(corrupted.get()),
436                  std::runtime_error);
437   }
438
439   EXPECT_THROW(codec_->uncompress(corrupted.get(), uncompressedLength),
440                std::runtime_error);
441 }
442
443 TEST_P(CompressionCorruptionTest, RandomData) {
444   runSimpleTest(randomDataHolder);
445 }
446
447 TEST_P(CompressionCorruptionTest, ConstantData) {
448   runSimpleTest(constantDataHolder);
449 }
450
451 INSTANTIATE_TEST_CASE_P(
452     CompressionCorruptionTest,
453     CompressionCorruptionTest,
454     testing::ValuesIn(
455         // NO_COMPRESSION can't detect corruption
456         // LZ4 can't detect corruption reliably (sigh)
457         supportedCodecs({
458             CodecType::SNAPPY,
459             CodecType::ZLIB,
460             CodecType::LZMA2,
461             CodecType::ZSTD,
462             CodecType::LZ4_FRAME,
463             CodecType::BZIP2,
464         })));
465
466 class StreamingUnitTest : public testing::TestWithParam<CodecType> {
467  protected:
468   void SetUp() override {
469     codec_ = getStreamCodec(GetParam());
470   }
471
472   std::unique_ptr<StreamCodec> codec_;
473 };
474
475 TEST(StreamingUnitTest, needsDataLength) {
476   static const struct {
477     CodecType type;
478     bool needsDataLength;
479   } expectations[] = {
480       {CodecType::ZLIB, false},
481       {CodecType::GZIP, false},
482       {CodecType::LZMA2, false},
483       {CodecType::LZMA2_VARINT_SIZE, true},
484       {CodecType::ZSTD, false},
485   };
486
487   for (auto const& test : expectations) {
488     if (hasStreamCodec(test.type)) {
489       EXPECT_EQ(
490           getStreamCodec(test.type)->needsDataLength(), test.needsDataLength);
491     }
492   }
493 }
494
495 TEST_P(StreamingUnitTest, maxCompressedLength) {
496   for (uint64_t const length : {1, 10, 100, 1000, 10000, 100000, 1000000}) {
497     EXPECT_GE(codec_->maxCompressedLength(length), length);
498   }
499 }
500
501 TEST_P(StreamingUnitTest, getUncompressedLength) {
502   auto const empty = IOBuf::create(0);
503   EXPECT_EQ(uint64_t(0), codec_->getUncompressedLength(empty.get()));
504   EXPECT_EQ(uint64_t(0), codec_->getUncompressedLength(empty.get(), 0));
505   EXPECT_ANY_THROW(codec_->getUncompressedLength(empty.get(), 1));
506
507   auto const data = IOBuf::wrapBuffer(randomDataHolder.data(100));
508   auto const compressed = codec_->compress(data.get());
509
510   if (auto const length = codec_->getUncompressedLength(data.get())) {
511     EXPECT_EQ(100, *length);
512   }
513   EXPECT_EQ(uint64_t(100), codec_->getUncompressedLength(data.get(), 100));
514   // If the uncompressed length is stored in the frame, then make sure it throws
515   // when it is given the wrong length.
516   if (codec_->getUncompressedLength(data.get()) == uint64_t(100)) {
517     EXPECT_ANY_THROW(codec_->getUncompressedLength(data.get(), 200));
518   }
519 }
520
521 TEST_P(StreamingUnitTest, emptyData) {
522   ByteRange input{};
523   auto buffer = IOBuf::create(codec_->maxCompressedLength(0));
524   buffer->append(buffer->capacity());
525   MutableByteRange output;
526
527   // Test compressing empty data in one pass
528   if (!codec_->needsDataLength()) {
529     output = {buffer->writableData(), buffer->length()};
530     EXPECT_TRUE(
531         codec_->compressStream(input, output, StreamCodec::FlushOp::END));
532   }
533   codec_->resetStream(0);
534   output = {buffer->writableData(), buffer->length()};
535   EXPECT_TRUE(codec_->compressStream(input, output, StreamCodec::FlushOp::END));
536
537   // Test uncompressing the compressed empty data is equivalent to the empty
538   // string
539   {
540     size_t compressedSize = buffer->length() - output.size();
541     auto const compressed =
542         IOBuf::copyBuffer(buffer->writableData(), compressedSize);
543     auto inputRange = compressed->coalesce();
544     codec_->resetStream(0);
545     output = {buffer->writableData(), buffer->length()};
546     EXPECT_TRUE(codec_->uncompressStream(
547         inputRange, output, StreamCodec::FlushOp::END));
548     EXPECT_EQ(output.size(), buffer->length());
549   }
550
551   // Test compressing empty data with multiple calls to compressStream()
552   {
553     auto largeBuffer = IOBuf::create(codec_->maxCompressedLength(0) * 2);
554     largeBuffer->append(largeBuffer->capacity());
555     codec_->resetStream(0);
556     output = {largeBuffer->writableData(), largeBuffer->length()};
557     EXPECT_FALSE(codec_->compressStream(input, output));
558     EXPECT_TRUE(
559         codec_->compressStream(input, output, StreamCodec::FlushOp::FLUSH));
560     EXPECT_TRUE(
561         codec_->compressStream(input, output, StreamCodec::FlushOp::END));
562   }
563
564   // Test uncompressing empty data
565   output = {};
566   codec_->resetStream();
567   EXPECT_TRUE(codec_->uncompressStream(input, output));
568   codec_->resetStream();
569   EXPECT_TRUE(
570       codec_->uncompressStream(input, output, StreamCodec::FlushOp::FLUSH));
571   codec_->resetStream();
572   EXPECT_TRUE(
573       codec_->uncompressStream(input, output, StreamCodec::FlushOp::END));
574   codec_->resetStream(0);
575   EXPECT_TRUE(codec_->uncompressStream(input, output));
576   codec_->resetStream(0);
577   EXPECT_TRUE(
578       codec_->uncompressStream(input, output, StreamCodec::FlushOp::FLUSH));
579   codec_->resetStream(0);
580   EXPECT_TRUE(
581       codec_->uncompressStream(input, output, StreamCodec::FlushOp::END));
582 }
583
584 TEST_P(StreamingUnitTest, noForwardProgress) {
585   auto inBuffer = IOBuf::create(2);
586   inBuffer->writableData()[0] = 'a';
587   inBuffer->writableData()[1] = 'a';
588   inBuffer->append(2);
589   const auto compressed = codec_->compress(inBuffer.get());
590   auto outBuffer = IOBuf::create(codec_->maxCompressedLength(2));
591
592   ByteRange emptyInput;
593   MutableByteRange emptyOutput;
594
595   const std::array<StreamCodec::FlushOp, 3> flushOps = {{
596       StreamCodec::FlushOp::NONE,
597       StreamCodec::FlushOp::FLUSH,
598       StreamCodec::FlushOp::END,
599   }};
600
601   // No progress is not okay twice in a row for all flush operations when
602   // compressing
603   for (const auto flushOp : flushOps) {
604     if (codec_->needsDataLength()) {
605       codec_->resetStream(inBuffer->computeChainDataLength());
606     } else {
607       codec_->resetStream();
608     }
609     auto input = inBuffer->coalesce();
610     MutableByteRange output = {outBuffer->writableTail(),
611                                outBuffer->tailroom()};
612     // Compress some data to avoid empty data special casing
613     while (!input.empty()) {
614       codec_->compressStream(input, output);
615     }
616     EXPECT_FALSE(codec_->compressStream(emptyInput, emptyOutput, flushOp));
617     EXPECT_THROW(
618         codec_->compressStream(emptyInput, emptyOutput, flushOp),
619         std::runtime_error);
620   }
621
622   // No progress is not okay twice in a row for all flush operations when
623   // uncompressing
624   for (const auto flushOp : flushOps) {
625     codec_->resetStream();
626     auto input = compressed->coalesce();
627     // Remove the last byte so the operation is incomplete
628     input.uncheckedSubtract(1);
629     MutableByteRange output = {inBuffer->writableData(), inBuffer->length()};
630     // Uncompress some data to avoid empty data special casing
631     while (!input.empty()) {
632       EXPECT_FALSE(codec_->uncompressStream(input, output));
633     }
634     EXPECT_FALSE(codec_->uncompressStream(emptyInput, emptyOutput, flushOp));
635     EXPECT_THROW(
636         codec_->uncompressStream(emptyInput, emptyOutput, flushOp),
637         std::runtime_error);
638   }
639 }
640
641 TEST_P(StreamingUnitTest, stateTransitions) {
642   auto inBuffer = IOBuf::create(2);
643   inBuffer->writableData()[0] = 'a';
644   inBuffer->writableData()[1] = 'a';
645   inBuffer->append(2);
646   auto compressed = codec_->compress(inBuffer.get());
647   ByteRange const in = compressed->coalesce();
648   auto outBuffer = IOBuf::create(codec_->maxCompressedLength(in.size()));
649   MutableByteRange const out{outBuffer->writableTail(), outBuffer->tailroom()};
650
651   auto compress = [&](
652       StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE,
653       bool empty = false) {
654     auto input = in;
655     auto output = empty ? MutableByteRange{} : out;
656     return codec_->compressStream(input, output, flushOp);
657   };
658   auto compress_all = [&](bool expect,
659                           StreamCodec::FlushOp flushOp =
660                               StreamCodec::FlushOp::NONE,
661                           bool empty = false) {
662     auto input = in;
663     auto output = empty ? MutableByteRange{} : out;
664     while (!input.empty()) {
665       if (expect) {
666         EXPECT_TRUE(codec_->compressStream(input, output, flushOp));
667       } else {
668         EXPECT_FALSE(codec_->compressStream(input, output, flushOp));
669       }
670     }
671   };
672   auto uncompress = [&](
673       StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE,
674       bool empty = false) {
675     auto input = in;
676     auto output = empty ? MutableByteRange{} : out;
677     return codec_->uncompressStream(input, output, flushOp);
678   };
679
680   // compression flow
681   if (!codec_->needsDataLength()) {
682     codec_->resetStream();
683     EXPECT_FALSE(compress());
684     EXPECT_FALSE(compress());
685     EXPECT_TRUE(compress(StreamCodec::FlushOp::FLUSH));
686     EXPECT_FALSE(compress());
687     EXPECT_TRUE(compress(StreamCodec::FlushOp::END));
688   }
689   codec_->resetStream(in.size() * 5);
690   compress_all(false);
691   compress_all(false);
692   compress_all(true, StreamCodec::FlushOp::FLUSH);
693   compress_all(false);
694   compress_all(true, StreamCodec::FlushOp::END);
695
696   // uncompression flow
697   codec_->resetStream();
698   EXPECT_FALSE(uncompress(StreamCodec::FlushOp::NONE, true));
699   codec_->resetStream();
700   EXPECT_FALSE(uncompress(StreamCodec::FlushOp::FLUSH, true));
701   codec_->resetStream();
702   EXPECT_FALSE(uncompress(StreamCodec::FlushOp::NONE, true));
703   codec_->resetStream();
704   EXPECT_FALSE(uncompress(StreamCodec::FlushOp::NONE, true));
705   codec_->resetStream();
706   EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH));
707   // compress -> uncompress
708   codec_->resetStream(in.size());
709   EXPECT_FALSE(compress());
710   EXPECT_THROW(uncompress(), std::logic_error);
711   // uncompress -> compress
712   codec_->resetStream(inBuffer->computeChainDataLength());
713   EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH));
714   EXPECT_THROW(compress(), std::logic_error);
715   // end -> compress
716   if (!codec_->needsDataLength()) {
717     codec_->resetStream();
718     EXPECT_FALSE(compress());
719     EXPECT_TRUE(compress(StreamCodec::FlushOp::END));
720     EXPECT_THROW(compress(), std::logic_error);
721   }
722   codec_->resetStream(in.size() * 2);
723   compress_all(false);
724   compress_all(true, StreamCodec::FlushOp::END);
725   EXPECT_THROW(compress(), std::logic_error);
726   // end -> uncompress
727   codec_->resetStream();
728   EXPECT_TRUE(uncompress(StreamCodec::FlushOp::FLUSH));
729   EXPECT_THROW(uncompress(), std::logic_error);
730   // flush -> compress
731   codec_->resetStream(in.size());
732   EXPECT_FALSE(compress(StreamCodec::FlushOp::FLUSH, true));
733   EXPECT_THROW(compress(), std::logic_error);
734   // flush -> end
735   codec_->resetStream(in.size());
736   EXPECT_FALSE(compress(StreamCodec::FlushOp::FLUSH, true));
737   EXPECT_THROW(compress(StreamCodec::FlushOp::END), std::logic_error);
738   // undefined -> compress
739   codec_->compress(inBuffer.get());
740   EXPECT_THROW(compress(), std::logic_error);
741   codec_->uncompress(compressed.get(), inBuffer->computeChainDataLength());
742   EXPECT_THROW(compress(), std::logic_error);
743   // undefined -> undefined
744   codec_->uncompress(compressed.get());
745   codec_->compress(inBuffer.get());
746 }
747
748 INSTANTIATE_TEST_CASE_P(
749     StreamingUnitTest,
750     StreamingUnitTest,
751     testing::ValuesIn(availableStreamCodecs()));
752
753 class StreamingCompressionTest
754     : public testing::TestWithParam<std::tuple<int, int, CodecType>> {
755  protected:
756   void SetUp() override {
757     auto const tup = GetParam();
758     uncompressedLength_ = uint64_t(1) << std::get<0>(tup);
759     chunkSize_ = size_t(1) << std::get<1>(tup);
760     codec_ = getStreamCodec(std::get<2>(tup));
761   }
762
763   void runResetStreamTest(DataHolder const& dh);
764   void runCompressStreamTest(DataHolder const& dh);
765   void runUncompressStreamTest(DataHolder const& dh);
766   void runFlushTest(DataHolder const& dh);
767
768  private:
769   std::vector<ByteRange> split(ByteRange data) const;
770
771   uint64_t uncompressedLength_;
772   size_t chunkSize_;
773   std::unique_ptr<StreamCodec> codec_;
774 };
775
776 std::vector<ByteRange> StreamingCompressionTest::split(ByteRange data) const {
777   size_t const pieces = std::max<size_t>(1, data.size() / chunkSize_);
778   std::vector<ByteRange> result;
779   result.reserve(pieces + 1);
780   while (!data.empty()) {
781     size_t const pieceSize = std::min(data.size(), chunkSize_);
782     result.push_back(data.subpiece(0, pieceSize));
783     data.uncheckedAdvance(pieceSize);
784   }
785   return result;
786 }
787
788 static std::unique_ptr<IOBuf> compressSome(
789     StreamCodec* codec,
790     ByteRange data,
791     uint64_t bufferSize,
792     StreamCodec::FlushOp flush) {
793   bool result;
794   IOBufQueue queue;
795   do {
796     auto buffer = IOBuf::create(bufferSize);
797     buffer->append(buffer->capacity());
798     MutableByteRange output{buffer->writableData(), buffer->length()};
799
800     result = codec->compressStream(data, output, flush);
801     buffer->trimEnd(output.size());
802     queue.append(std::move(buffer));
803
804   } while (!(flush == StreamCodec::FlushOp::NONE && data.empty()) && !result);
805   EXPECT_TRUE(data.empty());
806   return queue.move();
807 }
808
809 static std::pair<bool, std::unique_ptr<IOBuf>> uncompressSome(
810     StreamCodec* codec,
811     ByteRange& data,
812     uint64_t bufferSize,
813     StreamCodec::FlushOp flush) {
814   bool result;
815   IOBufQueue queue;
816   do {
817     auto buffer = IOBuf::create(bufferSize);
818     buffer->append(buffer->capacity());
819     MutableByteRange output{buffer->writableData(), buffer->length()};
820
821     result = codec->uncompressStream(data, output, flush);
822     buffer->trimEnd(output.size());
823     queue.append(std::move(buffer));
824
825   } while (queue.tailroom() == 0 && !result);
826   return std::make_pair(result, queue.move());
827 }
828
829 void StreamingCompressionTest::runResetStreamTest(DataHolder const& dh) {
830   auto const input = dh.data(uncompressedLength_);
831   // Compress some but leave state unclean
832   codec_->resetStream(uncompressedLength_);
833   compressSome(codec_.get(), input, chunkSize_, StreamCodec::FlushOp::NONE);
834   // Reset stream and compress all
835   if (codec_->needsDataLength()) {
836     codec_->resetStream(uncompressedLength_);
837   } else {
838     codec_->resetStream();
839   }
840   auto compressed =
841       compressSome(codec_.get(), input, chunkSize_, StreamCodec::FlushOp::END);
842   auto const uncompressed = codec_->uncompress(compressed.get(), input.size());
843   EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
844 }
845
846 TEST_P(StreamingCompressionTest, resetStream) {
847   runResetStreamTest(constantDataHolder);
848   runResetStreamTest(randomDataHolder);
849 }
850
851 void StreamingCompressionTest::runCompressStreamTest(
852     const folly::io::test::DataHolder& dh) {
853   auto const inputs = split(dh.data(uncompressedLength_));
854
855   IOBufQueue queue;
856   codec_->resetStream(uncompressedLength_);
857   // Compress many inputs in a row
858   for (auto const input : inputs) {
859     queue.append(compressSome(
860         codec_.get(), input, chunkSize_, StreamCodec::FlushOp::NONE));
861   }
862   // Finish the operation with empty input.
863   ByteRange empty;
864   queue.append(
865       compressSome(codec_.get(), empty, chunkSize_, StreamCodec::FlushOp::END));
866
867   auto const uncompressed = codec_->uncompress(queue.front());
868   EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
869 }
870
871 TEST_P(StreamingCompressionTest, compressStream) {
872   runCompressStreamTest(constantDataHolder);
873   runCompressStreamTest(randomDataHolder);
874 }
875
876 void StreamingCompressionTest::runUncompressStreamTest(
877     const folly::io::test::DataHolder& dh) {
878   auto const data = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
879   // Concatenate 3 compressed frames in a row
880   auto compressed = codec_->compress(data.get());
881   compressed->prependChain(codec_->compress(data.get()));
882   compressed->prependChain(codec_->compress(data.get()));
883   // Pass all 3 compressed frames in one input buffer
884   auto input = compressed->coalesce();
885   // Uncompress the first frame
886   codec_->resetStream(data->computeChainDataLength());
887   {
888     auto const result = uncompressSome(
889         codec_.get(), input, chunkSize_, StreamCodec::FlushOp::FLUSH);
890     ASSERT_TRUE(result.first);
891     ASSERT_EQ(hashIOBuf(data.get()), hashIOBuf(result.second.get()));
892   }
893   // Uncompress the second frame
894   codec_->resetStream();
895   {
896     auto const result = uncompressSome(
897         codec_.get(), input, chunkSize_, StreamCodec::FlushOp::END);
898     ASSERT_TRUE(result.first);
899     ASSERT_EQ(hashIOBuf(data.get()), hashIOBuf(result.second.get()));
900   }
901   // Uncompress the third frame
902   codec_->resetStream();
903   {
904     auto const result = uncompressSome(
905         codec_.get(), input, chunkSize_, StreamCodec::FlushOp::FLUSH);
906     ASSERT_TRUE(result.first);
907     ASSERT_EQ(hashIOBuf(data.get()), hashIOBuf(result.second.get()));
908   }
909   EXPECT_TRUE(input.empty());
910 }
911
912 TEST_P(StreamingCompressionTest, uncompressStream) {
913   runUncompressStreamTest(constantDataHolder);
914   runUncompressStreamTest(randomDataHolder);
915 }
916
917 void StreamingCompressionTest::runFlushTest(DataHolder const& dh) {
918   auto const inputs = split(dh.data(uncompressedLength_));
919   auto uncodec = getStreamCodec(codec_->type());
920
921   if (codec_->needsDataLength()) {
922     codec_->resetStream(uncompressedLength_);
923   } else {
924     codec_->resetStream();
925   }
926   for (auto input : inputs) {
927     // Compress some data and flush the stream
928     auto compressed = compressSome(
929         codec_.get(), input, chunkSize_, StreamCodec::FlushOp::FLUSH);
930     auto compressedRange = compressed->coalesce();
931     // Uncompress the compressed data
932     auto result = uncompressSome(
933         uncodec.get(),
934         compressedRange,
935         chunkSize_,
936         StreamCodec::FlushOp::FLUSH);
937     // All compressed data should have been consumed
938     EXPECT_TRUE(compressedRange.empty());
939     // The frame isn't complete
940     EXPECT_FALSE(result.first);
941     // The uncompressed data should be exactly the input data
942     EXPECT_EQ(input.size(), result.second->computeChainDataLength());
943     auto const data = IOBuf::wrapBuffer(input);
944     EXPECT_EQ(hashIOBuf(data.get()), hashIOBuf(result.second.get()));
945   }
946 }
947
948 TEST_P(StreamingCompressionTest, testFlush) {
949   runFlushTest(constantDataHolder);
950   runFlushTest(randomDataHolder);
951 }
952
953 INSTANTIATE_TEST_CASE_P(
954     StreamingCompressionTest,
955     StreamingCompressionTest,
956     testing::Combine(
957         testing::Values(0, 1, 12, 22, 27),
958         testing::Values(12, 17, 20),
959         testing::ValuesIn(availableStreamCodecs())));
960
961 namespace {
962
963 // Codec types included in the codec returned by getAutoUncompressionCodec() by
964 // default.
965 std::vector<CodecType> autoUncompressionCodecTypes = {{
966     CodecType::LZ4_FRAME,
967     CodecType::ZSTD,
968     CodecType::ZLIB,
969     CodecType::GZIP,
970     CodecType::LZMA2,
971     CodecType::BZIP2,
972 }};
973
974 } // namespace
975
976 class AutomaticCodecTest : public testing::TestWithParam<CodecType> {
977  protected:
978   void SetUp() override {
979     codecType_ = GetParam();
980     codec_ = getCodec(codecType_);
981     autoType_ = std::any_of(
982         autoUncompressionCodecTypes.begin(),
983         autoUncompressionCodecTypes.end(),
984         [&](CodecType o) { return codecType_ == o; });
985     // Add the codec with type codecType_ as the terminal codec if it is not in
986     // autoUncompressionCodecTypes.
987     auto_ = getAutoUncompressionCodec({}, getTerminalCodec());
988   }
989
990   void runSimpleTest(const DataHolder& dh);
991
992   std::unique_ptr<Codec> getTerminalCodec() {
993     return (autoType_ ? nullptr : getCodec(codecType_));
994   }
995
996   std::unique_ptr<Codec> codec_;
997   std::unique_ptr<Codec> auto_;
998   CodecType codecType_;
999   // true if codecType_ is in autoUncompressionCodecTypes
1000   bool autoType_;
1001 };
1002
1003 void AutomaticCodecTest::runSimpleTest(const DataHolder& dh) {
1004   constexpr uint64_t uncompressedLength = 1000;
1005   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength));
1006   auto compressed = codec_->compress(original.get());
1007
1008   if (!codec_->needsUncompressedLength()) {
1009     auto uncompressed = auto_->uncompress(compressed.get());
1010     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
1011     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
1012   }
1013   {
1014     auto uncompressed = auto_->uncompress(compressed.get(), uncompressedLength);
1015     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
1016     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
1017   }
1018   ASSERT_GE(compressed->computeChainDataLength(), 8);
1019   for (size_t i = 0; i < 8; ++i) {
1020     auto split = compressed->clone();
1021     auto rest = compressed->clone();
1022     split->trimEnd(split->length() - i);
1023     rest->trimStart(i);
1024     split->appendChain(std::move(rest));
1025     auto uncompressed = auto_->uncompress(split.get(), uncompressedLength);
1026     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
1027     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
1028   }
1029 }
1030
1031 TEST_P(AutomaticCodecTest, RandomData) {
1032   runSimpleTest(randomDataHolder);
1033 }
1034
1035 TEST_P(AutomaticCodecTest, ConstantData) {
1036   runSimpleTest(constantDataHolder);
1037 }
1038
1039 TEST_P(AutomaticCodecTest, ValidPrefixes) {
1040   const auto prefixes = codec_->validPrefixes();
1041   for (const auto& prefix : prefixes) {
1042     EXPECT_FALSE(prefix.empty());
1043     // Ensure that all strings are at least 8 bytes for LZMA2.
1044     // The bytes after the prefix should be ignored by `canUncompress()`.
1045     IOBuf data{IOBuf::COPY_BUFFER, prefix, 0, 8};
1046     data.append(8);
1047     EXPECT_TRUE(codec_->canUncompress(&data));
1048     EXPECT_TRUE(auto_->canUncompress(&data));
1049   }
1050 }
1051
1052 TEST_P(AutomaticCodecTest, NeedsUncompressedLength) {
1053   if (codec_->needsUncompressedLength()) {
1054     EXPECT_TRUE(auto_->needsUncompressedLength());
1055   }
1056 }
1057
1058 TEST_P(AutomaticCodecTest, maxUncompressedLength) {
1059   EXPECT_LE(codec_->maxUncompressedLength(), auto_->maxUncompressedLength());
1060 }
1061
1062 TEST_P(AutomaticCodecTest, DefaultCodec) {
1063   const uint64_t length = 42;
1064   std::vector<std::unique_ptr<Codec>> codecs;
1065   codecs.push_back(getCodec(CodecType::ZSTD));
1066   auto automatic =
1067       getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
1068   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
1069   auto compressed = codec_->compress(original.get());
1070   std::unique_ptr<IOBuf> decompressed;
1071
1072   if (automatic->needsUncompressedLength()) {
1073     decompressed = automatic->uncompress(compressed.get(), length);
1074   } else {
1075     decompressed = automatic->uncompress(compressed.get());
1076   }
1077
1078   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
1079 }
1080
1081 namespace {
1082 class CustomCodec : public Codec {
1083  public:
1084   static std::unique_ptr<Codec> create(std::string prefix, CodecType type) {
1085     return std::make_unique<CustomCodec>(std::move(prefix), type);
1086   }
1087   explicit CustomCodec(std::string prefix, CodecType type)
1088       : Codec(CodecType::USER_DEFINED),
1089         prefix_(std::move(prefix)),
1090         codec_(getCodec(type)) {}
1091
1092  private:
1093   std::vector<std::string> validPrefixes() const override {
1094     return {prefix_};
1095   }
1096
1097   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override {
1098     return codec_->maxCompressedLength(uncompressedLength) + prefix_.size();
1099   }
1100
1101   bool canUncompress(const IOBuf* data, Optional<uint64_t>) const override {
1102     auto clone = data->cloneCoalescedAsValue();
1103     if (clone.length() < prefix_.size()) {
1104       return false;
1105     }
1106     return memcmp(clone.data(), prefix_.data(), prefix_.size()) == 0;
1107   }
1108
1109   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override {
1110     auto result = IOBuf::copyBuffer(prefix_);
1111     result->appendChain(codec_->compress(data));
1112     EXPECT_TRUE(canUncompress(result.get(), data->computeChainDataLength()));
1113     return result;
1114   }
1115
1116   std::unique_ptr<IOBuf> doUncompress(
1117       const IOBuf* data,
1118       Optional<uint64_t> uncompressedLength) override {
1119     EXPECT_TRUE(canUncompress(data, uncompressedLength));
1120     auto clone = data->cloneCoalescedAsValue();
1121     clone.trimStart(prefix_.size());
1122     return codec_->uncompress(&clone, uncompressedLength);
1123   }
1124
1125   std::string prefix_;
1126   std::unique_ptr<Codec> codec_;
1127 };
1128 } // namespace
1129
1130 TEST_P(AutomaticCodecTest, CustomCodec) {
1131   const uint64_t length = 42;
1132   auto ab = CustomCodec::create("ab", CodecType::ZSTD);
1133   std::vector<std::unique_ptr<Codec>> codecs;
1134   codecs.push_back(CustomCodec::create("ab", CodecType::ZSTD));
1135   auto automatic =
1136       getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
1137   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
1138
1139   auto abCompressed = ab->compress(original.get());
1140   std::unique_ptr<IOBuf> abDecompressed;
1141   if (automatic->needsUncompressedLength()) {
1142     abDecompressed = automatic->uncompress(abCompressed.get(), length);
1143   } else {
1144     abDecompressed = automatic->uncompress(abCompressed.get());
1145   }
1146   EXPECT_TRUE(automatic->canUncompress(abCompressed.get()));
1147   EXPECT_FALSE(auto_->canUncompress(abCompressed.get()));
1148   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(abDecompressed.get()));
1149
1150   auto compressed = codec_->compress(original.get());
1151   std::unique_ptr<IOBuf> decompressed;
1152   if (automatic->needsUncompressedLength()) {
1153     decompressed = automatic->uncompress(compressed.get(), length);
1154   } else {
1155     decompressed = automatic->uncompress(compressed.get());
1156   }
1157   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
1158 }
1159
1160 TEST_P(AutomaticCodecTest, CustomDefaultCodec) {
1161   const uint64_t length = 42;
1162   auto none = CustomCodec::create("none", CodecType::NO_COMPRESSION);
1163   std::vector<std::unique_ptr<Codec>> codecs;
1164   codecs.push_back(CustomCodec::create("none", CodecType::NO_COMPRESSION));
1165   codecs.push_back(getCodec(CodecType::LZ4_FRAME));
1166   auto automatic =
1167       getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
1168   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
1169
1170   auto noneCompressed = none->compress(original.get());
1171   std::unique_ptr<IOBuf> noneDecompressed;
1172   if (automatic->needsUncompressedLength()) {
1173     noneDecompressed = automatic->uncompress(noneCompressed.get(), length);
1174   } else {
1175     noneDecompressed = automatic->uncompress(noneCompressed.get());
1176   }
1177   EXPECT_TRUE(automatic->canUncompress(noneCompressed.get()));
1178   EXPECT_FALSE(auto_->canUncompress(noneCompressed.get()));
1179   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(noneDecompressed.get()));
1180
1181   auto compressed = codec_->compress(original.get());
1182   std::unique_ptr<IOBuf> decompressed;
1183   if (automatic->needsUncompressedLength()) {
1184     decompressed = automatic->uncompress(compressed.get(), length);
1185   } else {
1186     decompressed = automatic->uncompress(compressed.get());
1187   }
1188   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
1189 }
1190
1191 TEST_P(AutomaticCodecTest, canUncompressOneBytes) {
1192   // No default codec can uncompress 1 bytes.
1193   IOBuf buf{IOBuf::CREATE, 1};
1194   buf.append(1);
1195   EXPECT_FALSE(codec_->canUncompress(&buf, 1));
1196   EXPECT_FALSE(codec_->canUncompress(&buf, folly::none));
1197   EXPECT_FALSE(auto_->canUncompress(&buf, 1));
1198   EXPECT_FALSE(auto_->canUncompress(&buf, folly::none));
1199 }
1200
1201 INSTANTIATE_TEST_CASE_P(
1202     AutomaticCodecTest,
1203     AutomaticCodecTest,
1204     testing::ValuesIn(availableCodecs()));
1205
1206 namespace {
1207
1208 // Codec that always "uncompresses" to the same string.
1209 class ConstantCodec : public Codec {
1210  public:
1211   static std::unique_ptr<Codec> create(
1212       std::string uncompressed,
1213       CodecType type) {
1214     return std::make_unique<ConstantCodec>(std::move(uncompressed), type);
1215   }
1216   explicit ConstantCodec(std::string uncompressed, CodecType type)
1217       : Codec(type), uncompressed_(std::move(uncompressed)) {}
1218
1219  private:
1220   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override {
1221     return uncompressedLength;
1222   }
1223
1224   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
1225     throw std::runtime_error("ConstantCodec error: compress() not supported.");
1226   }
1227
1228   std::unique_ptr<IOBuf> doUncompress(const IOBuf*, Optional<uint64_t>)
1229       override {
1230     return IOBuf::copyBuffer(uncompressed_);
1231   }
1232
1233   std::string uncompressed_;
1234   std::unique_ptr<Codec> codec_;
1235 };
1236
1237 } // namespace
1238
1239 class TerminalCodecTest : public testing::TestWithParam<CodecType> {
1240  protected:
1241   void SetUp() override {
1242     codecType_ = GetParam();
1243     codec_ = getCodec(codecType_);
1244     auto_ = getAutoUncompressionCodec();
1245   }
1246
1247   CodecType codecType_;
1248   std::unique_ptr<Codec> codec_;
1249   std::unique_ptr<Codec> auto_;
1250 };
1251
1252 // Test that the terminal codec's uncompress() function is called when the
1253 // default chosen automatic codec throws.
1254 TEST_P(TerminalCodecTest, uncompressIfDefaultThrows) {
1255   std::string const original = "abc";
1256   auto const compressed = codec_->compress(original);
1257
1258   // Sanity check: the automatic codec can uncompress the original string.
1259   auto const uncompressed = auto_->uncompress(compressed);
1260   EXPECT_EQ(uncompressed, original);
1261
1262   // Truncate the compressed string.
1263   auto const truncated = compressed.substr(0, compressed.size() - 1);
1264   auto const truncatedBuf =
1265       IOBuf::wrapBuffer(truncated.data(), truncated.size());
1266   EXPECT_TRUE(auto_->canUncompress(truncatedBuf.get()));
1267   EXPECT_ANY_THROW(auto_->uncompress(truncated));
1268
1269   // Expect the terminal codec to successfully uncompress the string.
1270   std::unique_ptr<Codec> terminal = getAutoUncompressionCodec(
1271       {}, ConstantCodec::create("dummyString", CodecType::USER_DEFINED));
1272   EXPECT_TRUE(terminal->canUncompress(truncatedBuf.get()));
1273   EXPECT_EQ(terminal->uncompress(truncated), "dummyString");
1274 }
1275
1276 // If the terminal codec has one of the "default types" automatically added in
1277 // the AutomaticCodec, check that the default codec is no longer added.
1278 TEST_P(TerminalCodecTest, terminalOverridesDefaults) {
1279   std::unique_ptr<Codec> terminal = getAutoUncompressionCodec(
1280       {}, ConstantCodec::create("dummyString", codecType_));
1281   std::string const original = "abc";
1282   auto const compressed = codec_->compress(original);
1283   EXPECT_EQ(terminal->uncompress(compressed), "dummyString");
1284 }
1285
1286 INSTANTIATE_TEST_CASE_P(
1287     TerminalCodecTest,
1288     TerminalCodecTest,
1289     testing::ValuesIn(autoUncompressionCodecTypes));
1290
1291 TEST(ValidPrefixesTest, CustomCodec) {
1292   std::vector<std::unique_ptr<Codec>> codecs;
1293   codecs.push_back(CustomCodec::create("none", CodecType::NO_COMPRESSION));
1294   const auto none = getAutoUncompressionCodec(std::move(codecs));
1295   const auto prefixes = none->validPrefixes();
1296   const auto it = std::find(prefixes.begin(), prefixes.end(), "none");
1297   EXPECT_TRUE(it != prefixes.end());
1298 }
1299
1300 #define EXPECT_THROW_IF_DEBUG(statement, expected_exception) \
1301   do {                                                       \
1302     if (kIsDebug) {                                          \
1303       EXPECT_THROW((statement), expected_exception);         \
1304     } else {                                                 \
1305       EXPECT_NO_THROW((statement));                          \
1306     }                                                        \
1307   } while (false)
1308
1309 TEST(CheckCompatibleTest, SimplePrefixSecond) {
1310   std::vector<std::unique_ptr<Codec>> codecs;
1311   codecs.push_back(CustomCodec::create("abc", CodecType::NO_COMPRESSION));
1312   codecs.push_back(CustomCodec::create("ab", CodecType::NO_COMPRESSION));
1313   EXPECT_THROW_IF_DEBUG(
1314       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1315 }
1316
1317 TEST(CheckCompatibleTest, SimplePrefixFirst) {
1318   std::vector<std::unique_ptr<Codec>> codecs;
1319   codecs.push_back(CustomCodec::create("ab", CodecType::NO_COMPRESSION));
1320   codecs.push_back(CustomCodec::create("abc", CodecType::NO_COMPRESSION));
1321   EXPECT_THROW_IF_DEBUG(
1322       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1323 }
1324
1325 TEST(CheckCompatibleTest, Empty) {
1326   std::vector<std::unique_ptr<Codec>> codecs;
1327   codecs.push_back(CustomCodec::create("", CodecType::NO_COMPRESSION));
1328   EXPECT_THROW_IF_DEBUG(
1329       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1330 }
1331
1332 TEST(CheckCompatibleTest, ZstdPrefix) {
1333   std::vector<std::unique_ptr<Codec>> codecs;
1334   codecs.push_back(CustomCodec::create("\x28\xB5\x2F", CodecType::ZSTD));
1335   EXPECT_THROW_IF_DEBUG(
1336       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1337 }
1338
1339 TEST(CheckCompatibleTest, ZstdDuplicate) {
1340   std::vector<std::unique_ptr<Codec>> codecs;
1341   codecs.push_back(CustomCodec::create("\x28\xB5\x2F\xFD", CodecType::ZSTD));
1342   EXPECT_THROW_IF_DEBUG(
1343       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1344 }
1345
1346 TEST(CheckCompatibleTest, ZlibIsPrefix) {
1347   std::vector<std::unique_ptr<Codec>> codecs;
1348   codecs.push_back(CustomCodec::create("\x18\x76zzasdf", CodecType::ZSTD));
1349   EXPECT_THROW_IF_DEBUG(
1350       getAutoUncompressionCodec(std::move(codecs)), std::invalid_argument);
1351 }
1352
1353 #if FOLLY_HAVE_LIBZSTD
1354
1355 TEST(ZstdTest, BackwardCompatible) {
1356   auto codec = getCodec(CodecType::ZSTD);
1357   {
1358     auto const data = IOBuf::wrapBuffer(randomDataHolder.data(size_t(1) << 20));
1359     auto compressed = codec->compress(data.get());
1360     compressed->coalesce();
1361     EXPECT_EQ(
1362         data->length(),
1363         ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
1364   }
1365   {
1366     auto const data =
1367         IOBuf::wrapBuffer(randomDataHolder.data(size_t(100) << 20));
1368     auto compressed = codec->compress(data.get());
1369     compressed->coalesce();
1370     EXPECT_EQ(
1371         data->length(),
1372         ZSTD_getDecompressedSize(compressed->data(), compressed->length()));
1373   }
1374 }
1375
1376 #endif
1377
1378 #if FOLLY_HAVE_LIBZ
1379
1380 using ZlibFormat = zlib::Options::Format;
1381
1382 TEST(ZlibTest, Auto) {
1383   size_t const uncompressedLength_ = (size_t)1 << 15;
1384   auto const original = std::string(
1385       reinterpret_cast<const char*>(
1386           randomDataHolder.data(uncompressedLength_).data()),
1387       uncompressedLength_);
1388   auto optionCodec = zlib::getCodec(zlib::Options(ZlibFormat::AUTO));
1389
1390   // Test the codec can uncompress zlib data.
1391   {
1392     auto codec = getCodec(CodecType::ZLIB);
1393     auto const compressed = codec->compress(original);
1394     auto const uncompressed = optionCodec->uncompress(compressed);
1395     EXPECT_EQ(original, uncompressed);
1396   }
1397
1398   // Test the codec can uncompress gzip data.
1399   {
1400     auto codec = getCodec(CodecType::GZIP);
1401     auto const compressed = codec->compress(original);
1402     auto const uncompressed = optionCodec->uncompress(compressed);
1403     EXPECT_EQ(original, uncompressed);
1404   }
1405 }
1406
1407 TEST(ZlibTest, DefaultOptions) {
1408   size_t const uncompressedLength_ = (size_t)1 << 20;
1409   auto const original = std::string(
1410       reinterpret_cast<const char*>(
1411           randomDataHolder.data(uncompressedLength_).data()),
1412       uncompressedLength_);
1413   {
1414     auto codec = getCodec(CodecType::ZLIB);
1415     auto optionCodec = zlib::getCodec(zlib::defaultZlibOptions());
1416     auto const compressed = optionCodec->compress(original);
1417     auto uncompressed = codec->uncompress(compressed);
1418     EXPECT_EQ(original, uncompressed);
1419     uncompressed = optionCodec->uncompress(compressed);
1420     EXPECT_EQ(original, uncompressed);
1421   }
1422
1423   {
1424     auto codec = getCodec(CodecType::GZIP);
1425     auto optionCodec = zlib::getCodec(zlib::defaultGzipOptions());
1426     auto const compressed = optionCodec->compress(original);
1427     auto uncompressed = codec->uncompress(compressed);
1428     EXPECT_EQ(original, uncompressed);
1429     uncompressed = optionCodec->uncompress(compressed);
1430     EXPECT_EQ(original, uncompressed);
1431   }
1432 }
1433
1434 class ZlibOptionsTest : public testing::TestWithParam<
1435                             std::tr1::tuple<ZlibFormat, int, int, int>> {
1436  protected:
1437   void SetUp() override {
1438     auto tup = GetParam();
1439     options_.format = std::tr1::get<0>(tup);
1440     options_.windowSize = std::tr1::get<1>(tup);
1441     options_.memLevel = std::tr1::get<2>(tup);
1442     options_.strategy = std::tr1::get<3>(tup);
1443     codec_ = zlib::getStreamCodec(options_);
1444   }
1445
1446   void runSimpleRoundTripTest(const DataHolder& dh);
1447
1448  private:
1449   zlib::Options options_;
1450   std::unique_ptr<StreamCodec> codec_;
1451 };
1452
1453 void ZlibOptionsTest::runSimpleRoundTripTest(const DataHolder& dh) {
1454   size_t const uncompressedLength = (size_t)1 << 16;
1455   auto const original = std::string(
1456       reinterpret_cast<const char*>(dh.data(uncompressedLength).data()),
1457       uncompressedLength);
1458
1459   auto const compressed = codec_->compress(original);
1460   auto const uncompressed = codec_->uncompress(compressed);
1461   EXPECT_EQ(uncompressed, original);
1462 }
1463
1464 TEST_P(ZlibOptionsTest, simpleRoundTripTest) {
1465   runSimpleRoundTripTest(constantDataHolder);
1466   runSimpleRoundTripTest(randomDataHolder);
1467 }
1468
1469 INSTANTIATE_TEST_CASE_P(
1470     ZlibOptionsTest,
1471     ZlibOptionsTest,
1472     testing::Combine(
1473         testing::Values(
1474             ZlibFormat::ZLIB,
1475             ZlibFormat::GZIP,
1476             ZlibFormat::RAW,
1477             ZlibFormat::AUTO),
1478         testing::Values(9, 12, 15),
1479         testing::Values(1, 8, 9),
1480         testing::Values(
1481             Z_DEFAULT_STRATEGY,
1482             Z_FILTERED,
1483             Z_HUFFMAN_ONLY,
1484             Z_RLE,
1485             Z_FIXED)));
1486
1487 #endif // FOLLY_HAVE_LIBZ
1488
1489 } // namespace test
1490 } // namespace io
1491 } // namespace folly
1492
1493 int main(int argc, char *argv[]) {
1494   testing::InitGoogleTest(&argc, argv);
1495   gflags::ParseCommandLineFlags(&argc, &argv, true);
1496
1497   auto ret = RUN_ALL_TESTS();
1498   if (!ret) {
1499     folly::runBenchmarksOnFlag();
1500   }
1501   return ret;
1502 }