Update zstd to 0.4.2
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2015 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   } else if (len > maxUncompressedLength()) {
60     throw std::runtime_error("Codec: uncompressed length too large");
61   }
62
63   return doCompress(data);
64 }
65
66 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
67                                          uint64_t uncompressedLength) {
68   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
69     if (needsUncompressedLength()) {
70       throw std::invalid_argument("Codec: uncompressed length required");
71     }
72   } else if (uncompressedLength > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   if (data->empty()) {
77     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
78         uncompressedLength != 0) {
79       throw std::runtime_error("Codec: invalid uncompressed length");
80     }
81     return IOBuf::create(0);
82   }
83
84   return doUncompress(data, uncompressedLength);
85 }
86
87 bool Codec::needsUncompressedLength() const {
88   return doNeedsUncompressedLength();
89 }
90
91 uint64_t Codec::maxUncompressedLength() const {
92   return doMaxUncompressedLength();
93 }
94
95 bool Codec::doNeedsUncompressedLength() const {
96   return false;
97 }
98
99 uint64_t Codec::doMaxUncompressedLength() const {
100   return UNLIMITED_UNCOMPRESSED_LENGTH;
101 }
102
103 namespace {
104
105 /**
106  * No compression
107  */
108 class NoCompressionCodec final : public Codec {
109  public:
110   static std::unique_ptr<Codec> create(int level, CodecType type);
111   explicit NoCompressionCodec(int level, CodecType type);
112
113  private:
114   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
115   std::unique_ptr<IOBuf> doUncompress(
116       const IOBuf* data,
117       uint64_t uncompressedLength) override;
118 };
119
120 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
121   return make_unique<NoCompressionCodec>(level, type);
122 }
123
124 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
125   : Codec(type) {
126   DCHECK(type == CodecType::NO_COMPRESSION);
127   switch (level) {
128   case COMPRESSION_LEVEL_DEFAULT:
129   case COMPRESSION_LEVEL_FASTEST:
130   case COMPRESSION_LEVEL_BEST:
131     level = 0;
132   }
133   if (level != 0) {
134     throw std::invalid_argument(to<std::string>(
135         "NoCompressionCodec: invalid level ", level));
136   }
137 }
138
139 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
140     const IOBuf* data) {
141   return data->clone();
142 }
143
144 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
145     const IOBuf* data,
146     uint64_t uncompressedLength) {
147   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
148       data->computeChainDataLength() != uncompressedLength) {
149     throw std::runtime_error(to<std::string>(
150         "NoCompressionCodec: invalid uncompressed length"));
151   }
152   return data->clone();
153 }
154
155 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
156
157 namespace {
158
159 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
160   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
161   out->append(encodeVarint(val, out->writableTail()));
162 }
163
164 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
165   uint64_t val = 0;
166   int8_t b = 0;
167   for (int shift = 0; shift <= 63; shift += 7) {
168     b = cursor.read<int8_t>();
169     val |= static_cast<uint64_t>(b & 0x7f) << shift;
170     if (b >= 0) {
171       break;
172     }
173   }
174   if (b < 0) {
175     throw std::invalid_argument("Invalid varint value. Too big.");
176   }
177   return val;
178 }
179
180 }  // namespace
181
182 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
183
184 #if FOLLY_HAVE_LIBLZ4
185
186 /**
187  * LZ4 compression
188  */
189 class LZ4Codec final : public Codec {
190  public:
191   static std::unique_ptr<Codec> create(int level, CodecType type);
192   explicit LZ4Codec(int level, CodecType type);
193
194  private:
195   bool doNeedsUncompressedLength() const override;
196   uint64_t doMaxUncompressedLength() const override;
197
198   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
199
200   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
201   std::unique_ptr<IOBuf> doUncompress(
202       const IOBuf* data,
203       uint64_t uncompressedLength) override;
204
205   bool highCompression_;
206 };
207
208 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
209   return make_unique<LZ4Codec>(level, type);
210 }
211
212 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
213   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
214
215   switch (level) {
216   case COMPRESSION_LEVEL_FASTEST:
217   case COMPRESSION_LEVEL_DEFAULT:
218     level = 1;
219     break;
220   case COMPRESSION_LEVEL_BEST:
221     level = 2;
222     break;
223   }
224   if (level < 1 || level > 2) {
225     throw std::invalid_argument(to<std::string>(
226         "LZ4Codec: invalid level: ", level));
227   }
228   highCompression_ = (level > 1);
229 }
230
231 bool LZ4Codec::doNeedsUncompressedLength() const {
232   return !encodeSize();
233 }
234
235 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
236 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
237 // here.
238 #ifndef LZ4_MAX_INPUT_SIZE
239 # define LZ4_MAX_INPUT_SIZE 0x7E000000
240 #endif
241
242 uint64_t LZ4Codec::doMaxUncompressedLength() const {
243   return LZ4_MAX_INPUT_SIZE;
244 }
245
246 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
247   std::unique_ptr<IOBuf> clone;
248   if (data->isChained()) {
249     // LZ4 doesn't support streaming, so we have to coalesce
250     clone = data->clone();
251     clone->coalesce();
252     data = clone.get();
253   }
254
255   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
256   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
257   if (encodeSize()) {
258     encodeVarintToIOBuf(data->length(), out.get());
259   }
260
261   int n;
262   if (highCompression_) {
263     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
264                        reinterpret_cast<char*>(out->writableTail()),
265                        data->length());
266   } else {
267     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
268                      reinterpret_cast<char*>(out->writableTail()),
269                      data->length());
270   }
271
272   CHECK_GE(n, 0);
273   CHECK_LE(n, out->capacity());
274
275   out->append(n);
276   return out;
277 }
278
279 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
280     const IOBuf* data,
281     uint64_t uncompressedLength) {
282   std::unique_ptr<IOBuf> clone;
283   if (data->isChained()) {
284     // LZ4 doesn't support streaming, so we have to coalesce
285     clone = data->clone();
286     clone->coalesce();
287     data = clone.get();
288   }
289
290   folly::io::Cursor cursor(data);
291   uint64_t actualUncompressedLength;
292   if (encodeSize()) {
293     actualUncompressedLength = decodeVarintFromCursor(cursor);
294     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
295         uncompressedLength != actualUncompressedLength) {
296       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
297     }
298   } else {
299     actualUncompressedLength = uncompressedLength;
300     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
301         actualUncompressedLength > maxUncompressedLength()) {
302       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
303     }
304   }
305
306   auto p = cursor.peek();
307   auto out = IOBuf::create(actualUncompressedLength);
308   int n = LZ4_decompress_safe(reinterpret_cast<const char*>(p.first),
309                               reinterpret_cast<char*>(out->writableTail()),
310                               p.second,
311                               actualUncompressedLength);
312
313   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
314     throw std::runtime_error(to<std::string>(
315         "LZ4 decompression returned invalid value ", n));
316   }
317   out->append(actualUncompressedLength);
318   return out;
319 }
320
321 #endif  // FOLLY_HAVE_LIBLZ4
322
323 #if FOLLY_HAVE_LIBSNAPPY
324
325 /**
326  * Snappy compression
327  */
328
329 /**
330  * Implementation of snappy::Source that reads from a IOBuf chain.
331  */
332 class IOBufSnappySource final : public snappy::Source {
333  public:
334   explicit IOBufSnappySource(const IOBuf* data);
335   size_t Available() const override;
336   const char* Peek(size_t* len) override;
337   void Skip(size_t n) override;
338  private:
339   size_t available_;
340   io::Cursor cursor_;
341 };
342
343 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
344   : available_(data->computeChainDataLength()),
345     cursor_(data) {
346 }
347
348 size_t IOBufSnappySource::Available() const {
349   return available_;
350 }
351
352 const char* IOBufSnappySource::Peek(size_t* len) {
353   auto p = cursor_.peek();
354   *len = p.second;
355   return reinterpret_cast<const char*>(p.first);
356 }
357
358 void IOBufSnappySource::Skip(size_t n) {
359   CHECK_LE(n, available_);
360   cursor_.skip(n);
361   available_ -= n;
362 }
363
364 class SnappyCodec final : public Codec {
365  public:
366   static std::unique_ptr<Codec> create(int level, CodecType type);
367   explicit SnappyCodec(int level, CodecType type);
368
369  private:
370   uint64_t doMaxUncompressedLength() const override;
371   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
372   std::unique_ptr<IOBuf> doUncompress(
373       const IOBuf* data,
374       uint64_t uncompressedLength) override;
375 };
376
377 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
378   return make_unique<SnappyCodec>(level, type);
379 }
380
381 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
382   DCHECK(type == CodecType::SNAPPY);
383   switch (level) {
384   case COMPRESSION_LEVEL_FASTEST:
385   case COMPRESSION_LEVEL_DEFAULT:
386   case COMPRESSION_LEVEL_BEST:
387     level = 1;
388   }
389   if (level != 1) {
390     throw std::invalid_argument(to<std::string>(
391         "SnappyCodec: invalid level: ", level));
392   }
393 }
394
395 uint64_t SnappyCodec::doMaxUncompressedLength() const {
396   // snappy.h uses uint32_t for lengths, so there's that.
397   return std::numeric_limits<uint32_t>::max();
398 }
399
400 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
401   IOBufSnappySource source(data);
402   auto out =
403     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
404
405   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
406       out->writableTail()));
407
408   size_t n = snappy::Compress(&source, &sink);
409
410   CHECK_LE(n, out->capacity());
411   out->append(n);
412   return out;
413 }
414
415 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
416                                                  uint64_t uncompressedLength) {
417   uint32_t actualUncompressedLength = 0;
418
419   {
420     IOBufSnappySource source(data);
421     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
422       throw std::runtime_error("snappy::GetUncompressedLength failed");
423     }
424     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
425         uncompressedLength != actualUncompressedLength) {
426       throw std::runtime_error("snappy: invalid uncompressed length");
427     }
428   }
429
430   auto out = IOBuf::create(actualUncompressedLength);
431
432   {
433     IOBufSnappySource source(data);
434     if (!snappy::RawUncompress(&source,
435                                reinterpret_cast<char*>(out->writableTail()))) {
436       throw std::runtime_error("snappy::RawUncompress failed");
437     }
438   }
439
440   out->append(actualUncompressedLength);
441   return out;
442 }
443
444 #endif  // FOLLY_HAVE_LIBSNAPPY
445
446 #if FOLLY_HAVE_LIBZ
447 /**
448  * Zlib codec
449  */
450 class ZlibCodec final : public Codec {
451  public:
452   static std::unique_ptr<Codec> create(int level, CodecType type);
453   explicit ZlibCodec(int level, CodecType type);
454
455  private:
456   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
457   std::unique_ptr<IOBuf> doUncompress(
458       const IOBuf* data,
459       uint64_t uncompressedLength) override;
460
461   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
462   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
463
464   int level_;
465 };
466
467 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
468   return make_unique<ZlibCodec>(level, type);
469 }
470
471 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
472   DCHECK(type == CodecType::ZLIB);
473   switch (level) {
474   case COMPRESSION_LEVEL_FASTEST:
475     level = 1;
476     break;
477   case COMPRESSION_LEVEL_DEFAULT:
478     level = Z_DEFAULT_COMPRESSION;
479     break;
480   case COMPRESSION_LEVEL_BEST:
481     level = 9;
482     break;
483   }
484   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
485     throw std::invalid_argument(to<std::string>(
486         "ZlibCodec: invalid level: ", level));
487   }
488   level_ = level;
489 }
490
491 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
492                                                   uint32_t length) {
493   CHECK_EQ(stream->avail_out, 0);
494
495   auto buf = IOBuf::create(length);
496   buf->append(length);
497
498   stream->next_out = buf->writableData();
499   stream->avail_out = buf->length();
500
501   return buf;
502 }
503
504 bool ZlibCodec::doInflate(z_stream* stream,
505                           IOBuf* head,
506                           uint32_t bufferLength) {
507   if (stream->avail_out == 0) {
508     head->prependChain(addOutputBuffer(stream, bufferLength));
509   }
510
511   int rc = inflate(stream, Z_NO_FLUSH);
512
513   switch (rc) {
514   case Z_OK:
515     break;
516   case Z_STREAM_END:
517     return true;
518   case Z_BUF_ERROR:
519   case Z_NEED_DICT:
520   case Z_DATA_ERROR:
521   case Z_MEM_ERROR:
522     throw std::runtime_error(to<std::string>(
523         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
524   default:
525     CHECK(false) << rc << ": " << stream->msg;
526   }
527
528   return false;
529 }
530
531 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
532   z_stream stream;
533   stream.zalloc = nullptr;
534   stream.zfree = nullptr;
535   stream.opaque = nullptr;
536
537   int rc = deflateInit(&stream, level_);
538   if (rc != Z_OK) {
539     throw std::runtime_error(to<std::string>(
540         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
541   }
542
543   stream.next_in = stream.next_out = nullptr;
544   stream.avail_in = stream.avail_out = 0;
545   stream.total_in = stream.total_out = 0;
546
547   bool success = false;
548
549   SCOPE_EXIT {
550     int rc = deflateEnd(&stream);
551     // If we're here because of an exception, it's okay if some data
552     // got dropped.
553     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
554       << rc << ": " << stream.msg;
555   };
556
557   uint64_t uncompressedLength = data->computeChainDataLength();
558   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
559
560   // Max 64MiB in one go
561   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
562   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
563
564   auto out = addOutputBuffer(
565       &stream,
566       (maxCompressedLength <= maxSingleStepLength ?
567        maxCompressedLength :
568        defaultBufferLength));
569
570   for (auto& range : *data) {
571     uint64_t remaining = range.size();
572     uint64_t written = 0;
573     while (remaining) {
574       uint32_t step = (remaining > maxSingleStepLength ?
575                        maxSingleStepLength : remaining);
576       stream.next_in = const_cast<uint8_t*>(range.data() + written);
577       stream.avail_in = step;
578       remaining -= step;
579       written += step;
580
581       while (stream.avail_in != 0) {
582         if (stream.avail_out == 0) {
583           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
584         }
585
586         rc = deflate(&stream, Z_NO_FLUSH);
587
588         CHECK_EQ(rc, Z_OK) << stream.msg;
589       }
590     }
591   }
592
593   do {
594     if (stream.avail_out == 0) {
595       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
596     }
597
598     rc = deflate(&stream, Z_FINISH);
599   } while (rc == Z_OK);
600
601   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
602
603   out->prev()->trimEnd(stream.avail_out);
604
605   success = true;  // we survived
606
607   return out;
608 }
609
610 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
611                                                uint64_t uncompressedLength) {
612   z_stream stream;
613   stream.zalloc = nullptr;
614   stream.zfree = nullptr;
615   stream.opaque = nullptr;
616
617   int rc = inflateInit(&stream);
618   if (rc != Z_OK) {
619     throw std::runtime_error(to<std::string>(
620         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
621   }
622
623   stream.next_in = stream.next_out = nullptr;
624   stream.avail_in = stream.avail_out = 0;
625   stream.total_in = stream.total_out = 0;
626
627   bool success = false;
628
629   SCOPE_EXIT {
630     int rc = inflateEnd(&stream);
631     // If we're here because of an exception, it's okay if some data
632     // got dropped.
633     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
634       << rc << ": " << stream.msg;
635   };
636
637   // Max 64MiB in one go
638   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
639   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
640
641   auto out = addOutputBuffer(
642       &stream,
643       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
644         uncompressedLength <= maxSingleStepLength) ?
645        uncompressedLength :
646        defaultBufferLength));
647
648   bool streamEnd = false;
649   for (auto& range : *data) {
650     if (range.empty()) {
651       continue;
652     }
653
654     stream.next_in = const_cast<uint8_t*>(range.data());
655     stream.avail_in = range.size();
656
657     while (stream.avail_in != 0) {
658       if (streamEnd) {
659         throw std::runtime_error(to<std::string>(
660             "ZlibCodec: junk after end of data"));
661       }
662
663       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
664     }
665   }
666
667   while (!streamEnd) {
668     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
669   }
670
671   out->prev()->trimEnd(stream.avail_out);
672
673   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
674       uncompressedLength != stream.total_out) {
675     throw std::runtime_error(to<std::string>(
676         "ZlibCodec: invalid uncompressed length"));
677   }
678
679   success = true;  // we survived
680
681   return out;
682 }
683
684 #endif  // FOLLY_HAVE_LIBZ
685
686 #if FOLLY_HAVE_LIBLZMA
687
688 /**
689  * LZMA2 compression
690  */
691 class LZMA2Codec final : public Codec {
692  public:
693   static std::unique_ptr<Codec> create(int level, CodecType type);
694   explicit LZMA2Codec(int level, CodecType type);
695
696  private:
697   bool doNeedsUncompressedLength() const override;
698   uint64_t doMaxUncompressedLength() const override;
699
700   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
701
702   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
703   std::unique_ptr<IOBuf> doUncompress(
704       const IOBuf* data,
705       uint64_t uncompressedLength) override;
706
707   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
708   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
709
710   int level_;
711 };
712
713 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
714   return make_unique<LZMA2Codec>(level, type);
715 }
716
717 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
718   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
719   switch (level) {
720   case COMPRESSION_LEVEL_FASTEST:
721     level = 0;
722     break;
723   case COMPRESSION_LEVEL_DEFAULT:
724     level = LZMA_PRESET_DEFAULT;
725     break;
726   case COMPRESSION_LEVEL_BEST:
727     level = 9;
728     break;
729   }
730   if (level < 0 || level > 9) {
731     throw std::invalid_argument(to<std::string>(
732         "LZMA2Codec: invalid level: ", level));
733   }
734   level_ = level;
735 }
736
737 bool LZMA2Codec::doNeedsUncompressedLength() const {
738   return !encodeSize();
739 }
740
741 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
742   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
743   return uint64_t(1) << 63;
744 }
745
746 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
747     lzma_stream* stream,
748     size_t length) {
749
750   CHECK_EQ(stream->avail_out, 0);
751
752   auto buf = IOBuf::create(length);
753   buf->append(length);
754
755   stream->next_out = buf->writableData();
756   stream->avail_out = buf->length();
757
758   return buf;
759 }
760
761 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
762   lzma_ret rc;
763   lzma_stream stream = LZMA_STREAM_INIT;
764
765   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
766   if (rc != LZMA_OK) {
767     throw std::runtime_error(folly::to<std::string>(
768       "LZMA2Codec: lzma_easy_encoder error: ", rc));
769   }
770
771   SCOPE_EXIT { lzma_end(&stream); };
772
773   uint64_t uncompressedLength = data->computeChainDataLength();
774   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
775
776   // Max 64MiB in one go
777   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
778   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
779
780   auto out = addOutputBuffer(
781     &stream,
782     (maxCompressedLength <= maxSingleStepLength ?
783      maxCompressedLength :
784      defaultBufferLength));
785
786   if (encodeSize()) {
787     auto size = IOBuf::createCombined(kMaxVarintLength64);
788     encodeVarintToIOBuf(uncompressedLength, size.get());
789     size->appendChain(std::move(out));
790     out = std::move(size);
791   }
792
793   for (auto& range : *data) {
794     if (range.empty()) {
795       continue;
796     }
797
798     stream.next_in = const_cast<uint8_t*>(range.data());
799     stream.avail_in = range.size();
800
801     while (stream.avail_in != 0) {
802       if (stream.avail_out == 0) {
803         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
804       }
805
806       rc = lzma_code(&stream, LZMA_RUN);
807
808       if (rc != LZMA_OK) {
809         throw std::runtime_error(folly::to<std::string>(
810           "LZMA2Codec: lzma_code error: ", rc));
811       }
812     }
813   }
814
815   do {
816     if (stream.avail_out == 0) {
817       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
818     }
819
820     rc = lzma_code(&stream, LZMA_FINISH);
821   } while (rc == LZMA_OK);
822
823   if (rc != LZMA_STREAM_END) {
824     throw std::runtime_error(folly::to<std::string>(
825       "LZMA2Codec: lzma_code ended with error: ", rc));
826   }
827
828   out->prev()->trimEnd(stream.avail_out);
829
830   return out;
831 }
832
833 bool LZMA2Codec::doInflate(lzma_stream* stream,
834                           IOBuf* head,
835                           size_t bufferLength) {
836   if (stream->avail_out == 0) {
837     head->prependChain(addOutputBuffer(stream, bufferLength));
838   }
839
840   lzma_ret rc = lzma_code(stream, LZMA_RUN);
841
842   switch (rc) {
843   case LZMA_OK:
844     break;
845   case LZMA_STREAM_END:
846     return true;
847   default:
848     throw std::runtime_error(to<std::string>(
849         "LZMA2Codec: lzma_code error: ", rc));
850   }
851
852   return false;
853 }
854
855 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
856                                                uint64_t uncompressedLength) {
857   lzma_ret rc;
858   lzma_stream stream = LZMA_STREAM_INIT;
859
860   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
861   if (rc != LZMA_OK) {
862     throw std::runtime_error(folly::to<std::string>(
863       "LZMA2Codec: lzma_auto_decoder error: ", rc));
864   }
865
866   SCOPE_EXIT { lzma_end(&stream); };
867
868   // Max 64MiB in one go
869   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
870   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
871
872   folly::io::Cursor cursor(data);
873   uint64_t actualUncompressedLength;
874   if (encodeSize()) {
875     actualUncompressedLength = decodeVarintFromCursor(cursor);
876     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
877         uncompressedLength != actualUncompressedLength) {
878       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
879     }
880   } else {
881     actualUncompressedLength = uncompressedLength;
882     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
883   }
884
885   auto out = addOutputBuffer(
886       &stream,
887       (actualUncompressedLength <= maxSingleStepLength ?
888        actualUncompressedLength :
889        defaultBufferLength));
890
891   bool streamEnd = false;
892   auto buf = cursor.peek();
893   while (buf.second != 0) {
894     stream.next_in = const_cast<uint8_t*>(buf.first);
895     stream.avail_in = buf.second;
896
897     while (stream.avail_in != 0) {
898       if (streamEnd) {
899         throw std::runtime_error(to<std::string>(
900             "LZMA2Codec: junk after end of data"));
901       }
902
903       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
904     }
905
906     cursor.skip(buf.second);
907     buf = cursor.peek();
908   }
909
910   while (!streamEnd) {
911     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
912   }
913
914   out->prev()->trimEnd(stream.avail_out);
915
916   if (actualUncompressedLength != stream.total_out) {
917     throw std::runtime_error(to<std::string>(
918         "LZMA2Codec: invalid uncompressed length"));
919   }
920
921   return out;
922 }
923
924 #endif  // FOLLY_HAVE_LIBLZMA
925
926 #ifdef FOLLY_HAVE_LIBZSTD
927
928 /**
929  * ZSTD_BETA compression
930  */
931 class ZSTDCodec final : public Codec {
932  public:
933   static std::unique_ptr<Codec> create(int level, CodecType);
934   explicit ZSTDCodec(int level, CodecType type);
935
936  private:
937   bool doNeedsUncompressedLength() const override;
938   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
939   std::unique_ptr<IOBuf> doUncompress(
940       const IOBuf* data,
941       uint64_t uncompressedLength) override;
942
943   int level_{1};
944 };
945
946 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
947   return make_unique<ZSTDCodec>(level, type);
948 }
949
950 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
951   DCHECK(type == CodecType::ZSTD_BETA);
952   switch (level) {
953     case COMPRESSION_LEVEL_FASTEST:
954       level_ = 1;
955       break;
956     case COMPRESSION_LEVEL_DEFAULT:
957       level_ = 1;
958       break;
959     case COMPRESSION_LEVEL_BEST:
960       level_ = 19;
961       break;
962   }
963 }
964
965 bool ZSTDCodec::doNeedsUncompressedLength() const {
966   return true;
967 }
968
969 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
970   size_t rc;
971   size_t maxCompressedLength = ZSTD_compressBound(data->length());
972   auto out = IOBuf::createCombined(maxCompressedLength);
973
974   CHECK_EQ(out->length(), 0);
975
976   rc = ZSTD_compress(out->writableTail(),
977                      out->capacity(),
978                      data->data(),
979                      data->length(),
980                      level_);
981
982   if (ZSTD_isError(rc)) {
983     throw std::runtime_error(to<std::string>(
984           "ZSTD compression returned an error: ",
985           ZSTD_getErrorName(rc)));
986   }
987
988   out->append(rc);
989   CHECK_EQ(out->length(), rc);
990
991   return out;
992 }
993
994 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
995                                                uint64_t uncompressedLength) {
996   size_t rc;
997   auto out = IOBuf::createCombined(uncompressedLength);
998
999   CHECK_GE(out->capacity(), uncompressedLength);
1000   CHECK_EQ(out->length(), 0);
1001
1002   rc = ZSTD_decompress(
1003       out->writableTail(), out->capacity(), data->data(), data->length());
1004
1005   if (ZSTD_isError(rc)) {
1006     throw std::runtime_error(to<std::string>(
1007           "ZSTD decompression returned an error: ",
1008           ZSTD_getErrorName(rc)));
1009   }
1010
1011   out->append(rc);
1012   CHECK_EQ(out->length(), rc);
1013
1014   return out;
1015 }
1016
1017 #endif  // FOLLY_HAVE_LIBZSTD
1018
1019 }  // namespace
1020
1021 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1022   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1023
1024   static CodecFactory codecFactories[
1025     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1026     nullptr,  // USER_DEFINED
1027     NoCompressionCodec::create,
1028
1029 #if FOLLY_HAVE_LIBLZ4
1030     LZ4Codec::create,
1031 #else
1032     nullptr,
1033 #endif
1034
1035 #if FOLLY_HAVE_LIBSNAPPY
1036     SnappyCodec::create,
1037 #else
1038     nullptr,
1039 #endif
1040
1041 #if FOLLY_HAVE_LIBZ
1042     ZlibCodec::create,
1043 #else
1044     nullptr,
1045 #endif
1046
1047 #if FOLLY_HAVE_LIBLZ4
1048     LZ4Codec::create,
1049 #else
1050     nullptr,
1051 #endif
1052
1053 #if FOLLY_HAVE_LIBLZMA
1054     LZMA2Codec::create,
1055     LZMA2Codec::create,
1056 #else
1057     nullptr,
1058     nullptr,
1059 #endif
1060
1061 #if FOLLY_HAVE_LIBZSTD
1062     ZSTDCodec::create,
1063 #else
1064     nullptr,
1065 #endif
1066   };
1067
1068   size_t idx = static_cast<size_t>(type);
1069   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1070     throw std::invalid_argument(to<std::string>(
1071         "Compression type ", idx, " not supported"));
1072   }
1073   auto factory = codecFactories[idx];
1074   if (!factory) {
1075     throw std::invalid_argument(to<std::string>(
1076         "Compression type ", idx, " not supported"));
1077   }
1078   auto codec = (*factory)(level, type);
1079   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1080   return codec;
1081 }
1082
1083 }}  // namespaces