zstd is no longer in beta -- s/ZSTD_BETA/ZSTD/g
[folly.git] / folly / io / test / CompressionTest.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #include <random>
20 #include <thread>
21 #include <unordered_map>
22
23 #include <boost/noncopyable.hpp>
24 #include <glog/logging.h>
25 #include <gtest/gtest.h>
26
27 #include <folly/Benchmark.h>
28 #include <folly/Hash.h>
29 #include <folly/Random.h>
30 #include <folly/Varint.h>
31 #include <folly/io/IOBufQueue.h>
32
33 namespace folly { namespace io { namespace test {
34
35 class DataHolder : private boost::noncopyable {
36  public:
37   uint64_t hash(size_t size) const;
38   ByteRange data(size_t size) const;
39
40  protected:
41   explicit DataHolder(size_t sizeLog2);
42   const size_t size_;
43   std::unique_ptr<uint8_t[]> data_;
44   mutable std::unordered_map<uint64_t, uint64_t> hashCache_;
45 };
46
47 DataHolder::DataHolder(size_t sizeLog2)
48   : size_(size_t(1) << sizeLog2),
49     data_(new uint8_t[size_]) {
50 }
51
52 uint64_t DataHolder::hash(size_t size) const {
53   CHECK_LE(size, size_);
54   auto p = hashCache_.find(size);
55   if (p != hashCache_.end()) {
56     return p->second;
57   }
58
59   uint64_t h = folly::hash::fnv64_buf(data_.get(), size);
60   hashCache_[size] = h;
61   return h;
62 }
63
64 ByteRange DataHolder::data(size_t size) const {
65   CHECK_LE(size, size_);
66   return ByteRange(data_.get(), size);
67 }
68
69 uint64_t hashIOBuf(const IOBuf* buf) {
70   uint64_t h = folly::hash::FNV_64_HASH_START;
71   for (auto& range : *buf) {
72     h = folly::hash::fnv64_buf(range.data(), range.size(), h);
73   }
74   return h;
75 }
76
77 class RandomDataHolder : public DataHolder {
78  public:
79   explicit RandomDataHolder(size_t sizeLog2);
80 };
81
82 RandomDataHolder::RandomDataHolder(size_t sizeLog2)
83   : DataHolder(sizeLog2) {
84   constexpr size_t numThreadsLog2 = 3;
85   constexpr size_t numThreads = size_t(1) << numThreadsLog2;
86
87   uint32_t seed = randomNumberSeed();
88
89   std::vector<std::thread> threads;
90   threads.reserve(numThreads);
91   for (size_t t = 0; t < numThreads; ++t) {
92     threads.emplace_back(
93         [this, seed, t, numThreadsLog2, sizeLog2] () {
94           std::mt19937 rng(seed + t);
95           size_t countLog2 = sizeLog2 - numThreadsLog2;
96           size_t start = size_t(t) << countLog2;
97           for (size_t i = 0; i < countLog2; ++i) {
98             this->data_[start + i] = rng();
99           }
100         });
101   }
102
103   for (auto& t : threads) {
104     t.join();
105   }
106 }
107
108 class ConstantDataHolder : public DataHolder {
109  public:
110   explicit ConstantDataHolder(size_t sizeLog2);
111 };
112
113 ConstantDataHolder::ConstantDataHolder(size_t sizeLog2)
114   : DataHolder(sizeLog2) {
115   memset(data_.get(), 'a', size_);
116 }
117
118 constexpr size_t dataSizeLog2 = 27;  // 128MiB
119 RandomDataHolder randomDataHolder(dataSizeLog2);
120 ConstantDataHolder constantDataHolder(dataSizeLog2);
121
122 TEST(CompressionTestNeedsUncompressedLength, Simple) {
123   EXPECT_FALSE(getCodec(CodecType::NO_COMPRESSION)->needsUncompressedLength());
124   EXPECT_TRUE(getCodec(CodecType::LZ4)->needsUncompressedLength());
125   EXPECT_FALSE(getCodec(CodecType::SNAPPY)->needsUncompressedLength());
126   EXPECT_FALSE(getCodec(CodecType::ZLIB)->needsUncompressedLength());
127   EXPECT_FALSE(getCodec(CodecType::LZ4_VARINT_SIZE)->needsUncompressedLength());
128   EXPECT_TRUE(getCodec(CodecType::LZMA2)->needsUncompressedLength());
129   EXPECT_FALSE(getCodec(CodecType::LZMA2_VARINT_SIZE)
130     ->needsUncompressedLength());
131   EXPECT_TRUE(getCodec(CodecType::ZSTD)->needsUncompressedLength());
132   EXPECT_FALSE(getCodec(CodecType::GZIP)->needsUncompressedLength());
133 }
134
135 class CompressionTest
136     : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
137   protected:
138    void SetUp() override {
139      auto tup = GetParam();
140      uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
141      codec_ = getCodec(std::tr1::get<1>(tup));
142    }
143
144    void runSimpleTest(const DataHolder& dh);
145
146    uint64_t uncompressedLength_;
147    std::unique_ptr<Codec> codec_;
148 };
149
150 void CompressionTest::runSimpleTest(const DataHolder& dh) {
151   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
152   auto compressed = codec_->compress(original.get());
153   if (!codec_->needsUncompressedLength()) {
154     auto uncompressed = codec_->uncompress(compressed.get());
155
156     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
157     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
158   }
159   {
160     auto uncompressed = codec_->uncompress(compressed.get(),
161                                            uncompressedLength_);
162     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
163     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
164   }
165 }
166
167 TEST_P(CompressionTest, RandomData) {
168   runSimpleTest(randomDataHolder);
169 }
170
171 TEST_P(CompressionTest, ConstantData) {
172   runSimpleTest(constantDataHolder);
173 }
174
175 INSTANTIATE_TEST_CASE_P(
176     CompressionTest,
177     CompressionTest,
178     testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
179                      testing::Values(CodecType::NO_COMPRESSION,
180                                      CodecType::LZ4,
181                                      CodecType::SNAPPY,
182                                      CodecType::ZLIB,
183                                      CodecType::LZ4_VARINT_SIZE,
184                                      CodecType::LZMA2,
185                                      CodecType::LZMA2_VARINT_SIZE,
186                                      CodecType::ZSTD,
187                                      CodecType::GZIP)));
188
189 class CompressionVarintTest
190     : public testing::TestWithParam<std::tr1::tuple<int, CodecType>> {
191  protected:
192   void SetUp() override {
193     auto tup = GetParam();
194     uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
195     codec_ = getCodec(std::tr1::get<1>(tup));
196   }
197
198   void runSimpleTest(const DataHolder& dh);
199
200   uint64_t uncompressedLength_;
201   std::unique_ptr<Codec> codec_;
202 };
203
// Returns the one-based position of the most significant set bit of
// `number`, i.e. how many right shifts it takes to reach zero.
// Returns 0 when `number` is 0.
inline uint64_t oneBasedMsbPos(uint64_t number) {
  uint64_t bits = 0;
  while (number != 0) {
    number >>= 1;
    ++bits;
  }
  return bits;
}
210
211 void CompressionVarintTest::runSimpleTest(const DataHolder& dh) {
212   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
213   auto compressed = codec_->compress(original.get());
214   auto breakPoint =
215       1UL +
216       Random::rand64(
217           std::max(uint64_t(9), oneBasedMsbPos(uncompressedLength_)) / 9UL);
218   auto tinyBuf = IOBuf::copyBuffer(compressed->data(),
219                                    std::min(compressed->length(), breakPoint));
220   compressed->trimStart(breakPoint);
221   tinyBuf->prependChain(std::move(compressed));
222   compressed = std::move(tinyBuf);
223
224   auto uncompressed = codec_->uncompress(compressed.get());
225
226   EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
227   EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
228 }
229
230 TEST_P(CompressionVarintTest, RandomData) { runSimpleTest(randomDataHolder); }
231
232 TEST_P(CompressionVarintTest, ConstantData) {
233   runSimpleTest(constantDataHolder);
234 }
235
236 INSTANTIATE_TEST_CASE_P(
237     CompressionVarintTest,
238     CompressionVarintTest,
239     testing::Combine(testing::Values(0, 1, 12, 22, 25, 27),
240                      testing::Values(CodecType::LZ4_VARINT_SIZE,
241                                      CodecType::LZMA2_VARINT_SIZE)));
242
243 class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
244  protected:
245   void SetUp() override { codec_ = getCodec(GetParam()); }
246
247   void runSimpleTest(const DataHolder& dh);
248
249   std::unique_ptr<Codec> codec_;
250 };
251
252 void CompressionCorruptionTest::runSimpleTest(const DataHolder& dh) {
253   constexpr uint64_t uncompressedLength = 42;
254   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength));
255   auto compressed = codec_->compress(original.get());
256
257   if (!codec_->needsUncompressedLength()) {
258     auto uncompressed = codec_->uncompress(compressed.get());
259     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
260     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
261   }
262   {
263     auto uncompressed = codec_->uncompress(compressed.get(),
264                                            uncompressedLength);
265     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
266     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
267   }
268
269   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength + 1),
270                std::runtime_error);
271
272   // Corrupt the first character
273   ++(compressed->writableData()[0]);
274
275   if (!codec_->needsUncompressedLength()) {
276     EXPECT_THROW(codec_->uncompress(compressed.get()),
277                  std::runtime_error);
278   }
279
280   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength),
281                std::runtime_error);
282 }
283
284 TEST_P(CompressionCorruptionTest, RandomData) {
285   runSimpleTest(randomDataHolder);
286 }
287
288 TEST_P(CompressionCorruptionTest, ConstantData) {
289   runSimpleTest(constantDataHolder);
290 }
291
292 INSTANTIATE_TEST_CASE_P(
293     CompressionCorruptionTest,
294     CompressionCorruptionTest,
295     testing::Values(
296         // NO_COMPRESSION can't detect corruption
297         // LZ4 can't detect corruption reliably (sigh)
298         CodecType::SNAPPY,
299         CodecType::ZLIB));
300
301 }}}  // namespaces
302
303 int main(int argc, char *argv[]) {
304   testing::InitGoogleTest(&argc, argv);
305   gflags::ParseCommandLineFlags(&argc, &argv, true);
306
307   auto ret = RUN_ALL_TESTS();
308   if (!ret) {
309     folly::runBenchmarksOnFlag();
310   }
311   return ret;
312 }