zstd is no longer in beta -- s/ZSTD_BETA/ZSTD/g
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2016 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   } else if (len > maxUncompressedLength()) {
60     throw std::runtime_error("Codec: uncompressed length too large");
61   }
62
63   return doCompress(data);
64 }
65
66 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
67                                          uint64_t uncompressedLength) {
68   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
69     if (needsUncompressedLength()) {
70       throw std::invalid_argument("Codec: uncompressed length required");
71     }
72   } else if (uncompressedLength > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   if (data->empty()) {
77     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
78         uncompressedLength != 0) {
79       throw std::runtime_error("Codec: invalid uncompressed length");
80     }
81     return IOBuf::create(0);
82   }
83
84   return doUncompress(data, uncompressedLength);
85 }
86
87 bool Codec::needsUncompressedLength() const {
88   return doNeedsUncompressedLength();
89 }
90
91 uint64_t Codec::maxUncompressedLength() const {
92   return doMaxUncompressedLength();
93 }
94
95 bool Codec::doNeedsUncompressedLength() const {
96   return false;
97 }
98
99 uint64_t Codec::doMaxUncompressedLength() const {
100   return UNLIMITED_UNCOMPRESSED_LENGTH;
101 }
102
103 namespace {
104
105 /**
106  * No compression
107  */
108 class NoCompressionCodec final : public Codec {
109  public:
110   static std::unique_ptr<Codec> create(int level, CodecType type);
111   explicit NoCompressionCodec(int level, CodecType type);
112
113  private:
114   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
115   std::unique_ptr<IOBuf> doUncompress(
116       const IOBuf* data,
117       uint64_t uncompressedLength) override;
118 };
119
120 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
121   return make_unique<NoCompressionCodec>(level, type);
122 }
123
124 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
125   : Codec(type) {
126   DCHECK(type == CodecType::NO_COMPRESSION);
127   switch (level) {
128   case COMPRESSION_LEVEL_DEFAULT:
129   case COMPRESSION_LEVEL_FASTEST:
130   case COMPRESSION_LEVEL_BEST:
131     level = 0;
132   }
133   if (level != 0) {
134     throw std::invalid_argument(to<std::string>(
135         "NoCompressionCodec: invalid level ", level));
136   }
137 }
138
139 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
140     const IOBuf* data) {
141   return data->clone();
142 }
143
144 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
145     const IOBuf* data,
146     uint64_t uncompressedLength) {
147   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
148       data->computeChainDataLength() != uncompressedLength) {
149     throw std::runtime_error(to<std::string>(
150         "NoCompressionCodec: invalid uncompressed length"));
151   }
152   return data->clone();
153 }
154
155 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
156
157 namespace {
158
159 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
160   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
161   out->append(encodeVarint(val, out->writableTail()));
162 }
163
164 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
165   uint64_t val = 0;
166   int8_t b = 0;
167   for (int shift = 0; shift <= 63; shift += 7) {
168     b = cursor.read<int8_t>();
169     val |= static_cast<uint64_t>(b & 0x7f) << shift;
170     if (b >= 0) {
171       break;
172     }
173   }
174   if (b < 0) {
175     throw std::invalid_argument("Invalid varint value. Too big.");
176   }
177   return val;
178 }
179
180 }  // namespace
181
182 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
183
184 #if FOLLY_HAVE_LIBLZ4
185
186 /**
187  * LZ4 compression
188  */
189 class LZ4Codec final : public Codec {
190  public:
191   static std::unique_ptr<Codec> create(int level, CodecType type);
192   explicit LZ4Codec(int level, CodecType type);
193
194  private:
195   bool doNeedsUncompressedLength() const override;
196   uint64_t doMaxUncompressedLength() const override;
197
198   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
199
200   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
201   std::unique_ptr<IOBuf> doUncompress(
202       const IOBuf* data,
203       uint64_t uncompressedLength) override;
204
205   bool highCompression_;
206 };
207
208 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
209   return make_unique<LZ4Codec>(level, type);
210 }
211
212 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
213   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
214
215   switch (level) {
216   case COMPRESSION_LEVEL_FASTEST:
217   case COMPRESSION_LEVEL_DEFAULT:
218     level = 1;
219     break;
220   case COMPRESSION_LEVEL_BEST:
221     level = 2;
222     break;
223   }
224   if (level < 1 || level > 2) {
225     throw std::invalid_argument(to<std::string>(
226         "LZ4Codec: invalid level: ", level));
227   }
228   highCompression_ = (level > 1);
229 }
230
231 bool LZ4Codec::doNeedsUncompressedLength() const {
232   return !encodeSize();
233 }
234
235 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
236 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
237 // here.
238 #ifndef LZ4_MAX_INPUT_SIZE
239 # define LZ4_MAX_INPUT_SIZE 0x7E000000
240 #endif
241
242 uint64_t LZ4Codec::doMaxUncompressedLength() const {
243   return LZ4_MAX_INPUT_SIZE;
244 }
245
246 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
247   std::unique_ptr<IOBuf> clone;
248   if (data->isChained()) {
249     // LZ4 doesn't support streaming, so we have to coalesce
250     clone = data->clone();
251     clone->coalesce();
252     data = clone.get();
253   }
254
255   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
256   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
257   if (encodeSize()) {
258     encodeVarintToIOBuf(data->length(), out.get());
259   }
260
261   int n;
262   if (highCompression_) {
263     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
264                        reinterpret_cast<char*>(out->writableTail()),
265                        data->length());
266   } else {
267     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
268                      reinterpret_cast<char*>(out->writableTail()),
269                      data->length());
270   }
271
272   CHECK_GE(n, 0);
273   CHECK_LE(n, out->capacity());
274
275   out->append(n);
276   return out;
277 }
278
279 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
280     const IOBuf* data,
281     uint64_t uncompressedLength) {
282   std::unique_ptr<IOBuf> clone;
283   if (data->isChained()) {
284     // LZ4 doesn't support streaming, so we have to coalesce
285     clone = data->clone();
286     clone->coalesce();
287     data = clone.get();
288   }
289
290   folly::io::Cursor cursor(data);
291   uint64_t actualUncompressedLength;
292   if (encodeSize()) {
293     actualUncompressedLength = decodeVarintFromCursor(cursor);
294     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
295         uncompressedLength != actualUncompressedLength) {
296       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
297     }
298   } else {
299     actualUncompressedLength = uncompressedLength;
300     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
301         actualUncompressedLength > maxUncompressedLength()) {
302       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
303     }
304   }
305
306   auto sp = StringPiece{cursor.peekBytes()};
307   auto out = IOBuf::create(actualUncompressedLength);
308   int n = LZ4_decompress_safe(
309       sp.data(),
310       reinterpret_cast<char*>(out->writableTail()),
311       sp.size(),
312       actualUncompressedLength);
313
314   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
315     throw std::runtime_error(to<std::string>(
316         "LZ4 decompression returned invalid value ", n));
317   }
318   out->append(actualUncompressedLength);
319   return out;
320 }
321
322 #endif  // FOLLY_HAVE_LIBLZ4
323
324 #if FOLLY_HAVE_LIBSNAPPY
325
326 /**
327  * Snappy compression
328  */
329
330 /**
331  * Implementation of snappy::Source that reads from a IOBuf chain.
332  */
333 class IOBufSnappySource final : public snappy::Source {
334  public:
335   explicit IOBufSnappySource(const IOBuf* data);
336   size_t Available() const override;
337   const char* Peek(size_t* len) override;
338   void Skip(size_t n) override;
339  private:
340   size_t available_;
341   io::Cursor cursor_;
342 };
343
344 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
345   : available_(data->computeChainDataLength()),
346     cursor_(data) {
347 }
348
349 size_t IOBufSnappySource::Available() const {
350   return available_;
351 }
352
353 const char* IOBufSnappySource::Peek(size_t* len) {
354   auto sp = StringPiece{cursor_.peekBytes()};
355   *len = sp.size();
356   return sp.data();
357 }
358
359 void IOBufSnappySource::Skip(size_t n) {
360   CHECK_LE(n, available_);
361   cursor_.skip(n);
362   available_ -= n;
363 }
364
365 class SnappyCodec final : public Codec {
366  public:
367   static std::unique_ptr<Codec> create(int level, CodecType type);
368   explicit SnappyCodec(int level, CodecType type);
369
370  private:
371   uint64_t doMaxUncompressedLength() const override;
372   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
373   std::unique_ptr<IOBuf> doUncompress(
374       const IOBuf* data,
375       uint64_t uncompressedLength) override;
376 };
377
378 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
379   return make_unique<SnappyCodec>(level, type);
380 }
381
382 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
383   DCHECK(type == CodecType::SNAPPY);
384   switch (level) {
385   case COMPRESSION_LEVEL_FASTEST:
386   case COMPRESSION_LEVEL_DEFAULT:
387   case COMPRESSION_LEVEL_BEST:
388     level = 1;
389   }
390   if (level != 1) {
391     throw std::invalid_argument(to<std::string>(
392         "SnappyCodec: invalid level: ", level));
393   }
394 }
395
396 uint64_t SnappyCodec::doMaxUncompressedLength() const {
397   // snappy.h uses uint32_t for lengths, so there's that.
398   return std::numeric_limits<uint32_t>::max();
399 }
400
401 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
402   IOBufSnappySource source(data);
403   auto out =
404     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
405
406   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
407       out->writableTail()));
408
409   size_t n = snappy::Compress(&source, &sink);
410
411   CHECK_LE(n, out->capacity());
412   out->append(n);
413   return out;
414 }
415
416 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
417                                                  uint64_t uncompressedLength) {
418   uint32_t actualUncompressedLength = 0;
419
420   {
421     IOBufSnappySource source(data);
422     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
423       throw std::runtime_error("snappy::GetUncompressedLength failed");
424     }
425     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
426         uncompressedLength != actualUncompressedLength) {
427       throw std::runtime_error("snappy: invalid uncompressed length");
428     }
429   }
430
431   auto out = IOBuf::create(actualUncompressedLength);
432
433   {
434     IOBufSnappySource source(data);
435     if (!snappy::RawUncompress(&source,
436                                reinterpret_cast<char*>(out->writableTail()))) {
437       throw std::runtime_error("snappy::RawUncompress failed");
438     }
439   }
440
441   out->append(actualUncompressedLength);
442   return out;
443 }
444
445 #endif  // FOLLY_HAVE_LIBSNAPPY
446
447 #if FOLLY_HAVE_LIBZ
448 /**
449  * Zlib codec
450  */
451 class ZlibCodec final : public Codec {
452  public:
453   static std::unique_ptr<Codec> create(int level, CodecType type);
454   explicit ZlibCodec(int level, CodecType type);
455
456  private:
457   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
458   std::unique_ptr<IOBuf> doUncompress(
459       const IOBuf* data,
460       uint64_t uncompressedLength) override;
461
462   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
463   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
464
465   int level_;
466 };
467
468 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
469   return make_unique<ZlibCodec>(level, type);
470 }
471
472 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
473   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
474   switch (level) {
475   case COMPRESSION_LEVEL_FASTEST:
476     level = 1;
477     break;
478   case COMPRESSION_LEVEL_DEFAULT:
479     level = Z_DEFAULT_COMPRESSION;
480     break;
481   case COMPRESSION_LEVEL_BEST:
482     level = 9;
483     break;
484   }
485   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
486     throw std::invalid_argument(to<std::string>(
487         "ZlibCodec: invalid level: ", level));
488   }
489   level_ = level;
490 }
491
492 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
493                                                   uint32_t length) {
494   CHECK_EQ(stream->avail_out, 0);
495
496   auto buf = IOBuf::create(length);
497   buf->append(length);
498
499   stream->next_out = buf->writableData();
500   stream->avail_out = buf->length();
501
502   return buf;
503 }
504
505 bool ZlibCodec::doInflate(z_stream* stream,
506                           IOBuf* head,
507                           uint32_t bufferLength) {
508   if (stream->avail_out == 0) {
509     head->prependChain(addOutputBuffer(stream, bufferLength));
510   }
511
512   int rc = inflate(stream, Z_NO_FLUSH);
513
514   switch (rc) {
515   case Z_OK:
516     break;
517   case Z_STREAM_END:
518     return true;
519   case Z_BUF_ERROR:
520   case Z_NEED_DICT:
521   case Z_DATA_ERROR:
522   case Z_MEM_ERROR:
523     throw std::runtime_error(to<std::string>(
524         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
525   default:
526     CHECK(false) << rc << ": " << stream->msg;
527   }
528
529   return false;
530 }
531
532 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
533   z_stream stream;
534   stream.zalloc = nullptr;
535   stream.zfree = nullptr;
536   stream.opaque = nullptr;
537
538   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
539   // base two logarithm of the maximum window size (...) The default value is
540   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
541   // around the compressed data instead of a zlib wrapper. The gzip header
542   // will have no file name, no extra data, no comment, no modification time
543   // (set to zero), no header crc, and the operating system will be set to 255
544   // (unknown)."
545   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
546   // All other parameters (method, memLevel, strategy) get default values from
547   // the zlib manual.
548   int rc = deflateInit2(&stream,
549                         level_,
550                         Z_DEFLATED,
551                         windowBits,
552                         /* memLevel */ 8,
553                         Z_DEFAULT_STRATEGY);
554   if (rc != Z_OK) {
555     throw std::runtime_error(to<std::string>(
556         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
557   }
558
559   stream.next_in = stream.next_out = nullptr;
560   stream.avail_in = stream.avail_out = 0;
561   stream.total_in = stream.total_out = 0;
562
563   bool success = false;
564
565   SCOPE_EXIT {
566     int rc = deflateEnd(&stream);
567     // If we're here because of an exception, it's okay if some data
568     // got dropped.
569     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
570       << rc << ": " << stream.msg;
571   };
572
573   uint64_t uncompressedLength = data->computeChainDataLength();
574   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
575
576   // Max 64MiB in one go
577   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
578   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
579
580   auto out = addOutputBuffer(
581       &stream,
582       (maxCompressedLength <= maxSingleStepLength ?
583        maxCompressedLength :
584        defaultBufferLength));
585
586   for (auto& range : *data) {
587     uint64_t remaining = range.size();
588     uint64_t written = 0;
589     while (remaining) {
590       uint32_t step = (remaining > maxSingleStepLength ?
591                        maxSingleStepLength : remaining);
592       stream.next_in = const_cast<uint8_t*>(range.data() + written);
593       stream.avail_in = step;
594       remaining -= step;
595       written += step;
596
597       while (stream.avail_in != 0) {
598         if (stream.avail_out == 0) {
599           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
600         }
601
602         rc = deflate(&stream, Z_NO_FLUSH);
603
604         CHECK_EQ(rc, Z_OK) << stream.msg;
605       }
606     }
607   }
608
609   do {
610     if (stream.avail_out == 0) {
611       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
612     }
613
614     rc = deflate(&stream, Z_FINISH);
615   } while (rc == Z_OK);
616
617   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
618
619   out->prev()->trimEnd(stream.avail_out);
620
621   success = true;  // we survived
622
623   return out;
624 }
625
626 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
627                                                uint64_t uncompressedLength) {
628   z_stream stream;
629   stream.zalloc = nullptr;
630   stream.zfree = nullptr;
631   stream.opaque = nullptr;
632
633   // "The windowBits parameter is the base two logarithm of the maximum window
634   // size (...) The default value is 15 (...) add 16 to decode only the gzip
635   // format (the zlib format will return a Z_DATA_ERROR)."
636   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
637   int rc = inflateInit2(&stream, windowBits);
638   if (rc != Z_OK) {
639     throw std::runtime_error(to<std::string>(
640         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
641   }
642
643   stream.next_in = stream.next_out = nullptr;
644   stream.avail_in = stream.avail_out = 0;
645   stream.total_in = stream.total_out = 0;
646
647   bool success = false;
648
649   SCOPE_EXIT {
650     int rc = inflateEnd(&stream);
651     // If we're here because of an exception, it's okay if some data
652     // got dropped.
653     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
654       << rc << ": " << stream.msg;
655   };
656
657   // Max 64MiB in one go
658   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
659   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
660
661   auto out = addOutputBuffer(
662       &stream,
663       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
664         uncompressedLength <= maxSingleStepLength) ?
665        uncompressedLength :
666        defaultBufferLength));
667
668   bool streamEnd = false;
669   for (auto& range : *data) {
670     if (range.empty()) {
671       continue;
672     }
673
674     stream.next_in = const_cast<uint8_t*>(range.data());
675     stream.avail_in = range.size();
676
677     while (stream.avail_in != 0) {
678       if (streamEnd) {
679         throw std::runtime_error(to<std::string>(
680             "ZlibCodec: junk after end of data"));
681       }
682
683       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
684     }
685   }
686
687   while (!streamEnd) {
688     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
689   }
690
691   out->prev()->trimEnd(stream.avail_out);
692
693   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
694       uncompressedLength != stream.total_out) {
695     throw std::runtime_error(to<std::string>(
696         "ZlibCodec: invalid uncompressed length"));
697   }
698
699   success = true;  // we survived
700
701   return out;
702 }
703
704 #endif  // FOLLY_HAVE_LIBZ
705
706 #if FOLLY_HAVE_LIBLZMA
707
708 /**
709  * LZMA2 compression
710  */
711 class LZMA2Codec final : public Codec {
712  public:
713   static std::unique_ptr<Codec> create(int level, CodecType type);
714   explicit LZMA2Codec(int level, CodecType type);
715
716  private:
717   bool doNeedsUncompressedLength() const override;
718   uint64_t doMaxUncompressedLength() const override;
719
720   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
721
722   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
723   std::unique_ptr<IOBuf> doUncompress(
724       const IOBuf* data,
725       uint64_t uncompressedLength) override;
726
727   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
728   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
729
730   int level_;
731 };
732
733 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
734   return make_unique<LZMA2Codec>(level, type);
735 }
736
737 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
738   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
739   switch (level) {
740   case COMPRESSION_LEVEL_FASTEST:
741     level = 0;
742     break;
743   case COMPRESSION_LEVEL_DEFAULT:
744     level = LZMA_PRESET_DEFAULT;
745     break;
746   case COMPRESSION_LEVEL_BEST:
747     level = 9;
748     break;
749   }
750   if (level < 0 || level > 9) {
751     throw std::invalid_argument(to<std::string>(
752         "LZMA2Codec: invalid level: ", level));
753   }
754   level_ = level;
755 }
756
757 bool LZMA2Codec::doNeedsUncompressedLength() const {
758   return !encodeSize();
759 }
760
761 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
762   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
763   return uint64_t(1) << 63;
764 }
765
766 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
767     lzma_stream* stream,
768     size_t length) {
769
770   CHECK_EQ(stream->avail_out, 0);
771
772   auto buf = IOBuf::create(length);
773   buf->append(length);
774
775   stream->next_out = buf->writableData();
776   stream->avail_out = buf->length();
777
778   return buf;
779 }
780
781 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
782   lzma_ret rc;
783   lzma_stream stream = LZMA_STREAM_INIT;
784
785   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
786   if (rc != LZMA_OK) {
787     throw std::runtime_error(folly::to<std::string>(
788       "LZMA2Codec: lzma_easy_encoder error: ", rc));
789   }
790
791   SCOPE_EXIT { lzma_end(&stream); };
792
793   uint64_t uncompressedLength = data->computeChainDataLength();
794   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
795
796   // Max 64MiB in one go
797   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
798   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
799
800   auto out = addOutputBuffer(
801     &stream,
802     (maxCompressedLength <= maxSingleStepLength ?
803      maxCompressedLength :
804      defaultBufferLength));
805
806   if (encodeSize()) {
807     auto size = IOBuf::createCombined(kMaxVarintLength64);
808     encodeVarintToIOBuf(uncompressedLength, size.get());
809     size->appendChain(std::move(out));
810     out = std::move(size);
811   }
812
813   for (auto& range : *data) {
814     if (range.empty()) {
815       continue;
816     }
817
818     stream.next_in = const_cast<uint8_t*>(range.data());
819     stream.avail_in = range.size();
820
821     while (stream.avail_in != 0) {
822       if (stream.avail_out == 0) {
823         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
824       }
825
826       rc = lzma_code(&stream, LZMA_RUN);
827
828       if (rc != LZMA_OK) {
829         throw std::runtime_error(folly::to<std::string>(
830           "LZMA2Codec: lzma_code error: ", rc));
831       }
832     }
833   }
834
835   do {
836     if (stream.avail_out == 0) {
837       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
838     }
839
840     rc = lzma_code(&stream, LZMA_FINISH);
841   } while (rc == LZMA_OK);
842
843   if (rc != LZMA_STREAM_END) {
844     throw std::runtime_error(folly::to<std::string>(
845       "LZMA2Codec: lzma_code ended with error: ", rc));
846   }
847
848   out->prev()->trimEnd(stream.avail_out);
849
850   return out;
851 }
852
853 bool LZMA2Codec::doInflate(lzma_stream* stream,
854                           IOBuf* head,
855                           size_t bufferLength) {
856   if (stream->avail_out == 0) {
857     head->prependChain(addOutputBuffer(stream, bufferLength));
858   }
859
860   lzma_ret rc = lzma_code(stream, LZMA_RUN);
861
862   switch (rc) {
863   case LZMA_OK:
864     break;
865   case LZMA_STREAM_END:
866     return true;
867   default:
868     throw std::runtime_error(to<std::string>(
869         "LZMA2Codec: lzma_code error: ", rc));
870   }
871
872   return false;
873 }
874
875 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
876                                                uint64_t uncompressedLength) {
877   lzma_ret rc;
878   lzma_stream stream = LZMA_STREAM_INIT;
879
880   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
881   if (rc != LZMA_OK) {
882     throw std::runtime_error(folly::to<std::string>(
883       "LZMA2Codec: lzma_auto_decoder error: ", rc));
884   }
885
886   SCOPE_EXIT { lzma_end(&stream); };
887
888   // Max 64MiB in one go
889   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
890   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
891
892   folly::io::Cursor cursor(data);
893   uint64_t actualUncompressedLength;
894   if (encodeSize()) {
895     actualUncompressedLength = decodeVarintFromCursor(cursor);
896     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
897         uncompressedLength != actualUncompressedLength) {
898       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
899     }
900   } else {
901     actualUncompressedLength = uncompressedLength;
902     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
903   }
904
905   auto out = addOutputBuffer(
906       &stream,
907       (actualUncompressedLength <= maxSingleStepLength ?
908        actualUncompressedLength :
909        defaultBufferLength));
910
911   bool streamEnd = false;
912   auto buf = cursor.peekBytes();
913   while (!buf.empty()) {
914     stream.next_in = const_cast<uint8_t*>(buf.data());
915     stream.avail_in = buf.size();
916
917     while (stream.avail_in != 0) {
918       if (streamEnd) {
919         throw std::runtime_error(to<std::string>(
920             "LZMA2Codec: junk after end of data"));
921       }
922
923       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
924     }
925
926     cursor.skip(buf.size());
927     buf = cursor.peekBytes();
928   }
929
930   while (!streamEnd) {
931     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
932   }
933
934   out->prev()->trimEnd(stream.avail_out);
935
936   if (actualUncompressedLength != stream.total_out) {
937     throw std::runtime_error(to<std::string>(
938         "LZMA2Codec: invalid uncompressed length"));
939   }
940
941   return out;
942 }
943
944 #endif  // FOLLY_HAVE_LIBLZMA
945
946 #ifdef FOLLY_HAVE_LIBZSTD
947
948 /**
949  * ZSTD compression
950  */
951 class ZSTDCodec final : public Codec {
952  public:
953   static std::unique_ptr<Codec> create(int level, CodecType);
954   explicit ZSTDCodec(int level, CodecType type);
955
956  private:
957   bool doNeedsUncompressedLength() const override;
958   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
959   std::unique_ptr<IOBuf> doUncompress(
960       const IOBuf* data,
961       uint64_t uncompressedLength) override;
962
963   int level_{1};
964 };
965
966 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
967   return make_unique<ZSTDCodec>(level, type);
968 }
969
970 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
971   DCHECK(type == CodecType::ZSTD);
972   switch (level) {
973     case COMPRESSION_LEVEL_FASTEST:
974       level_ = 1;
975       break;
976     case COMPRESSION_LEVEL_DEFAULT:
977       level_ = 1;
978       break;
979     case COMPRESSION_LEVEL_BEST:
980       level_ = 19;
981       break;
982   }
983 }
984
985 bool ZSTDCodec::doNeedsUncompressedLength() const {
986   return true;
987 }
988
989 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
990   size_t rc;
991   size_t maxCompressedLength = ZSTD_compressBound(data->length());
992   auto out = IOBuf::createCombined(maxCompressedLength);
993
994   CHECK_EQ(out->length(), 0);
995
996   rc = ZSTD_compress(out->writableTail(),
997                      out->capacity(),
998                      data->data(),
999                      data->length(),
1000                      level_);
1001
1002   if (ZSTD_isError(rc)) {
1003     throw std::runtime_error(to<std::string>(
1004           "ZSTD compression returned an error: ",
1005           ZSTD_getErrorName(rc)));
1006   }
1007
1008   out->append(rc);
1009   CHECK_EQ(out->length(), rc);
1010
1011   return out;
1012 }
1013
1014 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
1015                                                uint64_t uncompressedLength) {
1016   size_t rc;
1017   auto out = IOBuf::createCombined(uncompressedLength);
1018
1019   CHECK_GE(out->capacity(), uncompressedLength);
1020   CHECK_EQ(out->length(), 0);
1021
1022   rc = ZSTD_decompress(
1023       out->writableTail(), out->capacity(), data->data(), data->length());
1024
1025   if (ZSTD_isError(rc)) {
1026     throw std::runtime_error(to<std::string>(
1027           "ZSTD decompression returned an error: ",
1028           ZSTD_getErrorName(rc)));
1029   }
1030
1031   out->append(rc);
1032   CHECK_EQ(out->length(), rc);
1033
1034   return out;
1035 }
1036
1037 #endif  // FOLLY_HAVE_LIBZSTD
1038
1039 }  // namespace
1040
1041 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1042   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1043
1044   static CodecFactory codecFactories[
1045     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1046     nullptr,  // USER_DEFINED
1047     NoCompressionCodec::create,
1048
1049 #if FOLLY_HAVE_LIBLZ4
1050     LZ4Codec::create,
1051 #else
1052     nullptr,
1053 #endif
1054
1055 #if FOLLY_HAVE_LIBSNAPPY
1056     SnappyCodec::create,
1057 #else
1058     nullptr,
1059 #endif
1060
1061 #if FOLLY_HAVE_LIBZ
1062     ZlibCodec::create,
1063 #else
1064     nullptr,
1065 #endif
1066
1067 #if FOLLY_HAVE_LIBLZ4
1068     LZ4Codec::create,
1069 #else
1070     nullptr,
1071 #endif
1072
1073 #if FOLLY_HAVE_LIBLZMA
1074     LZMA2Codec::create,
1075     LZMA2Codec::create,
1076 #else
1077     nullptr,
1078     nullptr,
1079 #endif
1080
1081 #if FOLLY_HAVE_LIBZSTD
1082     ZSTDCodec::create,
1083 #else
1084     nullptr,
1085 #endif
1086
1087 #if FOLLY_HAVE_LIBZ
1088     ZlibCodec::create,
1089 #else
1090     nullptr,
1091 #endif
1092   };
1093
1094   size_t idx = static_cast<size_t>(type);
1095   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1096     throw std::invalid_argument(to<std::string>(
1097         "Compression type ", idx, " not supported"));
1098   }
1099   auto factory = codecFactories[idx];
1100   if (!factory) {
1101     throw std::invalid_argument(to<std::string>(
1102         "Compression type ", idx, " not supported"));
1103   }
1104   auto codec = (*factory)(level, type);
1105   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1106   return codec;
1107 }
1108
1109 }}  // namespaces