Add IOBuf::cloneCoalesced()
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   }
60   if (len > maxUncompressedLength()) {
61     throw std::runtime_error("Codec: uncompressed length too large");
62   }
63
64   return doCompress(data);
65 }
66
67 std::string Codec::compress(const StringPiece data) {
68   const uint64_t len = data.size();
69   if (len == 0) {
70     return "";
71   }
72   if (len > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   return doCompressString(data);
77 }
78
79 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
80                                          uint64_t uncompressedLength) {
81   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
82     if (needsUncompressedLength()) {
83       throw std::invalid_argument("Codec: uncompressed length required");
84     }
85   } else if (uncompressedLength > maxUncompressedLength()) {
86     throw std::runtime_error("Codec: uncompressed length too large");
87   }
88
89   if (data->empty()) {
90     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
91         uncompressedLength != 0) {
92       throw std::runtime_error("Codec: invalid uncompressed length");
93     }
94     return IOBuf::create(0);
95   }
96
97   return doUncompress(data, uncompressedLength);
98 }
99
100 std::string Codec::uncompress(
101     const StringPiece data,
102     uint64_t uncompressedLength) {
103   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
104     if (needsUncompressedLength()) {
105       throw std::invalid_argument("Codec: uncompressed length required");
106     }
107   } else if (uncompressedLength > maxUncompressedLength()) {
108     throw std::runtime_error("Codec: uncompressed length too large");
109   }
110
111   if (data.empty()) {
112     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
113         uncompressedLength != 0) {
114       throw std::runtime_error("Codec: invalid uncompressed length");
115     }
116     return "";
117   }
118
119   return doUncompressString(data, uncompressedLength);
120 }
121
122 bool Codec::needsUncompressedLength() const {
123   return doNeedsUncompressedLength();
124 }
125
126 uint64_t Codec::maxUncompressedLength() const {
127   return doMaxUncompressedLength();
128 }
129
130 bool Codec::doNeedsUncompressedLength() const {
131   return false;
132 }
133
134 uint64_t Codec::doMaxUncompressedLength() const {
135   return UNLIMITED_UNCOMPRESSED_LENGTH;
136 }
137
138 std::string Codec::doCompressString(const StringPiece data) {
139   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
140   auto outputBuffer = doCompress(&inputBuffer);
141   std::string output;
142   output.reserve(outputBuffer->computeChainDataLength());
143   for (auto range : *outputBuffer) {
144     output.append(reinterpret_cast<const char*>(range.data()), range.size());
145   }
146   return output;
147 }
148
149 std::string Codec::doUncompressString(
150     const StringPiece data,
151     uint64_t uncompressedLength) {
152   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
153   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
154   std::string output;
155   output.reserve(outputBuffer->computeChainDataLength());
156   for (auto range : *outputBuffer) {
157     output.append(reinterpret_cast<const char*>(range.data()), range.size());
158   }
159   return output;
160 }
161
162 namespace {
163
164 /**
165  * No compression
166  */
167 class NoCompressionCodec final : public Codec {
168  public:
169   static std::unique_ptr<Codec> create(int level, CodecType type);
170   explicit NoCompressionCodec(int level, CodecType type);
171
172  private:
173   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
174   std::unique_ptr<IOBuf> doUncompress(
175       const IOBuf* data,
176       uint64_t uncompressedLength) override;
177 };
178
179 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
180   return make_unique<NoCompressionCodec>(level, type);
181 }
182
183 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
184   : Codec(type) {
185   DCHECK(type == CodecType::NO_COMPRESSION);
186   switch (level) {
187   case COMPRESSION_LEVEL_DEFAULT:
188   case COMPRESSION_LEVEL_FASTEST:
189   case COMPRESSION_LEVEL_BEST:
190     level = 0;
191   }
192   if (level != 0) {
193     throw std::invalid_argument(to<std::string>(
194         "NoCompressionCodec: invalid level ", level));
195   }
196 }
197
198 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
199     const IOBuf* data) {
200   return data->clone();
201 }
202
203 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
204     const IOBuf* data,
205     uint64_t uncompressedLength) {
206   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
207       data->computeChainDataLength() != uncompressedLength) {
208     throw std::runtime_error(to<std::string>(
209         "NoCompressionCodec: invalid uncompressed length"));
210   }
211   return data->clone();
212 }
213
214 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
215
216 namespace {
217
218 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
219   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
220   out->append(encodeVarint(val, out->writableTail()));
221 }
222
223 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
224   uint64_t val = 0;
225   int8_t b = 0;
226   for (int shift = 0; shift <= 63; shift += 7) {
227     b = cursor.read<int8_t>();
228     val |= static_cast<uint64_t>(b & 0x7f) << shift;
229     if (b >= 0) {
230       break;
231     }
232   }
233   if (b < 0) {
234     throw std::invalid_argument("Invalid varint value. Too big.");
235   }
236   return val;
237 }
238
239 }  // namespace
240
241 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
242
243 #if FOLLY_HAVE_LIBLZ4
244
245 /**
246  * LZ4 compression
247  */
248 class LZ4Codec final : public Codec {
249  public:
250   static std::unique_ptr<Codec> create(int level, CodecType type);
251   explicit LZ4Codec(int level, CodecType type);
252
253  private:
254   bool doNeedsUncompressedLength() const override;
255   uint64_t doMaxUncompressedLength() const override;
256
257   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
258
259   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
260   std::unique_ptr<IOBuf> doUncompress(
261       const IOBuf* data,
262       uint64_t uncompressedLength) override;
263
264   bool highCompression_;
265 };
266
267 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
268   return make_unique<LZ4Codec>(level, type);
269 }
270
271 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
272   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
273
274   switch (level) {
275   case COMPRESSION_LEVEL_FASTEST:
276   case COMPRESSION_LEVEL_DEFAULT:
277     level = 1;
278     break;
279   case COMPRESSION_LEVEL_BEST:
280     level = 2;
281     break;
282   }
283   if (level < 1 || level > 2) {
284     throw std::invalid_argument(to<std::string>(
285         "LZ4Codec: invalid level: ", level));
286   }
287   highCompression_ = (level > 1);
288 }
289
290 bool LZ4Codec::doNeedsUncompressedLength() const {
291   return !encodeSize();
292 }
293
294 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
295 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
296 // here.
297 #ifndef LZ4_MAX_INPUT_SIZE
298 # define LZ4_MAX_INPUT_SIZE 0x7E000000
299 #endif
300
301 uint64_t LZ4Codec::doMaxUncompressedLength() const {
302   return LZ4_MAX_INPUT_SIZE;
303 }
304
305 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
306   IOBuf clone;
307   if (data->isChained()) {
308     // LZ4 doesn't support streaming, so we have to coalesce
309     clone = data->cloneCoalescedAsValue();
310     data = &clone;
311   }
312
313   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
314   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
315   if (encodeSize()) {
316     encodeVarintToIOBuf(data->length(), out.get());
317   }
318
319   int n;
320   auto input = reinterpret_cast<const char*>(data->data());
321   auto output = reinterpret_cast<char*>(out->writableTail());
322   const auto inputLength = data->length();
323 #if LZ4_VERSION_NUMBER >= 10700
324   if (highCompression_) {
325     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
326   } else {
327     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
328   }
329 #else
330   if (highCompression_) {
331     n = LZ4_compressHC(input, output, inputLength);
332   } else {
333     n = LZ4_compress(input, output, inputLength);
334   }
335 #endif
336
337   CHECK_GE(n, 0);
338   CHECK_LE(n, out->capacity());
339
340   out->append(n);
341   return out;
342 }
343
344 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
345     const IOBuf* data,
346     uint64_t uncompressedLength) {
347   IOBuf clone;
348   if (data->isChained()) {
349     // LZ4 doesn't support streaming, so we have to coalesce
350     clone = data->cloneCoalescedAsValue();
351     data = &clone;
352   }
353
354   folly::io::Cursor cursor(data);
355   uint64_t actualUncompressedLength;
356   if (encodeSize()) {
357     actualUncompressedLength = decodeVarintFromCursor(cursor);
358     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
359         uncompressedLength != actualUncompressedLength) {
360       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
361     }
362   } else {
363     actualUncompressedLength = uncompressedLength;
364     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
365         actualUncompressedLength > maxUncompressedLength()) {
366       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
367     }
368   }
369
370   auto sp = StringPiece{cursor.peekBytes()};
371   auto out = IOBuf::create(actualUncompressedLength);
372   int n = LZ4_decompress_safe(
373       sp.data(),
374       reinterpret_cast<char*>(out->writableTail()),
375       sp.size(),
376       actualUncompressedLength);
377
378   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
379     throw std::runtime_error(to<std::string>(
380         "LZ4 decompression returned invalid value ", n));
381   }
382   out->append(actualUncompressedLength);
383   return out;
384 }
385
386 #endif  // FOLLY_HAVE_LIBLZ4
387
388 #if FOLLY_HAVE_LIBSNAPPY
389
390 /**
391  * Snappy compression
392  */
393
394 /**
395  * Implementation of snappy::Source that reads from a IOBuf chain.
396  */
397 class IOBufSnappySource final : public snappy::Source {
398  public:
399   explicit IOBufSnappySource(const IOBuf* data);
400   size_t Available() const override;
401   const char* Peek(size_t* len) override;
402   void Skip(size_t n) override;
403  private:
404   size_t available_;
405   io::Cursor cursor_;
406 };
407
408 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
409   : available_(data->computeChainDataLength()),
410     cursor_(data) {
411 }
412
413 size_t IOBufSnappySource::Available() const {
414   return available_;
415 }
416
417 const char* IOBufSnappySource::Peek(size_t* len) {
418   auto sp = StringPiece{cursor_.peekBytes()};
419   *len = sp.size();
420   return sp.data();
421 }
422
423 void IOBufSnappySource::Skip(size_t n) {
424   CHECK_LE(n, available_);
425   cursor_.skip(n);
426   available_ -= n;
427 }
428
429 class SnappyCodec final : public Codec {
430  public:
431   static std::unique_ptr<Codec> create(int level, CodecType type);
432   explicit SnappyCodec(int level, CodecType type);
433
434  private:
435   uint64_t doMaxUncompressedLength() const override;
436   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
437   std::unique_ptr<IOBuf> doUncompress(
438       const IOBuf* data,
439       uint64_t uncompressedLength) override;
440 };
441
442 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
443   return make_unique<SnappyCodec>(level, type);
444 }
445
446 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
447   DCHECK(type == CodecType::SNAPPY);
448   switch (level) {
449   case COMPRESSION_LEVEL_FASTEST:
450   case COMPRESSION_LEVEL_DEFAULT:
451   case COMPRESSION_LEVEL_BEST:
452     level = 1;
453   }
454   if (level != 1) {
455     throw std::invalid_argument(to<std::string>(
456         "SnappyCodec: invalid level: ", level));
457   }
458 }
459
460 uint64_t SnappyCodec::doMaxUncompressedLength() const {
461   // snappy.h uses uint32_t for lengths, so there's that.
462   return std::numeric_limits<uint32_t>::max();
463 }
464
465 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
466   IOBufSnappySource source(data);
467   auto out =
468     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
469
470   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
471       out->writableTail()));
472
473   size_t n = snappy::Compress(&source, &sink);
474
475   CHECK_LE(n, out->capacity());
476   out->append(n);
477   return out;
478 }
479
480 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
481                                                  uint64_t uncompressedLength) {
482   uint32_t actualUncompressedLength = 0;
483
484   {
485     IOBufSnappySource source(data);
486     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
487       throw std::runtime_error("snappy::GetUncompressedLength failed");
488     }
489     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
490         uncompressedLength != actualUncompressedLength) {
491       throw std::runtime_error("snappy: invalid uncompressed length");
492     }
493   }
494
495   auto out = IOBuf::create(actualUncompressedLength);
496
497   {
498     IOBufSnappySource source(data);
499     if (!snappy::RawUncompress(&source,
500                                reinterpret_cast<char*>(out->writableTail()))) {
501       throw std::runtime_error("snappy::RawUncompress failed");
502     }
503   }
504
505   out->append(actualUncompressedLength);
506   return out;
507 }
508
509 #endif  // FOLLY_HAVE_LIBSNAPPY
510
511 #if FOLLY_HAVE_LIBZ
512 /**
513  * Zlib codec
514  */
515 class ZlibCodec final : public Codec {
516  public:
517   static std::unique_ptr<Codec> create(int level, CodecType type);
518   explicit ZlibCodec(int level, CodecType type);
519
520  private:
521   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
522   std::unique_ptr<IOBuf> doUncompress(
523       const IOBuf* data,
524       uint64_t uncompressedLength) override;
525
526   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
527   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
528
529   int level_;
530 };
531
532 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
533   return make_unique<ZlibCodec>(level, type);
534 }
535
536 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
537   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
538   switch (level) {
539   case COMPRESSION_LEVEL_FASTEST:
540     level = 1;
541     break;
542   case COMPRESSION_LEVEL_DEFAULT:
543     level = Z_DEFAULT_COMPRESSION;
544     break;
545   case COMPRESSION_LEVEL_BEST:
546     level = 9;
547     break;
548   }
549   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
550     throw std::invalid_argument(to<std::string>(
551         "ZlibCodec: invalid level: ", level));
552   }
553   level_ = level;
554 }
555
556 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
557                                                   uint32_t length) {
558   CHECK_EQ(stream->avail_out, 0);
559
560   auto buf = IOBuf::create(length);
561   buf->append(length);
562
563   stream->next_out = buf->writableData();
564   stream->avail_out = buf->length();
565
566   return buf;
567 }
568
569 bool ZlibCodec::doInflate(z_stream* stream,
570                           IOBuf* head,
571                           uint32_t bufferLength) {
572   if (stream->avail_out == 0) {
573     head->prependChain(addOutputBuffer(stream, bufferLength));
574   }
575
576   int rc = inflate(stream, Z_NO_FLUSH);
577
578   switch (rc) {
579   case Z_OK:
580     break;
581   case Z_STREAM_END:
582     return true;
583   case Z_BUF_ERROR:
584   case Z_NEED_DICT:
585   case Z_DATA_ERROR:
586   case Z_MEM_ERROR:
587     throw std::runtime_error(to<std::string>(
588         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
589   default:
590     CHECK(false) << rc << ": " << stream->msg;
591   }
592
593   return false;
594 }
595
596 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
597   z_stream stream;
598   stream.zalloc = nullptr;
599   stream.zfree = nullptr;
600   stream.opaque = nullptr;
601
602   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
603   // base two logarithm of the maximum window size (...) The default value is
604   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
605   // around the compressed data instead of a zlib wrapper. The gzip header
606   // will have no file name, no extra data, no comment, no modification time
607   // (set to zero), no header crc, and the operating system will be set to 255
608   // (unknown)."
609   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
610   // All other parameters (method, memLevel, strategy) get default values from
611   // the zlib manual.
612   int rc = deflateInit2(&stream,
613                         level_,
614                         Z_DEFLATED,
615                         windowBits,
616                         /* memLevel */ 8,
617                         Z_DEFAULT_STRATEGY);
618   if (rc != Z_OK) {
619     throw std::runtime_error(to<std::string>(
620         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
621   }
622
623   stream.next_in = stream.next_out = nullptr;
624   stream.avail_in = stream.avail_out = 0;
625   stream.total_in = stream.total_out = 0;
626
627   bool success = false;
628
629   SCOPE_EXIT {
630     rc = deflateEnd(&stream);
631     // If we're here because of an exception, it's okay if some data
632     // got dropped.
633     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
634       << rc << ": " << stream.msg;
635   };
636
637   uint64_t uncompressedLength = data->computeChainDataLength();
638   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
639
640   // Max 64MiB in one go
641   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
642   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
643
644   auto out = addOutputBuffer(
645       &stream,
646       (maxCompressedLength <= maxSingleStepLength ?
647        maxCompressedLength :
648        defaultBufferLength));
649
650   for (auto& range : *data) {
651     uint64_t remaining = range.size();
652     uint64_t written = 0;
653     while (remaining) {
654       uint32_t step = (remaining > maxSingleStepLength ?
655                        maxSingleStepLength : remaining);
656       stream.next_in = const_cast<uint8_t*>(range.data() + written);
657       stream.avail_in = step;
658       remaining -= step;
659       written += step;
660
661       while (stream.avail_in != 0) {
662         if (stream.avail_out == 0) {
663           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
664         }
665
666         rc = deflate(&stream, Z_NO_FLUSH);
667
668         CHECK_EQ(rc, Z_OK) << stream.msg;
669       }
670     }
671   }
672
673   do {
674     if (stream.avail_out == 0) {
675       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
676     }
677
678     rc = deflate(&stream, Z_FINISH);
679   } while (rc == Z_OK);
680
681   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
682
683   out->prev()->trimEnd(stream.avail_out);
684
685   success = true;  // we survived
686
687   return out;
688 }
689
690 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
691                                                uint64_t uncompressedLength) {
692   z_stream stream;
693   stream.zalloc = nullptr;
694   stream.zfree = nullptr;
695   stream.opaque = nullptr;
696
697   // "The windowBits parameter is the base two logarithm of the maximum window
698   // size (...) The default value is 15 (...) add 16 to decode only the gzip
699   // format (the zlib format will return a Z_DATA_ERROR)."
700   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
701   int rc = inflateInit2(&stream, windowBits);
702   if (rc != Z_OK) {
703     throw std::runtime_error(to<std::string>(
704         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
705   }
706
707   stream.next_in = stream.next_out = nullptr;
708   stream.avail_in = stream.avail_out = 0;
709   stream.total_in = stream.total_out = 0;
710
711   bool success = false;
712
713   SCOPE_EXIT {
714     rc = inflateEnd(&stream);
715     // If we're here because of an exception, it's okay if some data
716     // got dropped.
717     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
718       << rc << ": " << stream.msg;
719   };
720
721   // Max 64MiB in one go
722   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
723   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
724
725   auto out = addOutputBuffer(
726       &stream,
727       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
728         uncompressedLength <= maxSingleStepLength) ?
729        uncompressedLength :
730        defaultBufferLength));
731
732   bool streamEnd = false;
733   for (auto& range : *data) {
734     if (range.empty()) {
735       continue;
736     }
737
738     stream.next_in = const_cast<uint8_t*>(range.data());
739     stream.avail_in = range.size();
740
741     while (stream.avail_in != 0) {
742       if (streamEnd) {
743         throw std::runtime_error(to<std::string>(
744             "ZlibCodec: junk after end of data"));
745       }
746
747       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
748     }
749   }
750
751   while (!streamEnd) {
752     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
753   }
754
755   out->prev()->trimEnd(stream.avail_out);
756
757   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
758       uncompressedLength != stream.total_out) {
759     throw std::runtime_error(to<std::string>(
760         "ZlibCodec: invalid uncompressed length"));
761   }
762
763   success = true;  // we survived
764
765   return out;
766 }
767
768 #endif  // FOLLY_HAVE_LIBZ
769
770 #if FOLLY_HAVE_LIBLZMA
771
772 /**
773  * LZMA2 compression
774  */
775 class LZMA2Codec final : public Codec {
776  public:
777   static std::unique_ptr<Codec> create(int level, CodecType type);
778   explicit LZMA2Codec(int level, CodecType type);
779
780  private:
781   bool doNeedsUncompressedLength() const override;
782   uint64_t doMaxUncompressedLength() const override;
783
784   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
785
786   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
787   std::unique_ptr<IOBuf> doUncompress(
788       const IOBuf* data,
789       uint64_t uncompressedLength) override;
790
791   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
792   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
793
794   int level_;
795 };
796
797 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
798   return make_unique<LZMA2Codec>(level, type);
799 }
800
801 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
802   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
803   switch (level) {
804   case COMPRESSION_LEVEL_FASTEST:
805     level = 0;
806     break;
807   case COMPRESSION_LEVEL_DEFAULT:
808     level = LZMA_PRESET_DEFAULT;
809     break;
810   case COMPRESSION_LEVEL_BEST:
811     level = 9;
812     break;
813   }
814   if (level < 0 || level > 9) {
815     throw std::invalid_argument(to<std::string>(
816         "LZMA2Codec: invalid level: ", level));
817   }
818   level_ = level;
819 }
820
821 bool LZMA2Codec::doNeedsUncompressedLength() const {
822   return !encodeSize();
823 }
824
825 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
826   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
827   return uint64_t(1) << 63;
828 }
829
830 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
831     lzma_stream* stream,
832     size_t length) {
833
834   CHECK_EQ(stream->avail_out, 0);
835
836   auto buf = IOBuf::create(length);
837   buf->append(length);
838
839   stream->next_out = buf->writableData();
840   stream->avail_out = buf->length();
841
842   return buf;
843 }
844
845 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
846   lzma_ret rc;
847   lzma_stream stream = LZMA_STREAM_INIT;
848
849   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
850   if (rc != LZMA_OK) {
851     throw std::runtime_error(folly::to<std::string>(
852       "LZMA2Codec: lzma_easy_encoder error: ", rc));
853   }
854
855   SCOPE_EXIT { lzma_end(&stream); };
856
857   uint64_t uncompressedLength = data->computeChainDataLength();
858   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
859
860   // Max 64MiB in one go
861   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
862   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
863
864   auto out = addOutputBuffer(
865     &stream,
866     (maxCompressedLength <= maxSingleStepLength ?
867      maxCompressedLength :
868      defaultBufferLength));
869
870   if (encodeSize()) {
871     auto size = IOBuf::createCombined(kMaxVarintLength64);
872     encodeVarintToIOBuf(uncompressedLength, size.get());
873     size->appendChain(std::move(out));
874     out = std::move(size);
875   }
876
877   for (auto& range : *data) {
878     if (range.empty()) {
879       continue;
880     }
881
882     stream.next_in = const_cast<uint8_t*>(range.data());
883     stream.avail_in = range.size();
884
885     while (stream.avail_in != 0) {
886       if (stream.avail_out == 0) {
887         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
888       }
889
890       rc = lzma_code(&stream, LZMA_RUN);
891
892       if (rc != LZMA_OK) {
893         throw std::runtime_error(folly::to<std::string>(
894           "LZMA2Codec: lzma_code error: ", rc));
895       }
896     }
897   }
898
899   do {
900     if (stream.avail_out == 0) {
901       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
902     }
903
904     rc = lzma_code(&stream, LZMA_FINISH);
905   } while (rc == LZMA_OK);
906
907   if (rc != LZMA_STREAM_END) {
908     throw std::runtime_error(folly::to<std::string>(
909       "LZMA2Codec: lzma_code ended with error: ", rc));
910   }
911
912   out->prev()->trimEnd(stream.avail_out);
913
914   return out;
915 }
916
917 bool LZMA2Codec::doInflate(lzma_stream* stream,
918                           IOBuf* head,
919                           size_t bufferLength) {
920   if (stream->avail_out == 0) {
921     head->prependChain(addOutputBuffer(stream, bufferLength));
922   }
923
924   lzma_ret rc = lzma_code(stream, LZMA_RUN);
925
926   switch (rc) {
927   case LZMA_OK:
928     break;
929   case LZMA_STREAM_END:
930     return true;
931   default:
932     throw std::runtime_error(to<std::string>(
933         "LZMA2Codec: lzma_code error: ", rc));
934   }
935
936   return false;
937 }
938
939 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
940                                                uint64_t uncompressedLength) {
941   lzma_ret rc;
942   lzma_stream stream = LZMA_STREAM_INIT;
943
944   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
945   if (rc != LZMA_OK) {
946     throw std::runtime_error(folly::to<std::string>(
947       "LZMA2Codec: lzma_auto_decoder error: ", rc));
948   }
949
950   SCOPE_EXIT { lzma_end(&stream); };
951
952   // Max 64MiB in one go
953   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
954   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
955
956   folly::io::Cursor cursor(data);
957   uint64_t actualUncompressedLength;
958   if (encodeSize()) {
959     actualUncompressedLength = decodeVarintFromCursor(cursor);
960     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
961         uncompressedLength != actualUncompressedLength) {
962       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
963     }
964   } else {
965     actualUncompressedLength = uncompressedLength;
966     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
967   }
968
969   auto out = addOutputBuffer(
970       &stream,
971       (actualUncompressedLength <= maxSingleStepLength ?
972        actualUncompressedLength :
973        defaultBufferLength));
974
975   bool streamEnd = false;
976   auto buf = cursor.peekBytes();
977   while (!buf.empty()) {
978     stream.next_in = const_cast<uint8_t*>(buf.data());
979     stream.avail_in = buf.size();
980
981     while (stream.avail_in != 0) {
982       if (streamEnd) {
983         throw std::runtime_error(to<std::string>(
984             "LZMA2Codec: junk after end of data"));
985       }
986
987       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
988     }
989
990     cursor.skip(buf.size());
991     buf = cursor.peekBytes();
992   }
993
994   while (!streamEnd) {
995     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
996   }
997
998   out->prev()->trimEnd(stream.avail_out);
999
1000   if (actualUncompressedLength != stream.total_out) {
1001     throw std::runtime_error(to<std::string>(
1002         "LZMA2Codec: invalid uncompressed length"));
1003   }
1004
1005   return out;
1006 }
1007
1008 #endif  // FOLLY_HAVE_LIBLZMA
1009
1010 #ifdef FOLLY_HAVE_LIBZSTD
1011
1012 /**
1013  * ZSTD compression
1014  */
1015 class ZSTDCodec final : public Codec {
1016  public:
1017   static std::unique_ptr<Codec> create(int level, CodecType);
1018   explicit ZSTDCodec(int level, CodecType type);
1019
1020  private:
1021   bool doNeedsUncompressedLength() const override;
1022   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1023   std::unique_ptr<IOBuf> doUncompress(
1024       const IOBuf* data,
1025       uint64_t uncompressedLength) override;
1026
1027   int level_;
1028 };
1029
1030 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
1031   return make_unique<ZSTDCodec>(level, type);
1032 }
1033
1034 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
1035   DCHECK(type == CodecType::ZSTD);
1036   switch (level) {
1037     case COMPRESSION_LEVEL_FASTEST:
1038       level = 1;
1039       break;
1040     case COMPRESSION_LEVEL_DEFAULT:
1041       level = 1;
1042       break;
1043     case COMPRESSION_LEVEL_BEST:
1044       level = 19;
1045       break;
1046   }
1047   if (level < 1 || level > ZSTD_maxCLevel()) {
1048     throw std::invalid_argument(
1049         to<std::string>("ZSTD: invalid level: ", level));
1050   }
1051   level_ = level;
1052 }
1053
1054 bool ZSTDCodec::doNeedsUncompressedLength() const {
1055   return false;
1056 }
1057
1058 void zstdThrowIfError(size_t rc) {
1059   if (!ZSTD_isError(rc)) {
1060     return;
1061   }
1062   throw std::runtime_error(
1063       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1064 }
1065
1066 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1067   // Support earlier versions of the codec (working with a single IOBuf,
1068   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1069   // which isn't populated by streaming API).
1070   if (!data->isChained()) {
1071     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1072     const auto rc = ZSTD_compress(
1073         out->writableData(),
1074         out->capacity(),
1075         data->data(),
1076         data->length(),
1077         level_);
1078     zstdThrowIfError(rc);
1079     out->append(rc);
1080     return out;
1081   }
1082
1083   auto zcs = ZSTD_createCStream();
1084   SCOPE_EXIT {
1085     ZSTD_freeCStream(zcs);
1086   };
1087
1088   auto rc = ZSTD_initCStream(zcs, level_);
1089   zstdThrowIfError(rc);
1090
1091   Cursor cursor(data);
1092   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1093
1094   ZSTD_outBuffer out;
1095   out.dst = result->writableTail();
1096   out.size = result->capacity();
1097   out.pos = 0;
1098
1099   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1100     ZSTD_inBuffer in;
1101     in.src = buffer.data();
1102     in.size = buffer.size();
1103     for (in.pos = 0; in.pos != in.size;) {
1104       rc = ZSTD_compressStream(zcs, &out, &in);
1105       zstdThrowIfError(rc);
1106     }
1107     cursor.skip(in.size);
1108     buffer = cursor.peekBytes();
1109   }
1110
1111   rc = ZSTD_endStream(zcs, &out);
1112   zstdThrowIfError(rc);
1113   CHECK_EQ(rc, 0);
1114
1115   result->append(out.pos);
1116   return result;
1117 }
1118
1119 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1120     const IOBuf* data,
1121     uint64_t uncompressedLength) {
1122   auto zds = ZSTD_createDStream();
1123   SCOPE_EXIT {
1124     ZSTD_freeDStream(zds);
1125   };
1126
1127   auto rc = ZSTD_initDStream(zds);
1128   zstdThrowIfError(rc);
1129
1130   ZSTD_outBuffer out{};
1131   ZSTD_inBuffer in{};
1132
1133   auto outputSize = ZSTD_DStreamOutSize();
1134   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
1135     outputSize = uncompressedLength;
1136   } else {
1137     auto decompressedSize =
1138         ZSTD_getDecompressedSize(data->data(), data->length());
1139     if (decompressedSize != 0 && decompressedSize < outputSize) {
1140       outputSize = decompressedSize;
1141     }
1142   }
1143
1144   IOBufQueue queue(IOBufQueue::cacheChainLength());
1145
1146   Cursor cursor(data);
1147   for (rc = 0;;) {
1148     if (in.pos == in.size) {
1149       auto buffer = cursor.peekBytes();
1150       in.src = buffer.data();
1151       in.size = buffer.size();
1152       in.pos = 0;
1153       cursor.skip(in.size);
1154       if (rc > 1 && in.size == 0) {
1155         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1156       }
1157     }
1158     if (out.pos == out.size) {
1159       if (out.pos != 0) {
1160         queue.postallocate(out.pos);
1161       }
1162       auto buffer = queue.preallocate(outputSize, outputSize);
1163       out.dst = buffer.first;
1164       out.size = buffer.second;
1165       out.pos = 0;
1166       outputSize = ZSTD_DStreamOutSize();
1167     }
1168     rc = ZSTD_decompressStream(zds, &out, &in);
1169     zstdThrowIfError(rc);
1170     if (rc == 0) {
1171       break;
1172     }
1173   }
1174   if (out.pos != 0) {
1175     queue.postallocate(out.pos);
1176   }
1177   if (in.pos != in.size || !cursor.isAtEnd()) {
1178     throw std::runtime_error("ZSTD: junk after end of data");
1179   }
1180   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1181       queue.chainLength() != uncompressedLength) {
1182     throw std::runtime_error("ZSTD: invalid uncompressed length");
1183   }
1184
1185   return queue.move();
1186 }
1187
1188 #endif  // FOLLY_HAVE_LIBZSTD
1189
1190 }  // namespace
1191
1192 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1193 static CodecFactory
1194     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1195         nullptr, // USER_DEFINED
1196         NoCompressionCodec::create,
1197
1198 #if FOLLY_HAVE_LIBLZ4
1199         LZ4Codec::create,
1200 #else
1201         nullptr,
1202 #endif
1203
1204 #if FOLLY_HAVE_LIBSNAPPY
1205         SnappyCodec::create,
1206 #else
1207         nullptr,
1208 #endif
1209
1210 #if FOLLY_HAVE_LIBZ
1211         ZlibCodec::create,
1212 #else
1213         nullptr,
1214 #endif
1215
1216 #if FOLLY_HAVE_LIBLZ4
1217         LZ4Codec::create,
1218 #else
1219         nullptr,
1220 #endif
1221
1222 #if FOLLY_HAVE_LIBLZMA
1223         LZMA2Codec::create,
1224         LZMA2Codec::create,
1225 #else
1226         nullptr,
1227         nullptr,
1228 #endif
1229
1230 #if FOLLY_HAVE_LIBZSTD
1231         ZSTDCodec::create,
1232 #else
1233         nullptr,
1234 #endif
1235
1236 #if FOLLY_HAVE_LIBZ
1237         ZlibCodec::create,
1238 #else
1239         nullptr,
1240 #endif
1241 };
1242
1243 bool hasCodec(CodecType type) {
1244   size_t idx = static_cast<size_t>(type);
1245   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1246     throw std::invalid_argument(
1247         to<std::string>("Compression type ", idx, " invalid"));
1248   }
1249   return codecFactories[idx] != nullptr;
1250 }
1251
1252 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1253   size_t idx = static_cast<size_t>(type);
1254   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1255     throw std::invalid_argument(
1256         to<std::string>("Compression type ", idx, " invalid"));
1257   }
1258   auto factory = codecFactories[idx];
1259   if (!factory) {
1260     throw std::invalid_argument(to<std::string>(
1261         "Compression type ", idx, " not supported"));
1262   }
1263   auto codec = (*factory)(level, type);
1264   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1265   return codec;
1266 }
1267
1268 }}  // namespaces