zlib compression fails on large IOBufs
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #include <folly/Conv.h>
40 #include <folly/Memory.h>
41 #include <folly/Portability.h>
42 #include <folly/ScopeGuard.h>
43 #include <folly/Varint.h>
44 #include <folly/io/Cursor.h>
45
46 namespace folly { namespace io {
47
48 Codec::Codec(CodecType type) : type_(type) { }
49
50 // Ensure consistent behavior in the nullptr case
51 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
52   uint64_t len = data->computeChainDataLength();
53   if (len == 0) {
54     return IOBuf::create(0);
55   } else if (len > maxUncompressedLength()) {
56     throw std::runtime_error("Codec: uncompressed length too large");
57   }
58
59   return doCompress(data);
60 }
61
62 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
63                                          uint64_t uncompressedLength) {
64   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
65     if (needsUncompressedLength()) {
66       throw std::invalid_argument("Codec: uncompressed length required");
67     }
68   } else if (uncompressedLength > maxUncompressedLength()) {
69     throw std::runtime_error("Codec: uncompressed length too large");
70   }
71
72   if (data->empty()) {
73     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
74         uncompressedLength != 0) {
75       throw std::runtime_error("Codec: invalid uncompressed length");
76     }
77     return IOBuf::create(0);
78   }
79
80   return doUncompress(data, uncompressedLength);
81 }
82
83 bool Codec::needsUncompressedLength() const {
84   return doNeedsUncompressedLength();
85 }
86
87 uint64_t Codec::maxUncompressedLength() const {
88   return doMaxUncompressedLength();
89 }
90
91 bool Codec::doNeedsUncompressedLength() const {
92   return false;
93 }
94
95 uint64_t Codec::doMaxUncompressedLength() const {
96   return UNLIMITED_UNCOMPRESSED_LENGTH;
97 }
98
99 namespace {
100
101 /**
102  * No compression
103  */
104 class NoCompressionCodec FOLLY_FINAL : public Codec {
105  public:
106   static std::unique_ptr<Codec> create(int level, CodecType type);
107   explicit NoCompressionCodec(int level, CodecType type);
108
109  private:
110   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
111   std::unique_ptr<IOBuf> doUncompress(
112       const IOBuf* data,
113       uint64_t uncompressedLength) FOLLY_OVERRIDE;
114 };
115
116 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
117   return make_unique<NoCompressionCodec>(level, type);
118 }
119
120 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
121   : Codec(type) {
122   DCHECK(type == CodecType::NO_COMPRESSION);
123   switch (level) {
124   case COMPRESSION_LEVEL_DEFAULT:
125   case COMPRESSION_LEVEL_FASTEST:
126   case COMPRESSION_LEVEL_BEST:
127     level = 0;
128   }
129   if (level != 0) {
130     throw std::invalid_argument(to<std::string>(
131         "NoCompressionCodec: invalid level ", level));
132   }
133 }
134
135 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
136     const IOBuf* data) {
137   return data->clone();
138 }
139
140 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
141     const IOBuf* data,
142     uint64_t uncompressedLength) {
143   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
144       data->computeChainDataLength() != uncompressedLength) {
145     throw std::runtime_error(to<std::string>(
146         "NoCompressionCodec: invalid uncompressed length"));
147   }
148   return data->clone();
149 }
150
151 namespace {
152
153 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
154   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
155   out->append(encodeVarint(val, out->writableTail()));
156 }
157
158 uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
159   // Must have enough room in *this* buffer.
160   auto p = cursor.peek();
161   folly::ByteRange range(p.first, p.second);
162   uint64_t val = decodeVarint(range);
163   cursor.skip(range.data() - p.first);
164   return val;
165 }
166
167 }  // namespace
168
169 #if FOLLY_HAVE_LIBLZ4
170
171 /**
172  * LZ4 compression
173  */
174 class LZ4Codec FOLLY_FINAL : public Codec {
175  public:
176   static std::unique_ptr<Codec> create(int level, CodecType type);
177   explicit LZ4Codec(int level, CodecType type);
178
179  private:
180   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
181   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
182
183   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
184
185   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
186   std::unique_ptr<IOBuf> doUncompress(
187       const IOBuf* data,
188       uint64_t uncompressedLength) FOLLY_OVERRIDE;
189
190   bool highCompression_;
191 };
192
193 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
194   return make_unique<LZ4Codec>(level, type);
195 }
196
197 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
198   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
199
200   switch (level) {
201   case COMPRESSION_LEVEL_FASTEST:
202   case COMPRESSION_LEVEL_DEFAULT:
203     level = 1;
204     break;
205   case COMPRESSION_LEVEL_BEST:
206     level = 2;
207     break;
208   }
209   if (level < 1 || level > 2) {
210     throw std::invalid_argument(to<std::string>(
211         "LZ4Codec: invalid level: ", level));
212   }
213   highCompression_ = (level > 1);
214 }
215
216 bool LZ4Codec::doNeedsUncompressedLength() const {
217   return !encodeSize();
218 }
219
220 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
221 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
222 // here.
223 #ifndef LZ4_MAX_INPUT_SIZE
224 # define LZ4_MAX_INPUT_SIZE 0x7E000000
225 #endif
226
227 uint64_t LZ4Codec::doMaxUncompressedLength() const {
228   return LZ4_MAX_INPUT_SIZE;
229 }
230
231 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
232   std::unique_ptr<IOBuf> clone;
233   if (data->isChained()) {
234     // LZ4 doesn't support streaming, so we have to coalesce
235     clone = data->clone();
236     clone->coalesce();
237     data = clone.get();
238   }
239
240   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
241   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
242   if (encodeSize()) {
243     encodeVarintToIOBuf(data->length(), out.get());
244   }
245
246   int n;
247   if (highCompression_) {
248     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
249                        reinterpret_cast<char*>(out->writableTail()),
250                        data->length());
251   } else {
252     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
253                      reinterpret_cast<char*>(out->writableTail()),
254                      data->length());
255   }
256
257   CHECK_GE(n, 0);
258   CHECK_LE(n, out->capacity());
259
260   out->append(n);
261   return out;
262 }
263
264 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
265     const IOBuf* data,
266     uint64_t uncompressedLength) {
267   std::unique_ptr<IOBuf> clone;
268   if (data->isChained()) {
269     // LZ4 doesn't support streaming, so we have to coalesce
270     clone = data->clone();
271     clone->coalesce();
272     data = clone.get();
273   }
274
275   folly::io::Cursor cursor(data);
276   uint64_t actualUncompressedLength;
277   if (encodeSize()) {
278     actualUncompressedLength = decodeVarintFromCursor(cursor);
279     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
280         uncompressedLength != actualUncompressedLength) {
281       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
282     }
283   } else {
284     actualUncompressedLength = uncompressedLength;
285     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
286         actualUncompressedLength > maxUncompressedLength()) {
287       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
288     }
289   }
290
291   auto p = cursor.peek();
292   auto out = IOBuf::create(actualUncompressedLength);
293   int n = LZ4_decompress_safe(reinterpret_cast<const char*>(p.first),
294                               reinterpret_cast<char*>(out->writableTail()),
295                               p.second,
296                               actualUncompressedLength);
297
298   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
299     throw std::runtime_error(to<std::string>(
300         "LZ4 decompression returned invalid value ", n));
301   }
302   out->append(actualUncompressedLength);
303   return out;
304 }
305
306 #endif  // FOLLY_HAVE_LIBLZ4
307
308 #if FOLLY_HAVE_LIBSNAPPY
309
310 /**
311  * Snappy compression
312  */
313
314 /**
315  * Implementation of snappy::Source that reads from a IOBuf chain.
316  */
317 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
318  public:
319   explicit IOBufSnappySource(const IOBuf* data);
320   size_t Available() const FOLLY_OVERRIDE;
321   const char* Peek(size_t* len) FOLLY_OVERRIDE;
322   void Skip(size_t n) FOLLY_OVERRIDE;
323  private:
324   size_t available_;
325   io::Cursor cursor_;
326 };
327
328 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
329   : available_(data->computeChainDataLength()),
330     cursor_(data) {
331 }
332
333 size_t IOBufSnappySource::Available() const {
334   return available_;
335 }
336
337 const char* IOBufSnappySource::Peek(size_t* len) {
338   auto p = cursor_.peek();
339   *len = p.second;
340   return reinterpret_cast<const char*>(p.first);
341 }
342
343 void IOBufSnappySource::Skip(size_t n) {
344   CHECK_LE(n, available_);
345   cursor_.skip(n);
346   available_ -= n;
347 }
348
349 class SnappyCodec FOLLY_FINAL : public Codec {
350  public:
351   static std::unique_ptr<Codec> create(int level, CodecType type);
352   explicit SnappyCodec(int level, CodecType type);
353
354  private:
355   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
356   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
357   std::unique_ptr<IOBuf> doUncompress(
358       const IOBuf* data,
359       uint64_t uncompressedLength) FOLLY_OVERRIDE;
360 };
361
362 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
363   return make_unique<SnappyCodec>(level, type);
364 }
365
366 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
367   DCHECK(type == CodecType::SNAPPY);
368   switch (level) {
369   case COMPRESSION_LEVEL_FASTEST:
370   case COMPRESSION_LEVEL_DEFAULT:
371   case COMPRESSION_LEVEL_BEST:
372     level = 1;
373   }
374   if (level != 1) {
375     throw std::invalid_argument(to<std::string>(
376         "SnappyCodec: invalid level: ", level));
377   }
378 }
379
380 uint64_t SnappyCodec::doMaxUncompressedLength() const {
381   // snappy.h uses uint32_t for lengths, so there's that.
382   return std::numeric_limits<uint32_t>::max();
383 }
384
385 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
386   IOBufSnappySource source(data);
387   auto out =
388     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
389
390   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
391       out->writableTail()));
392
393   size_t n = snappy::Compress(&source, &sink);
394
395   CHECK_LE(n, out->capacity());
396   out->append(n);
397   return out;
398 }
399
400 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
401                                                  uint64_t uncompressedLength) {
402   uint32_t actualUncompressedLength = 0;
403
404   {
405     IOBufSnappySource source(data);
406     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
407       throw std::runtime_error("snappy::GetUncompressedLength failed");
408     }
409     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
410         uncompressedLength != actualUncompressedLength) {
411       throw std::runtime_error("snappy: invalid uncompressed length");
412     }
413   }
414
415   auto out = IOBuf::create(actualUncompressedLength);
416
417   {
418     IOBufSnappySource source(data);
419     if (!snappy::RawUncompress(&source,
420                                reinterpret_cast<char*>(out->writableTail()))) {
421       throw std::runtime_error("snappy::RawUncompress failed");
422     }
423   }
424
425   out->append(actualUncompressedLength);
426   return out;
427 }
428
429 #endif  // FOLLY_HAVE_LIBSNAPPY
430
431 #if FOLLY_HAVE_LIBZ
432 /**
433  * Zlib codec
434  */
435 class ZlibCodec FOLLY_FINAL : public Codec {
436  public:
437   static std::unique_ptr<Codec> create(int level, CodecType type);
438   explicit ZlibCodec(int level, CodecType type);
439
440  private:
441   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
442   std::unique_ptr<IOBuf> doUncompress(
443       const IOBuf* data,
444       uint64_t uncompressedLength) FOLLY_OVERRIDE;
445
446   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
447   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
448
449   int level_;
450 };
451
452 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
453   return make_unique<ZlibCodec>(level, type);
454 }
455
456 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
457   DCHECK(type == CodecType::ZLIB);
458   switch (level) {
459   case COMPRESSION_LEVEL_FASTEST:
460     level = 1;
461     break;
462   case COMPRESSION_LEVEL_DEFAULT:
463     level = Z_DEFAULT_COMPRESSION;
464     break;
465   case COMPRESSION_LEVEL_BEST:
466     level = 9;
467     break;
468   }
469   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
470     throw std::invalid_argument(to<std::string>(
471         "ZlibCodec: invalid level: ", level));
472   }
473   level_ = level;
474 }
475
476 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
477                                                   uint32_t length) {
478   CHECK_EQ(stream->avail_out, 0);
479
480   auto buf = IOBuf::create(length);
481   buf->append(length);
482
483   stream->next_out = buf->writableData();
484   stream->avail_out = buf->length();
485
486   return buf;
487 }
488
489 bool ZlibCodec::doInflate(z_stream* stream,
490                           IOBuf* head,
491                           uint32_t bufferLength) {
492   if (stream->avail_out == 0) {
493     head->prependChain(addOutputBuffer(stream, bufferLength));
494   }
495
496   int rc = inflate(stream, Z_NO_FLUSH);
497
498   switch (rc) {
499   case Z_OK:
500     break;
501   case Z_STREAM_END:
502     return true;
503   case Z_BUF_ERROR:
504   case Z_NEED_DICT:
505   case Z_DATA_ERROR:
506   case Z_MEM_ERROR:
507     throw std::runtime_error(to<std::string>(
508         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
509   default:
510     CHECK(false) << rc << ": " << stream->msg;
511   }
512
513   return false;
514 }
515
516 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
517   z_stream stream;
518   stream.zalloc = nullptr;
519   stream.zfree = nullptr;
520   stream.opaque = nullptr;
521
522   int rc = deflateInit(&stream, level_);
523   if (rc != Z_OK) {
524     throw std::runtime_error(to<std::string>(
525         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
526   }
527
528   stream.next_in = stream.next_out = nullptr;
529   stream.avail_in = stream.avail_out = 0;
530   stream.total_in = stream.total_out = 0;
531
532   bool success = false;
533
534   SCOPE_EXIT {
535     int rc = deflateEnd(&stream);
536     // If we're here because of an exception, it's okay if some data
537     // got dropped.
538     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
539       << rc << ": " << stream.msg;
540   };
541
542   uint64_t uncompressedLength = data->computeChainDataLength();
543   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
544
545   // Max 64MiB in one go
546   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
547   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
548
549   auto out = addOutputBuffer(
550       &stream,
551       (maxCompressedLength <= maxSingleStepLength ?
552        maxCompressedLength :
553        defaultBufferLength));
554
555   for (auto& range : *data) {
556     uint64_t remaining = range.size();
557     uint64_t written = 0;
558     while (remaining) {
559       uint32_t step = (remaining > maxSingleStepLength ?
560                        maxSingleStepLength : remaining);
561       stream.next_in = const_cast<uint8_t*>(range.data() + written);
562       stream.avail_in = step;
563       remaining -= step;
564       written += step;
565
566       while (stream.avail_in != 0) {
567         if (stream.avail_out == 0) {
568           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
569         }
570
571         rc = deflate(&stream, Z_NO_FLUSH);
572
573         CHECK_EQ(rc, Z_OK) << stream.msg;
574       }
575     }
576   }
577
578   do {
579     if (stream.avail_out == 0) {
580       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
581     }
582
583     rc = deflate(&stream, Z_FINISH);
584   } while (rc == Z_OK);
585
586   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
587
588   out->prev()->trimEnd(stream.avail_out);
589
590   success = true;  // we survived
591
592   return out;
593 }
594
595 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
596                                                uint64_t uncompressedLength) {
597   z_stream stream;
598   stream.zalloc = nullptr;
599   stream.zfree = nullptr;
600   stream.opaque = nullptr;
601
602   int rc = inflateInit(&stream);
603   if (rc != Z_OK) {
604     throw std::runtime_error(to<std::string>(
605         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
606   }
607
608   stream.next_in = stream.next_out = nullptr;
609   stream.avail_in = stream.avail_out = 0;
610   stream.total_in = stream.total_out = 0;
611
612   bool success = false;
613
614   SCOPE_EXIT {
615     int rc = inflateEnd(&stream);
616     // If we're here because of an exception, it's okay if some data
617     // got dropped.
618     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
619       << rc << ": " << stream.msg;
620   };
621
622   // Max 64MiB in one go
623   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
624   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
625
626   auto out = addOutputBuffer(
627       &stream,
628       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
629         uncompressedLength <= maxSingleStepLength) ?
630        uncompressedLength :
631        defaultBufferLength));
632
633   bool streamEnd = false;
634   for (auto& range : *data) {
635     if (range.empty()) {
636       continue;
637     }
638
639     stream.next_in = const_cast<uint8_t*>(range.data());
640     stream.avail_in = range.size();
641
642     while (stream.avail_in != 0) {
643       if (streamEnd) {
644         throw std::runtime_error(to<std::string>(
645             "ZlibCodec: junk after end of data"));
646       }
647
648       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
649     }
650   }
651
652   while (!streamEnd) {
653     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
654   }
655
656   out->prev()->trimEnd(stream.avail_out);
657
658   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
659       uncompressedLength != stream.total_out) {
660     throw std::runtime_error(to<std::string>(
661         "ZlibCodec: invalid uncompressed length"));
662   }
663
664   success = true;  // we survived
665
666   return out;
667 }
668
669 #endif  // FOLLY_HAVE_LIBZ
670
671 #if FOLLY_HAVE_LIBLZMA
672
673 /**
674  * LZMA2 compression
675  */
676 class LZMA2Codec FOLLY_FINAL : public Codec {
677  public:
678   static std::unique_ptr<Codec> create(int level, CodecType type);
679   explicit LZMA2Codec(int level, CodecType type);
680
681  private:
682   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
683   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
684
685   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
686
687   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
688   std::unique_ptr<IOBuf> doUncompress(
689       const IOBuf* data,
690       uint64_t uncompressedLength) FOLLY_OVERRIDE;
691
692   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
693   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
694
695   int level_;
696 };
697
698 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
699   return make_unique<LZMA2Codec>(level, type);
700 }
701
702 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
703   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
704   switch (level) {
705   case COMPRESSION_LEVEL_FASTEST:
706     level = 0;
707     break;
708   case COMPRESSION_LEVEL_DEFAULT:
709     level = LZMA_PRESET_DEFAULT;
710     break;
711   case COMPRESSION_LEVEL_BEST:
712     level = 9;
713     break;
714   }
715   if (level < 0 || level > 9) {
716     throw std::invalid_argument(to<std::string>(
717         "LZMA2Codec: invalid level: ", level));
718   }
719   level_ = level;
720 }
721
722 bool LZMA2Codec::doNeedsUncompressedLength() const {
723   return !encodeSize();
724 }
725
726 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
727   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
728   return uint64_t(1) << 63;
729 }
730
731 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
732     lzma_stream* stream,
733     size_t length) {
734
735   CHECK_EQ(stream->avail_out, 0);
736
737   auto buf = IOBuf::create(length);
738   buf->append(length);
739
740   stream->next_out = buf->writableData();
741   stream->avail_out = buf->length();
742
743   return buf;
744 }
745
746 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
747   lzma_ret rc;
748   lzma_stream stream = LZMA_STREAM_INIT;
749
750   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
751   if (rc != LZMA_OK) {
752     throw std::runtime_error(folly::to<std::string>(
753       "LZMA2Codec: lzma_easy_encoder error: ", rc));
754   }
755
756   SCOPE_EXIT { lzma_end(&stream); };
757
758   uint64_t uncompressedLength = data->computeChainDataLength();
759   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
760
761   // Max 64MiB in one go
762   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
763   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
764
765   auto out = addOutputBuffer(
766     &stream,
767     (maxCompressedLength <= maxSingleStepLength ?
768      maxCompressedLength :
769      defaultBufferLength));
770
771   if (encodeSize()) {
772     auto size = IOBuf::createCombined(kMaxVarintLength64);
773     encodeVarintToIOBuf(uncompressedLength, size.get());
774     size->appendChain(std::move(out));
775     out = std::move(size);
776   }
777
778   for (auto& range : *data) {
779     if (range.empty()) {
780       continue;
781     }
782
783     stream.next_in = const_cast<uint8_t*>(range.data());
784     stream.avail_in = range.size();
785
786     while (stream.avail_in != 0) {
787       if (stream.avail_out == 0) {
788         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
789       }
790
791       rc = lzma_code(&stream, LZMA_RUN);
792
793       if (rc != LZMA_OK) {
794         throw std::runtime_error(folly::to<std::string>(
795           "LZMA2Codec: lzma_code error: ", rc));
796       }
797     }
798   }
799
800   do {
801     if (stream.avail_out == 0) {
802       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
803     }
804
805     rc = lzma_code(&stream, LZMA_FINISH);
806   } while (rc == LZMA_OK);
807
808   if (rc != LZMA_STREAM_END) {
809     throw std::runtime_error(folly::to<std::string>(
810       "LZMA2Codec: lzma_code ended with error: ", rc));
811   }
812
813   out->prev()->trimEnd(stream.avail_out);
814
815   return out;
816 }
817
818 bool LZMA2Codec::doInflate(lzma_stream* stream,
819                           IOBuf* head,
820                           size_t bufferLength) {
821   if (stream->avail_out == 0) {
822     head->prependChain(addOutputBuffer(stream, bufferLength));
823   }
824
825   lzma_ret rc = lzma_code(stream, LZMA_RUN);
826
827   switch (rc) {
828   case LZMA_OK:
829     break;
830   case LZMA_STREAM_END:
831     return true;
832   default:
833     throw std::runtime_error(to<std::string>(
834         "LZMA2Codec: lzma_code error: ", rc));
835   }
836
837   return false;
838 }
839
840 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
841                                                uint64_t uncompressedLength) {
842   lzma_ret rc;
843   lzma_stream stream = LZMA_STREAM_INIT;
844
845   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
846   if (rc != LZMA_OK) {
847     throw std::runtime_error(folly::to<std::string>(
848       "LZMA2Codec: lzma_auto_decoder error: ", rc));
849   }
850
851   SCOPE_EXIT { lzma_end(&stream); };
852
853   // Max 64MiB in one go
854   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
855   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
856
857   folly::io::Cursor cursor(data);
858   uint64_t actualUncompressedLength;
859   if (encodeSize()) {
860     actualUncompressedLength = decodeVarintFromCursor(cursor);
861     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
862         uncompressedLength != actualUncompressedLength) {
863       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
864     }
865   } else {
866     actualUncompressedLength = uncompressedLength;
867     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
868   }
869
870   auto out = addOutputBuffer(
871       &stream,
872       (actualUncompressedLength <= maxSingleStepLength ?
873        actualUncompressedLength :
874        defaultBufferLength));
875
876   bool streamEnd = false;
877   auto buf = cursor.peek();
878   while (buf.second != 0) {
879     stream.next_in = const_cast<uint8_t*>(buf.first);
880     stream.avail_in = buf.second;
881
882     while (stream.avail_in != 0) {
883       if (streamEnd) {
884         throw std::runtime_error(to<std::string>(
885             "LZMA2Codec: junk after end of data"));
886       }
887
888       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
889     }
890
891     cursor.skip(buf.second);
892     buf = cursor.peek();
893   }
894
895   while (!streamEnd) {
896     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
897   }
898
899   out->prev()->trimEnd(stream.avail_out);
900
901   if (actualUncompressedLength != stream.total_out) {
902     throw std::runtime_error(to<std::string>(
903         "LZMA2Codec: invalid uncompressed length"));
904   }
905
906   return out;
907 }
908
909 #endif  // FOLLY_HAVE_LIBLZMA
910
911 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
912
913 CodecFactory gCodecFactories[
914     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
915   nullptr,  // USER_DEFINED
916   NoCompressionCodec::create,
917
918 #if FOLLY_HAVE_LIBLZ4
919   LZ4Codec::create,
920 #else
921   nullptr,
922 #endif
923
924 #if FOLLY_HAVE_LIBSNAPPY
925   SnappyCodec::create,
926 #else
927   nullptr,
928 #endif
929
930 #if FOLLY_HAVE_LIBZ
931   ZlibCodec::create,
932 #else
933   nullptr,
934 #endif
935
936 #if FOLLY_HAVE_LIBLZ4
937   LZ4Codec::create,
938 #else
939   nullptr,
940 #endif
941
942 #if FOLLY_HAVE_LIBLZMA
943   LZMA2Codec::create,
944   LZMA2Codec::create,
945 #else
946   nullptr,
947   nullptr,
948 #endif
949 };
950
951 }  // namespace
952
953 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
954   size_t idx = static_cast<size_t>(type);
955   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
956     throw std::invalid_argument(to<std::string>(
957         "Compression type ", idx, " not supported"));
958   }
959   auto factory = gCodecFactories[idx];
960   if (!factory) {
961     throw std::invalid_argument(to<std::string>(
962         "Compression type ", idx, " not supported"));
963   }
964   auto codec = (*factory)(level, type);
965   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
966   return codec;
967 }
968
969 }}  // namespaces