Add String Support to Compression Codec
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   }
60   if (len > maxUncompressedLength()) {
61     throw std::runtime_error("Codec: uncompressed length too large");
62   }
63
64   return doCompress(data);
65 }
66
67 std::string Codec::compress(const StringPiece data) {
68   const uint64_t len = data.size();
69   if (len == 0) {
70     return "";
71   }
72   if (len > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   return doCompressString(data);
77 }
78
79 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
80                                          uint64_t uncompressedLength) {
81   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
82     if (needsUncompressedLength()) {
83       throw std::invalid_argument("Codec: uncompressed length required");
84     }
85   } else if (uncompressedLength > maxUncompressedLength()) {
86     throw std::runtime_error("Codec: uncompressed length too large");
87   }
88
89   if (data->empty()) {
90     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
91         uncompressedLength != 0) {
92       throw std::runtime_error("Codec: invalid uncompressed length");
93     }
94     return IOBuf::create(0);
95   }
96
97   return doUncompress(data, uncompressedLength);
98 }
99
100 std::string Codec::uncompress(
101     const StringPiece data,
102     uint64_t uncompressedLength) {
103   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
104     if (needsUncompressedLength()) {
105       throw std::invalid_argument("Codec: uncompressed length required");
106     }
107   } else if (uncompressedLength > maxUncompressedLength()) {
108     throw std::runtime_error("Codec: uncompressed length too large");
109   }
110
111   if (data.empty()) {
112     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
113         uncompressedLength != 0) {
114       throw std::runtime_error("Codec: invalid uncompressed length");
115     }
116     return "";
117   }
118
119   return doUncompressString(data, uncompressedLength);
120 }
121
122 bool Codec::needsUncompressedLength() const {
123   return doNeedsUncompressedLength();
124 }
125
126 uint64_t Codec::maxUncompressedLength() const {
127   return doMaxUncompressedLength();
128 }
129
130 bool Codec::doNeedsUncompressedLength() const {
131   return false;
132 }
133
134 uint64_t Codec::doMaxUncompressedLength() const {
135   return UNLIMITED_UNCOMPRESSED_LENGTH;
136 }
137
138 std::string Codec::doCompressString(const StringPiece data) {
139   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
140   auto outputBuffer = doCompress(&inputBuffer);
141   std::string output;
142   output.reserve(outputBuffer->computeChainDataLength());
143   for (auto range : *outputBuffer) {
144     output.append(reinterpret_cast<const char*>(range.data()), range.size());
145   }
146   return output;
147 }
148
149 std::string Codec::doUncompressString(
150     const StringPiece data,
151     uint64_t uncompressedLength) {
152   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
153   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
154   std::string output;
155   output.reserve(outputBuffer->computeChainDataLength());
156   for (auto range : *outputBuffer) {
157     output.append(reinterpret_cast<const char*>(range.data()), range.size());
158   }
159   return output;
160 }
161
162 namespace {
163
164 /**
165  * No compression
166  */
167 class NoCompressionCodec final : public Codec {
168  public:
169   static std::unique_ptr<Codec> create(int level, CodecType type);
170   explicit NoCompressionCodec(int level, CodecType type);
171
172  private:
173   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
174   std::unique_ptr<IOBuf> doUncompress(
175       const IOBuf* data,
176       uint64_t uncompressedLength) override;
177 };
178
179 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
180   return make_unique<NoCompressionCodec>(level, type);
181 }
182
183 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
184   : Codec(type) {
185   DCHECK(type == CodecType::NO_COMPRESSION);
186   switch (level) {
187   case COMPRESSION_LEVEL_DEFAULT:
188   case COMPRESSION_LEVEL_FASTEST:
189   case COMPRESSION_LEVEL_BEST:
190     level = 0;
191   }
192   if (level != 0) {
193     throw std::invalid_argument(to<std::string>(
194         "NoCompressionCodec: invalid level ", level));
195   }
196 }
197
198 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
199     const IOBuf* data) {
200   return data->clone();
201 }
202
203 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
204     const IOBuf* data,
205     uint64_t uncompressedLength) {
206   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
207       data->computeChainDataLength() != uncompressedLength) {
208     throw std::runtime_error(to<std::string>(
209         "NoCompressionCodec: invalid uncompressed length"));
210   }
211   return data->clone();
212 }
213
214 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
215
216 namespace {
217
218 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
219   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
220   out->append(encodeVarint(val, out->writableTail()));
221 }
222
223 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
224   uint64_t val = 0;
225   int8_t b = 0;
226   for (int shift = 0; shift <= 63; shift += 7) {
227     b = cursor.read<int8_t>();
228     val |= static_cast<uint64_t>(b & 0x7f) << shift;
229     if (b >= 0) {
230       break;
231     }
232   }
233   if (b < 0) {
234     throw std::invalid_argument("Invalid varint value. Too big.");
235   }
236   return val;
237 }
238
239 }  // namespace
240
241 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
242
243 #if FOLLY_HAVE_LIBLZ4
244
245 /**
246  * LZ4 compression
247  */
248 class LZ4Codec final : public Codec {
249  public:
250   static std::unique_ptr<Codec> create(int level, CodecType type);
251   explicit LZ4Codec(int level, CodecType type);
252
253  private:
254   bool doNeedsUncompressedLength() const override;
255   uint64_t doMaxUncompressedLength() const override;
256
257   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
258
259   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
260   std::unique_ptr<IOBuf> doUncompress(
261       const IOBuf* data,
262       uint64_t uncompressedLength) override;
263
264   bool highCompression_;
265 };
266
267 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
268   return make_unique<LZ4Codec>(level, type);
269 }
270
271 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
272   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
273
274   switch (level) {
275   case COMPRESSION_LEVEL_FASTEST:
276   case COMPRESSION_LEVEL_DEFAULT:
277     level = 1;
278     break;
279   case COMPRESSION_LEVEL_BEST:
280     level = 2;
281     break;
282   }
283   if (level < 1 || level > 2) {
284     throw std::invalid_argument(to<std::string>(
285         "LZ4Codec: invalid level: ", level));
286   }
287   highCompression_ = (level > 1);
288 }
289
290 bool LZ4Codec::doNeedsUncompressedLength() const {
291   return !encodeSize();
292 }
293
294 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
295 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
296 // here.
297 #ifndef LZ4_MAX_INPUT_SIZE
298 # define LZ4_MAX_INPUT_SIZE 0x7E000000
299 #endif
300
301 uint64_t LZ4Codec::doMaxUncompressedLength() const {
302   return LZ4_MAX_INPUT_SIZE;
303 }
304
305 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
306   std::unique_ptr<IOBuf> clone;
307   if (data->isChained()) {
308     // LZ4 doesn't support streaming, so we have to coalesce
309     clone = data->clone();
310     clone->coalesce();
311     data = clone.get();
312   }
313
314   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
315   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
316   if (encodeSize()) {
317     encodeVarintToIOBuf(data->length(), out.get());
318   }
319
320   int n;
321   auto input = reinterpret_cast<const char*>(data->data());
322   auto output = reinterpret_cast<char*>(out->writableTail());
323   const auto inputLength = data->length();
324 #if LZ4_VERSION_NUMBER >= 10700
325   if (highCompression_) {
326     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
327   } else {
328     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
329   }
330 #else
331   if (highCompression_) {
332     n = LZ4_compressHC(input, output, inputLength);
333   } else {
334     n = LZ4_compress(input, output, inputLength);
335   }
336 #endif
337
338   CHECK_GE(n, 0);
339   CHECK_LE(n, out->capacity());
340
341   out->append(n);
342   return out;
343 }
344
345 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
346     const IOBuf* data,
347     uint64_t uncompressedLength) {
348   std::unique_ptr<IOBuf> clone;
349   if (data->isChained()) {
350     // LZ4 doesn't support streaming, so we have to coalesce
351     clone = data->clone();
352     clone->coalesce();
353     data = clone.get();
354   }
355
356   folly::io::Cursor cursor(data);
357   uint64_t actualUncompressedLength;
358   if (encodeSize()) {
359     actualUncompressedLength = decodeVarintFromCursor(cursor);
360     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
361         uncompressedLength != actualUncompressedLength) {
362       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
363     }
364   } else {
365     actualUncompressedLength = uncompressedLength;
366     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
367         actualUncompressedLength > maxUncompressedLength()) {
368       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
369     }
370   }
371
372   auto sp = StringPiece{cursor.peekBytes()};
373   auto out = IOBuf::create(actualUncompressedLength);
374   int n = LZ4_decompress_safe(
375       sp.data(),
376       reinterpret_cast<char*>(out->writableTail()),
377       sp.size(),
378       actualUncompressedLength);
379
380   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
381     throw std::runtime_error(to<std::string>(
382         "LZ4 decompression returned invalid value ", n));
383   }
384   out->append(actualUncompressedLength);
385   return out;
386 }
387
388 #endif  // FOLLY_HAVE_LIBLZ4
389
390 #if FOLLY_HAVE_LIBSNAPPY
391
392 /**
393  * Snappy compression
394  */
395
396 /**
397  * Implementation of snappy::Source that reads from a IOBuf chain.
398  */
399 class IOBufSnappySource final : public snappy::Source {
400  public:
401   explicit IOBufSnappySource(const IOBuf* data);
402   size_t Available() const override;
403   const char* Peek(size_t* len) override;
404   void Skip(size_t n) override;
405  private:
406   size_t available_;
407   io::Cursor cursor_;
408 };
409
410 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
411   : available_(data->computeChainDataLength()),
412     cursor_(data) {
413 }
414
415 size_t IOBufSnappySource::Available() const {
416   return available_;
417 }
418
419 const char* IOBufSnappySource::Peek(size_t* len) {
420   auto sp = StringPiece{cursor_.peekBytes()};
421   *len = sp.size();
422   return sp.data();
423 }
424
425 void IOBufSnappySource::Skip(size_t n) {
426   CHECK_LE(n, available_);
427   cursor_.skip(n);
428   available_ -= n;
429 }
430
431 class SnappyCodec final : public Codec {
432  public:
433   static std::unique_ptr<Codec> create(int level, CodecType type);
434   explicit SnappyCodec(int level, CodecType type);
435
436  private:
437   uint64_t doMaxUncompressedLength() const override;
438   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
439   std::unique_ptr<IOBuf> doUncompress(
440       const IOBuf* data,
441       uint64_t uncompressedLength) override;
442 };
443
444 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
445   return make_unique<SnappyCodec>(level, type);
446 }
447
448 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
449   DCHECK(type == CodecType::SNAPPY);
450   switch (level) {
451   case COMPRESSION_LEVEL_FASTEST:
452   case COMPRESSION_LEVEL_DEFAULT:
453   case COMPRESSION_LEVEL_BEST:
454     level = 1;
455   }
456   if (level != 1) {
457     throw std::invalid_argument(to<std::string>(
458         "SnappyCodec: invalid level: ", level));
459   }
460 }
461
462 uint64_t SnappyCodec::doMaxUncompressedLength() const {
463   // snappy.h uses uint32_t for lengths, so there's that.
464   return std::numeric_limits<uint32_t>::max();
465 }
466
467 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
468   IOBufSnappySource source(data);
469   auto out =
470     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
471
472   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
473       out->writableTail()));
474
475   size_t n = snappy::Compress(&source, &sink);
476
477   CHECK_LE(n, out->capacity());
478   out->append(n);
479   return out;
480 }
481
482 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
483                                                  uint64_t uncompressedLength) {
484   uint32_t actualUncompressedLength = 0;
485
486   {
487     IOBufSnappySource source(data);
488     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
489       throw std::runtime_error("snappy::GetUncompressedLength failed");
490     }
491     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
492         uncompressedLength != actualUncompressedLength) {
493       throw std::runtime_error("snappy: invalid uncompressed length");
494     }
495   }
496
497   auto out = IOBuf::create(actualUncompressedLength);
498
499   {
500     IOBufSnappySource source(data);
501     if (!snappy::RawUncompress(&source,
502                                reinterpret_cast<char*>(out->writableTail()))) {
503       throw std::runtime_error("snappy::RawUncompress failed");
504     }
505   }
506
507   out->append(actualUncompressedLength);
508   return out;
509 }
510
511 #endif  // FOLLY_HAVE_LIBSNAPPY
512
513 #if FOLLY_HAVE_LIBZ
514 /**
515  * Zlib codec
516  */
517 class ZlibCodec final : public Codec {
518  public:
519   static std::unique_ptr<Codec> create(int level, CodecType type);
520   explicit ZlibCodec(int level, CodecType type);
521
522  private:
523   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
524   std::unique_ptr<IOBuf> doUncompress(
525       const IOBuf* data,
526       uint64_t uncompressedLength) override;
527
528   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
529   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
530
531   int level_;
532 };
533
534 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
535   return make_unique<ZlibCodec>(level, type);
536 }
537
538 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
539   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
540   switch (level) {
541   case COMPRESSION_LEVEL_FASTEST:
542     level = 1;
543     break;
544   case COMPRESSION_LEVEL_DEFAULT:
545     level = Z_DEFAULT_COMPRESSION;
546     break;
547   case COMPRESSION_LEVEL_BEST:
548     level = 9;
549     break;
550   }
551   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
552     throw std::invalid_argument(to<std::string>(
553         "ZlibCodec: invalid level: ", level));
554   }
555   level_ = level;
556 }
557
558 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
559                                                   uint32_t length) {
560   CHECK_EQ(stream->avail_out, 0);
561
562   auto buf = IOBuf::create(length);
563   buf->append(length);
564
565   stream->next_out = buf->writableData();
566   stream->avail_out = buf->length();
567
568   return buf;
569 }
570
571 bool ZlibCodec::doInflate(z_stream* stream,
572                           IOBuf* head,
573                           uint32_t bufferLength) {
574   if (stream->avail_out == 0) {
575     head->prependChain(addOutputBuffer(stream, bufferLength));
576   }
577
578   int rc = inflate(stream, Z_NO_FLUSH);
579
580   switch (rc) {
581   case Z_OK:
582     break;
583   case Z_STREAM_END:
584     return true;
585   case Z_BUF_ERROR:
586   case Z_NEED_DICT:
587   case Z_DATA_ERROR:
588   case Z_MEM_ERROR:
589     throw std::runtime_error(to<std::string>(
590         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
591   default:
592     CHECK(false) << rc << ": " << stream->msg;
593   }
594
595   return false;
596 }
597
598 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
599   z_stream stream;
600   stream.zalloc = nullptr;
601   stream.zfree = nullptr;
602   stream.opaque = nullptr;
603
604   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
605   // base two logarithm of the maximum window size (...) The default value is
606   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
607   // around the compressed data instead of a zlib wrapper. The gzip header
608   // will have no file name, no extra data, no comment, no modification time
609   // (set to zero), no header crc, and the operating system will be set to 255
610   // (unknown)."
611   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
612   // All other parameters (method, memLevel, strategy) get default values from
613   // the zlib manual.
614   int rc = deflateInit2(&stream,
615                         level_,
616                         Z_DEFLATED,
617                         windowBits,
618                         /* memLevel */ 8,
619                         Z_DEFAULT_STRATEGY);
620   if (rc != Z_OK) {
621     throw std::runtime_error(to<std::string>(
622         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
623   }
624
625   stream.next_in = stream.next_out = nullptr;
626   stream.avail_in = stream.avail_out = 0;
627   stream.total_in = stream.total_out = 0;
628
629   bool success = false;
630
631   SCOPE_EXIT {
632     rc = deflateEnd(&stream);
633     // If we're here because of an exception, it's okay if some data
634     // got dropped.
635     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
636       << rc << ": " << stream.msg;
637   };
638
639   uint64_t uncompressedLength = data->computeChainDataLength();
640   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
641
642   // Max 64MiB in one go
643   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
644   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
645
646   auto out = addOutputBuffer(
647       &stream,
648       (maxCompressedLength <= maxSingleStepLength ?
649        maxCompressedLength :
650        defaultBufferLength));
651
652   for (auto& range : *data) {
653     uint64_t remaining = range.size();
654     uint64_t written = 0;
655     while (remaining) {
656       uint32_t step = (remaining > maxSingleStepLength ?
657                        maxSingleStepLength : remaining);
658       stream.next_in = const_cast<uint8_t*>(range.data() + written);
659       stream.avail_in = step;
660       remaining -= step;
661       written += step;
662
663       while (stream.avail_in != 0) {
664         if (stream.avail_out == 0) {
665           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
666         }
667
668         rc = deflate(&stream, Z_NO_FLUSH);
669
670         CHECK_EQ(rc, Z_OK) << stream.msg;
671       }
672     }
673   }
674
675   do {
676     if (stream.avail_out == 0) {
677       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
678     }
679
680     rc = deflate(&stream, Z_FINISH);
681   } while (rc == Z_OK);
682
683   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
684
685   out->prev()->trimEnd(stream.avail_out);
686
687   success = true;  // we survived
688
689   return out;
690 }
691
692 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
693                                                uint64_t uncompressedLength) {
694   z_stream stream;
695   stream.zalloc = nullptr;
696   stream.zfree = nullptr;
697   stream.opaque = nullptr;
698
699   // "The windowBits parameter is the base two logarithm of the maximum window
700   // size (...) The default value is 15 (...) add 16 to decode only the gzip
701   // format (the zlib format will return a Z_DATA_ERROR)."
702   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
703   int rc = inflateInit2(&stream, windowBits);
704   if (rc != Z_OK) {
705     throw std::runtime_error(to<std::string>(
706         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
707   }
708
709   stream.next_in = stream.next_out = nullptr;
710   stream.avail_in = stream.avail_out = 0;
711   stream.total_in = stream.total_out = 0;
712
713   bool success = false;
714
715   SCOPE_EXIT {
716     rc = inflateEnd(&stream);
717     // If we're here because of an exception, it's okay if some data
718     // got dropped.
719     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
720       << rc << ": " << stream.msg;
721   };
722
723   // Max 64MiB in one go
724   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
725   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
726
727   auto out = addOutputBuffer(
728       &stream,
729       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
730         uncompressedLength <= maxSingleStepLength) ?
731        uncompressedLength :
732        defaultBufferLength));
733
734   bool streamEnd = false;
735   for (auto& range : *data) {
736     if (range.empty()) {
737       continue;
738     }
739
740     stream.next_in = const_cast<uint8_t*>(range.data());
741     stream.avail_in = range.size();
742
743     while (stream.avail_in != 0) {
744       if (streamEnd) {
745         throw std::runtime_error(to<std::string>(
746             "ZlibCodec: junk after end of data"));
747       }
748
749       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
750     }
751   }
752
753   while (!streamEnd) {
754     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
755   }
756
757   out->prev()->trimEnd(stream.avail_out);
758
759   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
760       uncompressedLength != stream.total_out) {
761     throw std::runtime_error(to<std::string>(
762         "ZlibCodec: invalid uncompressed length"));
763   }
764
765   success = true;  // we survived
766
767   return out;
768 }
769
770 #endif  // FOLLY_HAVE_LIBZ
771
772 #if FOLLY_HAVE_LIBLZMA
773
774 /**
775  * LZMA2 compression
776  */
777 class LZMA2Codec final : public Codec {
778  public:
779   static std::unique_ptr<Codec> create(int level, CodecType type);
780   explicit LZMA2Codec(int level, CodecType type);
781
782  private:
783   bool doNeedsUncompressedLength() const override;
784   uint64_t doMaxUncompressedLength() const override;
785
786   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
787
788   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
789   std::unique_ptr<IOBuf> doUncompress(
790       const IOBuf* data,
791       uint64_t uncompressedLength) override;
792
793   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
794   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
795
796   int level_;
797 };
798
799 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
800   return make_unique<LZMA2Codec>(level, type);
801 }
802
803 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
804   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
805   switch (level) {
806   case COMPRESSION_LEVEL_FASTEST:
807     level = 0;
808     break;
809   case COMPRESSION_LEVEL_DEFAULT:
810     level = LZMA_PRESET_DEFAULT;
811     break;
812   case COMPRESSION_LEVEL_BEST:
813     level = 9;
814     break;
815   }
816   if (level < 0 || level > 9) {
817     throw std::invalid_argument(to<std::string>(
818         "LZMA2Codec: invalid level: ", level));
819   }
820   level_ = level;
821 }
822
823 bool LZMA2Codec::doNeedsUncompressedLength() const {
824   return !encodeSize();
825 }
826
827 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
828   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
829   return uint64_t(1) << 63;
830 }
831
832 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
833     lzma_stream* stream,
834     size_t length) {
835
836   CHECK_EQ(stream->avail_out, 0);
837
838   auto buf = IOBuf::create(length);
839   buf->append(length);
840
841   stream->next_out = buf->writableData();
842   stream->avail_out = buf->length();
843
844   return buf;
845 }
846
847 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
848   lzma_ret rc;
849   lzma_stream stream = LZMA_STREAM_INIT;
850
851   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
852   if (rc != LZMA_OK) {
853     throw std::runtime_error(folly::to<std::string>(
854       "LZMA2Codec: lzma_easy_encoder error: ", rc));
855   }
856
857   SCOPE_EXIT { lzma_end(&stream); };
858
859   uint64_t uncompressedLength = data->computeChainDataLength();
860   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
861
862   // Max 64MiB in one go
863   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
864   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
865
866   auto out = addOutputBuffer(
867     &stream,
868     (maxCompressedLength <= maxSingleStepLength ?
869      maxCompressedLength :
870      defaultBufferLength));
871
872   if (encodeSize()) {
873     auto size = IOBuf::createCombined(kMaxVarintLength64);
874     encodeVarintToIOBuf(uncompressedLength, size.get());
875     size->appendChain(std::move(out));
876     out = std::move(size);
877   }
878
879   for (auto& range : *data) {
880     if (range.empty()) {
881       continue;
882     }
883
884     stream.next_in = const_cast<uint8_t*>(range.data());
885     stream.avail_in = range.size();
886
887     while (stream.avail_in != 0) {
888       if (stream.avail_out == 0) {
889         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
890       }
891
892       rc = lzma_code(&stream, LZMA_RUN);
893
894       if (rc != LZMA_OK) {
895         throw std::runtime_error(folly::to<std::string>(
896           "LZMA2Codec: lzma_code error: ", rc));
897       }
898     }
899   }
900
901   do {
902     if (stream.avail_out == 0) {
903       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
904     }
905
906     rc = lzma_code(&stream, LZMA_FINISH);
907   } while (rc == LZMA_OK);
908
909   if (rc != LZMA_STREAM_END) {
910     throw std::runtime_error(folly::to<std::string>(
911       "LZMA2Codec: lzma_code ended with error: ", rc));
912   }
913
914   out->prev()->trimEnd(stream.avail_out);
915
916   return out;
917 }
918
919 bool LZMA2Codec::doInflate(lzma_stream* stream,
920                           IOBuf* head,
921                           size_t bufferLength) {
922   if (stream->avail_out == 0) {
923     head->prependChain(addOutputBuffer(stream, bufferLength));
924   }
925
926   lzma_ret rc = lzma_code(stream, LZMA_RUN);
927
928   switch (rc) {
929   case LZMA_OK:
930     break;
931   case LZMA_STREAM_END:
932     return true;
933   default:
934     throw std::runtime_error(to<std::string>(
935         "LZMA2Codec: lzma_code error: ", rc));
936   }
937
938   return false;
939 }
940
941 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
942                                                uint64_t uncompressedLength) {
943   lzma_ret rc;
944   lzma_stream stream = LZMA_STREAM_INIT;
945
946   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
947   if (rc != LZMA_OK) {
948     throw std::runtime_error(folly::to<std::string>(
949       "LZMA2Codec: lzma_auto_decoder error: ", rc));
950   }
951
952   SCOPE_EXIT { lzma_end(&stream); };
953
954   // Max 64MiB in one go
955   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
956   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
957
958   folly::io::Cursor cursor(data);
959   uint64_t actualUncompressedLength;
960   if (encodeSize()) {
961     actualUncompressedLength = decodeVarintFromCursor(cursor);
962     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
963         uncompressedLength != actualUncompressedLength) {
964       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
965     }
966   } else {
967     actualUncompressedLength = uncompressedLength;
968     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
969   }
970
971   auto out = addOutputBuffer(
972       &stream,
973       (actualUncompressedLength <= maxSingleStepLength ?
974        actualUncompressedLength :
975        defaultBufferLength));
976
977   bool streamEnd = false;
978   auto buf = cursor.peekBytes();
979   while (!buf.empty()) {
980     stream.next_in = const_cast<uint8_t*>(buf.data());
981     stream.avail_in = buf.size();
982
983     while (stream.avail_in != 0) {
984       if (streamEnd) {
985         throw std::runtime_error(to<std::string>(
986             "LZMA2Codec: junk after end of data"));
987       }
988
989       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
990     }
991
992     cursor.skip(buf.size());
993     buf = cursor.peekBytes();
994   }
995
996   while (!streamEnd) {
997     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
998   }
999
1000   out->prev()->trimEnd(stream.avail_out);
1001
1002   if (actualUncompressedLength != stream.total_out) {
1003     throw std::runtime_error(to<std::string>(
1004         "LZMA2Codec: invalid uncompressed length"));
1005   }
1006
1007   return out;
1008 }
1009
1010 #endif  // FOLLY_HAVE_LIBLZMA
1011
1012 #ifdef FOLLY_HAVE_LIBZSTD
1013
1014 /**
1015  * ZSTD compression
1016  */
1017 class ZSTDCodec final : public Codec {
1018  public:
1019   static std::unique_ptr<Codec> create(int level, CodecType);
1020   explicit ZSTDCodec(int level, CodecType type);
1021
1022  private:
1023   bool doNeedsUncompressedLength() const override;
1024   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1025   std::unique_ptr<IOBuf> doUncompress(
1026       const IOBuf* data,
1027       uint64_t uncompressedLength) override;
1028
1029   int level_;
1030 };
1031
1032 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
1033   return make_unique<ZSTDCodec>(level, type);
1034 }
1035
1036 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
1037   DCHECK(type == CodecType::ZSTD);
1038   switch (level) {
1039     case COMPRESSION_LEVEL_FASTEST:
1040       level = 1;
1041       break;
1042     case COMPRESSION_LEVEL_DEFAULT:
1043       level = 1;
1044       break;
1045     case COMPRESSION_LEVEL_BEST:
1046       level = 19;
1047       break;
1048   }
1049   if (level < 1 || level > ZSTD_maxCLevel()) {
1050     throw std::invalid_argument(
1051         to<std::string>("ZSTD: invalid level: ", level));
1052   }
1053   level_ = level;
1054 }
1055
1056 bool ZSTDCodec::doNeedsUncompressedLength() const {
1057   return false;
1058 }
1059
1060 void zstdThrowIfError(size_t rc) {
1061   if (!ZSTD_isError(rc)) {
1062     return;
1063   }
1064   throw std::runtime_error(
1065       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1066 }
1067
1068 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1069   // Support earlier versions of the codec (working with a single IOBuf,
1070   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1071   // which isn't populated by streaming API).
1072   if (!data->isChained()) {
1073     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1074     const auto rc = ZSTD_compress(
1075         out->writableData(),
1076         out->capacity(),
1077         data->data(),
1078         data->length(),
1079         level_);
1080     zstdThrowIfError(rc);
1081     out->append(rc);
1082     return out;
1083   }
1084
1085   auto zcs = ZSTD_createCStream();
1086   SCOPE_EXIT {
1087     ZSTD_freeCStream(zcs);
1088   };
1089
1090   auto rc = ZSTD_initCStream(zcs, level_);
1091   zstdThrowIfError(rc);
1092
1093   Cursor cursor(data);
1094   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1095
1096   ZSTD_outBuffer out;
1097   out.dst = result->writableTail();
1098   out.size = result->capacity();
1099   out.pos = 0;
1100
1101   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1102     ZSTD_inBuffer in;
1103     in.src = buffer.data();
1104     in.size = buffer.size();
1105     for (in.pos = 0; in.pos != in.size;) {
1106       rc = ZSTD_compressStream(zcs, &out, &in);
1107       zstdThrowIfError(rc);
1108     }
1109     cursor.skip(in.size);
1110     buffer = cursor.peekBytes();
1111   }
1112
1113   rc = ZSTD_endStream(zcs, &out);
1114   zstdThrowIfError(rc);
1115   CHECK_EQ(rc, 0);
1116
1117   result->append(out.pos);
1118   return result;
1119 }
1120
1121 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1122     const IOBuf* data,
1123     uint64_t uncompressedLength) {
1124   auto zds = ZSTD_createDStream();
1125   SCOPE_EXIT {
1126     ZSTD_freeDStream(zds);
1127   };
1128
1129   auto rc = ZSTD_initDStream(zds);
1130   zstdThrowIfError(rc);
1131
1132   ZSTD_outBuffer out{};
1133   ZSTD_inBuffer in{};
1134
1135   auto outputSize = ZSTD_DStreamOutSize();
1136   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
1137     outputSize = uncompressedLength;
1138   } else {
1139     auto decompressedSize =
1140         ZSTD_getDecompressedSize(data->data(), data->length());
1141     if (decompressedSize != 0 && decompressedSize < outputSize) {
1142       outputSize = decompressedSize;
1143     }
1144   }
1145
1146   IOBufQueue queue(IOBufQueue::cacheChainLength());
1147
1148   Cursor cursor(data);
1149   for (rc = 0;;) {
1150     if (in.pos == in.size) {
1151       auto buffer = cursor.peekBytes();
1152       in.src = buffer.data();
1153       in.size = buffer.size();
1154       in.pos = 0;
1155       cursor.skip(in.size);
1156       if (rc > 1 && in.size == 0) {
1157         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1158       }
1159     }
1160     if (out.pos == out.size) {
1161       if (out.pos != 0) {
1162         queue.postallocate(out.pos);
1163       }
1164       auto buffer = queue.preallocate(outputSize, outputSize);
1165       out.dst = buffer.first;
1166       out.size = buffer.second;
1167       out.pos = 0;
1168       outputSize = ZSTD_DStreamOutSize();
1169     }
1170     rc = ZSTD_decompressStream(zds, &out, &in);
1171     zstdThrowIfError(rc);
1172     if (rc == 0) {
1173       break;
1174     }
1175   }
1176   if (out.pos != 0) {
1177     queue.postallocate(out.pos);
1178   }
1179   if (in.pos != in.size || !cursor.isAtEnd()) {
1180     throw std::runtime_error("ZSTD: junk after end of data");
1181   }
1182   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1183       queue.chainLength() != uncompressedLength) {
1184     throw std::runtime_error("ZSTD: invalid uncompressed length");
1185   }
1186
1187   return queue.move();
1188 }
1189
1190 #endif  // FOLLY_HAVE_LIBZSTD
1191
1192 }  // namespace
1193
1194 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1195 static CodecFactory
1196     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1197         nullptr, // USER_DEFINED
1198         NoCompressionCodec::create,
1199
1200 #if FOLLY_HAVE_LIBLZ4
1201         LZ4Codec::create,
1202 #else
1203         nullptr,
1204 #endif
1205
1206 #if FOLLY_HAVE_LIBSNAPPY
1207         SnappyCodec::create,
1208 #else
1209         nullptr,
1210 #endif
1211
1212 #if FOLLY_HAVE_LIBZ
1213         ZlibCodec::create,
1214 #else
1215         nullptr,
1216 #endif
1217
1218 #if FOLLY_HAVE_LIBLZ4
1219         LZ4Codec::create,
1220 #else
1221         nullptr,
1222 #endif
1223
1224 #if FOLLY_HAVE_LIBLZMA
1225         LZMA2Codec::create,
1226         LZMA2Codec::create,
1227 #else
1228         nullptr,
1229         nullptr,
1230 #endif
1231
1232 #if FOLLY_HAVE_LIBZSTD
1233         ZSTDCodec::create,
1234 #else
1235         nullptr,
1236 #endif
1237
1238 #if FOLLY_HAVE_LIBZ
1239         ZlibCodec::create,
1240 #else
1241         nullptr,
1242 #endif
1243 };
1244
1245 bool hasCodec(CodecType type) {
1246   size_t idx = static_cast<size_t>(type);
1247   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1248     throw std::invalid_argument(
1249         to<std::string>("Compression type ", idx, " invalid"));
1250   }
1251   return codecFactories[idx] != nullptr;
1252 }
1253
1254 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1255   size_t idx = static_cast<size_t>(type);
1256   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1257     throw std::invalid_argument(
1258         to<std::string>("Compression type ", idx, " invalid"));
1259   }
1260   auto factory = codecFactories[idx];
1261   if (!factory) {
1262     throw std::invalid_argument(to<std::string>(
1263         "Compression type ", idx, " not supported"));
1264   }
1265   auto codec = (*factory)(level, type);
1266   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1267   return codec;
1268 }
1269
1270 }}  // namespaces