folly copyright 2015 -> copyright 2016
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   } else if (len > maxUncompressedLength()) {
60     throw std::runtime_error("Codec: uncompressed length too large");
61   }
62
63   return doCompress(data);
64 }
65
66 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
67                                          uint64_t uncompressedLength) {
68   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
69     if (needsUncompressedLength()) {
70       throw std::invalid_argument("Codec: uncompressed length required");
71     }
72   } else if (uncompressedLength > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   if (data->empty()) {
77     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
78         uncompressedLength != 0) {
79       throw std::runtime_error("Codec: invalid uncompressed length");
80     }
81     return IOBuf::create(0);
82   }
83
84   return doUncompress(data, uncompressedLength);
85 }
86
87 bool Codec::needsUncompressedLength() const {
88   return doNeedsUncompressedLength();
89 }
90
91 uint64_t Codec::maxUncompressedLength() const {
92   return doMaxUncompressedLength();
93 }
94
95 bool Codec::doNeedsUncompressedLength() const {
96   return false;
97 }
98
99 uint64_t Codec::doMaxUncompressedLength() const {
100   return UNLIMITED_UNCOMPRESSED_LENGTH;
101 }
102
103 namespace {
104
105 /**
106  * No compression
107  */
108 class NoCompressionCodec final : public Codec {
109  public:
110   static std::unique_ptr<Codec> create(int level, CodecType type);
111   explicit NoCompressionCodec(int level, CodecType type);
112
113  private:
114   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
115   std::unique_ptr<IOBuf> doUncompress(
116       const IOBuf* data,
117       uint64_t uncompressedLength) override;
118 };
119
120 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
121   return make_unique<NoCompressionCodec>(level, type);
122 }
123
124 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
125   : Codec(type) {
126   DCHECK(type == CodecType::NO_COMPRESSION);
127   switch (level) {
128   case COMPRESSION_LEVEL_DEFAULT:
129   case COMPRESSION_LEVEL_FASTEST:
130   case COMPRESSION_LEVEL_BEST:
131     level = 0;
132   }
133   if (level != 0) {
134     throw std::invalid_argument(to<std::string>(
135         "NoCompressionCodec: invalid level ", level));
136   }
137 }
138
139 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
140     const IOBuf* data) {
141   return data->clone();
142 }
143
144 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
145     const IOBuf* data,
146     uint64_t uncompressedLength) {
147   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
148       data->computeChainDataLength() != uncompressedLength) {
149     throw std::runtime_error(to<std::string>(
150         "NoCompressionCodec: invalid uncompressed length"));
151   }
152   return data->clone();
153 }
154
155 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
156
157 namespace {
158
159 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
160   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
161   out->append(encodeVarint(val, out->writableTail()));
162 }
163
164 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
165   uint64_t val = 0;
166   int8_t b = 0;
167   for (int shift = 0; shift <= 63; shift += 7) {
168     b = cursor.read<int8_t>();
169     val |= static_cast<uint64_t>(b & 0x7f) << shift;
170     if (b >= 0) {
171       break;
172     }
173   }
174   if (b < 0) {
175     throw std::invalid_argument("Invalid varint value. Too big.");
176   }
177   return val;
178 }
179
180 }  // namespace
181
182 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
183
184 #if FOLLY_HAVE_LIBLZ4
185
186 /**
187  * LZ4 compression
188  */
189 class LZ4Codec final : public Codec {
190  public:
191   static std::unique_ptr<Codec> create(int level, CodecType type);
192   explicit LZ4Codec(int level, CodecType type);
193
194  private:
195   bool doNeedsUncompressedLength() const override;
196   uint64_t doMaxUncompressedLength() const override;
197
198   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
199
200   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
201   std::unique_ptr<IOBuf> doUncompress(
202       const IOBuf* data,
203       uint64_t uncompressedLength) override;
204
205   bool highCompression_;
206 };
207
208 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
209   return make_unique<LZ4Codec>(level, type);
210 }
211
212 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
213   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
214
215   switch (level) {
216   case COMPRESSION_LEVEL_FASTEST:
217   case COMPRESSION_LEVEL_DEFAULT:
218     level = 1;
219     break;
220   case COMPRESSION_LEVEL_BEST:
221     level = 2;
222     break;
223   }
224   if (level < 1 || level > 2) {
225     throw std::invalid_argument(to<std::string>(
226         "LZ4Codec: invalid level: ", level));
227   }
228   highCompression_ = (level > 1);
229 }
230
231 bool LZ4Codec::doNeedsUncompressedLength() const {
232   return !encodeSize();
233 }
234
235 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
236 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
237 // here.
238 #ifndef LZ4_MAX_INPUT_SIZE
239 # define LZ4_MAX_INPUT_SIZE 0x7E000000
240 #endif
241
242 uint64_t LZ4Codec::doMaxUncompressedLength() const {
243   return LZ4_MAX_INPUT_SIZE;
244 }
245
246 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
247   std::unique_ptr<IOBuf> clone;
248   if (data->isChained()) {
249     // LZ4 doesn't support streaming, so we have to coalesce
250     clone = data->clone();
251     clone->coalesce();
252     data = clone.get();
253   }
254
255   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
256   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
257   if (encodeSize()) {
258     encodeVarintToIOBuf(data->length(), out.get());
259   }
260
261   int n;
262   if (highCompression_) {
263     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
264                        reinterpret_cast<char*>(out->writableTail()),
265                        data->length());
266   } else {
267     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
268                      reinterpret_cast<char*>(out->writableTail()),
269                      data->length());
270   }
271
272   CHECK_GE(n, 0);
273   CHECK_LE(n, out->capacity());
274
275   out->append(n);
276   return out;
277 }
278
279 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
280     const IOBuf* data,
281     uint64_t uncompressedLength) {
282   std::unique_ptr<IOBuf> clone;
283   if (data->isChained()) {
284     // LZ4 doesn't support streaming, so we have to coalesce
285     clone = data->clone();
286     clone->coalesce();
287     data = clone.get();
288   }
289
290   folly::io::Cursor cursor(data);
291   uint64_t actualUncompressedLength;
292   if (encodeSize()) {
293     actualUncompressedLength = decodeVarintFromCursor(cursor);
294     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
295         uncompressedLength != actualUncompressedLength) {
296       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
297     }
298   } else {
299     actualUncompressedLength = uncompressedLength;
300     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
301         actualUncompressedLength > maxUncompressedLength()) {
302       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
303     }
304   }
305
306   auto p = cursor.peek();
307   auto out = IOBuf::create(actualUncompressedLength);
308   int n = LZ4_decompress_safe(reinterpret_cast<const char*>(p.first),
309                               reinterpret_cast<char*>(out->writableTail()),
310                               p.second,
311                               actualUncompressedLength);
312
313   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
314     throw std::runtime_error(to<std::string>(
315         "LZ4 decompression returned invalid value ", n));
316   }
317   out->append(actualUncompressedLength);
318   return out;
319 }
320
321 #endif  // FOLLY_HAVE_LIBLZ4
322
323 #if FOLLY_HAVE_LIBSNAPPY
324
325 /**
326  * Snappy compression
327  */
328
329 /**
330  * Implementation of snappy::Source that reads from a IOBuf chain.
331  */
332 class IOBufSnappySource final : public snappy::Source {
333  public:
334   explicit IOBufSnappySource(const IOBuf* data);
335   size_t Available() const override;
336   const char* Peek(size_t* len) override;
337   void Skip(size_t n) override;
338  private:
339   size_t available_;
340   io::Cursor cursor_;
341 };
342
343 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
344   : available_(data->computeChainDataLength()),
345     cursor_(data) {
346 }
347
348 size_t IOBufSnappySource::Available() const {
349   return available_;
350 }
351
352 const char* IOBufSnappySource::Peek(size_t* len) {
353   auto p = cursor_.peek();
354   *len = p.second;
355   return reinterpret_cast<const char*>(p.first);
356 }
357
358 void IOBufSnappySource::Skip(size_t n) {
359   CHECK_LE(n, available_);
360   cursor_.skip(n);
361   available_ -= n;
362 }
363
364 class SnappyCodec final : public Codec {
365  public:
366   static std::unique_ptr<Codec> create(int level, CodecType type);
367   explicit SnappyCodec(int level, CodecType type);
368
369  private:
370   uint64_t doMaxUncompressedLength() const override;
371   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
372   std::unique_ptr<IOBuf> doUncompress(
373       const IOBuf* data,
374       uint64_t uncompressedLength) override;
375 };
376
377 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
378   return make_unique<SnappyCodec>(level, type);
379 }
380
381 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
382   DCHECK(type == CodecType::SNAPPY);
383   switch (level) {
384   case COMPRESSION_LEVEL_FASTEST:
385   case COMPRESSION_LEVEL_DEFAULT:
386   case COMPRESSION_LEVEL_BEST:
387     level = 1;
388   }
389   if (level != 1) {
390     throw std::invalid_argument(to<std::string>(
391         "SnappyCodec: invalid level: ", level));
392   }
393 }
394
395 uint64_t SnappyCodec::doMaxUncompressedLength() const {
396   // snappy.h uses uint32_t for lengths, so there's that.
397   return std::numeric_limits<uint32_t>::max();
398 }
399
400 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
401   IOBufSnappySource source(data);
402   auto out =
403     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
404
405   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
406       out->writableTail()));
407
408   size_t n = snappy::Compress(&source, &sink);
409
410   CHECK_LE(n, out->capacity());
411   out->append(n);
412   return out;
413 }
414
415 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
416                                                  uint64_t uncompressedLength) {
417   uint32_t actualUncompressedLength = 0;
418
419   {
420     IOBufSnappySource source(data);
421     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
422       throw std::runtime_error("snappy::GetUncompressedLength failed");
423     }
424     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
425         uncompressedLength != actualUncompressedLength) {
426       throw std::runtime_error("snappy: invalid uncompressed length");
427     }
428   }
429
430   auto out = IOBuf::create(actualUncompressedLength);
431
432   {
433     IOBufSnappySource source(data);
434     if (!snappy::RawUncompress(&source,
435                                reinterpret_cast<char*>(out->writableTail()))) {
436       throw std::runtime_error("snappy::RawUncompress failed");
437     }
438   }
439
440   out->append(actualUncompressedLength);
441   return out;
442 }
443
444 #endif  // FOLLY_HAVE_LIBSNAPPY
445
446 #if FOLLY_HAVE_LIBZ
447 /**
448  * Zlib codec
449  */
450 class ZlibCodec final : public Codec {
451  public:
452   static std::unique_ptr<Codec> create(int level, CodecType type);
453   explicit ZlibCodec(int level, CodecType type);
454
455  private:
456   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
457   std::unique_ptr<IOBuf> doUncompress(
458       const IOBuf* data,
459       uint64_t uncompressedLength) override;
460
461   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
462   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
463
464   int level_;
465 };
466
467 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
468   return make_unique<ZlibCodec>(level, type);
469 }
470
471 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
472   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
473   switch (level) {
474   case COMPRESSION_LEVEL_FASTEST:
475     level = 1;
476     break;
477   case COMPRESSION_LEVEL_DEFAULT:
478     level = Z_DEFAULT_COMPRESSION;
479     break;
480   case COMPRESSION_LEVEL_BEST:
481     level = 9;
482     break;
483   }
484   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
485     throw std::invalid_argument(to<std::string>(
486         "ZlibCodec: invalid level: ", level));
487   }
488   level_ = level;
489 }
490
491 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
492                                                   uint32_t length) {
493   CHECK_EQ(stream->avail_out, 0);
494
495   auto buf = IOBuf::create(length);
496   buf->append(length);
497
498   stream->next_out = buf->writableData();
499   stream->avail_out = buf->length();
500
501   return buf;
502 }
503
504 bool ZlibCodec::doInflate(z_stream* stream,
505                           IOBuf* head,
506                           uint32_t bufferLength) {
507   if (stream->avail_out == 0) {
508     head->prependChain(addOutputBuffer(stream, bufferLength));
509   }
510
511   int rc = inflate(stream, Z_NO_FLUSH);
512
513   switch (rc) {
514   case Z_OK:
515     break;
516   case Z_STREAM_END:
517     return true;
518   case Z_BUF_ERROR:
519   case Z_NEED_DICT:
520   case Z_DATA_ERROR:
521   case Z_MEM_ERROR:
522     throw std::runtime_error(to<std::string>(
523         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
524   default:
525     CHECK(false) << rc << ": " << stream->msg;
526   }
527
528   return false;
529 }
530
531 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
532   z_stream stream;
533   stream.zalloc = nullptr;
534   stream.zfree = nullptr;
535   stream.opaque = nullptr;
536
537   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
538   // base two logarithm of the maximum window size (...) The default value is
539   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
540   // around the compressed data instead of a zlib wrapper. The gzip header
541   // will have no file name, no extra data, no comment, no modification time
542   // (set to zero), no header crc, and the operating system will be set to 255
543   // (unknown)."
544   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
545   // All other parameters (method, memLevel, strategy) get default values from
546   // the zlib manual.
547   int rc = deflateInit2(&stream,
548                         level_,
549                         Z_DEFLATED,
550                         windowBits,
551                         /* memLevel */ 8,
552                         Z_DEFAULT_STRATEGY);
553   if (rc != Z_OK) {
554     throw std::runtime_error(to<std::string>(
555         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
556   }
557
558   stream.next_in = stream.next_out = nullptr;
559   stream.avail_in = stream.avail_out = 0;
560   stream.total_in = stream.total_out = 0;
561
562   bool success = false;
563
564   SCOPE_EXIT {
565     int rc = deflateEnd(&stream);
566     // If we're here because of an exception, it's okay if some data
567     // got dropped.
568     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
569       << rc << ": " << stream.msg;
570   };
571
572   uint64_t uncompressedLength = data->computeChainDataLength();
573   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
574
575   // Max 64MiB in one go
576   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
577   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
578
579   auto out = addOutputBuffer(
580       &stream,
581       (maxCompressedLength <= maxSingleStepLength ?
582        maxCompressedLength :
583        defaultBufferLength));
584
585   for (auto& range : *data) {
586     uint64_t remaining = range.size();
587     uint64_t written = 0;
588     while (remaining) {
589       uint32_t step = (remaining > maxSingleStepLength ?
590                        maxSingleStepLength : remaining);
591       stream.next_in = const_cast<uint8_t*>(range.data() + written);
592       stream.avail_in = step;
593       remaining -= step;
594       written += step;
595
596       while (stream.avail_in != 0) {
597         if (stream.avail_out == 0) {
598           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
599         }
600
601         rc = deflate(&stream, Z_NO_FLUSH);
602
603         CHECK_EQ(rc, Z_OK) << stream.msg;
604       }
605     }
606   }
607
608   do {
609     if (stream.avail_out == 0) {
610       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
611     }
612
613     rc = deflate(&stream, Z_FINISH);
614   } while (rc == Z_OK);
615
616   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
617
618   out->prev()->trimEnd(stream.avail_out);
619
620   success = true;  // we survived
621
622   return out;
623 }
624
625 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
626                                                uint64_t uncompressedLength) {
627   z_stream stream;
628   stream.zalloc = nullptr;
629   stream.zfree = nullptr;
630   stream.opaque = nullptr;
631
632   // "The windowBits parameter is the base two logarithm of the maximum window
633   // size (...) The default value is 15 (...) add 16 to decode only the gzip
634   // format (the zlib format will return a Z_DATA_ERROR)."
635   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
636   int rc = inflateInit2(&stream, windowBits);
637   if (rc != Z_OK) {
638     throw std::runtime_error(to<std::string>(
639         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
640   }
641
642   stream.next_in = stream.next_out = nullptr;
643   stream.avail_in = stream.avail_out = 0;
644   stream.total_in = stream.total_out = 0;
645
646   bool success = false;
647
648   SCOPE_EXIT {
649     int rc = inflateEnd(&stream);
650     // If we're here because of an exception, it's okay if some data
651     // got dropped.
652     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
653       << rc << ": " << stream.msg;
654   };
655
656   // Max 64MiB in one go
657   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
658   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
659
660   auto out = addOutputBuffer(
661       &stream,
662       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
663         uncompressedLength <= maxSingleStepLength) ?
664        uncompressedLength :
665        defaultBufferLength));
666
667   bool streamEnd = false;
668   for (auto& range : *data) {
669     if (range.empty()) {
670       continue;
671     }
672
673     stream.next_in = const_cast<uint8_t*>(range.data());
674     stream.avail_in = range.size();
675
676     while (stream.avail_in != 0) {
677       if (streamEnd) {
678         throw std::runtime_error(to<std::string>(
679             "ZlibCodec: junk after end of data"));
680       }
681
682       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
683     }
684   }
685
686   while (!streamEnd) {
687     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
688   }
689
690   out->prev()->trimEnd(stream.avail_out);
691
692   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
693       uncompressedLength != stream.total_out) {
694     throw std::runtime_error(to<std::string>(
695         "ZlibCodec: invalid uncompressed length"));
696   }
697
698   success = true;  // we survived
699
700   return out;
701 }
702
703 #endif  // FOLLY_HAVE_LIBZ
704
705 #if FOLLY_HAVE_LIBLZMA
706
707 /**
708  * LZMA2 compression
709  */
710 class LZMA2Codec final : public Codec {
711  public:
712   static std::unique_ptr<Codec> create(int level, CodecType type);
713   explicit LZMA2Codec(int level, CodecType type);
714
715  private:
716   bool doNeedsUncompressedLength() const override;
717   uint64_t doMaxUncompressedLength() const override;
718
719   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
720
721   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
722   std::unique_ptr<IOBuf> doUncompress(
723       const IOBuf* data,
724       uint64_t uncompressedLength) override;
725
726   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
727   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
728
729   int level_;
730 };
731
732 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
733   return make_unique<LZMA2Codec>(level, type);
734 }
735
736 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
737   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
738   switch (level) {
739   case COMPRESSION_LEVEL_FASTEST:
740     level = 0;
741     break;
742   case COMPRESSION_LEVEL_DEFAULT:
743     level = LZMA_PRESET_DEFAULT;
744     break;
745   case COMPRESSION_LEVEL_BEST:
746     level = 9;
747     break;
748   }
749   if (level < 0 || level > 9) {
750     throw std::invalid_argument(to<std::string>(
751         "LZMA2Codec: invalid level: ", level));
752   }
753   level_ = level;
754 }
755
756 bool LZMA2Codec::doNeedsUncompressedLength() const {
757   return !encodeSize();
758 }
759
760 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
761   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
762   return uint64_t(1) << 63;
763 }
764
765 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
766     lzma_stream* stream,
767     size_t length) {
768
769   CHECK_EQ(stream->avail_out, 0);
770
771   auto buf = IOBuf::create(length);
772   buf->append(length);
773
774   stream->next_out = buf->writableData();
775   stream->avail_out = buf->length();
776
777   return buf;
778 }
779
780 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
781   lzma_ret rc;
782   lzma_stream stream = LZMA_STREAM_INIT;
783
784   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
785   if (rc != LZMA_OK) {
786     throw std::runtime_error(folly::to<std::string>(
787       "LZMA2Codec: lzma_easy_encoder error: ", rc));
788   }
789
790   SCOPE_EXIT { lzma_end(&stream); };
791
792   uint64_t uncompressedLength = data->computeChainDataLength();
793   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
794
795   // Max 64MiB in one go
796   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
797   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
798
799   auto out = addOutputBuffer(
800     &stream,
801     (maxCompressedLength <= maxSingleStepLength ?
802      maxCompressedLength :
803      defaultBufferLength));
804
805   if (encodeSize()) {
806     auto size = IOBuf::createCombined(kMaxVarintLength64);
807     encodeVarintToIOBuf(uncompressedLength, size.get());
808     size->appendChain(std::move(out));
809     out = std::move(size);
810   }
811
812   for (auto& range : *data) {
813     if (range.empty()) {
814       continue;
815     }
816
817     stream.next_in = const_cast<uint8_t*>(range.data());
818     stream.avail_in = range.size();
819
820     while (stream.avail_in != 0) {
821       if (stream.avail_out == 0) {
822         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
823       }
824
825       rc = lzma_code(&stream, LZMA_RUN);
826
827       if (rc != LZMA_OK) {
828         throw std::runtime_error(folly::to<std::string>(
829           "LZMA2Codec: lzma_code error: ", rc));
830       }
831     }
832   }
833
834   do {
835     if (stream.avail_out == 0) {
836       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
837     }
838
839     rc = lzma_code(&stream, LZMA_FINISH);
840   } while (rc == LZMA_OK);
841
842   if (rc != LZMA_STREAM_END) {
843     throw std::runtime_error(folly::to<std::string>(
844       "LZMA2Codec: lzma_code ended with error: ", rc));
845   }
846
847   out->prev()->trimEnd(stream.avail_out);
848
849   return out;
850 }
851
852 bool LZMA2Codec::doInflate(lzma_stream* stream,
853                           IOBuf* head,
854                           size_t bufferLength) {
855   if (stream->avail_out == 0) {
856     head->prependChain(addOutputBuffer(stream, bufferLength));
857   }
858
859   lzma_ret rc = lzma_code(stream, LZMA_RUN);
860
861   switch (rc) {
862   case LZMA_OK:
863     break;
864   case LZMA_STREAM_END:
865     return true;
866   default:
867     throw std::runtime_error(to<std::string>(
868         "LZMA2Codec: lzma_code error: ", rc));
869   }
870
871   return false;
872 }
873
874 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
875                                                uint64_t uncompressedLength) {
876   lzma_ret rc;
877   lzma_stream stream = LZMA_STREAM_INIT;
878
879   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
880   if (rc != LZMA_OK) {
881     throw std::runtime_error(folly::to<std::string>(
882       "LZMA2Codec: lzma_auto_decoder error: ", rc));
883   }
884
885   SCOPE_EXIT { lzma_end(&stream); };
886
887   // Max 64MiB in one go
888   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
889   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
890
891   folly::io::Cursor cursor(data);
892   uint64_t actualUncompressedLength;
893   if (encodeSize()) {
894     actualUncompressedLength = decodeVarintFromCursor(cursor);
895     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
896         uncompressedLength != actualUncompressedLength) {
897       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
898     }
899   } else {
900     actualUncompressedLength = uncompressedLength;
901     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
902   }
903
904   auto out = addOutputBuffer(
905       &stream,
906       (actualUncompressedLength <= maxSingleStepLength ?
907        actualUncompressedLength :
908        defaultBufferLength));
909
910   bool streamEnd = false;
911   auto buf = cursor.peek();
912   while (buf.second != 0) {
913     stream.next_in = const_cast<uint8_t*>(buf.first);
914     stream.avail_in = buf.second;
915
916     while (stream.avail_in != 0) {
917       if (streamEnd) {
918         throw std::runtime_error(to<std::string>(
919             "LZMA2Codec: junk after end of data"));
920       }
921
922       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
923     }
924
925     cursor.skip(buf.second);
926     buf = cursor.peek();
927   }
928
929   while (!streamEnd) {
930     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
931   }
932
933   out->prev()->trimEnd(stream.avail_out);
934
935   if (actualUncompressedLength != stream.total_out) {
936     throw std::runtime_error(to<std::string>(
937         "LZMA2Codec: invalid uncompressed length"));
938   }
939
940   return out;
941 }
942
943 #endif  // FOLLY_HAVE_LIBLZMA
944
945 #ifdef FOLLY_HAVE_LIBZSTD
946
947 /**
948  * ZSTD_BETA compression
949  */
950 class ZSTDCodec final : public Codec {
951  public:
952   static std::unique_ptr<Codec> create(int level, CodecType);
953   explicit ZSTDCodec(int level, CodecType type);
954
955  private:
956   bool doNeedsUncompressedLength() const override;
957   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
958   std::unique_ptr<IOBuf> doUncompress(
959       const IOBuf* data,
960       uint64_t uncompressedLength) override;
961
962   int level_{1};
963 };
964
965 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
966   return make_unique<ZSTDCodec>(level, type);
967 }
968
969 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
970   DCHECK(type == CodecType::ZSTD_BETA);
971   switch (level) {
972     case COMPRESSION_LEVEL_FASTEST:
973       level_ = 1;
974       break;
975     case COMPRESSION_LEVEL_DEFAULT:
976       level_ = 1;
977       break;
978     case COMPRESSION_LEVEL_BEST:
979       level_ = 19;
980       break;
981   }
982 }
983
984 bool ZSTDCodec::doNeedsUncompressedLength() const {
985   return true;
986 }
987
988 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
989   size_t rc;
990   size_t maxCompressedLength = ZSTD_compressBound(data->length());
991   auto out = IOBuf::createCombined(maxCompressedLength);
992
993   CHECK_EQ(out->length(), 0);
994
995   rc = ZSTD_compress(out->writableTail(),
996                      out->capacity(),
997                      data->data(),
998                      data->length(),
999                      level_);
1000
1001   if (ZSTD_isError(rc)) {
1002     throw std::runtime_error(to<std::string>(
1003           "ZSTD compression returned an error: ",
1004           ZSTD_getErrorName(rc)));
1005   }
1006
1007   out->append(rc);
1008   CHECK_EQ(out->length(), rc);
1009
1010   return out;
1011 }
1012
1013 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
1014                                                uint64_t uncompressedLength) {
1015   size_t rc;
1016   auto out = IOBuf::createCombined(uncompressedLength);
1017
1018   CHECK_GE(out->capacity(), uncompressedLength);
1019   CHECK_EQ(out->length(), 0);
1020
1021   rc = ZSTD_decompress(
1022       out->writableTail(), out->capacity(), data->data(), data->length());
1023
1024   if (ZSTD_isError(rc)) {
1025     throw std::runtime_error(to<std::string>(
1026           "ZSTD decompression returned an error: ",
1027           ZSTD_getErrorName(rc)));
1028   }
1029
1030   out->append(rc);
1031   CHECK_EQ(out->length(), rc);
1032
1033   return out;
1034 }
1035
1036 #endif  // FOLLY_HAVE_LIBZSTD
1037
1038 }  // namespace
1039
1040 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1041   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1042
1043   static CodecFactory codecFactories[
1044     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1045     nullptr,  // USER_DEFINED
1046     NoCompressionCodec::create,
1047
1048 #if FOLLY_HAVE_LIBLZ4
1049     LZ4Codec::create,
1050 #else
1051     nullptr,
1052 #endif
1053
1054 #if FOLLY_HAVE_LIBSNAPPY
1055     SnappyCodec::create,
1056 #else
1057     nullptr,
1058 #endif
1059
1060 #if FOLLY_HAVE_LIBZ
1061     ZlibCodec::create,
1062 #else
1063     nullptr,
1064 #endif
1065
1066 #if FOLLY_HAVE_LIBLZ4
1067     LZ4Codec::create,
1068 #else
1069     nullptr,
1070 #endif
1071
1072 #if FOLLY_HAVE_LIBLZMA
1073     LZMA2Codec::create,
1074     LZMA2Codec::create,
1075 #else
1076     nullptr,
1077     nullptr,
1078 #endif
1079
1080 #if FOLLY_HAVE_LIBZSTD
1081     ZSTDCodec::create,
1082 #else
1083     nullptr,
1084 #endif
1085
1086 #if FOLLY_HAVE_LIBZ
1087     ZlibCodec::create,
1088 #else
1089     nullptr,
1090 #endif
1091   };
1092
1093   size_t idx = static_cast<size_t>(type);
1094   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1095     throw std::invalid_argument(to<std::string>(
1096         "Compression type ", idx, " not supported"));
1097   }
1098   auto factory = codecFactories[idx];
1099   if (!factory) {
1100     throw std::invalid_argument(to<std::string>(
1101         "Compression type ", idx, " not supported"));
1102   }
1103   auto codec = (*factory)(level, type);
1104   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1105   return codec;
1106 }
1107
1108 }}  // namespaces