27a6b9fdccde103bdff6a12a4310d109b221e3e5
[folly.git] / folly / io / test / CompressionTest.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/io/Compression.h"
18
19 // Yes, tr1, as that's what gtest requires
20 #include <random>
21 #include <thread>
22 #include <tr1/tuple>
23 #include <unordered_map>
24
25 #include <boost/noncopyable.hpp>
26 #include <glog/logging.h>
27 #include <gtest/gtest.h>
28
29 #include "folly/Benchmark.h"
30 #include "folly/Hash.h"
31 #include "folly/Random.h"
32 #include "folly/io/IOBufQueue.h"
33
34 namespace folly { namespace io { namespace test {
35
36 class DataHolder : private boost::noncopyable {
37  public:
38   uint64_t hash(size_t size) const;
39   ByteRange data(size_t size) const;
40
41  protected:
42   explicit DataHolder(size_t sizeLog2);
43   const size_t size_;
44   std::unique_ptr<uint8_t[]> data_;
45   mutable std::unordered_map<uint64_t, uint64_t> hashCache_;
46 };
47
48 DataHolder::DataHolder(size_t sizeLog2)
49   : size_(size_t(1) << sizeLog2),
50     data_(new uint8_t[size_]) {
51 }
52
53 uint64_t DataHolder::hash(size_t size) const {
54   CHECK_LE(size, size_);
55   auto p = hashCache_.find(size);
56   if (p != hashCache_.end()) {
57     return p->second;
58   }
59
60   uint64_t h = folly::hash::fnv64_buf(data_.get(), size);
61   hashCache_[size] = h;
62   return h;
63 }
64
65 ByteRange DataHolder::data(size_t size) const {
66   CHECK_LE(size, size_);
67   return ByteRange(data_.get(), size);
68 }
69
70 uint64_t hashIOBuf(const IOBuf* buf) {
71   uint64_t h = folly::hash::FNV_64_HASH_START;
72   for (auto& range : *buf) {
73     h = folly::hash::fnv64_buf(range.data(), range.size(), h);
74   }
75   return h;
76 }
77
78 class RandomDataHolder : public DataHolder {
79  public:
80   explicit RandomDataHolder(size_t sizeLog2);
81 };
82
83 RandomDataHolder::RandomDataHolder(size_t sizeLog2)
84   : DataHolder(sizeLog2) {
85   constexpr size_t numThreadsLog2 = 3;
86   constexpr size_t numThreads = size_t(1) << numThreadsLog2;
87
88   uint32_t seed = randomNumberSeed();
89
90   std::vector<std::thread> threads;
91   threads.reserve(numThreads);
92   for (size_t t = 0; t < numThreads; ++t) {
93     threads.emplace_back(
94         [this, seed, t, numThreadsLog2, sizeLog2] () {
95           std::mt19937 rng(seed + t);
96           size_t countLog2 = size_t(1) << (sizeLog2 - numThreadsLog2);
97           size_t start = size_t(t) << countLog2;
98           for (size_t i = 0; i < countLog2; ++i) {
99             this->data_[start + i] = rng();
100           }
101         });
102   }
103
104   for (auto& t : threads) {
105     t.join();
106   }
107 }
108
109 class ConstantDataHolder : public DataHolder {
110  public:
111   explicit ConstantDataHolder(size_t sizeLog2);
112 };
113
114 ConstantDataHolder::ConstantDataHolder(size_t sizeLog2)
115   : DataHolder(sizeLog2) {
116   memset(data_.get(), 'a', size_);
117 }
118
119 constexpr size_t dataSizeLog2 = 27;  // 128MiB
120 RandomDataHolder randomDataHolder(dataSizeLog2);
121 ConstantDataHolder constantDataHolder(dataSizeLog2);
122
123 TEST(CompressionTestNeedsUncompressedLength, Simple) {
124   EXPECT_FALSE(getCodec(CodecType::NO_COMPRESSION)->needsUncompressedLength());
125   EXPECT_TRUE(getCodec(CodecType::LZ4)->needsUncompressedLength());
126   EXPECT_FALSE(getCodec(CodecType::SNAPPY)->needsUncompressedLength());
127   EXPECT_FALSE(getCodec(CodecType::ZLIB)->needsUncompressedLength());
128   EXPECT_FALSE(getCodec(CodecType::LZ4_VARINT_SIZE)->needsUncompressedLength());
129 }
130
131 class CompressionTest : public testing::TestWithParam<
132     std::tr1::tuple<int, CodecType>> {
133   protected:
134    void SetUp() {
135      auto tup = GetParam();
136      uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
137      codec_ = getCodec(std::tr1::get<1>(tup));
138    }
139
140    void runSimpleTest(const DataHolder& dh);
141
142    uint64_t uncompressedLength_;
143    std::unique_ptr<Codec> codec_;
144 };
145
146 void CompressionTest::runSimpleTest(const DataHolder& dh) {
147   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength_));
148   auto compressed = codec_->compress(original.get());
149   if (!codec_->needsUncompressedLength()) {
150     auto uncompressed = codec_->uncompress(compressed.get());
151     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
152     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
153   }
154   {
155     auto uncompressed = codec_->uncompress(compressed.get(),
156                                            uncompressedLength_);
157     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
158     EXPECT_EQ(dh.hash(uncompressedLength_), hashIOBuf(uncompressed.get()));
159   }
160 }
161
162 TEST_P(CompressionTest, RandomData) {
163   runSimpleTest(randomDataHolder);
164 }
165
166 TEST_P(CompressionTest, ConstantData) {
167   runSimpleTest(constantDataHolder);
168 }
169
170 INSTANTIATE_TEST_CASE_P(
171     CompressionTest,
172     CompressionTest,
173     testing::Combine(
174         testing::Values(0, 1, 12, 22, 25, 27),
175         testing::Values(CodecType::NO_COMPRESSION,
176                         CodecType::LZ4,
177                         CodecType::SNAPPY,
178                         CodecType::ZLIB,
179                         CodecType::LZ4_VARINT_SIZE)));
180
181 class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
182  protected:
183   void SetUp() {
184     codec_ = getCodec(GetParam());
185   }
186
187   void runSimpleTest(const DataHolder& dh);
188
189   std::unique_ptr<Codec> codec_;
190 };
191
192 void CompressionCorruptionTest::runSimpleTest(const DataHolder& dh) {
193   constexpr uint64_t uncompressedLength = 42;
194   auto original = IOBuf::wrapBuffer(dh.data(uncompressedLength));
195   auto compressed = codec_->compress(original.get());
196
197   if (!codec_->needsUncompressedLength()) {
198     auto uncompressed = codec_->uncompress(compressed.get());
199     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
200     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
201   }
202   {
203     auto uncompressed = codec_->uncompress(compressed.get(),
204                                            uncompressedLength);
205     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
206     EXPECT_EQ(dh.hash(uncompressedLength), hashIOBuf(uncompressed.get()));
207   }
208
209   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength + 1),
210                std::runtime_error);
211
212   // Corrupt the first character
213   ++(compressed->writableData()[0]);
214
215   if (!codec_->needsUncompressedLength()) {
216     EXPECT_THROW(codec_->uncompress(compressed.get()),
217                  std::runtime_error);
218   }
219
220   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength),
221                std::runtime_error);
222 }
223
224 TEST_P(CompressionCorruptionTest, RandomData) {
225   runSimpleTest(randomDataHolder);
226 }
227
228 TEST_P(CompressionCorruptionTest, ConstantData) {
229   runSimpleTest(constantDataHolder);
230 }
231
232 INSTANTIATE_TEST_CASE_P(
233     CompressionCorruptionTest,
234     CompressionCorruptionTest,
235     testing::Values(
236         // NO_COMPRESSION can't detect corruption
237         // LZ4 can't detect corruption reliably (sigh)
238         CodecType::SNAPPY,
239         CodecType::ZLIB));
240
241 }}}  // namespaces
242
243 int main(int argc, char *argv[]) {
244   testing::InitGoogleTest(&argc, argv);
245   google::ParseCommandLineFlags(&argc, &argv, true);
246
247   auto ret = RUN_ALL_TESTS();
248   if (!ret) {
249     folly::runBenchmarksOnFlag();
250   }
251   return ret;
252 }
253