Codemod: use #include angle brackets in folly and thrift
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #include <folly/Conv.h>
40 #include <folly/Memory.h>
41 #include <folly/Portability.h>
42 #include <folly/ScopeGuard.h>
43 #include <folly/Varint.h>
44 #include <folly/io/Cursor.h>
45
46 namespace folly { namespace io {
47
48 Codec::Codec(CodecType type) : type_(type) { }
49
50 // Ensure consistent behavior in the nullptr case
51 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
52   return !data->empty() ? doCompress(data) : IOBuf::create(0);
53 }
54
55 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
56                                          uint64_t uncompressedLength) {
57   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
58     if (needsUncompressedLength()) {
59       throw std::invalid_argument("Codec: uncompressed length required");
60     }
61   } else if (uncompressedLength > maxUncompressedLength()) {
62     throw std::runtime_error("Codec: uncompressed length too large");
63   }
64
65   if (data->empty()) {
66     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
67         uncompressedLength != 0) {
68       throw std::runtime_error("Codec: invalid uncompressed length");
69     }
70     return IOBuf::create(0);
71   }
72
73   return doUncompress(data, uncompressedLength);
74 }
75
76 bool Codec::needsUncompressedLength() const {
77   return doNeedsUncompressedLength();
78 }
79
80 uint64_t Codec::maxUncompressedLength() const {
81   return doMaxUncompressedLength();
82 }
83
84 bool Codec::doNeedsUncompressedLength() const {
85   return false;
86 }
87
88 uint64_t Codec::doMaxUncompressedLength() const {
89   return std::numeric_limits<uint64_t>::max() - 1;
90 }
91
92 namespace {
93
94 /**
95  * No compression
96  */
97 class NoCompressionCodec FOLLY_FINAL : public Codec {
98  public:
99   static std::unique_ptr<Codec> create(int level, CodecType type);
100   explicit NoCompressionCodec(int level, CodecType type);
101
102  private:
103   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
104   std::unique_ptr<IOBuf> doUncompress(
105       const IOBuf* data,
106       uint64_t uncompressedLength) FOLLY_OVERRIDE;
107 };
108
109 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
110   return make_unique<NoCompressionCodec>(level, type);
111 }
112
113 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
114   : Codec(type) {
115   DCHECK(type == CodecType::NO_COMPRESSION);
116   switch (level) {
117   case COMPRESSION_LEVEL_DEFAULT:
118   case COMPRESSION_LEVEL_FASTEST:
119   case COMPRESSION_LEVEL_BEST:
120     level = 0;
121   }
122   if (level != 0) {
123     throw std::invalid_argument(to<std::string>(
124         "NoCompressionCodec: invalid level ", level));
125   }
126 }
127
128 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
129     const IOBuf* data) {
130   return data->clone();
131 }
132
133 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
134     const IOBuf* data,
135     uint64_t uncompressedLength) {
136   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
137       data->computeChainDataLength() != uncompressedLength) {
138     throw std::runtime_error(to<std::string>(
139         "NoCompressionCodec: invalid uncompressed length"));
140   }
141   return data->clone();
142 }
143
144 namespace {
145
146 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
147   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
148   out->append(encodeVarint(val, out->writableTail()));
149 }
150
151 uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
152   // Must have enough room in *this* buffer.
153   auto p = cursor.peek();
154   folly::ByteRange range(p.first, p.second);
155   uint64_t val = decodeVarint(range);
156   cursor.skip(range.data() - p.first);
157   return val;
158 }
159
160 }  // namespace
161
162 #if FOLLY_HAVE_LIBLZ4
163
164 /**
165  * LZ4 compression
166  */
167 class LZ4Codec FOLLY_FINAL : public Codec {
168  public:
169   static std::unique_ptr<Codec> create(int level, CodecType type);
170   explicit LZ4Codec(int level, CodecType type);
171
172  private:
173   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
174   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
175
176   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
177
178   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
179   std::unique_ptr<IOBuf> doUncompress(
180       const IOBuf* data,
181       uint64_t uncompressedLength) FOLLY_OVERRIDE;
182
183   bool highCompression_;
184 };
185
186 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
187   return make_unique<LZ4Codec>(level, type);
188 }
189
190 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
191   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
192
193   switch (level) {
194   case COMPRESSION_LEVEL_FASTEST:
195   case COMPRESSION_LEVEL_DEFAULT:
196     level = 1;
197     break;
198   case COMPRESSION_LEVEL_BEST:
199     level = 2;
200     break;
201   }
202   if (level < 1 || level > 2) {
203     throw std::invalid_argument(to<std::string>(
204         "LZ4Codec: invalid level: ", level));
205   }
206   highCompression_ = (level > 1);
207 }
208
209 bool LZ4Codec::doNeedsUncompressedLength() const {
210   return !encodeSize();
211 }
212
213 uint64_t LZ4Codec::doMaxUncompressedLength() const {
214   // From lz4.h: "Max supported value is ~1.9GB"; I wish we had something
215   // more accurate.
216   return 1.8 * (uint64_t(1) << 30);
217 }
218
219 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
220   std::unique_ptr<IOBuf> clone;
221   if (data->isChained()) {
222     // LZ4 doesn't support streaming, so we have to coalesce
223     clone = data->clone();
224     clone->coalesce();
225     data = clone.get();
226   }
227
228   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
229   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
230   if (encodeSize()) {
231     encodeVarintToIOBuf(data->length(), out.get());
232   }
233
234   int n;
235   if (highCompression_) {
236     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
237                        reinterpret_cast<char*>(out->writableTail()),
238                        data->length());
239   } else {
240     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
241                      reinterpret_cast<char*>(out->writableTail()),
242                      data->length());
243   }
244
245   CHECK_GE(n, 0);
246   CHECK_LE(n, out->capacity());
247
248   out->append(n);
249   return out;
250 }
251
252 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
253     const IOBuf* data,
254     uint64_t uncompressedLength) {
255   std::unique_ptr<IOBuf> clone;
256   if (data->isChained()) {
257     // LZ4 doesn't support streaming, so we have to coalesce
258     clone = data->clone();
259     clone->coalesce();
260     data = clone.get();
261   }
262
263   folly::io::Cursor cursor(data);
264   uint64_t actualUncompressedLength;
265   if (encodeSize()) {
266     actualUncompressedLength = decodeVarintFromCursor(cursor);
267     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
268         uncompressedLength != actualUncompressedLength) {
269       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
270     }
271   } else {
272     actualUncompressedLength = uncompressedLength;
273     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
274   }
275
276   auto out = IOBuf::create(actualUncompressedLength);
277   auto p = cursor.peek();
278   int n = LZ4_uncompress(reinterpret_cast<const char*>(p.first),
279                          reinterpret_cast<char*>(out->writableTail()),
280                          actualUncompressedLength);
281   if (n != p.second) {
282     throw std::runtime_error(to<std::string>(
283         "LZ4 decompression returned invalid value ", n));
284   }
285   out->append(actualUncompressedLength);
286   return out;
287 }
288
289 #endif  // FOLLY_HAVE_LIBLZ4
290
291 #if FOLLY_HAVE_LIBSNAPPY
292
293 /**
294  * Snappy compression
295  */
296
297 /**
298  * Implementation of snappy::Source that reads from a IOBuf chain.
299  */
300 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
301  public:
302   explicit IOBufSnappySource(const IOBuf* data);
303   size_t Available() const FOLLY_OVERRIDE;
304   const char* Peek(size_t* len) FOLLY_OVERRIDE;
305   void Skip(size_t n) FOLLY_OVERRIDE;
306  private:
307   size_t available_;
308   io::Cursor cursor_;
309 };
310
311 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
312   : available_(data->computeChainDataLength()),
313     cursor_(data) {
314 }
315
316 size_t IOBufSnappySource::Available() const {
317   return available_;
318 }
319
320 const char* IOBufSnappySource::Peek(size_t* len) {
321   auto p = cursor_.peek();
322   *len = p.second;
323   return reinterpret_cast<const char*>(p.first);
324 }
325
326 void IOBufSnappySource::Skip(size_t n) {
327   CHECK_LE(n, available_);
328   cursor_.skip(n);
329   available_ -= n;
330 }
331
332 class SnappyCodec FOLLY_FINAL : public Codec {
333  public:
334   static std::unique_ptr<Codec> create(int level, CodecType type);
335   explicit SnappyCodec(int level, CodecType type);
336
337  private:
338   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
339   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
340   std::unique_ptr<IOBuf> doUncompress(
341       const IOBuf* data,
342       uint64_t uncompressedLength) FOLLY_OVERRIDE;
343 };
344
345 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
346   return make_unique<SnappyCodec>(level, type);
347 }
348
349 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
350   DCHECK(type == CodecType::SNAPPY);
351   switch (level) {
352   case COMPRESSION_LEVEL_FASTEST:
353   case COMPRESSION_LEVEL_DEFAULT:
354   case COMPRESSION_LEVEL_BEST:
355     level = 1;
356   }
357   if (level != 1) {
358     throw std::invalid_argument(to<std::string>(
359         "SnappyCodec: invalid level: ", level));
360   }
361 }
362
363 uint64_t SnappyCodec::doMaxUncompressedLength() const {
364   // snappy.h uses uint32_t for lengths, so there's that.
365   return std::numeric_limits<uint32_t>::max();
366 }
367
368 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
369   IOBufSnappySource source(data);
370   auto out =
371     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
372
373   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
374       out->writableTail()));
375
376   size_t n = snappy::Compress(&source, &sink);
377
378   CHECK_LE(n, out->capacity());
379   out->append(n);
380   return out;
381 }
382
383 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
384                                                  uint64_t uncompressedLength) {
385   uint32_t actualUncompressedLength = 0;
386
387   {
388     IOBufSnappySource source(data);
389     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
390       throw std::runtime_error("snappy::GetUncompressedLength failed");
391     }
392     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
393         uncompressedLength != actualUncompressedLength) {
394       throw std::runtime_error("snappy: invalid uncompressed length");
395     }
396   }
397
398   auto out = IOBuf::create(actualUncompressedLength);
399
400   {
401     IOBufSnappySource source(data);
402     if (!snappy::RawUncompress(&source,
403                                reinterpret_cast<char*>(out->writableTail()))) {
404       throw std::runtime_error("snappy::RawUncompress failed");
405     }
406   }
407
408   out->append(actualUncompressedLength);
409   return out;
410 }
411
412 #endif  // FOLLY_HAVE_LIBSNAPPY
413
414 #if FOLLY_HAVE_LIBZ
415 /**
416  * Zlib codec
417  */
418 class ZlibCodec FOLLY_FINAL : public Codec {
419  public:
420   static std::unique_ptr<Codec> create(int level, CodecType type);
421   explicit ZlibCodec(int level, CodecType type);
422
423  private:
424   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
425   std::unique_ptr<IOBuf> doUncompress(
426       const IOBuf* data,
427       uint64_t uncompressedLength) FOLLY_OVERRIDE;
428
429   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
430   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
431
432   int level_;
433 };
434
435 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
436   return make_unique<ZlibCodec>(level, type);
437 }
438
439 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
440   DCHECK(type == CodecType::ZLIB);
441   switch (level) {
442   case COMPRESSION_LEVEL_FASTEST:
443     level = 1;
444     break;
445   case COMPRESSION_LEVEL_DEFAULT:
446     level = Z_DEFAULT_COMPRESSION;
447     break;
448   case COMPRESSION_LEVEL_BEST:
449     level = 9;
450     break;
451   }
452   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
453     throw std::invalid_argument(to<std::string>(
454         "ZlibCodec: invalid level: ", level));
455   }
456   level_ = level;
457 }
458
459 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
460                                                   uint32_t length) {
461   CHECK_EQ(stream->avail_out, 0);
462
463   auto buf = IOBuf::create(length);
464   buf->append(length);
465
466   stream->next_out = buf->writableData();
467   stream->avail_out = buf->length();
468
469   return buf;
470 }
471
472 bool ZlibCodec::doInflate(z_stream* stream,
473                           IOBuf* head,
474                           uint32_t bufferLength) {
475   if (stream->avail_out == 0) {
476     head->prependChain(addOutputBuffer(stream, bufferLength));
477   }
478
479   int rc = inflate(stream, Z_NO_FLUSH);
480
481   switch (rc) {
482   case Z_OK:
483     break;
484   case Z_STREAM_END:
485     return true;
486   case Z_BUF_ERROR:
487   case Z_NEED_DICT:
488   case Z_DATA_ERROR:
489   case Z_MEM_ERROR:
490     throw std::runtime_error(to<std::string>(
491         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
492   default:
493     CHECK(false) << rc << ": " << stream->msg;
494   }
495
496   return false;
497 }
498
499 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
500   z_stream stream;
501   stream.zalloc = nullptr;
502   stream.zfree = nullptr;
503   stream.opaque = nullptr;
504
505   int rc = deflateInit(&stream, level_);
506   if (rc != Z_OK) {
507     throw std::runtime_error(to<std::string>(
508         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
509   }
510
511   stream.next_in = stream.next_out = nullptr;
512   stream.avail_in = stream.avail_out = 0;
513   stream.total_in = stream.total_out = 0;
514
515   bool success = false;
516
517   SCOPE_EXIT {
518     int rc = deflateEnd(&stream);
519     // If we're here because of an exception, it's okay if some data
520     // got dropped.
521     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
522       << rc << ": " << stream.msg;
523   };
524
525   uint64_t uncompressedLength = data->computeChainDataLength();
526   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
527
528   // Max 64MiB in one go
529   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
530   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
531
532   auto out = addOutputBuffer(
533       &stream,
534       (maxCompressedLength <= maxSingleStepLength ?
535        maxCompressedLength :
536        defaultBufferLength));
537
538   for (auto& range : *data) {
539     if (range.empty()) {
540       continue;
541     }
542
543     stream.next_in = const_cast<uint8_t*>(range.data());
544     stream.avail_in = range.size();
545
546     while (stream.avail_in != 0) {
547       if (stream.avail_out == 0) {
548         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
549       }
550
551       rc = deflate(&stream, Z_NO_FLUSH);
552
553       CHECK_EQ(rc, Z_OK) << stream.msg;
554     }
555   }
556
557   do {
558     if (stream.avail_out == 0) {
559       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
560     }
561
562     rc = deflate(&stream, Z_FINISH);
563   } while (rc == Z_OK);
564
565   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
566
567   out->prev()->trimEnd(stream.avail_out);
568
569   success = true;  // we survived
570
571   return out;
572 }
573
574 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
575                                                uint64_t uncompressedLength) {
576   z_stream stream;
577   stream.zalloc = nullptr;
578   stream.zfree = nullptr;
579   stream.opaque = nullptr;
580
581   int rc = inflateInit(&stream);
582   if (rc != Z_OK) {
583     throw std::runtime_error(to<std::string>(
584         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
585   }
586
587   stream.next_in = stream.next_out = nullptr;
588   stream.avail_in = stream.avail_out = 0;
589   stream.total_in = stream.total_out = 0;
590
591   bool success = false;
592
593   SCOPE_EXIT {
594     int rc = inflateEnd(&stream);
595     // If we're here because of an exception, it's okay if some data
596     // got dropped.
597     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
598       << rc << ": " << stream.msg;
599   };
600
601   // Max 64MiB in one go
602   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
603   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
604
605   auto out = addOutputBuffer(
606       &stream,
607       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
608         uncompressedLength <= maxSingleStepLength) ?
609        uncompressedLength :
610        defaultBufferLength));
611
612   bool streamEnd = false;
613   for (auto& range : *data) {
614     if (range.empty()) {
615       continue;
616     }
617
618     stream.next_in = const_cast<uint8_t*>(range.data());
619     stream.avail_in = range.size();
620
621     while (stream.avail_in != 0) {
622       if (streamEnd) {
623         throw std::runtime_error(to<std::string>(
624             "ZlibCodec: junk after end of data"));
625       }
626
627       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
628     }
629   }
630
631   while (!streamEnd) {
632     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
633   }
634
635   out->prev()->trimEnd(stream.avail_out);
636
637   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
638       uncompressedLength != stream.total_out) {
639     throw std::runtime_error(to<std::string>(
640         "ZlibCodec: invalid uncompressed length"));
641   }
642
643   success = true;  // we survived
644
645   return out;
646 }
647
648 #endif  // FOLLY_HAVE_LIBZ
649
650 #if FOLLY_HAVE_LIBLZMA
651
652 /**
653  * LZMA2 compression
654  */
655 class LZMA2Codec FOLLY_FINAL : public Codec {
656  public:
657   static std::unique_ptr<Codec> create(int level, CodecType type);
658   explicit LZMA2Codec(int level, CodecType type);
659
660  private:
661   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
662   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
663
664   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
665
666   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
667   std::unique_ptr<IOBuf> doUncompress(
668       const IOBuf* data,
669       uint64_t uncompressedLength) FOLLY_OVERRIDE;
670
671   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
672   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
673
674   int level_;
675 };
676
677 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
678   return make_unique<LZMA2Codec>(level, type);
679 }
680
681 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
682   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
683   switch (level) {
684   case COMPRESSION_LEVEL_FASTEST:
685     level = 0;
686     break;
687   case COMPRESSION_LEVEL_DEFAULT:
688     level = LZMA_PRESET_DEFAULT;
689     break;
690   case COMPRESSION_LEVEL_BEST:
691     level = 9;
692     break;
693   }
694   if (level < 0 || level > 9) {
695     throw std::invalid_argument(to<std::string>(
696         "LZMA2Codec: invalid level: ", level));
697   }
698   level_ = level;
699 }
700
701 bool LZMA2Codec::doNeedsUncompressedLength() const {
702   return !encodeSize();
703 }
704
705 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
706   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
707   return uint64_t(1) << 63;
708 }
709
710 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
711     lzma_stream* stream,
712     size_t length) {
713
714   CHECK_EQ(stream->avail_out, 0);
715
716   auto buf = IOBuf::create(length);
717   buf->append(length);
718
719   stream->next_out = buf->writableData();
720   stream->avail_out = buf->length();
721
722   return buf;
723 }
724
725 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
726   lzma_ret rc;
727   lzma_stream stream = LZMA_STREAM_INIT;
728
729   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
730   if (rc != LZMA_OK) {
731     throw std::runtime_error(folly::to<std::string>(
732       "LZMA2Codec: lzma_easy_encoder error: ", rc));
733   }
734
735   SCOPE_EXIT { lzma_end(&stream); };
736
737   uint64_t uncompressedLength = data->computeChainDataLength();
738   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
739
740   // Max 64MiB in one go
741   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
742   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
743
744   auto out = addOutputBuffer(
745     &stream,
746     (maxCompressedLength <= maxSingleStepLength ?
747      maxCompressedLength :
748      defaultBufferLength));
749
750   if (encodeSize()) {
751     auto size = IOBuf::createCombined(kMaxVarintLength64);
752     encodeVarintToIOBuf(uncompressedLength, size.get());
753     size->appendChain(std::move(out));
754     out = std::move(size);
755   }
756
757   for (auto& range : *data) {
758     if (range.empty()) {
759       continue;
760     }
761
762     stream.next_in = const_cast<uint8_t*>(range.data());
763     stream.avail_in = range.size();
764
765     while (stream.avail_in != 0) {
766       if (stream.avail_out == 0) {
767         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
768       }
769
770       rc = lzma_code(&stream, LZMA_RUN);
771
772       if (rc != LZMA_OK) {
773         throw std::runtime_error(folly::to<std::string>(
774           "LZMA2Codec: lzma_code error: ", rc));
775       }
776     }
777   }
778
779   do {
780     if (stream.avail_out == 0) {
781       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
782     }
783
784     rc = lzma_code(&stream, LZMA_FINISH);
785   } while (rc == LZMA_OK);
786
787   if (rc != LZMA_STREAM_END) {
788     throw std::runtime_error(folly::to<std::string>(
789       "LZMA2Codec: lzma_code ended with error: ", rc));
790   }
791
792   out->prev()->trimEnd(stream.avail_out);
793
794   return out;
795 }
796
797 bool LZMA2Codec::doInflate(lzma_stream* stream,
798                           IOBuf* head,
799                           size_t bufferLength) {
800   if (stream->avail_out == 0) {
801     head->prependChain(addOutputBuffer(stream, bufferLength));
802   }
803
804   lzma_ret rc = lzma_code(stream, LZMA_RUN);
805
806   switch (rc) {
807   case LZMA_OK:
808     break;
809   case LZMA_STREAM_END:
810     return true;
811   default:
812     throw std::runtime_error(to<std::string>(
813         "LZMA2Codec: lzma_code error: ", rc));
814   }
815
816   return false;
817 }
818
819 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
820                                                uint64_t uncompressedLength) {
821   lzma_ret rc;
822   lzma_stream stream = LZMA_STREAM_INIT;
823
824   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
825   if (rc != LZMA_OK) {
826     throw std::runtime_error(folly::to<std::string>(
827       "LZMA2Codec: lzma_auto_decoder error: ", rc));
828   }
829
830   SCOPE_EXIT { lzma_end(&stream); };
831
832   // Max 64MiB in one go
833   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
834   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
835
836   folly::io::Cursor cursor(data);
837   uint64_t actualUncompressedLength;
838   if (encodeSize()) {
839     actualUncompressedLength = decodeVarintFromCursor(cursor);
840     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
841         uncompressedLength != actualUncompressedLength) {
842       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
843     }
844   } else {
845     actualUncompressedLength = uncompressedLength;
846     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
847   }
848
849   auto out = addOutputBuffer(
850       &stream,
851       (actualUncompressedLength <= maxSingleStepLength ?
852        actualUncompressedLength :
853        defaultBufferLength));
854
855   bool streamEnd = false;
856   auto buf = cursor.peek();
857   while (buf.second != 0) {
858     stream.next_in = const_cast<uint8_t*>(buf.first);
859     stream.avail_in = buf.second;
860
861     while (stream.avail_in != 0) {
862       if (streamEnd) {
863         throw std::runtime_error(to<std::string>(
864             "LZMA2Codec: junk after end of data"));
865       }
866
867       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
868     }
869
870     cursor.skip(buf.second);
871     buf = cursor.peek();
872   }
873
874   while (!streamEnd) {
875     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
876   }
877
878   out->prev()->trimEnd(stream.avail_out);
879
880   if (actualUncompressedLength != stream.total_out) {
881     throw std::runtime_error(to<std::string>(
882         "LZMA2Codec: invalid uncompressed length"));
883   }
884
885   return out;
886 }
887
888 #endif  // FOLLY_HAVE_LIBLZMA
889
890 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
891
892 CodecFactory gCodecFactories[
893     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
894   nullptr,  // USER_DEFINED
895   NoCompressionCodec::create,
896
897 #if FOLLY_HAVE_LIBLZ4
898   LZ4Codec::create,
899 #else
900   nullptr,
901 #endif
902
903 #if FOLLY_HAVE_LIBSNAPPY
904   SnappyCodec::create,
905 #else
906   nullptr,
907 #endif
908
909 #if FOLLY_HAVE_LIBZ
910   ZlibCodec::create,
911 #else
912   nullptr,
913 #endif
914
915 #if FOLLY_HAVE_LIBLZ4
916   LZ4Codec::create,
917 #else
918   nullptr,
919 #endif
920
921 #if FOLLY_HAVE_LIBLZMA
922   LZMA2Codec::create,
923   LZMA2Codec::create,
924 #else
925   nullptr,
926   nullptr,
927 #endif
928 };
929
930 }  // namespace
931
932 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
933   size_t idx = static_cast<size_t>(type);
934   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
935     throw std::invalid_argument(to<std::string>(
936         "Compression type ", idx, " not supported"));
937   }
938   auto factory = gCodecFactories[idx];
939   if (!factory) {
940     throw std::invalid_argument(to<std::string>(
941         "Compression type ", idx, " not supported"));
942   }
943   auto codec = (*factory)(level, type);
944   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
945   return codec;
946 }
947
948 }}  // namespaces
949