f46df40103759db64b44682186134415305098b0
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/io/Compression.h"
18
19 #include <lz4.h>
20 #include <lz4hc.h>
21 #include <glog/logging.h>
22 #include <snappy.h>
23 #include <snappy-sinksource.h>
24 #include <zlib.h>
25
26 #include "folly/Conv.h"
27 #include "folly/Memory.h"
28 #include "folly/Portability.h"
29 #include "folly/ScopeGuard.h"
30 #include "folly/io/Cursor.h"
31
32 namespace folly { namespace io {
33
34 // Ensure consistent behavior in the nullptr case
35 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
36   return !data->empty() ? doCompress(data) : IOBuf::create(0);
37 }
38
39 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
40                                          uint64_t uncompressedLength) {
41   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
42     if (needsUncompressedLength()) {
43       throw std::invalid_argument("Codec: uncompressed length required");
44     }
45   } else if (uncompressedLength > maxUncompressedLength()) {
46     throw std::runtime_error("Codec: uncompressed length too large");
47   }
48
49   if (data->empty()) {
50     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
51         uncompressedLength != 0) {
52       throw std::runtime_error("Codec: invalid uncompressed length");
53     }
54     return IOBuf::create(0);
55   }
56
57   return doUncompress(data, uncompressedLength);
58 }
59
60 bool Codec::needsUncompressedLength() const {
61   return doNeedsUncompressedLength();
62 }
63
64 uint64_t Codec::maxUncompressedLength() const {
65   return doMaxUncompressedLength();
66 }
67
68 CodecType Codec::type() const {
69   return doType();
70 }
71
72 bool Codec::doNeedsUncompressedLength() const {
73   return false;
74 }
75
76 uint64_t Codec::doMaxUncompressedLength() const {
77   return std::numeric_limits<uint64_t>::max() - 1;
78 }
79
80 namespace {
81
82 /**
83  * No compression
84  */
85 class NoCompressionCodec FOLLY_FINAL : public Codec {
86  public:
87   static std::unique_ptr<Codec> create(int level);
88   explicit NoCompressionCodec(int level);
89
90  private:
91   CodecType doType() const FOLLY_OVERRIDE;
92   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
93   std::unique_ptr<IOBuf> doUncompress(
94       const IOBuf* data,
95       uint64_t uncompressedLength) FOLLY_OVERRIDE;
96 };
97
98 std::unique_ptr<Codec> NoCompressionCodec::create(int level) {
99   return make_unique<NoCompressionCodec>(level);
100 }
101
102 NoCompressionCodec::NoCompressionCodec(int level) {
103   switch (level) {
104   case COMPRESSION_LEVEL_DEFAULT:
105   case COMPRESSION_LEVEL_FASTEST:
106   case COMPRESSION_LEVEL_BEST:
107     level = 0;
108   }
109   if (level != 0) {
110     throw std::invalid_argument(to<std::string>(
111         "NoCompressionCodec: invalid level ", level));
112   }
113 }
114
115 CodecType NoCompressionCodec::doType() const {
116   return CodecType::NO_COMPRESSION;
117 }
118
119 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
120     const IOBuf* data) {
121   return data->clone();
122 }
123
124 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
125     const IOBuf* data,
126     uint64_t uncompressedLength) {
127   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
128       data->computeChainDataLength() != uncompressedLength) {
129     throw std::runtime_error(to<std::string>(
130         "NoCompressionCodec: invalid uncompressed length"));
131   }
132   return data->clone();
133 }
134
135 /**
136  * LZ4 compression
137  */
138 class LZ4Codec FOLLY_FINAL : public Codec {
139  public:
140   static std::unique_ptr<Codec> create(int level);
141   explicit LZ4Codec(int level);
142
143  private:
144   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
145   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
146   CodecType doType() const FOLLY_OVERRIDE;
147   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
148   std::unique_ptr<IOBuf> doUncompress(
149       const IOBuf* data,
150       uint64_t uncompressedLength) FOLLY_OVERRIDE;
151
152   bool highCompression_;
153 };
154
155 std::unique_ptr<Codec> LZ4Codec::create(int level) {
156   return make_unique<LZ4Codec>(level);
157 }
158
159 LZ4Codec::LZ4Codec(int level) {
160   switch (level) {
161   case COMPRESSION_LEVEL_FASTEST:
162   case COMPRESSION_LEVEL_DEFAULT:
163     level = 1;
164     break;
165   case COMPRESSION_LEVEL_BEST:
166     level = 2;
167     break;
168   }
169   if (level < 1 || level > 2) {
170     throw std::invalid_argument(to<std::string>(
171         "LZ4Codec: invalid level: ", level));
172   }
173   highCompression_ = (level > 1);
174 }
175
176 bool LZ4Codec::doNeedsUncompressedLength() const {
177   return true;
178 }
179
180 uint64_t LZ4Codec::doMaxUncompressedLength() const {
181   // From lz4.h: "Max supported value is ~1.9GB"; I wish we had something
182   // more accurate.
183   return 1.8 * (uint64_t(1) << 30);
184 }
185
186 CodecType LZ4Codec::doType() const {
187   return CodecType::LZ4;
188 }
189
190 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
191   std::unique_ptr<IOBuf> clone;
192   if (data->isChained()) {
193     // LZ4 doesn't support streaming, so we have to coalesce
194     clone = data->clone();
195     clone->coalesce();
196     data = clone.get();
197   }
198
199   auto out = IOBuf::create(LZ4_compressBound(data->length()));
200   int n;
201   if (highCompression_) {
202     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
203                      reinterpret_cast<char*>(out->writableTail()),
204                      data->length());
205   } else {
206     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
207                        reinterpret_cast<char*>(out->writableTail()),
208                        data->length());
209   }
210
211   CHECK_GE(n, 0);
212   CHECK_LE(n, out->capacity());
213
214   out->append(n);
215   return out;
216 }
217
218 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
219     const IOBuf* data,
220     uint64_t uncompressedLength) {
221   std::unique_ptr<IOBuf> clone;
222   if (data->isChained()) {
223     // LZ4 doesn't support streaming, so we have to coalesce
224     clone = data->clone();
225     clone->coalesce();
226     data = clone.get();
227   }
228
229   auto out = IOBuf::create(uncompressedLength);
230   int n = LZ4_uncompress(reinterpret_cast<const char*>(data->data()),
231                          reinterpret_cast<char*>(out->writableTail()),
232                          uncompressedLength);
233   if (n != data->length()) {
234     throw std::runtime_error(to<std::string>(
235         "LZ4 decompression returned invalid value ", n));
236   }
237   out->append(uncompressedLength);
238   return out;
239 }
240
241 /**
242  * Snappy compression
243  */
244
245 /**
246  * Implementation of snappy::Source that reads from a IOBuf chain.
247  */
248 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
249  public:
250   explicit IOBufSnappySource(const IOBuf* data);
251   size_t Available() const FOLLY_OVERRIDE;
252   const char* Peek(size_t* len) FOLLY_OVERRIDE;
253   void Skip(size_t n) FOLLY_OVERRIDE;
254  private:
255   size_t available_;
256   io::Cursor cursor_;
257 };
258
259 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
260   : available_(data->computeChainDataLength()),
261     cursor_(data) {
262 }
263
264 size_t IOBufSnappySource::Available() const {
265   return available_;
266 }
267
268 const char* IOBufSnappySource::Peek(size_t* len) {
269   auto p = cursor_.peek();
270   *len = p.second;
271   return reinterpret_cast<const char*>(p.first);
272 }
273
274 void IOBufSnappySource::Skip(size_t n) {
275   CHECK_LE(n, available_);
276   cursor_.skip(n);
277   available_ -= n;
278 }
279
280 class SnappyCodec FOLLY_FINAL : public Codec {
281  public:
282   static std::unique_ptr<Codec> create(int level);
283   explicit SnappyCodec(int level);
284
285  private:
286   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
287   CodecType doType() const FOLLY_OVERRIDE;
288   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
289   std::unique_ptr<IOBuf> doUncompress(
290       const IOBuf* data,
291       uint64_t uncompressedLength) FOLLY_OVERRIDE;
292 };
293
294 std::unique_ptr<Codec> SnappyCodec::create(int level) {
295   return make_unique<SnappyCodec>(level);
296 }
297
298 SnappyCodec::SnappyCodec(int level) {
299   switch (level) {
300   case COMPRESSION_LEVEL_FASTEST:
301   case COMPRESSION_LEVEL_DEFAULT:
302   case COMPRESSION_LEVEL_BEST:
303     level = 1;
304   }
305   if (level != 1) {
306     throw std::invalid_argument(to<std::string>(
307         "SnappyCodec: invalid level: ", level));
308   }
309 }
310
311 uint64_t SnappyCodec::doMaxUncompressedLength() const {
312   // snappy.h uses uint32_t for lengths, so there's that.
313   return std::numeric_limits<uint32_t>::max();
314 }
315
316 CodecType SnappyCodec::doType() const {
317   return CodecType::SNAPPY;
318 }
319
320 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
321   IOBufSnappySource source(data);
322   auto out =
323     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
324
325   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
326       out->writableTail()));
327
328   size_t n = snappy::Compress(&source, &sink);
329
330   CHECK_LE(n, out->capacity());
331   out->append(n);
332   return out;
333 }
334
335 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
336                                                  uint64_t uncompressedLength) {
337   uint32_t actualUncompressedLength = 0;
338
339   {
340     IOBufSnappySource source(data);
341     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
342       throw std::runtime_error("snappy::GetUncompressedLength failed");
343     }
344     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
345         uncompressedLength != actualUncompressedLength) {
346       throw std::runtime_error("snappy: invalid uncompressed length");
347     }
348   }
349
350   auto out = IOBuf::create(actualUncompressedLength);
351
352   {
353     IOBufSnappySource source(data);
354     if (!snappy::RawUncompress(&source,
355                                reinterpret_cast<char*>(out->writableTail()))) {
356       throw std::runtime_error("snappy::RawUncompress failed");
357     }
358   }
359
360   out->append(actualUncompressedLength);
361   return out;
362 }
363
364 /**
365  * Zlib codec
366  */
367 class ZlibCodec FOLLY_FINAL : public Codec {
368  public:
369   static std::unique_ptr<Codec> create(int level);
370   explicit ZlibCodec(int level);
371
372  private:
373   CodecType doType() const FOLLY_OVERRIDE;
374   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
375   std::unique_ptr<IOBuf> doUncompress(
376       const IOBuf* data,
377       uint64_t uncompressedLength) FOLLY_OVERRIDE;
378
379   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
380   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
381
382   int level_;
383 };
384
385 std::unique_ptr<Codec> ZlibCodec::create(int level) {
386   return make_unique<ZlibCodec>(level);
387 }
388
389 ZlibCodec::ZlibCodec(int level) {
390   switch (level) {
391   case COMPRESSION_LEVEL_FASTEST:
392     level = 1;
393     break;
394   case COMPRESSION_LEVEL_DEFAULT:
395     level = Z_DEFAULT_COMPRESSION;
396     break;
397   case COMPRESSION_LEVEL_BEST:
398     level = 9;
399     break;
400   }
401   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
402     throw std::invalid_argument(to<std::string>(
403         "ZlibCodec: invalid level: ", level));
404   }
405   level_ = level;
406 }
407
408 CodecType ZlibCodec::doType() const {
409   return CodecType::ZLIB;
410 }
411
412 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
413                                                   uint32_t length) {
414   CHECK_EQ(stream->avail_out, 0);
415
416   auto buf = IOBuf::create(length);
417   buf->append(length);
418
419   stream->next_out = buf->writableData();
420   stream->avail_out = buf->length();
421
422   return buf;
423 }
424
425 bool ZlibCodec::doInflate(z_stream* stream,
426                           IOBuf* head,
427                           uint32_t bufferLength) {
428   if (stream->avail_out == 0) {
429     head->prependChain(addOutputBuffer(stream, bufferLength));
430   }
431
432   int rc = inflate(stream, Z_NO_FLUSH);
433
434   switch (rc) {
435   case Z_OK:
436     break;
437   case Z_STREAM_END:
438     return true;
439   case Z_BUF_ERROR:
440   case Z_NEED_DICT:
441   case Z_DATA_ERROR:
442   case Z_MEM_ERROR:
443     throw std::runtime_error(to<std::string>(
444         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
445   default:
446     CHECK(false) << rc << ": " << stream->msg;
447   }
448
449   return false;
450 }
451
452
453 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
454   z_stream stream;
455   stream.zalloc = nullptr;
456   stream.zfree = nullptr;
457   stream.opaque = nullptr;
458
459   int rc = deflateInit(&stream, level_);
460   if (rc != Z_OK) {
461     throw std::runtime_error(to<std::string>(
462         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
463   }
464
465   stream.next_in = stream.next_out = nullptr;
466   stream.avail_in = stream.avail_out = 0;
467   stream.total_in = stream.total_out = 0;
468
469   bool success = false;
470
471   SCOPE_EXIT {
472     int rc = deflateEnd(&stream);
473     // If we're here because of an exception, it's okay if some data
474     // got dropped.
475     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
476       << rc << ": " << stream.msg;
477   };
478
479   uint64_t uncompressedLength = data->computeChainDataLength();
480   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
481
482   // Max 64MiB in one go
483   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
484   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
485
486   auto out = addOutputBuffer(
487       &stream,
488       (maxCompressedLength <= maxSingleStepLength ?
489        maxCompressedLength :
490        defaultBufferLength));
491
492   for (auto& range : *data) {
493     if (range.empty()) {
494       continue;
495     }
496
497     stream.next_in = const_cast<uint8_t*>(range.data());
498     stream.avail_in = range.size();
499
500     while (stream.avail_in != 0) {
501       if (stream.avail_out == 0) {
502         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
503       }
504
505       rc = deflate(&stream, Z_NO_FLUSH);
506
507       CHECK_EQ(rc, Z_OK) << stream.msg;
508     }
509   }
510
511   do {
512     if (stream.avail_out == 0) {
513       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
514     }
515
516     rc = deflate(&stream, Z_FINISH);
517   } while (rc == Z_OK);
518
519   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
520
521   out->prev()->trimEnd(stream.avail_out);
522
523   success = true;  // we survived
524
525   return out;
526 }
527
528 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
529                                                uint64_t uncompressedLength) {
530   z_stream stream;
531   stream.zalloc = nullptr;
532   stream.zfree = nullptr;
533   stream.opaque = nullptr;
534
535   int rc = inflateInit(&stream);
536   if (rc != Z_OK) {
537     throw std::runtime_error(to<std::string>(
538         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
539   }
540
541   stream.next_in = stream.next_out = nullptr;
542   stream.avail_in = stream.avail_out = 0;
543   stream.total_in = stream.total_out = 0;
544
545   bool success = false;
546
547   SCOPE_EXIT {
548     int rc = inflateEnd(&stream);
549     // If we're here because of an exception, it's okay if some data
550     // got dropped.
551     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
552       << rc << ": " << stream.msg;
553   };
554
555   // Max 64MiB in one go
556   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
557   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
558
559   auto out = addOutputBuffer(
560       &stream,
561       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
562         uncompressedLength <= maxSingleStepLength) ?
563        uncompressedLength :
564        defaultBufferLength));
565
566   bool streamEnd = false;
567   for (auto& range : *data) {
568     if (range.empty()) {
569       continue;
570     }
571
572     stream.next_in = const_cast<uint8_t*>(range.data());
573     stream.avail_in = range.size();
574
575     while (stream.avail_in != 0) {
576       if (streamEnd) {
577         throw std::runtime_error(to<std::string>(
578             "ZlibCodec: junk after end of data"));
579       }
580
581       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
582     }
583   }
584
585   while (!streamEnd) {
586     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
587   }
588
589   out->prev()->trimEnd(stream.avail_out);
590
591   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
592       uncompressedLength != stream.total_out) {
593     throw std::runtime_error(to<std::string>(
594         "ZlibCodec: invalid uncompressed length"));
595   }
596
597   success = true;  // we survived
598
599   return out;
600 }
601
602 typedef std::unique_ptr<Codec> (*CodecFactory)(int);
603
604 CodecFactory gCodecFactories[
605     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
606   NoCompressionCodec::create,
607   LZ4Codec::create,
608   SnappyCodec::create,
609   ZlibCodec::create
610 };
611
612 }  // namespace
613
614 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
615   size_t idx = static_cast<size_t>(type);
616   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
617     throw std::invalid_argument(to<std::string>(
618         "Compression type ", idx, " not supported"));
619   }
620   auto factory = gCodecFactories[idx];
621   if (!factory) {
622     throw std::invalid_argument(to<std::string>(
623         "Compression type ", idx, " not supported"));
624   }
625   auto codec = (*factory)(level);
626   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
627   return codec;
628 }
629
630 }}  // namespaces
631