9425f98f278c69595380b575ce7723272b99829d
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #include <folly/Conv.h>
40 #include <folly/Memory.h>
41 #include <folly/Portability.h>
42 #include <folly/ScopeGuard.h>
43 #include <folly/Varint.h>
44 #include <folly/io/Cursor.h>
45
46 namespace folly { namespace io {
47
48 Codec::Codec(CodecType type) : type_(type) { }
49
50 // Ensure consistent behavior in the nullptr case
51 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
52   uint64_t len = data->computeChainDataLength();
53   if (len == 0) {
54     return IOBuf::create(0);
55   } else if (len > maxUncompressedLength()) {
56     throw std::runtime_error("Codec: uncompressed length too large");
57   }
58
59   return doCompress(data);
60 }
61
62 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
63                                          uint64_t uncompressedLength) {
64   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
65     if (needsUncompressedLength()) {
66       throw std::invalid_argument("Codec: uncompressed length required");
67     }
68   } else if (uncompressedLength > maxUncompressedLength()) {
69     throw std::runtime_error("Codec: uncompressed length too large");
70   }
71
72   if (data->empty()) {
73     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
74         uncompressedLength != 0) {
75       throw std::runtime_error("Codec: invalid uncompressed length");
76     }
77     return IOBuf::create(0);
78   }
79
80   return doUncompress(data, uncompressedLength);
81 }
82
83 bool Codec::needsUncompressedLength() const {
84   return doNeedsUncompressedLength();
85 }
86
87 uint64_t Codec::maxUncompressedLength() const {
88   return doMaxUncompressedLength();
89 }
90
91 bool Codec::doNeedsUncompressedLength() const {
92   return false;
93 }
94
95 uint64_t Codec::doMaxUncompressedLength() const {
96   return UNLIMITED_UNCOMPRESSED_LENGTH;
97 }
98
99 namespace {
100
101 /**
102  * No compression
103  */
104 class NoCompressionCodec FOLLY_FINAL : public Codec {
105  public:
106   static std::unique_ptr<Codec> create(int level, CodecType type);
107   explicit NoCompressionCodec(int level, CodecType type);
108
109  private:
110   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
111   std::unique_ptr<IOBuf> doUncompress(
112       const IOBuf* data,
113       uint64_t uncompressedLength) FOLLY_OVERRIDE;
114 };
115
116 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
117   return make_unique<NoCompressionCodec>(level, type);
118 }
119
120 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
121   : Codec(type) {
122   DCHECK(type == CodecType::NO_COMPRESSION);
123   switch (level) {
124   case COMPRESSION_LEVEL_DEFAULT:
125   case COMPRESSION_LEVEL_FASTEST:
126   case COMPRESSION_LEVEL_BEST:
127     level = 0;
128   }
129   if (level != 0) {
130     throw std::invalid_argument(to<std::string>(
131         "NoCompressionCodec: invalid level ", level));
132   }
133 }
134
135 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
136     const IOBuf* data) {
137   return data->clone();
138 }
139
140 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
141     const IOBuf* data,
142     uint64_t uncompressedLength) {
143   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
144       data->computeChainDataLength() != uncompressedLength) {
145     throw std::runtime_error(to<std::string>(
146         "NoCompressionCodec: invalid uncompressed length"));
147   }
148   return data->clone();
149 }
150
151 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
152
153 namespace {
154
155 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
156   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
157   out->append(encodeVarint(val, out->writableTail()));
158 }
159
160 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
161   uint64_t val = 0;
162   int8_t b = 0;
163   for (int shift = 0; shift <= 63; shift += 7) {
164     b = cursor.read<int8_t>();
165     val |= static_cast<uint64_t>(b & 0x7f) << shift;
166     if (b >= 0) {
167       break;
168     }
169   }
170   if (b < 0) {
171     throw std::invalid_argument("Invalid varint value. Too big.");
172   }
173   return val;
174 }
175
176 }  // namespace
177
178 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
179
180 #if FOLLY_HAVE_LIBLZ4
181
182 /**
183  * LZ4 compression
184  */
185 class LZ4Codec FOLLY_FINAL : public Codec {
186  public:
187   static std::unique_ptr<Codec> create(int level, CodecType type);
188   explicit LZ4Codec(int level, CodecType type);
189
190  private:
191   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
192   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
193
194   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
195
196   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
197   std::unique_ptr<IOBuf> doUncompress(
198       const IOBuf* data,
199       uint64_t uncompressedLength) FOLLY_OVERRIDE;
200
201   bool highCompression_;
202 };
203
204 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
205   return make_unique<LZ4Codec>(level, type);
206 }
207
208 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
209   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
210
211   switch (level) {
212   case COMPRESSION_LEVEL_FASTEST:
213   case COMPRESSION_LEVEL_DEFAULT:
214     level = 1;
215     break;
216   case COMPRESSION_LEVEL_BEST:
217     level = 2;
218     break;
219   }
220   if (level < 1 || level > 2) {
221     throw std::invalid_argument(to<std::string>(
222         "LZ4Codec: invalid level: ", level));
223   }
224   highCompression_ = (level > 1);
225 }
226
227 bool LZ4Codec::doNeedsUncompressedLength() const {
228   return !encodeSize();
229 }
230
231 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
232 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
233 // here.
234 #ifndef LZ4_MAX_INPUT_SIZE
235 # define LZ4_MAX_INPUT_SIZE 0x7E000000
236 #endif
237
238 uint64_t LZ4Codec::doMaxUncompressedLength() const {
239   return LZ4_MAX_INPUT_SIZE;
240 }
241
242 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
243   std::unique_ptr<IOBuf> clone;
244   if (data->isChained()) {
245     // LZ4 doesn't support streaming, so we have to coalesce
246     clone = data->clone();
247     clone->coalesce();
248     data = clone.get();
249   }
250
251   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
252   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
253   if (encodeSize()) {
254     encodeVarintToIOBuf(data->length(), out.get());
255   }
256
257   int n;
258   if (highCompression_) {
259     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
260                        reinterpret_cast<char*>(out->writableTail()),
261                        data->length());
262   } else {
263     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
264                      reinterpret_cast<char*>(out->writableTail()),
265                      data->length());
266   }
267
268   CHECK_GE(n, 0);
269   CHECK_LE(n, out->capacity());
270
271   out->append(n);
272   return out;
273 }
274
275 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
276     const IOBuf* data,
277     uint64_t uncompressedLength) {
278   std::unique_ptr<IOBuf> clone;
279   if (data->isChained()) {
280     // LZ4 doesn't support streaming, so we have to coalesce
281     clone = data->clone();
282     clone->coalesce();
283     data = clone.get();
284   }
285
286   folly::io::Cursor cursor(data);
287   uint64_t actualUncompressedLength;
288   if (encodeSize()) {
289     actualUncompressedLength = decodeVarintFromCursor(cursor);
290     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
291         uncompressedLength != actualUncompressedLength) {
292       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
293     }
294   } else {
295     actualUncompressedLength = uncompressedLength;
296     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
297         actualUncompressedLength > maxUncompressedLength()) {
298       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
299     }
300   }
301
302   auto p = cursor.peek();
303   auto out = IOBuf::create(actualUncompressedLength);
304   int n = LZ4_decompress_safe(reinterpret_cast<const char*>(p.first),
305                               reinterpret_cast<char*>(out->writableTail()),
306                               p.second,
307                               actualUncompressedLength);
308
309   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
310     throw std::runtime_error(to<std::string>(
311         "LZ4 decompression returned invalid value ", n));
312   }
313   out->append(actualUncompressedLength);
314   return out;
315 }
316
317 #endif  // FOLLY_HAVE_LIBLZ4
318
319 #if FOLLY_HAVE_LIBSNAPPY
320
321 /**
322  * Snappy compression
323  */
324
325 /**
326  * Implementation of snappy::Source that reads from a IOBuf chain.
327  */
328 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
329  public:
330   explicit IOBufSnappySource(const IOBuf* data);
331   size_t Available() const FOLLY_OVERRIDE;
332   const char* Peek(size_t* len) FOLLY_OVERRIDE;
333   void Skip(size_t n) FOLLY_OVERRIDE;
334  private:
335   size_t available_;
336   io::Cursor cursor_;
337 };
338
339 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
340   : available_(data->computeChainDataLength()),
341     cursor_(data) {
342 }
343
344 size_t IOBufSnappySource::Available() const {
345   return available_;
346 }
347
348 const char* IOBufSnappySource::Peek(size_t* len) {
349   auto p = cursor_.peek();
350   *len = p.second;
351   return reinterpret_cast<const char*>(p.first);
352 }
353
354 void IOBufSnappySource::Skip(size_t n) {
355   CHECK_LE(n, available_);
356   cursor_.skip(n);
357   available_ -= n;
358 }
359
360 class SnappyCodec FOLLY_FINAL : public Codec {
361  public:
362   static std::unique_ptr<Codec> create(int level, CodecType type);
363   explicit SnappyCodec(int level, CodecType type);
364
365  private:
366   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
367   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
368   std::unique_ptr<IOBuf> doUncompress(
369       const IOBuf* data,
370       uint64_t uncompressedLength) FOLLY_OVERRIDE;
371 };
372
373 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
374   return make_unique<SnappyCodec>(level, type);
375 }
376
377 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
378   DCHECK(type == CodecType::SNAPPY);
379   switch (level) {
380   case COMPRESSION_LEVEL_FASTEST:
381   case COMPRESSION_LEVEL_DEFAULT:
382   case COMPRESSION_LEVEL_BEST:
383     level = 1;
384   }
385   if (level != 1) {
386     throw std::invalid_argument(to<std::string>(
387         "SnappyCodec: invalid level: ", level));
388   }
389 }
390
391 uint64_t SnappyCodec::doMaxUncompressedLength() const {
392   // snappy.h uses uint32_t for lengths, so there's that.
393   return std::numeric_limits<uint32_t>::max();
394 }
395
396 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
397   IOBufSnappySource source(data);
398   auto out =
399     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
400
401   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
402       out->writableTail()));
403
404   size_t n = snappy::Compress(&source, &sink);
405
406   CHECK_LE(n, out->capacity());
407   out->append(n);
408   return out;
409 }
410
411 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
412                                                  uint64_t uncompressedLength) {
413   uint32_t actualUncompressedLength = 0;
414
415   {
416     IOBufSnappySource source(data);
417     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
418       throw std::runtime_error("snappy::GetUncompressedLength failed");
419     }
420     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
421         uncompressedLength != actualUncompressedLength) {
422       throw std::runtime_error("snappy: invalid uncompressed length");
423     }
424   }
425
426   auto out = IOBuf::create(actualUncompressedLength);
427
428   {
429     IOBufSnappySource source(data);
430     if (!snappy::RawUncompress(&source,
431                                reinterpret_cast<char*>(out->writableTail()))) {
432       throw std::runtime_error("snappy::RawUncompress failed");
433     }
434   }
435
436   out->append(actualUncompressedLength);
437   return out;
438 }
439
440 #endif  // FOLLY_HAVE_LIBSNAPPY
441
442 #if FOLLY_HAVE_LIBZ
443 /**
444  * Zlib codec
445  */
446 class ZlibCodec FOLLY_FINAL : public Codec {
447  public:
448   static std::unique_ptr<Codec> create(int level, CodecType type);
449   explicit ZlibCodec(int level, CodecType type);
450
451  private:
452   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
453   std::unique_ptr<IOBuf> doUncompress(
454       const IOBuf* data,
455       uint64_t uncompressedLength) FOLLY_OVERRIDE;
456
457   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
458   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
459
460   int level_;
461 };
462
463 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
464   return make_unique<ZlibCodec>(level, type);
465 }
466
467 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
468   DCHECK(type == CodecType::ZLIB);
469   switch (level) {
470   case COMPRESSION_LEVEL_FASTEST:
471     level = 1;
472     break;
473   case COMPRESSION_LEVEL_DEFAULT:
474     level = Z_DEFAULT_COMPRESSION;
475     break;
476   case COMPRESSION_LEVEL_BEST:
477     level = 9;
478     break;
479   }
480   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
481     throw std::invalid_argument(to<std::string>(
482         "ZlibCodec: invalid level: ", level));
483   }
484   level_ = level;
485 }
486
487 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
488                                                   uint32_t length) {
489   CHECK_EQ(stream->avail_out, 0);
490
491   auto buf = IOBuf::create(length);
492   buf->append(length);
493
494   stream->next_out = buf->writableData();
495   stream->avail_out = buf->length();
496
497   return buf;
498 }
499
500 bool ZlibCodec::doInflate(z_stream* stream,
501                           IOBuf* head,
502                           uint32_t bufferLength) {
503   if (stream->avail_out == 0) {
504     head->prependChain(addOutputBuffer(stream, bufferLength));
505   }
506
507   int rc = inflate(stream, Z_NO_FLUSH);
508
509   switch (rc) {
510   case Z_OK:
511     break;
512   case Z_STREAM_END:
513     return true;
514   case Z_BUF_ERROR:
515   case Z_NEED_DICT:
516   case Z_DATA_ERROR:
517   case Z_MEM_ERROR:
518     throw std::runtime_error(to<std::string>(
519         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
520   default:
521     CHECK(false) << rc << ": " << stream->msg;
522   }
523
524   return false;
525 }
526
527 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
528   z_stream stream;
529   stream.zalloc = nullptr;
530   stream.zfree = nullptr;
531   stream.opaque = nullptr;
532
533   int rc = deflateInit(&stream, level_);
534   if (rc != Z_OK) {
535     throw std::runtime_error(to<std::string>(
536         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
537   }
538
539   stream.next_in = stream.next_out = nullptr;
540   stream.avail_in = stream.avail_out = 0;
541   stream.total_in = stream.total_out = 0;
542
543   bool success = false;
544
545   SCOPE_EXIT {
546     int rc = deflateEnd(&stream);
547     // If we're here because of an exception, it's okay if some data
548     // got dropped.
549     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
550       << rc << ": " << stream.msg;
551   };
552
553   uint64_t uncompressedLength = data->computeChainDataLength();
554   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
555
556   // Max 64MiB in one go
557   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
558   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
559
560   auto out = addOutputBuffer(
561       &stream,
562       (maxCompressedLength <= maxSingleStepLength ?
563        maxCompressedLength :
564        defaultBufferLength));
565
566   for (auto& range : *data) {
567     uint64_t remaining = range.size();
568     uint64_t written = 0;
569     while (remaining) {
570       uint32_t step = (remaining > maxSingleStepLength ?
571                        maxSingleStepLength : remaining);
572       stream.next_in = const_cast<uint8_t*>(range.data() + written);
573       stream.avail_in = step;
574       remaining -= step;
575       written += step;
576
577       while (stream.avail_in != 0) {
578         if (stream.avail_out == 0) {
579           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
580         }
581
582         rc = deflate(&stream, Z_NO_FLUSH);
583
584         CHECK_EQ(rc, Z_OK) << stream.msg;
585       }
586     }
587   }
588
589   do {
590     if (stream.avail_out == 0) {
591       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
592     }
593
594     rc = deflate(&stream, Z_FINISH);
595   } while (rc == Z_OK);
596
597   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
598
599   out->prev()->trimEnd(stream.avail_out);
600
601   success = true;  // we survived
602
603   return out;
604 }
605
606 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
607                                                uint64_t uncompressedLength) {
608   z_stream stream;
609   stream.zalloc = nullptr;
610   stream.zfree = nullptr;
611   stream.opaque = nullptr;
612
613   int rc = inflateInit(&stream);
614   if (rc != Z_OK) {
615     throw std::runtime_error(to<std::string>(
616         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
617   }
618
619   stream.next_in = stream.next_out = nullptr;
620   stream.avail_in = stream.avail_out = 0;
621   stream.total_in = stream.total_out = 0;
622
623   bool success = false;
624
625   SCOPE_EXIT {
626     int rc = inflateEnd(&stream);
627     // If we're here because of an exception, it's okay if some data
628     // got dropped.
629     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
630       << rc << ": " << stream.msg;
631   };
632
633   // Max 64MiB in one go
634   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
635   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
636
637   auto out = addOutputBuffer(
638       &stream,
639       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
640         uncompressedLength <= maxSingleStepLength) ?
641        uncompressedLength :
642        defaultBufferLength));
643
644   bool streamEnd = false;
645   for (auto& range : *data) {
646     if (range.empty()) {
647       continue;
648     }
649
650     stream.next_in = const_cast<uint8_t*>(range.data());
651     stream.avail_in = range.size();
652
653     while (stream.avail_in != 0) {
654       if (streamEnd) {
655         throw std::runtime_error(to<std::string>(
656             "ZlibCodec: junk after end of data"));
657       }
658
659       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
660     }
661   }
662
663   while (!streamEnd) {
664     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
665   }
666
667   out->prev()->trimEnd(stream.avail_out);
668
669   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
670       uncompressedLength != stream.total_out) {
671     throw std::runtime_error(to<std::string>(
672         "ZlibCodec: invalid uncompressed length"));
673   }
674
675   success = true;  // we survived
676
677   return out;
678 }
679
680 #endif  // FOLLY_HAVE_LIBZ
681
682 #if FOLLY_HAVE_LIBLZMA
683
684 /**
685  * LZMA2 compression
686  */
687 class LZMA2Codec FOLLY_FINAL : public Codec {
688  public:
689   static std::unique_ptr<Codec> create(int level, CodecType type);
690   explicit LZMA2Codec(int level, CodecType type);
691
692  private:
693   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
694   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
695
696   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
697
698   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
699   std::unique_ptr<IOBuf> doUncompress(
700       const IOBuf* data,
701       uint64_t uncompressedLength) FOLLY_OVERRIDE;
702
703   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
704   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
705
706   int level_;
707 };
708
709 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
710   return make_unique<LZMA2Codec>(level, type);
711 }
712
713 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
714   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
715   switch (level) {
716   case COMPRESSION_LEVEL_FASTEST:
717     level = 0;
718     break;
719   case COMPRESSION_LEVEL_DEFAULT:
720     level = LZMA_PRESET_DEFAULT;
721     break;
722   case COMPRESSION_LEVEL_BEST:
723     level = 9;
724     break;
725   }
726   if (level < 0 || level > 9) {
727     throw std::invalid_argument(to<std::string>(
728         "LZMA2Codec: invalid level: ", level));
729   }
730   level_ = level;
731 }
732
733 bool LZMA2Codec::doNeedsUncompressedLength() const {
734   return !encodeSize();
735 }
736
737 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
738   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
739   return uint64_t(1) << 63;
740 }
741
742 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
743     lzma_stream* stream,
744     size_t length) {
745
746   CHECK_EQ(stream->avail_out, 0);
747
748   auto buf = IOBuf::create(length);
749   buf->append(length);
750
751   stream->next_out = buf->writableData();
752   stream->avail_out = buf->length();
753
754   return buf;
755 }
756
757 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
758   lzma_ret rc;
759   lzma_stream stream = LZMA_STREAM_INIT;
760
761   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
762   if (rc != LZMA_OK) {
763     throw std::runtime_error(folly::to<std::string>(
764       "LZMA2Codec: lzma_easy_encoder error: ", rc));
765   }
766
767   SCOPE_EXIT { lzma_end(&stream); };
768
769   uint64_t uncompressedLength = data->computeChainDataLength();
770   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
771
772   // Max 64MiB in one go
773   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
774   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
775
776   auto out = addOutputBuffer(
777     &stream,
778     (maxCompressedLength <= maxSingleStepLength ?
779      maxCompressedLength :
780      defaultBufferLength));
781
782   if (encodeSize()) {
783     auto size = IOBuf::createCombined(kMaxVarintLength64);
784     encodeVarintToIOBuf(uncompressedLength, size.get());
785     size->appendChain(std::move(out));
786     out = std::move(size);
787   }
788
789   for (auto& range : *data) {
790     if (range.empty()) {
791       continue;
792     }
793
794     stream.next_in = const_cast<uint8_t*>(range.data());
795     stream.avail_in = range.size();
796
797     while (stream.avail_in != 0) {
798       if (stream.avail_out == 0) {
799         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
800       }
801
802       rc = lzma_code(&stream, LZMA_RUN);
803
804       if (rc != LZMA_OK) {
805         throw std::runtime_error(folly::to<std::string>(
806           "LZMA2Codec: lzma_code error: ", rc));
807       }
808     }
809   }
810
811   do {
812     if (stream.avail_out == 0) {
813       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
814     }
815
816     rc = lzma_code(&stream, LZMA_FINISH);
817   } while (rc == LZMA_OK);
818
819   if (rc != LZMA_STREAM_END) {
820     throw std::runtime_error(folly::to<std::string>(
821       "LZMA2Codec: lzma_code ended with error: ", rc));
822   }
823
824   out->prev()->trimEnd(stream.avail_out);
825
826   return out;
827 }
828
829 bool LZMA2Codec::doInflate(lzma_stream* stream,
830                           IOBuf* head,
831                           size_t bufferLength) {
832   if (stream->avail_out == 0) {
833     head->prependChain(addOutputBuffer(stream, bufferLength));
834   }
835
836   lzma_ret rc = lzma_code(stream, LZMA_RUN);
837
838   switch (rc) {
839   case LZMA_OK:
840     break;
841   case LZMA_STREAM_END:
842     return true;
843   default:
844     throw std::runtime_error(to<std::string>(
845         "LZMA2Codec: lzma_code error: ", rc));
846   }
847
848   return false;
849 }
850
851 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
852                                                uint64_t uncompressedLength) {
853   lzma_ret rc;
854   lzma_stream stream = LZMA_STREAM_INIT;
855
856   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
857   if (rc != LZMA_OK) {
858     throw std::runtime_error(folly::to<std::string>(
859       "LZMA2Codec: lzma_auto_decoder error: ", rc));
860   }
861
862   SCOPE_EXIT { lzma_end(&stream); };
863
864   // Max 64MiB in one go
865   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
866   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
867
868   folly::io::Cursor cursor(data);
869   uint64_t actualUncompressedLength;
870   if (encodeSize()) {
871     actualUncompressedLength = decodeVarintFromCursor(cursor);
872     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
873         uncompressedLength != actualUncompressedLength) {
874       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
875     }
876   } else {
877     actualUncompressedLength = uncompressedLength;
878     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
879   }
880
881   auto out = addOutputBuffer(
882       &stream,
883       (actualUncompressedLength <= maxSingleStepLength ?
884        actualUncompressedLength :
885        defaultBufferLength));
886
887   bool streamEnd = false;
888   auto buf = cursor.peek();
889   while (buf.second != 0) {
890     stream.next_in = const_cast<uint8_t*>(buf.first);
891     stream.avail_in = buf.second;
892
893     while (stream.avail_in != 0) {
894       if (streamEnd) {
895         throw std::runtime_error(to<std::string>(
896             "LZMA2Codec: junk after end of data"));
897       }
898
899       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
900     }
901
902     cursor.skip(buf.second);
903     buf = cursor.peek();
904   }
905
906   while (!streamEnd) {
907     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
908   }
909
910   out->prev()->trimEnd(stream.avail_out);
911
912   if (actualUncompressedLength != stream.total_out) {
913     throw std::runtime_error(to<std::string>(
914         "LZMA2Codec: invalid uncompressed length"));
915   }
916
917   return out;
918 }
919
920 #endif  // FOLLY_HAVE_LIBLZMA
921
922 }  // namespace
923
924 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
925   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
926
927   static CodecFactory codecFactories[
928     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
929     nullptr,  // USER_DEFINED
930     NoCompressionCodec::create,
931
932 #if FOLLY_HAVE_LIBLZ4
933     LZ4Codec::create,
934 #else
935     nullptr,
936 #endif
937
938 #if FOLLY_HAVE_LIBSNAPPY
939     SnappyCodec::create,
940 #else
941     nullptr,
942 #endif
943
944 #if FOLLY_HAVE_LIBZ
945     ZlibCodec::create,
946 #else
947     nullptr,
948 #endif
949
950 #if FOLLY_HAVE_LIBLZ4
951     LZ4Codec::create,
952 #else
953     nullptr,
954 #endif
955
956 #if FOLLY_HAVE_LIBLZMA
957     LZMA2Codec::create,
958     LZMA2Codec::create,
959 #else
960     nullptr,
961     nullptr,
962 #endif
963   };
964
965   size_t idx = static_cast<size_t>(type);
966   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
967     throw std::invalid_argument(to<std::string>(
968         "Compression type ", idx, " not supported"));
969   }
970   auto factory = codecFactories[idx];
971   if (!factory) {
972     throw std::invalid_argument(to<std::string>(
973         "Compression type ", idx, " not supported"));
974   }
975   auto codec = (*factory)(level, type);
976   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
977   return codec;
978 }
979
980 }}  // namespaces