Enforce max uncompressed size; use safe LZ4 decompressor, which requires a newer...
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #include <folly/Conv.h>
40 #include <folly/Memory.h>
41 #include <folly/Portability.h>
42 #include <folly/ScopeGuard.h>
43 #include <folly/Varint.h>
44 #include <folly/io/Cursor.h>
45
46 namespace folly { namespace io {
47
48 Codec::Codec(CodecType type) : type_(type) { }
49
50 // Ensure consistent behavior in the nullptr case
51 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
52   uint64_t len = data->computeChainDataLength();
53   if (len == 0) {
54     return IOBuf::create(0);
55   } else if (len > maxUncompressedLength()) {
56     throw std::runtime_error("Codec: uncompressed length too large");
57   }
58
59   return doCompress(data);
60 }
61
62 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
63                                          uint64_t uncompressedLength) {
64   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
65     if (needsUncompressedLength()) {
66       throw std::invalid_argument("Codec: uncompressed length required");
67     }
68   } else if (uncompressedLength > maxUncompressedLength()) {
69     throw std::runtime_error("Codec: uncompressed length too large");
70   }
71
72   if (data->empty()) {
73     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
74         uncompressedLength != 0) {
75       throw std::runtime_error("Codec: invalid uncompressed length");
76     }
77     return IOBuf::create(0);
78   }
79
80   return doUncompress(data, uncompressedLength);
81 }
82
83 bool Codec::needsUncompressedLength() const {
84   return doNeedsUncompressedLength();
85 }
86
87 uint64_t Codec::maxUncompressedLength() const {
88   return doMaxUncompressedLength();
89 }
90
91 bool Codec::doNeedsUncompressedLength() const {
92   return false;
93 }
94
95 uint64_t Codec::doMaxUncompressedLength() const {
96   return UNLIMITED_UNCOMPRESSED_LENGTH;
97 }
98
99 namespace {
100
101 /**
102  * No compression
103  */
104 class NoCompressionCodec FOLLY_FINAL : public Codec {
105  public:
106   static std::unique_ptr<Codec> create(int level, CodecType type);
107   explicit NoCompressionCodec(int level, CodecType type);
108
109  private:
110   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
111   std::unique_ptr<IOBuf> doUncompress(
112       const IOBuf* data,
113       uint64_t uncompressedLength) FOLLY_OVERRIDE;
114 };
115
116 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
117   return make_unique<NoCompressionCodec>(level, type);
118 }
119
120 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
121   : Codec(type) {
122   DCHECK(type == CodecType::NO_COMPRESSION);
123   switch (level) {
124   case COMPRESSION_LEVEL_DEFAULT:
125   case COMPRESSION_LEVEL_FASTEST:
126   case COMPRESSION_LEVEL_BEST:
127     level = 0;
128   }
129   if (level != 0) {
130     throw std::invalid_argument(to<std::string>(
131         "NoCompressionCodec: invalid level ", level));
132   }
133 }
134
135 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
136     const IOBuf* data) {
137   return data->clone();
138 }
139
140 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
141     const IOBuf* data,
142     uint64_t uncompressedLength) {
143   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
144       data->computeChainDataLength() != uncompressedLength) {
145     throw std::runtime_error(to<std::string>(
146         "NoCompressionCodec: invalid uncompressed length"));
147   }
148   return data->clone();
149 }
150
151 namespace {
152
153 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
154   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
155   out->append(encodeVarint(val, out->writableTail()));
156 }
157
158 uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
159   // Must have enough room in *this* buffer.
160   auto p = cursor.peek();
161   folly::ByteRange range(p.first, p.second);
162   uint64_t val = decodeVarint(range);
163   cursor.skip(range.data() - p.first);
164   return val;
165 }
166
167 }  // namespace
168
169 #if FOLLY_HAVE_LIBLZ4
170
171 /**
172  * LZ4 compression
173  */
174 class LZ4Codec FOLLY_FINAL : public Codec {
175  public:
176   static std::unique_ptr<Codec> create(int level, CodecType type);
177   explicit LZ4Codec(int level, CodecType type);
178
179  private:
180   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
181   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
182
183   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
184
185   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
186   std::unique_ptr<IOBuf> doUncompress(
187       const IOBuf* data,
188       uint64_t uncompressedLength) FOLLY_OVERRIDE;
189
190   bool highCompression_;
191 };
192
193 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
194   return make_unique<LZ4Codec>(level, type);
195 }
196
197 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
198   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
199
200   switch (level) {
201   case COMPRESSION_LEVEL_FASTEST:
202   case COMPRESSION_LEVEL_DEFAULT:
203     level = 1;
204     break;
205   case COMPRESSION_LEVEL_BEST:
206     level = 2;
207     break;
208   }
209   if (level < 1 || level > 2) {
210     throw std::invalid_argument(to<std::string>(
211         "LZ4Codec: invalid level: ", level));
212   }
213   highCompression_ = (level > 1);
214 }
215
216 bool LZ4Codec::doNeedsUncompressedLength() const {
217   return !encodeSize();
218 }
219
220 uint64_t LZ4Codec::doMaxUncompressedLength() const {
221   return LZ4_MAX_INPUT_SIZE;
222 }
223
224 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
225   std::unique_ptr<IOBuf> clone;
226   if (data->isChained()) {
227     // LZ4 doesn't support streaming, so we have to coalesce
228     clone = data->clone();
229     clone->coalesce();
230     data = clone.get();
231   }
232
233   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
234   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
235   if (encodeSize()) {
236     encodeVarintToIOBuf(data->length(), out.get());
237   }
238
239   int n;
240   if (highCompression_) {
241     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
242                        reinterpret_cast<char*>(out->writableTail()),
243                        data->length());
244   } else {
245     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
246                      reinterpret_cast<char*>(out->writableTail()),
247                      data->length());
248   }
249
250   CHECK_GE(n, 0);
251   CHECK_LE(n, out->capacity());
252
253   out->append(n);
254   return out;
255 }
256
257 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
258     const IOBuf* data,
259     uint64_t uncompressedLength) {
260   std::unique_ptr<IOBuf> clone;
261   if (data->isChained()) {
262     // LZ4 doesn't support streaming, so we have to coalesce
263     clone = data->clone();
264     clone->coalesce();
265     data = clone.get();
266   }
267
268   folly::io::Cursor cursor(data);
269   uint64_t actualUncompressedLength;
270   if (encodeSize()) {
271     actualUncompressedLength = decodeVarintFromCursor(cursor);
272     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
273         uncompressedLength != actualUncompressedLength) {
274       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
275     }
276   } else {
277     actualUncompressedLength = uncompressedLength;
278     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
279         actualUncompressedLength > maxUncompressedLength()) {
280       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
281     }
282   }
283
284   auto p = cursor.peek();
285   auto out = IOBuf::create(actualUncompressedLength);
286   int n = LZ4_decompress_safe(reinterpret_cast<const char*>(p.first),
287                               reinterpret_cast<char*>(out->writableTail()),
288                               p.second,
289                               actualUncompressedLength);
290
291   if (n != actualUncompressedLength) {
292     throw std::runtime_error(to<std::string>(
293         "LZ4 decompression returned invalid value ", n));
294   }
295   out->append(actualUncompressedLength);
296   return out;
297 }
298
299 #endif  // FOLLY_HAVE_LIBLZ4
300
301 #if FOLLY_HAVE_LIBSNAPPY
302
303 /**
304  * Snappy compression
305  */
306
307 /**
308  * Implementation of snappy::Source that reads from a IOBuf chain.
309  */
310 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
311  public:
312   explicit IOBufSnappySource(const IOBuf* data);
313   size_t Available() const FOLLY_OVERRIDE;
314   const char* Peek(size_t* len) FOLLY_OVERRIDE;
315   void Skip(size_t n) FOLLY_OVERRIDE;
316  private:
317   size_t available_;
318   io::Cursor cursor_;
319 };
320
321 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
322   : available_(data->computeChainDataLength()),
323     cursor_(data) {
324 }
325
326 size_t IOBufSnappySource::Available() const {
327   return available_;
328 }
329
330 const char* IOBufSnappySource::Peek(size_t* len) {
331   auto p = cursor_.peek();
332   *len = p.second;
333   return reinterpret_cast<const char*>(p.first);
334 }
335
336 void IOBufSnappySource::Skip(size_t n) {
337   CHECK_LE(n, available_);
338   cursor_.skip(n);
339   available_ -= n;
340 }
341
342 class SnappyCodec FOLLY_FINAL : public Codec {
343  public:
344   static std::unique_ptr<Codec> create(int level, CodecType type);
345   explicit SnappyCodec(int level, CodecType type);
346
347  private:
348   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
349   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
350   std::unique_ptr<IOBuf> doUncompress(
351       const IOBuf* data,
352       uint64_t uncompressedLength) FOLLY_OVERRIDE;
353 };
354
355 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
356   return make_unique<SnappyCodec>(level, type);
357 }
358
359 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
360   DCHECK(type == CodecType::SNAPPY);
361   switch (level) {
362   case COMPRESSION_LEVEL_FASTEST:
363   case COMPRESSION_LEVEL_DEFAULT:
364   case COMPRESSION_LEVEL_BEST:
365     level = 1;
366   }
367   if (level != 1) {
368     throw std::invalid_argument(to<std::string>(
369         "SnappyCodec: invalid level: ", level));
370   }
371 }
372
373 uint64_t SnappyCodec::doMaxUncompressedLength() const {
374   // snappy.h uses uint32_t for lengths, so there's that.
375   return std::numeric_limits<uint32_t>::max();
376 }
377
378 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
379   IOBufSnappySource source(data);
380   auto out =
381     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
382
383   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
384       out->writableTail()));
385
386   size_t n = snappy::Compress(&source, &sink);
387
388   CHECK_LE(n, out->capacity());
389   out->append(n);
390   return out;
391 }
392
393 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
394                                                  uint64_t uncompressedLength) {
395   uint32_t actualUncompressedLength = 0;
396
397   {
398     IOBufSnappySource source(data);
399     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
400       throw std::runtime_error("snappy::GetUncompressedLength failed");
401     }
402     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
403         uncompressedLength != actualUncompressedLength) {
404       throw std::runtime_error("snappy: invalid uncompressed length");
405     }
406   }
407
408   auto out = IOBuf::create(actualUncompressedLength);
409
410   {
411     IOBufSnappySource source(data);
412     if (!snappy::RawUncompress(&source,
413                                reinterpret_cast<char*>(out->writableTail()))) {
414       throw std::runtime_error("snappy::RawUncompress failed");
415     }
416   }
417
418   out->append(actualUncompressedLength);
419   return out;
420 }
421
422 #endif  // FOLLY_HAVE_LIBSNAPPY
423
424 #if FOLLY_HAVE_LIBZ
425 /**
426  * Zlib codec
427  */
428 class ZlibCodec FOLLY_FINAL : public Codec {
429  public:
430   static std::unique_ptr<Codec> create(int level, CodecType type);
431   explicit ZlibCodec(int level, CodecType type);
432
433  private:
434   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
435   std::unique_ptr<IOBuf> doUncompress(
436       const IOBuf* data,
437       uint64_t uncompressedLength) FOLLY_OVERRIDE;
438
439   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
440   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
441
442   int level_;
443 };
444
445 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
446   return make_unique<ZlibCodec>(level, type);
447 }
448
449 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
450   DCHECK(type == CodecType::ZLIB);
451   switch (level) {
452   case COMPRESSION_LEVEL_FASTEST:
453     level = 1;
454     break;
455   case COMPRESSION_LEVEL_DEFAULT:
456     level = Z_DEFAULT_COMPRESSION;
457     break;
458   case COMPRESSION_LEVEL_BEST:
459     level = 9;
460     break;
461   }
462   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
463     throw std::invalid_argument(to<std::string>(
464         "ZlibCodec: invalid level: ", level));
465   }
466   level_ = level;
467 }
468
469 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
470                                                   uint32_t length) {
471   CHECK_EQ(stream->avail_out, 0);
472
473   auto buf = IOBuf::create(length);
474   buf->append(length);
475
476   stream->next_out = buf->writableData();
477   stream->avail_out = buf->length();
478
479   return buf;
480 }
481
482 bool ZlibCodec::doInflate(z_stream* stream,
483                           IOBuf* head,
484                           uint32_t bufferLength) {
485   if (stream->avail_out == 0) {
486     head->prependChain(addOutputBuffer(stream, bufferLength));
487   }
488
489   int rc = inflate(stream, Z_NO_FLUSH);
490
491   switch (rc) {
492   case Z_OK:
493     break;
494   case Z_STREAM_END:
495     return true;
496   case Z_BUF_ERROR:
497   case Z_NEED_DICT:
498   case Z_DATA_ERROR:
499   case Z_MEM_ERROR:
500     throw std::runtime_error(to<std::string>(
501         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
502   default:
503     CHECK(false) << rc << ": " << stream->msg;
504   }
505
506   return false;
507 }
508
509 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
510   z_stream stream;
511   stream.zalloc = nullptr;
512   stream.zfree = nullptr;
513   stream.opaque = nullptr;
514
515   int rc = deflateInit(&stream, level_);
516   if (rc != Z_OK) {
517     throw std::runtime_error(to<std::string>(
518         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
519   }
520
521   stream.next_in = stream.next_out = nullptr;
522   stream.avail_in = stream.avail_out = 0;
523   stream.total_in = stream.total_out = 0;
524
525   bool success = false;
526
527   SCOPE_EXIT {
528     int rc = deflateEnd(&stream);
529     // If we're here because of an exception, it's okay if some data
530     // got dropped.
531     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
532       << rc << ": " << stream.msg;
533   };
534
535   uint64_t uncompressedLength = data->computeChainDataLength();
536   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
537
538   // Max 64MiB in one go
539   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
540   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
541
542   auto out = addOutputBuffer(
543       &stream,
544       (maxCompressedLength <= maxSingleStepLength ?
545        maxCompressedLength :
546        defaultBufferLength));
547
548   for (auto& range : *data) {
549     if (range.empty()) {
550       continue;
551     }
552
553     stream.next_in = const_cast<uint8_t*>(range.data());
554     stream.avail_in = range.size();
555
556     while (stream.avail_in != 0) {
557       if (stream.avail_out == 0) {
558         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
559       }
560
561       rc = deflate(&stream, Z_NO_FLUSH);
562
563       CHECK_EQ(rc, Z_OK) << stream.msg;
564     }
565   }
566
567   do {
568     if (stream.avail_out == 0) {
569       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
570     }
571
572     rc = deflate(&stream, Z_FINISH);
573   } while (rc == Z_OK);
574
575   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
576
577   out->prev()->trimEnd(stream.avail_out);
578
579   success = true;  // we survived
580
581   return out;
582 }
583
584 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
585                                                uint64_t uncompressedLength) {
586   z_stream stream;
587   stream.zalloc = nullptr;
588   stream.zfree = nullptr;
589   stream.opaque = nullptr;
590
591   int rc = inflateInit(&stream);
592   if (rc != Z_OK) {
593     throw std::runtime_error(to<std::string>(
594         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
595   }
596
597   stream.next_in = stream.next_out = nullptr;
598   stream.avail_in = stream.avail_out = 0;
599   stream.total_in = stream.total_out = 0;
600
601   bool success = false;
602
603   SCOPE_EXIT {
604     int rc = inflateEnd(&stream);
605     // If we're here because of an exception, it's okay if some data
606     // got dropped.
607     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
608       << rc << ": " << stream.msg;
609   };
610
611   // Max 64MiB in one go
612   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
613   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
614
615   auto out = addOutputBuffer(
616       &stream,
617       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
618         uncompressedLength <= maxSingleStepLength) ?
619        uncompressedLength :
620        defaultBufferLength));
621
622   bool streamEnd = false;
623   for (auto& range : *data) {
624     if (range.empty()) {
625       continue;
626     }
627
628     stream.next_in = const_cast<uint8_t*>(range.data());
629     stream.avail_in = range.size();
630
631     while (stream.avail_in != 0) {
632       if (streamEnd) {
633         throw std::runtime_error(to<std::string>(
634             "ZlibCodec: junk after end of data"));
635       }
636
637       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
638     }
639   }
640
641   while (!streamEnd) {
642     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
643   }
644
645   out->prev()->trimEnd(stream.avail_out);
646
647   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
648       uncompressedLength != stream.total_out) {
649     throw std::runtime_error(to<std::string>(
650         "ZlibCodec: invalid uncompressed length"));
651   }
652
653   success = true;  // we survived
654
655   return out;
656 }
657
658 #endif  // FOLLY_HAVE_LIBZ
659
660 #if FOLLY_HAVE_LIBLZMA
661
662 /**
663  * LZMA2 compression
664  */
665 class LZMA2Codec FOLLY_FINAL : public Codec {
666  public:
667   static std::unique_ptr<Codec> create(int level, CodecType type);
668   explicit LZMA2Codec(int level, CodecType type);
669
670  private:
671   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
672   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
673
674   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
675
676   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
677   std::unique_ptr<IOBuf> doUncompress(
678       const IOBuf* data,
679       uint64_t uncompressedLength) FOLLY_OVERRIDE;
680
681   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
682   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
683
684   int level_;
685 };
686
687 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
688   return make_unique<LZMA2Codec>(level, type);
689 }
690
691 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
692   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
693   switch (level) {
694   case COMPRESSION_LEVEL_FASTEST:
695     level = 0;
696     break;
697   case COMPRESSION_LEVEL_DEFAULT:
698     level = LZMA_PRESET_DEFAULT;
699     break;
700   case COMPRESSION_LEVEL_BEST:
701     level = 9;
702     break;
703   }
704   if (level < 0 || level > 9) {
705     throw std::invalid_argument(to<std::string>(
706         "LZMA2Codec: invalid level: ", level));
707   }
708   level_ = level;
709 }
710
711 bool LZMA2Codec::doNeedsUncompressedLength() const {
712   return !encodeSize();
713 }
714
715 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
716   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
717   return uint64_t(1) << 63;
718 }
719
720 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
721     lzma_stream* stream,
722     size_t length) {
723
724   CHECK_EQ(stream->avail_out, 0);
725
726   auto buf = IOBuf::create(length);
727   buf->append(length);
728
729   stream->next_out = buf->writableData();
730   stream->avail_out = buf->length();
731
732   return buf;
733 }
734
735 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
736   lzma_ret rc;
737   lzma_stream stream = LZMA_STREAM_INIT;
738
739   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
740   if (rc != LZMA_OK) {
741     throw std::runtime_error(folly::to<std::string>(
742       "LZMA2Codec: lzma_easy_encoder error: ", rc));
743   }
744
745   SCOPE_EXIT { lzma_end(&stream); };
746
747   uint64_t uncompressedLength = data->computeChainDataLength();
748   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
749
750   // Max 64MiB in one go
751   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
752   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
753
754   auto out = addOutputBuffer(
755     &stream,
756     (maxCompressedLength <= maxSingleStepLength ?
757      maxCompressedLength :
758      defaultBufferLength));
759
760   if (encodeSize()) {
761     auto size = IOBuf::createCombined(kMaxVarintLength64);
762     encodeVarintToIOBuf(uncompressedLength, size.get());
763     size->appendChain(std::move(out));
764     out = std::move(size);
765   }
766
767   for (auto& range : *data) {
768     if (range.empty()) {
769       continue;
770     }
771
772     stream.next_in = const_cast<uint8_t*>(range.data());
773     stream.avail_in = range.size();
774
775     while (stream.avail_in != 0) {
776       if (stream.avail_out == 0) {
777         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
778       }
779
780       rc = lzma_code(&stream, LZMA_RUN);
781
782       if (rc != LZMA_OK) {
783         throw std::runtime_error(folly::to<std::string>(
784           "LZMA2Codec: lzma_code error: ", rc));
785       }
786     }
787   }
788
789   do {
790     if (stream.avail_out == 0) {
791       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
792     }
793
794     rc = lzma_code(&stream, LZMA_FINISH);
795   } while (rc == LZMA_OK);
796
797   if (rc != LZMA_STREAM_END) {
798     throw std::runtime_error(folly::to<std::string>(
799       "LZMA2Codec: lzma_code ended with error: ", rc));
800   }
801
802   out->prev()->trimEnd(stream.avail_out);
803
804   return out;
805 }
806
807 bool LZMA2Codec::doInflate(lzma_stream* stream,
808                           IOBuf* head,
809                           size_t bufferLength) {
810   if (stream->avail_out == 0) {
811     head->prependChain(addOutputBuffer(stream, bufferLength));
812   }
813
814   lzma_ret rc = lzma_code(stream, LZMA_RUN);
815
816   switch (rc) {
817   case LZMA_OK:
818     break;
819   case LZMA_STREAM_END:
820     return true;
821   default:
822     throw std::runtime_error(to<std::string>(
823         "LZMA2Codec: lzma_code error: ", rc));
824   }
825
826   return false;
827 }
828
829 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
830                                                uint64_t uncompressedLength) {
831   lzma_ret rc;
832   lzma_stream stream = LZMA_STREAM_INIT;
833
834   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
835   if (rc != LZMA_OK) {
836     throw std::runtime_error(folly::to<std::string>(
837       "LZMA2Codec: lzma_auto_decoder error: ", rc));
838   }
839
840   SCOPE_EXIT { lzma_end(&stream); };
841
842   // Max 64MiB in one go
843   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
844   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
845
846   folly::io::Cursor cursor(data);
847   uint64_t actualUncompressedLength;
848   if (encodeSize()) {
849     actualUncompressedLength = decodeVarintFromCursor(cursor);
850     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
851         uncompressedLength != actualUncompressedLength) {
852       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
853     }
854   } else {
855     actualUncompressedLength = uncompressedLength;
856     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
857   }
858
859   auto out = addOutputBuffer(
860       &stream,
861       (actualUncompressedLength <= maxSingleStepLength ?
862        actualUncompressedLength :
863        defaultBufferLength));
864
865   bool streamEnd = false;
866   auto buf = cursor.peek();
867   while (buf.second != 0) {
868     stream.next_in = const_cast<uint8_t*>(buf.first);
869     stream.avail_in = buf.second;
870
871     while (stream.avail_in != 0) {
872       if (streamEnd) {
873         throw std::runtime_error(to<std::string>(
874             "LZMA2Codec: junk after end of data"));
875       }
876
877       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
878     }
879
880     cursor.skip(buf.second);
881     buf = cursor.peek();
882   }
883
884   while (!streamEnd) {
885     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
886   }
887
888   out->prev()->trimEnd(stream.avail_out);
889
890   if (actualUncompressedLength != stream.total_out) {
891     throw std::runtime_error(to<std::string>(
892         "LZMA2Codec: invalid uncompressed length"));
893   }
894
895   return out;
896 }
897
898 #endif  // FOLLY_HAVE_LIBLZMA
899
900 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
901
902 CodecFactory gCodecFactories[
903     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
904   nullptr,  // USER_DEFINED
905   NoCompressionCodec::create,
906
907 #if FOLLY_HAVE_LIBLZ4
908   LZ4Codec::create,
909 #else
910   nullptr,
911 #endif
912
913 #if FOLLY_HAVE_LIBSNAPPY
914   SnappyCodec::create,
915 #else
916   nullptr,
917 #endif
918
919 #if FOLLY_HAVE_LIBZ
920   ZlibCodec::create,
921 #else
922   nullptr,
923 #endif
924
925 #if FOLLY_HAVE_LIBLZ4
926   LZ4Codec::create,
927 #else
928   nullptr,
929 #endif
930
931 #if FOLLY_HAVE_LIBLZMA
932   LZMA2Codec::create,
933   LZMA2Codec::create,
934 #else
935   nullptr,
936   nullptr,
937 #endif
938 };
939
940 }  // namespace
941
942 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
943   size_t idx = static_cast<size_t>(type);
944   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
945     throw std::invalid_argument(to<std::string>(
946         "Compression type ", idx, " not supported"));
947   }
948   auto factory = gCodecFactories[idx];
949   if (!factory) {
950     throw std::invalid_argument(to<std::string>(
951         "Compression type ", idx, " not supported"));
952   }
953   auto codec = (*factory)(level, type);
954   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
955   return codec;
956 }
957
958 }}  // namespaces
959