IOBuf compression
[folly.git] / folly / io / test / CompressionTest.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
#include "folly/io/Compression.h"

#include <memory>
#include <random>
#include <thread>
// Yes, tr1, as that's what gtest requires
#include <tr1/tuple>
#include <unordered_map>
#include <vector>

#include <glog/logging.h>
#include <gtest/gtest.h>

#include "folly/Benchmark.h"
#include "folly/Hash.h"
#include "folly/Random.h"
#include "folly/io/IOBufQueue.h"
32
33 namespace folly { namespace io { namespace test {
34
// The shared random input buffer holds 2^randomDataSizeLog2 bytes (128MiB).
constexpr size_t randomDataSizeLog2 = 27;
constexpr size_t randomDataSize = static_cast<size_t>(1) << randomDataSizeLog2;

// Shared random input for every test; allocated and filled by
// generateRandomData().
std::unique_ptr<uint8_t[]> randomData;

// Cache of FNV-64 hashes of randomData prefixes, keyed by prefix length in
// bytes (populated lazily by getRandomDataHash()).
std::unordered_map<uint64_t, uint64_t> hashes;
40
41 uint64_t hashIOBuf(const IOBuf* buf) {
42   uint64_t h = folly::hash::FNV_64_HASH_START;
43   for (auto& range : *buf) {
44     h = folly::hash::fnv64_buf(range.data(), range.size(), h);
45   }
46   return h;
47 }
48
49 uint64_t getRandomDataHash(uint64_t size) {
50   auto p = hashes.find(size);
51   if (p != hashes.end()) {
52     return p->second;
53   }
54
55   uint64_t h = folly::hash::fnv64_buf(randomData.get(), size);
56   hashes[size] = h;
57   return h;
58 }
59
60 void generateRandomData() {
61   randomData.reset(new uint8_t[size_t(1) << randomDataSizeLog2]);
62
63   constexpr size_t numThreadsLog2 = 3;
64   constexpr size_t numThreads = size_t(1) << numThreadsLog2;
65
66   uint32_t seed = randomNumberSeed();
67
68   std::vector<std::thread> threads;
69   threads.reserve(numThreads);
70   for (size_t t = 0; t < numThreads; ++t) {
71     threads.emplace_back(
72         [seed, t, numThreadsLog2] () {
73           std::mt19937 rng(seed + t);
74           size_t countLog2 = size_t(1) << (randomDataSizeLog2 - numThreadsLog2);
75           size_t start = size_t(t) << countLog2;
76           for (size_t i = 0; i < countLog2; ++i) {
77             randomData[start + i] = rng();
78           }
79         });
80   }
81
82   for (auto& t : threads) {
83     t.join();
84   }
85 }
86
87 class CompressionTest : public testing::TestWithParam<
88     std::tr1::tuple<int, CodecType>> {
89   protected:
90    void SetUp() {
91      auto tup = GetParam();
92      uncompressedLength_ = uint64_t(1) << std::tr1::get<0>(tup);
93      codec_ = getCodec(std::tr1::get<1>(tup));
94    }
95
96    uint64_t uncompressedLength_;
97    std::unique_ptr<Codec> codec_;
98 };
99
100 TEST_P(CompressionTest, Simple) {
101   auto original = IOBuf::wrapBuffer(randomData.get(), uncompressedLength_);
102   auto compressed = codec_->compress(original.get());
103   if (!codec_->needsUncompressedLength()) {
104     auto uncompressed = codec_->uncompress(compressed.get());
105     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
106     EXPECT_EQ(getRandomDataHash(uncompressedLength_),
107               hashIOBuf(uncompressed.get()));
108   }
109   {
110     auto uncompressed = codec_->uncompress(compressed.get(),
111                                            uncompressedLength_);
112     EXPECT_EQ(uncompressedLength_, uncompressed->computeChainDataLength());
113     EXPECT_EQ(getRandomDataHash(uncompressedLength_),
114               hashIOBuf(uncompressed.get()));
115   }
116 }
117
118 INSTANTIATE_TEST_CASE_P(
119     CompressionTest,
120     CompressionTest,
121     testing::Combine(
122         testing::Values(0, 1, 12, 22, int(randomDataSizeLog2)),
123         testing::Values(CodecType::NO_COMPRESSION,
124                         CodecType::LZ4,
125                         CodecType::SNAPPY,
126                         CodecType::ZLIB)));
127
128 class CompressionCorruptionTest : public testing::TestWithParam<CodecType> {
129  protected:
130   void SetUp() {
131     codec_ = getCodec(GetParam());
132   }
133
134   std::unique_ptr<Codec> codec_;
135 };
136
137 TEST_P(CompressionCorruptionTest, Simple) {
138   constexpr uint64_t uncompressedLength = 42;
139   auto original = IOBuf::wrapBuffer(randomData.get(), uncompressedLength);
140   auto compressed = codec_->compress(original.get());
141
142   if (!codec_->needsUncompressedLength()) {
143     auto uncompressed = codec_->uncompress(compressed.get());
144     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
145     EXPECT_EQ(getRandomDataHash(uncompressedLength),
146               hashIOBuf(uncompressed.get()));
147   }
148   {
149     auto uncompressed = codec_->uncompress(compressed.get(),
150                                            uncompressedLength);
151     EXPECT_EQ(uncompressedLength, uncompressed->computeChainDataLength());
152     EXPECT_EQ(getRandomDataHash(uncompressedLength),
153               hashIOBuf(uncompressed.get()));
154   }
155
156   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength + 1),
157                std::runtime_error);
158
159   // Corrupt the first character
160   ++(compressed->writableData()[0]);
161
162   if (!codec_->needsUncompressedLength()) {
163     EXPECT_THROW(codec_->uncompress(compressed.get()),
164                  std::runtime_error);
165   }
166
167   EXPECT_THROW(codec_->uncompress(compressed.get(), uncompressedLength),
168                std::runtime_error);
169 }
170
171 INSTANTIATE_TEST_CASE_P(
172     CompressionCorruptionTest,
173     CompressionCorruptionTest,
174     testing::Values(
175         // NO_COMPRESSION can't detect corruption
176         // LZ4 can't detect corruption reliably (sigh)
177         CodecType::SNAPPY,
178         CodecType::ZLIB));
179
180 }}}  // namespaces
181
182 int main(int argc, char *argv[]) {
183   testing::InitGoogleTest(&argc, argv);
184   google::ParseCommandLineFlags(&argc, &argv, true);
185
186   folly::io::test::generateRandomData();  // 4GB
187
188   auto ret = RUN_ALL_TESTS();
189   if (!ret) {
190     folly::runBenchmarksOnFlag();
191   }
192   return ret;
193 }
194