d014b69abc6e8202a8e1f25c8cca205f8771c1de
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #if LZ4_VERSION_NUMBER >= 10301
23 #include <lz4frame.h>
24 #endif
25 #endif
26
27 #include <glog/logging.h>
28
29 #if FOLLY_HAVE_LIBSNAPPY
30 #include <snappy.h>
31 #include <snappy-sinksource.h>
32 #endif
33
34 #if FOLLY_HAVE_LIBZ
35 #include <zlib.h>
36 #endif
37
38 #if FOLLY_HAVE_LIBLZMA
39 #include <lzma.h>
40 #endif
41
42 #if FOLLY_HAVE_LIBZSTD
43 #define ZSTD_STATIC_LINKING_ONLY
44 #include <zstd.h>
45 #endif
46
47 #if FOLLY_HAVE_LIBBZ2
48 #include <bzlib.h>
49 #endif
50
51 #include <folly/Bits.h>
52 #include <folly/Conv.h>
53 #include <folly/Memory.h>
54 #include <folly/Portability.h>
55 #include <folly/ScopeGuard.h>
56 #include <folly/Varint.h>
57 #include <folly/io/Cursor.h>
58 #include <algorithm>
59 #include <unordered_set>
60
61 namespace folly { namespace io {
62
63 Codec::Codec(CodecType type) : type_(type) { }
64
65 // Ensure consistent behavior in the nullptr case
66 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
67   uint64_t len = data->computeChainDataLength();
68   if (len == 0) {
69     return IOBuf::create(0);
70   }
71   if (len > maxUncompressedLength()) {
72     throw std::runtime_error("Codec: uncompressed length too large");
73   }
74
75   return doCompress(data);
76 }
77
78 std::string Codec::compress(const StringPiece data) {
79   const uint64_t len = data.size();
80   if (len == 0) {
81     return "";
82   }
83   if (len > maxUncompressedLength()) {
84     throw std::runtime_error("Codec: uncompressed length too large");
85   }
86
87   return doCompressString(data);
88 }
89
90 std::unique_ptr<IOBuf> Codec::uncompress(
91     const IOBuf* data,
92     Optional<uint64_t> uncompressedLength) {
93   if (!uncompressedLength) {
94     if (needsUncompressedLength()) {
95       throw std::invalid_argument("Codec: uncompressed length required");
96     }
97   } else if (*uncompressedLength > maxUncompressedLength()) {
98     throw std::runtime_error("Codec: uncompressed length too large");
99   }
100
101   if (data->empty()) {
102     if (uncompressedLength.value_or(0) != 0) {
103       throw std::runtime_error("Codec: invalid uncompressed length");
104     }
105     return IOBuf::create(0);
106   }
107
108   return doUncompress(data, uncompressedLength);
109 }
110
111 std::string Codec::uncompress(
112     const StringPiece data,
113     Optional<uint64_t> uncompressedLength) {
114   if (!uncompressedLength) {
115     if (needsUncompressedLength()) {
116       throw std::invalid_argument("Codec: uncompressed length required");
117     }
118   } else if (*uncompressedLength > maxUncompressedLength()) {
119     throw std::runtime_error("Codec: uncompressed length too large");
120   }
121
122   if (data.empty()) {
123     if (uncompressedLength.value_or(0) != 0) {
124       throw std::runtime_error("Codec: invalid uncompressed length");
125     }
126     return "";
127   }
128
129   return doUncompressString(data, uncompressedLength);
130 }
131
132 bool Codec::needsUncompressedLength() const {
133   return doNeedsUncompressedLength();
134 }
135
136 uint64_t Codec::maxUncompressedLength() const {
137   return doMaxUncompressedLength();
138 }
139
140 bool Codec::doNeedsUncompressedLength() const {
141   return false;
142 }
143
144 uint64_t Codec::doMaxUncompressedLength() const {
145   return UNLIMITED_UNCOMPRESSED_LENGTH;
146 }
147
148 std::vector<std::string> Codec::validPrefixes() const {
149   return {};
150 }
151
152 bool Codec::canUncompress(const IOBuf*, Optional<uint64_t>) const {
153   return false;
154 }
155
156 std::string Codec::doCompressString(const StringPiece data) {
157   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
158   auto outputBuffer = doCompress(&inputBuffer);
159   std::string output;
160   output.reserve(outputBuffer->computeChainDataLength());
161   for (auto range : *outputBuffer) {
162     output.append(reinterpret_cast<const char*>(range.data()), range.size());
163   }
164   return output;
165 }
166
167 std::string Codec::doUncompressString(
168     const StringPiece data,
169     Optional<uint64_t> uncompressedLength) {
170   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
171   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
172   std::string output;
173   output.reserve(outputBuffer->computeChainDataLength());
174   for (auto range : *outputBuffer) {
175     output.append(reinterpret_cast<const char*>(range.data()), range.size());
176   }
177   return output;
178 }
179
180 uint64_t Codec::maxCompressedLength(uint64_t uncompressedLength) const {
181   if (uncompressedLength == 0) {
182     return 0;
183   }
184   return doMaxCompressedLength(uncompressedLength);
185 }
186
187 Optional<uint64_t> Codec::getUncompressedLength(
188     const folly::IOBuf* data,
189     Optional<uint64_t> uncompressedLength) const {
190   auto const compressedLength = data->computeChainDataLength();
191   if (uncompressedLength == uint64_t(0) || compressedLength == 0) {
192     if (uncompressedLength.value_or(0) != 0 || compressedLength != 0) {
193       throw std::runtime_error("Invalid uncompressed length");
194     }
195     return 0;
196   }
197   return doGetUncompressedLength(data, uncompressedLength);
198 }
199
200 Optional<uint64_t> Codec::doGetUncompressedLength(
201     const folly::IOBuf*,
202     Optional<uint64_t> uncompressedLength) const {
203   return uncompressedLength;
204 }
205
206 bool StreamCodec::needsDataLength() const {
207   return doNeedsDataLength();
208 }
209
210 bool StreamCodec::doNeedsDataLength() const {
211   return false;
212 }
213
214 void StreamCodec::assertStateIs(State expected) const {
215   if (state_ != expected) {
216     throw std::logic_error(folly::to<std::string>(
217         "Codec: state is ", state_, "; expected state ", expected));
218   }
219 }
220
221 void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) {
222   state_ = State::RESET;
223   uncompressedLength_ = uncompressedLength;
224   doResetStream();
225 }
226
227 bool StreamCodec::compressStream(
228     ByteRange& input,
229     MutableByteRange& output,
230     StreamCodec::FlushOp flushOp) {
231   if (state_ == State::RESET && input.empty()) {
232     if (flushOp == StreamCodec::FlushOp::NONE) {
233       return false;
234     }
235     if (flushOp == StreamCodec::FlushOp::END &&
236         uncompressedLength().value_or(0) != 0) {
237       throw std::runtime_error("Codec: invalid uncompressed length");
238     }
239     return true;
240   }
241   if (state_ == State::RESET && !input.empty() &&
242       uncompressedLength() == uint64_t(0)) {
243     throw std::runtime_error("Codec: invalid uncompressed length");
244   }
245   // Handle input state transitions
246   switch (flushOp) {
247     case StreamCodec::FlushOp::NONE:
248       if (state_ == State::RESET) {
249         state_ = State::COMPRESS;
250       }
251       assertStateIs(State::COMPRESS);
252       break;
253     case StreamCodec::FlushOp::FLUSH:
254       if (state_ == State::RESET || state_ == State::COMPRESS) {
255         state_ = State::COMPRESS_FLUSH;
256       }
257       assertStateIs(State::COMPRESS_FLUSH);
258       break;
259     case StreamCodec::FlushOp::END:
260       if (state_ == State::RESET || state_ == State::COMPRESS) {
261         state_ = State::COMPRESS_END;
262       }
263       assertStateIs(State::COMPRESS_END);
264       break;
265   }
266   bool const done = doCompressStream(input, output, flushOp);
267   // Handle output state transitions
268   if (done) {
269     if (state_ == State::COMPRESS_FLUSH) {
270       state_ = State::COMPRESS;
271     } else if (state_ == State::COMPRESS_END) {
272       state_ = State::END;
273     }
274     // Check internal invariants
275     DCHECK(input.empty());
276     DCHECK(flushOp != StreamCodec::FlushOp::NONE);
277   }
278   return done;
279 }
280
281 bool StreamCodec::uncompressStream(
282     ByteRange& input,
283     MutableByteRange& output,
284     StreamCodec::FlushOp flushOp) {
285   if (state_ == State::RESET && input.empty()) {
286     if (uncompressedLength().value_or(0) == 0) {
287       return true;
288     }
289     return false;
290   }
291   // Handle input state transitions
292   if (state_ == State::RESET) {
293     state_ = State::UNCOMPRESS;
294   }
295   assertStateIs(State::UNCOMPRESS);
296   bool const done = doUncompressStream(input, output, flushOp);
297   // Handle output state transitions
298   if (done) {
299     state_ = State::END;
300   }
301   return done;
302 }
303
304 static std::unique_ptr<IOBuf> addOutputBuffer(
305     MutableByteRange& output,
306     uint64_t size) {
307   DCHECK(output.empty());
308   auto buffer = IOBuf::create(size);
309   buffer->append(buffer->capacity());
310   output = {buffer->writableData(), buffer->length()};
311   return buffer;
312 }
313
314 std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
315   uint64_t const uncompressedLength = data->computeChainDataLength();
316   resetStream(uncompressedLength);
317   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
318
319   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
320   auto constexpr kDefaultBufferLength = uint64_t(4) << 20; // 4 MB
321
322   MutableByteRange output;
323   auto buffer = addOutputBuffer(
324       output,
325       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
326                                                : kDefaultBufferLength);
327
328   // Compress the entire IOBuf chain into the IOBuf chain pointed to by buffer
329   IOBuf const* current = data;
330   ByteRange input{current->data(), current->length()};
331   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
332   for (;;) {
333     while (input.empty() && current->next() != data) {
334       current = current->next();
335       input = {current->data(), current->length()};
336     }
337     if (current->next() == data) {
338       // This is the last input buffer so end the stream
339       flushOp = StreamCodec::FlushOp::END;
340     }
341     if (output.empty()) {
342       buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength));
343     }
344     bool const done = compressStream(input, output, flushOp);
345     if (done) {
346       DCHECK(input.empty());
347       DCHECK(flushOp == StreamCodec::FlushOp::END);
348       DCHECK_EQ(current->next(), data);
349       break;
350     }
351   }
352   buffer->prev()->trimEnd(output.size());
353   return buffer;
354 }
355
356 static uint64_t computeBufferLength(
357     uint64_t const compressedLength,
358     uint64_t const blockSize) {
359   uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB
360   uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength);
361   return std::min(goodBufferSize, kMaxBufferLength);
362 }
363
364 std::unique_ptr<IOBuf> StreamCodec::doUncompress(
365     IOBuf const* data,
366     Optional<uint64_t> uncompressedLength) {
367   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
368   auto constexpr kBlockSize = uint64_t(128) << 10;
369   auto const defaultBufferLength =
370       computeBufferLength(data->computeChainDataLength(), kBlockSize);
371
372   uncompressedLength = getUncompressedLength(data, uncompressedLength);
373   resetStream(uncompressedLength);
374
375   MutableByteRange output;
376   auto buffer = addOutputBuffer(
377       output,
378       (uncompressedLength && *uncompressedLength <= kMaxSingleStepLength
379            ? *uncompressedLength
380            : defaultBufferLength));
381
382   // Uncompress the entire IOBuf chain into the IOBuf chain pointed to by buffer
383   IOBuf const* current = data;
384   ByteRange input{current->data(), current->length()};
385   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
386   for (;;) {
387     while (input.empty() && current->next() != data) {
388       current = current->next();
389       input = {current->data(), current->length()};
390     }
391     if (current->next() == data) {
392       // Tell the uncompressor there is no more input (it may optimize)
393       flushOp = StreamCodec::FlushOp::END;
394     }
395     if (output.empty()) {
396       buffer->prependChain(addOutputBuffer(output, defaultBufferLength));
397     }
398     bool const done = uncompressStream(input, output, flushOp);
399     if (done) {
400       break;
401     }
402   }
403   if (!input.empty()) {
404     throw std::runtime_error("Codec: Junk after end of data");
405   }
406
407   buffer->prev()->trimEnd(output.size());
408   if (uncompressedLength &&
409       *uncompressedLength != buffer->computeChainDataLength()) {
410     throw std::runtime_error("Codec: invalid uncompressed length");
411   }
412
413   return buffer;
414 }
415
416 namespace {
417
418 /**
419  * No compression
420  */
421 class NoCompressionCodec final : public Codec {
422  public:
423   static std::unique_ptr<Codec> create(int level, CodecType type);
424   explicit NoCompressionCodec(int level, CodecType type);
425
426  private:
427   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
428   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
429   std::unique_ptr<IOBuf> doUncompress(
430       const IOBuf* data,
431       Optional<uint64_t> uncompressedLength) override;
432 };
433
434 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
435   return std::make_unique<NoCompressionCodec>(level, type);
436 }
437
438 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
439   : Codec(type) {
440   DCHECK(type == CodecType::NO_COMPRESSION);
441   switch (level) {
442   case COMPRESSION_LEVEL_DEFAULT:
443   case COMPRESSION_LEVEL_FASTEST:
444   case COMPRESSION_LEVEL_BEST:
445     level = 0;
446   }
447   if (level != 0) {
448     throw std::invalid_argument(to<std::string>(
449         "NoCompressionCodec: invalid level ", level));
450   }
451 }
452
453 uint64_t NoCompressionCodec::doMaxCompressedLength(
454     uint64_t uncompressedLength) const {
455   return uncompressedLength;
456 }
457
458 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
459     const IOBuf* data) {
460   return data->clone();
461 }
462
463 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
464     const IOBuf* data,
465     Optional<uint64_t> uncompressedLength) {
466   if (uncompressedLength &&
467       data->computeChainDataLength() != *uncompressedLength) {
468     throw std::runtime_error(
469         to<std::string>("NoCompressionCodec: invalid uncompressed length"));
470   }
471   return data->clone();
472 }
473
474 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
475
476 namespace {
477
478 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
479   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
480   out->append(encodeVarint(val, out->writableTail()));
481 }
482
483 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
484   uint64_t val = 0;
485   int8_t b = 0;
486   for (int shift = 0; shift <= 63; shift += 7) {
487     b = cursor.read<int8_t>();
488     val |= static_cast<uint64_t>(b & 0x7f) << shift;
489     if (b >= 0) {
490       break;
491     }
492   }
493   if (b < 0) {
494     throw std::invalid_argument("Invalid varint value. Too big.");
495   }
496   return val;
497 }
498
499 }  // namespace
500
501 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
502
503 namespace {
504 /**
505  * Reads sizeof(T) bytes, and returns false if not enough bytes are available.
506  * Returns true if the first n bytes are equal to prefix when interpreted as
507  * a little endian T.
508  */
509 template <typename T>
510 typename std::enable_if<std::is_unsigned<T>::value, bool>::type
511 dataStartsWithLE(const IOBuf* data, T prefix, uint64_t n = sizeof(T)) {
512   DCHECK_GT(n, 0);
513   DCHECK_LE(n, sizeof(T));
514   T value;
515   Cursor cursor{data};
516   if (!cursor.tryReadLE(value)) {
517     return false;
518   }
519   const T mask = n == sizeof(T) ? T(-1) : (T(1) << (8 * n)) - 1;
520   return prefix == (value & mask);
521 }
522
523 template <typename T>
524 typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
525 prefixToStringLE(T prefix, uint64_t n = sizeof(T)) {
526   DCHECK_GT(n, 0);
527   DCHECK_LE(n, sizeof(T));
528   prefix = Endian::little(prefix);
529   std::string result;
530   result.resize(n);
531   memcpy(&result[0], &prefix, n);
532   return result;
533 }
534 } // namespace
535
536 #if FOLLY_HAVE_LIBLZ4
537
538 /**
539  * LZ4 compression
540  */
541 class LZ4Codec final : public Codec {
542  public:
543   static std::unique_ptr<Codec> create(int level, CodecType type);
544   explicit LZ4Codec(int level, CodecType type);
545
546  private:
547   bool doNeedsUncompressedLength() const override;
548   uint64_t doMaxUncompressedLength() const override;
549   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
550
551   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
552
553   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
554   std::unique_ptr<IOBuf> doUncompress(
555       const IOBuf* data,
556       Optional<uint64_t> uncompressedLength) override;
557
558   bool highCompression_;
559 };
560
561 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
562   return std::make_unique<LZ4Codec>(level, type);
563 }
564
565 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
566   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
567
568   switch (level) {
569   case COMPRESSION_LEVEL_FASTEST:
570   case COMPRESSION_LEVEL_DEFAULT:
571     level = 1;
572     break;
573   case COMPRESSION_LEVEL_BEST:
574     level = 2;
575     break;
576   }
577   if (level < 1 || level > 2) {
578     throw std::invalid_argument(to<std::string>(
579         "LZ4Codec: invalid level: ", level));
580   }
581   highCompression_ = (level > 1);
582 }
583
584 bool LZ4Codec::doNeedsUncompressedLength() const {
585   return !encodeSize();
586 }
587
588 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
589 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
590 // here.
591 #ifndef LZ4_MAX_INPUT_SIZE
592 # define LZ4_MAX_INPUT_SIZE 0x7E000000
593 #endif
594
595 uint64_t LZ4Codec::doMaxUncompressedLength() const {
596   return LZ4_MAX_INPUT_SIZE;
597 }
598
599 uint64_t LZ4Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
600   return LZ4_compressBound(uncompressedLength) +
601       (encodeSize() ? kMaxVarintLength64 : 0);
602 }
603
604 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
605   IOBuf clone;
606   if (data->isChained()) {
607     // LZ4 doesn't support streaming, so we have to coalesce
608     clone = data->cloneCoalescedAsValue();
609     data = &clone;
610   }
611
612   auto out = IOBuf::create(maxCompressedLength(data->length()));
613   if (encodeSize()) {
614     encodeVarintToIOBuf(data->length(), out.get());
615   }
616
617   int n;
618   auto input = reinterpret_cast<const char*>(data->data());
619   auto output = reinterpret_cast<char*>(out->writableTail());
620   const auto inputLength = data->length();
621 #if LZ4_VERSION_NUMBER >= 10700
622   if (highCompression_) {
623     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
624   } else {
625     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
626   }
627 #else
628   if (highCompression_) {
629     n = LZ4_compressHC(input, output, inputLength);
630   } else {
631     n = LZ4_compress(input, output, inputLength);
632   }
633 #endif
634
635   CHECK_GE(n, 0);
636   CHECK_LE(n, out->capacity());
637
638   out->append(n);
639   return out;
640 }
641
642 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
643     const IOBuf* data,
644     Optional<uint64_t> uncompressedLength) {
645   IOBuf clone;
646   if (data->isChained()) {
647     // LZ4 doesn't support streaming, so we have to coalesce
648     clone = data->cloneCoalescedAsValue();
649     data = &clone;
650   }
651
652   folly::io::Cursor cursor(data);
653   uint64_t actualUncompressedLength;
654   if (encodeSize()) {
655     actualUncompressedLength = decodeVarintFromCursor(cursor);
656     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
657       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
658     }
659   } else {
660     // Invariants
661     DCHECK(uncompressedLength.hasValue());
662     DCHECK(*uncompressedLength <= maxUncompressedLength());
663     actualUncompressedLength = *uncompressedLength;
664   }
665
666   auto sp = StringPiece{cursor.peekBytes()};
667   auto out = IOBuf::create(actualUncompressedLength);
668   int n = LZ4_decompress_safe(
669       sp.data(),
670       reinterpret_cast<char*>(out->writableTail()),
671       sp.size(),
672       actualUncompressedLength);
673
674   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
675     throw std::runtime_error(to<std::string>(
676         "LZ4 decompression returned invalid value ", n));
677   }
678   out->append(actualUncompressedLength);
679   return out;
680 }
681
682 #if LZ4_VERSION_NUMBER >= 10301
683
684 class LZ4FrameCodec final : public Codec {
685  public:
686   static std::unique_ptr<Codec> create(int level, CodecType type);
687   explicit LZ4FrameCodec(int level, CodecType type);
688   ~LZ4FrameCodec() override;
689
690   std::vector<std::string> validPrefixes() const override;
691   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
692       const override;
693
694  private:
695   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
696
697   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
698   std::unique_ptr<IOBuf> doUncompress(
699       const IOBuf* data,
700       Optional<uint64_t> uncompressedLength) override;
701
702   // Reset the dctx_ if it is dirty or null.
703   void resetDCtx();
704
705   int level_;
706   LZ4F_decompressionContext_t dctx_{nullptr};
707   bool dirty_{false};
708 };
709
710 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
711     int level,
712     CodecType type) {
713   return std::make_unique<LZ4FrameCodec>(level, type);
714 }
715
716 static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204;
717
718 std::vector<std::string> LZ4FrameCodec::validPrefixes() const {
719   return {prefixToStringLE(kLZ4FrameMagicLE)};
720 }
721
722 bool LZ4FrameCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
723   return dataStartsWithLE(data, kLZ4FrameMagicLE);
724 }
725
726 uint64_t LZ4FrameCodec::doMaxCompressedLength(
727     uint64_t uncompressedLength) const {
728   LZ4F_preferences_t prefs{};
729   prefs.compressionLevel = level_;
730   prefs.frameInfo.contentSize = uncompressedLength;
731   return LZ4F_compressFrameBound(uncompressedLength, &prefs);
732 }
733
734 static size_t lz4FrameThrowOnError(size_t code) {
735   if (LZ4F_isError(code)) {
736     throw std::runtime_error(
737         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
738   }
739   return code;
740 }
741
742 void LZ4FrameCodec::resetDCtx() {
743   if (dctx_ && !dirty_) {
744     return;
745   }
746   if (dctx_) {
747     LZ4F_freeDecompressionContext(dctx_);
748   }
749   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
750   dirty_ = false;
751 }
752
753 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
754   DCHECK(type == CodecType::LZ4_FRAME);
755   switch (level) {
756     case COMPRESSION_LEVEL_FASTEST:
757     case COMPRESSION_LEVEL_DEFAULT:
758       level_ = 0;
759       break;
760     case COMPRESSION_LEVEL_BEST:
761       level_ = 16;
762       break;
763     default:
764       level_ = level;
765       break;
766   }
767 }
768
769 LZ4FrameCodec::~LZ4FrameCodec() {
770   if (dctx_) {
771     LZ4F_freeDecompressionContext(dctx_);
772   }
773 }
774
775 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
776   // LZ4 Frame compression doesn't support streaming so we have to coalesce
777   IOBuf clone;
778   if (data->isChained()) {
779     clone = data->cloneCoalescedAsValue();
780     data = &clone;
781   }
782   // Set preferences
783   const auto uncompressedLength = data->length();
784   LZ4F_preferences_t prefs{};
785   prefs.compressionLevel = level_;
786   prefs.frameInfo.contentSize = uncompressedLength;
787   // Compress
788   auto buf = IOBuf::create(maxCompressedLength(uncompressedLength));
789   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
790       buf->writableTail(),
791       buf->tailroom(),
792       data->data(),
793       data->length(),
794       &prefs));
795   buf->append(written);
796   return buf;
797 }
798
799 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
800     const IOBuf* data,
801     Optional<uint64_t> uncompressedLength) {
802   // Reset the dctx if any errors have occurred
803   resetDCtx();
804   // Coalesce the data
805   ByteRange in = *data->begin();
806   IOBuf clone;
807   if (data->isChained()) {
808     clone = data->cloneCoalescedAsValue();
809     in = clone.coalesce();
810   }
811   data = nullptr;
812   // Select decompression options
813   LZ4F_decompressOptions_t options;
814   options.stableDst = 1;
815   // Select blockSize and growthSize for the IOBufQueue
816   IOBufQueue queue(IOBufQueue::cacheChainLength());
817   auto blockSize = uint64_t{64} << 10;
818   auto growthSize = uint64_t{4} << 20;
819   if (uncompressedLength) {
820     // Allocate uncompressedLength in one chunk (up to 64 MB)
821     const auto allocateSize = std::min(*uncompressedLength, uint64_t{64} << 20);
822     queue.preallocate(allocateSize, allocateSize);
823     blockSize = std::min(*uncompressedLength, blockSize);
824     growthSize = std::min(*uncompressedLength, growthSize);
825   } else {
826     // Reduce growthSize for small data
827     const auto guessUncompressedLen =
828         4 * std::max<uint64_t>(blockSize, in.size());
829     growthSize = std::min(guessUncompressedLen, growthSize);
830   }
831   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
832   // returns 0
833   dirty_ = true;
834   // Decompress until the frame is over
835   size_t code = 0;
836   do {
837     // Allocate enough space to decompress at least a block
838     void* out;
839     size_t outSize;
840     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
841     // Decompress
842     size_t inSize = in.size();
843     code = lz4FrameThrowOnError(
844         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
845     if (in.empty() && outSize == 0 && code != 0) {
846       // We passed no input, no output was produced, and the frame isn't over
847       // No more forward progress is possible
848       throw std::runtime_error("LZ4Frame error: Incomplete frame");
849     }
850     in.uncheckedAdvance(inSize);
851     queue.postallocate(outSize);
852   } while (code != 0);
853   // At this point the decompression context can be reused
854   dirty_ = false;
855   if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
856     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
857   }
858   return queue.move();
859 }
860
861 #endif // LZ4_VERSION_NUMBER >= 10301
862 #endif // FOLLY_HAVE_LIBLZ4
863
864 #if FOLLY_HAVE_LIBSNAPPY
865
866 /**
867  * Snappy compression
868  */
869
870 /**
871  * Implementation of snappy::Source that reads from a IOBuf chain.
872  */
873 class IOBufSnappySource final : public snappy::Source {
874  public:
875   explicit IOBufSnappySource(const IOBuf* data);
876   size_t Available() const override;
877   const char* Peek(size_t* len) override;
878   void Skip(size_t n) override;
879  private:
880   size_t available_;
881   io::Cursor cursor_;
882 };
883
884 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
885   : available_(data->computeChainDataLength()),
886     cursor_(data) {
887 }
888
889 size_t IOBufSnappySource::Available() const {
890   return available_;
891 }
892
893 const char* IOBufSnappySource::Peek(size_t* len) {
894   auto sp = StringPiece{cursor_.peekBytes()};
895   *len = sp.size();
896   return sp.data();
897 }
898
899 void IOBufSnappySource::Skip(size_t n) {
900   CHECK_LE(n, available_);
901   cursor_.skip(n);
902   available_ -= n;
903 }
904
905 class SnappyCodec final : public Codec {
906  public:
907   static std::unique_ptr<Codec> create(int level, CodecType type);
908   explicit SnappyCodec(int level, CodecType type);
909
910  private:
911   uint64_t doMaxUncompressedLength() const override;
912   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
913   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
914   std::unique_ptr<IOBuf> doUncompress(
915       const IOBuf* data,
916       Optional<uint64_t> uncompressedLength) override;
917 };
918
919 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
920   return std::make_unique<SnappyCodec>(level, type);
921 }
922
923 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
924   DCHECK(type == CodecType::SNAPPY);
925   switch (level) {
926   case COMPRESSION_LEVEL_FASTEST:
927   case COMPRESSION_LEVEL_DEFAULT:
928   case COMPRESSION_LEVEL_BEST:
929     level = 1;
930   }
931   if (level != 1) {
932     throw std::invalid_argument(to<std::string>(
933         "SnappyCodec: invalid level: ", level));
934   }
935 }
936
937 uint64_t SnappyCodec::doMaxUncompressedLength() const {
938   // snappy.h uses uint32_t for lengths, so there's that.
939   return std::numeric_limits<uint32_t>::max();
940 }
941
942 uint64_t SnappyCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
943   return snappy::MaxCompressedLength(uncompressedLength);
944 }
945
946 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
947   IOBufSnappySource source(data);
948   auto out = IOBuf::create(maxCompressedLength(source.Available()));
949
950   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
951       out->writableTail()));
952
953   size_t n = snappy::Compress(&source, &sink);
954
955   CHECK_LE(n, out->capacity());
956   out->append(n);
957   return out;
958 }
959
960 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(
961     const IOBuf* data,
962     Optional<uint64_t> uncompressedLength) {
963   uint32_t actualUncompressedLength = 0;
964
965   {
966     IOBufSnappySource source(data);
967     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
968       throw std::runtime_error("snappy::GetUncompressedLength failed");
969     }
970     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
971       throw std::runtime_error("snappy: invalid uncompressed length");
972     }
973   }
974
975   auto out = IOBuf::create(actualUncompressedLength);
976
977   {
978     IOBufSnappySource source(data);
979     if (!snappy::RawUncompress(&source,
980                                reinterpret_cast<char*>(out->writableTail()))) {
981       throw std::runtime_error("snappy::RawUncompress failed");
982     }
983   }
984
985   out->append(actualUncompressedLength);
986   return out;
987 }
988
989 #endif  // FOLLY_HAVE_LIBSNAPPY
990
991 #if FOLLY_HAVE_LIBZ
992 /**
993  * Zlib codec
994  */
995 class ZlibStreamCodec final : public StreamCodec {
996  public:
997   static std::unique_ptr<Codec> createCodec(int level, CodecType type);
998   static std::unique_ptr<StreamCodec> createStream(int level, CodecType type);
999   explicit ZlibStreamCodec(int level, CodecType type);
1000   ~ZlibStreamCodec();
1001
1002   std::vector<std::string> validPrefixes() const override;
1003   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1004       const override;
1005
1006  private:
1007   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1008
1009   void doResetStream() override;
1010   bool doCompressStream(
1011       ByteRange& input,
1012       MutableByteRange& output,
1013       StreamCodec::FlushOp flush) override;
1014   bool doUncompressStream(
1015       ByteRange& input,
1016       MutableByteRange& output,
1017       StreamCodec::FlushOp flush) override;
1018
1019   void resetDeflateStream();
1020   void resetInflateStream();
1021
1022   Optional<z_stream> deflateStream_{};
1023   Optional<z_stream> inflateStream_{};
1024   int level_;
1025   bool needReset_{true};
1026 };
1027
1028 static constexpr uint16_t kGZIPMagicLE = 0x8B1F;
1029
1030 std::vector<std::string> ZlibStreamCodec::validPrefixes() const {
1031   if (type() == CodecType::ZLIB) {
1032     // Zlib streams start with a 2 byte header.
1033     //
1034     //   0   1
1035     // +---+---+
1036     // |CMF|FLG|
1037     // +---+---+
1038     //
1039     // We won't restrict the values of any sub-fields except as described below.
1040     //
1041     // The lowest 4 bits of CMF is the compression method (CM).
1042     // CM == 0x8 is the deflate compression method, which is currently the only
1043     // supported compression method, so any valid prefix must have CM == 0x8.
1044     //
1045     // The lowest 5 bits of FLG is FCHECK.
1046     // FCHECK must be such that the two header bytes are a multiple of 31 when
1047     // interpreted as a big endian 16-bit number.
1048     std::vector<std::string> result;
1049     // 16 values for the first byte, 8 values for the second byte.
1050     // There are also 4 combinations where both 0x00 and 0x1F work as FCHECK.
1051     result.reserve(132);
1052     // Select all values for the CMF byte that use the deflate algorithm 0x8.
1053     for (uint32_t first = 0x0800; first <= 0xF800; first += 0x1000) {
1054       // Select all values for the FLG, but leave FCHECK as 0 since it's fixed.
1055       for (uint32_t second = 0x00; second <= 0xE0; second += 0x20) {
1056         uint16_t prefix = first | second;
1057         // Compute FCHECK.
1058         prefix += 31 - (prefix % 31);
1059         result.push_back(prefixToStringLE(Endian::big(prefix)));
1060         // zlib won't produce this, but it is a valid prefix.
1061         if ((prefix & 0x1F) == 31) {
1062           prefix -= 31;
1063           result.push_back(prefixToStringLE(Endian::big(prefix)));
1064         }
1065       }
1066     }
1067     return result;
1068   } else {
1069     // The gzip frame starts with 2 magic bytes.
1070     return {prefixToStringLE(kGZIPMagicLE)};
1071   }
1072 }
1073
1074 bool ZlibStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1075     const {
1076   if (type() == CodecType::ZLIB) {
1077     uint16_t value;
1078     Cursor cursor{data};
1079     if (!cursor.tryReadBE(value)) {
1080       return false;
1081     }
1082     // zlib compressed if using deflate and is a multiple of 31.
1083     return (value & 0x0F00) == 0x0800 && value % 31 == 0;
1084   } else {
1085     return dataStartsWithLE(data, kGZIPMagicLE);
1086   }
1087 }
1088
1089 uint64_t ZlibStreamCodec::doMaxCompressedLength(
1090     uint64_t uncompressedLength) const {
1091   return deflateBound(nullptr, uncompressedLength);
1092 }
1093
1094 std::unique_ptr<Codec> ZlibStreamCodec::createCodec(int level, CodecType type) {
1095   return std::make_unique<ZlibStreamCodec>(level, type);
1096 }
1097
1098 std::unique_ptr<StreamCodec> ZlibStreamCodec::createStream(
1099     int level,
1100     CodecType type) {
1101   return std::make_unique<ZlibStreamCodec>(level, type);
1102 }
1103
1104 ZlibStreamCodec::ZlibStreamCodec(int level, CodecType type)
1105     : StreamCodec(type) {
1106   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
1107   switch (level) {
1108     case COMPRESSION_LEVEL_FASTEST:
1109       level = 1;
1110       break;
1111     case COMPRESSION_LEVEL_DEFAULT:
1112       level = Z_DEFAULT_COMPRESSION;
1113       break;
1114     case COMPRESSION_LEVEL_BEST:
1115       level = 9;
1116       break;
1117   }
1118   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
1119     throw std::invalid_argument(
1120         to<std::string>("ZlibStreamCodec: invalid level: ", level));
1121   }
1122   level_ = level;
1123 }
1124
1125 ZlibStreamCodec::~ZlibStreamCodec() {
1126   if (deflateStream_) {
1127     deflateEnd(deflateStream_.get_pointer());
1128     deflateStream_.clear();
1129   }
1130   if (inflateStream_) {
1131     inflateEnd(inflateStream_.get_pointer());
1132     inflateStream_.clear();
1133   }
1134 }
1135
1136 void ZlibStreamCodec::doResetStream() {
1137   needReset_ = true;
1138 }
1139
1140 void ZlibStreamCodec::resetDeflateStream() {
1141   if (deflateStream_) {
1142     int const rc = deflateReset(deflateStream_.get_pointer());
1143     if (rc != Z_OK) {
1144       deflateStream_.clear();
1145       throw std::runtime_error(
1146           to<std::string>("ZlibStreamCodec: deflateReset error: ", rc));
1147     }
1148     return;
1149   }
1150   deflateStream_ = z_stream{};
1151   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
1152   // base two logarithm of the maximum window size (...) The default value is
1153   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
1154   // around the compressed data instead of a zlib wrapper. The gzip header
1155   // will have no file name, no extra data, no comment, no modification time
1156   // (set to zero), no header crc, and the operating system will be set to 255
1157   // (unknown)."
1158   int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1159   // All other parameters (method, memLevel, strategy) get default values from
1160   // the zlib manual.
1161   int const rc = deflateInit2(
1162       deflateStream_.get_pointer(),
1163       level_,
1164       Z_DEFLATED,
1165       windowBits,
1166       /* memLevel */ 8,
1167       Z_DEFAULT_STRATEGY);
1168   if (rc != Z_OK) {
1169     deflateStream_.clear();
1170     throw std::runtime_error(
1171         to<std::string>("ZlibStreamCodec: deflateInit error: ", rc));
1172   }
1173 }
1174
1175 void ZlibStreamCodec::resetInflateStream() {
1176   if (inflateStream_) {
1177     int const rc = inflateReset(inflateStream_.get_pointer());
1178     if (rc != Z_OK) {
1179       inflateStream_.clear();
1180       throw std::runtime_error(
1181           to<std::string>("ZlibStreamCodec: inflateReset error: ", rc));
1182     }
1183     return;
1184   }
1185   inflateStream_ = z_stream{};
1186   // "The windowBits parameter is the base two logarithm of the maximum window
1187   // size (...) The default value is 15 (...) add 16 to decode only the gzip
1188   // format (the zlib format will return a Z_DATA_ERROR)."
1189   int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1190   int const rc = inflateInit2(inflateStream_.get_pointer(), windowBits);
1191   if (rc != Z_OK) {
1192     inflateStream_.clear();
1193     throw std::runtime_error(
1194         to<std::string>("ZlibStreamCodec: inflateInit error: ", rc));
1195   }
1196 }
1197
1198 static int zlibTranslateFlush(StreamCodec::FlushOp flush) {
1199   switch (flush) {
1200     case StreamCodec::FlushOp::NONE:
1201       return Z_NO_FLUSH;
1202     case StreamCodec::FlushOp::FLUSH:
1203       return Z_SYNC_FLUSH;
1204     case StreamCodec::FlushOp::END:
1205       return Z_FINISH;
1206     default:
1207       throw std::invalid_argument("ZlibStreamCodec: Invalid flush");
1208   }
1209 }
1210
1211 static int zlibThrowOnError(int rc) {
1212   switch (rc) {
1213     case Z_OK:
1214     case Z_BUF_ERROR:
1215     case Z_STREAM_END:
1216       return rc;
1217     default:
1218       throw std::runtime_error(to<std::string>("ZlibStreamCodec: error: ", rc));
1219   }
1220 }
1221
1222 bool ZlibStreamCodec::doCompressStream(
1223     ByteRange& input,
1224     MutableByteRange& output,
1225     StreamCodec::FlushOp flush) {
1226   if (needReset_) {
1227     resetDeflateStream();
1228     needReset_ = false;
1229   }
1230   DCHECK(deflateStream_.hasValue());
1231   // zlib will return Z_STREAM_ERROR if output.data() is null.
1232   if (output.data() == nullptr) {
1233     return false;
1234   }
1235   deflateStream_->next_in = const_cast<uint8_t*>(input.data());
1236   deflateStream_->avail_in = input.size();
1237   deflateStream_->next_out = output.data();
1238   deflateStream_->avail_out = output.size();
1239   SCOPE_EXIT {
1240     input.uncheckedAdvance(input.size() - deflateStream_->avail_in);
1241     output.uncheckedAdvance(output.size() - deflateStream_->avail_out);
1242   };
1243   int const rc = zlibThrowOnError(
1244       deflate(deflateStream_.get_pointer(), zlibTranslateFlush(flush)));
1245   switch (flush) {
1246     case StreamCodec::FlushOp::NONE:
1247       return false;
1248     case StreamCodec::FlushOp::FLUSH:
1249       return deflateStream_->avail_in == 0 && deflateStream_->avail_out != 0;
1250     case StreamCodec::FlushOp::END:
1251       return rc == Z_STREAM_END;
1252     default:
1253       throw std::invalid_argument("ZlibStreamCodec: Invalid flush");
1254   }
1255 }
1256
1257 bool ZlibStreamCodec::doUncompressStream(
1258     ByteRange& input,
1259     MutableByteRange& output,
1260     StreamCodec::FlushOp flush) {
1261   if (needReset_) {
1262     resetInflateStream();
1263     needReset_ = false;
1264   }
1265   DCHECK(inflateStream_.hasValue());
1266   // zlib will return Z_STREAM_ERROR if output.data() is null.
1267   if (output.data() == nullptr) {
1268     return false;
1269   }
1270   inflateStream_->next_in = const_cast<uint8_t*>(input.data());
1271   inflateStream_->avail_in = input.size();
1272   inflateStream_->next_out = output.data();
1273   inflateStream_->avail_out = output.size();
1274   SCOPE_EXIT {
1275     input.advance(input.size() - inflateStream_->avail_in);
1276     output.advance(output.size() - inflateStream_->avail_out);
1277   };
1278   int const rc = zlibThrowOnError(
1279       inflate(inflateStream_.get_pointer(), zlibTranslateFlush(flush)));
1280   return rc == Z_STREAM_END;
1281 }
1282
1283 #endif // FOLLY_HAVE_LIBZ
1284
1285 #if FOLLY_HAVE_LIBLZMA
1286
1287 /**
1288  * LZMA2 compression
1289  */
1290 class LZMA2Codec final : public Codec {
1291  public:
1292   static std::unique_ptr<Codec> create(int level, CodecType type);
1293   explicit LZMA2Codec(int level, CodecType type);
1294
1295   std::vector<std::string> validPrefixes() const override;
1296   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1297       const override;
1298
1299  private:
1300   bool doNeedsUncompressedLength() const override;
1301   uint64_t doMaxUncompressedLength() const override;
1302   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1303
1304   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
1305
1306   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1307   std::unique_ptr<IOBuf> doUncompress(
1308       const IOBuf* data,
1309       Optional<uint64_t> uncompressedLength) override;
1310
1311   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
1312   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
1313
1314   int level_;
1315 };
1316
1317 static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD;
1318 static constexpr unsigned kLZMA2MagicBytes = 6;
1319
1320 std::vector<std::string> LZMA2Codec::validPrefixes() const {
1321   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1322     return {};
1323   }
1324   return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)};
1325 }
1326
1327 bool LZMA2Codec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
1328   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1329     return false;
1330   }
1331   // Returns false for all inputs less than 8 bytes.
1332   // This is okay, because no valid LZMA2 streams are less than 8 bytes.
1333   return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes);
1334 }
1335
1336 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
1337   return std::make_unique<LZMA2Codec>(level, type);
1338 }
1339
1340 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
1341   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
1342   switch (level) {
1343   case COMPRESSION_LEVEL_FASTEST:
1344     level = 0;
1345     break;
1346   case COMPRESSION_LEVEL_DEFAULT:
1347     level = LZMA_PRESET_DEFAULT;
1348     break;
1349   case COMPRESSION_LEVEL_BEST:
1350     level = 9;
1351     break;
1352   }
1353   if (level < 0 || level > 9) {
1354     throw std::invalid_argument(to<std::string>(
1355         "LZMA2Codec: invalid level: ", level));
1356   }
1357   level_ = level;
1358 }
1359
1360 bool LZMA2Codec::doNeedsUncompressedLength() const {
1361   return false;
1362 }
1363
1364 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
1365   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
1366   return uint64_t(1) << 63;
1367 }
1368
1369 uint64_t LZMA2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1370   return lzma_stream_buffer_bound(uncompressedLength) +
1371       (encodeSize() ? kMaxVarintLength64 : 0);
1372 }
1373
1374 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
1375     lzma_stream* stream,
1376     size_t length) {
1377
1378   CHECK_EQ(stream->avail_out, 0);
1379
1380   auto buf = IOBuf::create(length);
1381   buf->append(buf->capacity());
1382
1383   stream->next_out = buf->writableData();
1384   stream->avail_out = buf->length();
1385
1386   return buf;
1387 }
1388
1389 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
1390   lzma_ret rc;
1391   lzma_stream stream = LZMA_STREAM_INIT;
1392
1393   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
1394   if (rc != LZMA_OK) {
1395     throw std::runtime_error(folly::to<std::string>(
1396       "LZMA2Codec: lzma_easy_encoder error: ", rc));
1397   }
1398
1399   SCOPE_EXIT { lzma_end(&stream); };
1400
1401   uint64_t uncompressedLength = data->computeChainDataLength();
1402   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
1403
1404   // Max 64MiB in one go
1405   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1406   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1407
1408   auto out = addOutputBuffer(
1409     &stream,
1410     (maxCompressedLength <= maxSingleStepLength ?
1411      maxCompressedLength :
1412      defaultBufferLength));
1413
1414   if (encodeSize()) {
1415     auto size = IOBuf::createCombined(kMaxVarintLength64);
1416     encodeVarintToIOBuf(uncompressedLength, size.get());
1417     size->appendChain(std::move(out));
1418     out = std::move(size);
1419   }
1420
1421   for (auto& range : *data) {
1422     if (range.empty()) {
1423       continue;
1424     }
1425
1426     stream.next_in = const_cast<uint8_t*>(range.data());
1427     stream.avail_in = range.size();
1428
1429     while (stream.avail_in != 0) {
1430       if (stream.avail_out == 0) {
1431         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1432       }
1433
1434       rc = lzma_code(&stream, LZMA_RUN);
1435
1436       if (rc != LZMA_OK) {
1437         throw std::runtime_error(folly::to<std::string>(
1438           "LZMA2Codec: lzma_code error: ", rc));
1439       }
1440     }
1441   }
1442
1443   do {
1444     if (stream.avail_out == 0) {
1445       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1446     }
1447
1448     rc = lzma_code(&stream, LZMA_FINISH);
1449   } while (rc == LZMA_OK);
1450
1451   if (rc != LZMA_STREAM_END) {
1452     throw std::runtime_error(folly::to<std::string>(
1453       "LZMA2Codec: lzma_code ended with error: ", rc));
1454   }
1455
1456   out->prev()->trimEnd(stream.avail_out);
1457
1458   return out;
1459 }
1460
1461 bool LZMA2Codec::doInflate(lzma_stream* stream,
1462                           IOBuf* head,
1463                           size_t bufferLength) {
1464   if (stream->avail_out == 0) {
1465     head->prependChain(addOutputBuffer(stream, bufferLength));
1466   }
1467
1468   lzma_ret rc = lzma_code(stream, LZMA_RUN);
1469
1470   switch (rc) {
1471   case LZMA_OK:
1472     break;
1473   case LZMA_STREAM_END:
1474     return true;
1475   default:
1476     throw std::runtime_error(to<std::string>(
1477         "LZMA2Codec: lzma_code error: ", rc));
1478   }
1479
1480   return false;
1481 }
1482
1483 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(
1484     const IOBuf* data,
1485     Optional<uint64_t> uncompressedLength) {
1486   lzma_ret rc;
1487   lzma_stream stream = LZMA_STREAM_INIT;
1488
1489   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
1490   if (rc != LZMA_OK) {
1491     throw std::runtime_error(folly::to<std::string>(
1492       "LZMA2Codec: lzma_auto_decoder error: ", rc));
1493   }
1494
1495   SCOPE_EXIT { lzma_end(&stream); };
1496
1497   // Max 64MiB in one go
1498   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB
1499   constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB
1500
1501   folly::io::Cursor cursor(data);
1502   if (encodeSize()) {
1503     const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor);
1504     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
1505       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
1506     }
1507     uncompressedLength = actualUncompressedLength;
1508   }
1509
1510   auto out = addOutputBuffer(
1511       &stream,
1512       ((uncompressedLength && *uncompressedLength <= maxSingleStepLength)
1513            ? *uncompressedLength
1514            : defaultBufferLength));
1515
1516   bool streamEnd = false;
1517   auto buf = cursor.peekBytes();
1518   while (!buf.empty()) {
1519     stream.next_in = const_cast<uint8_t*>(buf.data());
1520     stream.avail_in = buf.size();
1521
1522     while (stream.avail_in != 0) {
1523       if (streamEnd) {
1524         throw std::runtime_error(to<std::string>(
1525             "LZMA2Codec: junk after end of data"));
1526       }
1527
1528       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1529     }
1530
1531     cursor.skip(buf.size());
1532     buf = cursor.peekBytes();
1533   }
1534
1535   while (!streamEnd) {
1536     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1537   }
1538
1539   out->prev()->trimEnd(stream.avail_out);
1540
1541   if (uncompressedLength && *uncompressedLength != stream.total_out) {
1542     throw std::runtime_error(
1543         to<std::string>("LZMA2Codec: invalid uncompressed length"));
1544   }
1545
1546   return out;
1547 }
1548
1549 #endif  // FOLLY_HAVE_LIBLZMA
1550
1551 #ifdef FOLLY_HAVE_LIBZSTD
1552
1553 namespace {
1554 void zstdFreeCStream(ZSTD_CStream* zcs) {
1555   ZSTD_freeCStream(zcs);
1556 }
1557
1558 void zstdFreeDStream(ZSTD_DStream* zds) {
1559   ZSTD_freeDStream(zds);
1560 }
1561 }
1562
1563 /**
1564  * ZSTD compression
1565  */
1566 class ZSTDStreamCodec final : public StreamCodec {
1567  public:
1568   static std::unique_ptr<Codec> createCodec(int level, CodecType);
1569   static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
1570   explicit ZSTDStreamCodec(int level, CodecType type);
1571
1572   std::vector<std::string> validPrefixes() const override;
1573   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1574       const override;
1575
1576  private:
1577   bool doNeedsUncompressedLength() const override;
1578   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1579   Optional<uint64_t> doGetUncompressedLength(
1580       IOBuf const* data,
1581       Optional<uint64_t> uncompressedLength) const override;
1582
1583   void doResetStream() override;
1584   bool doCompressStream(
1585       ByteRange& input,
1586       MutableByteRange& output,
1587       StreamCodec::FlushOp flushOp) override;
1588   bool doUncompressStream(
1589       ByteRange& input,
1590       MutableByteRange& output,
1591       StreamCodec::FlushOp flushOp) override;
1592
1593   void resetCStream();
1594   void resetDStream();
1595
1596   bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
1597   bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
1598
1599   int level_;
1600   bool needReset_{true};
1601   std::unique_ptr<
1602       ZSTD_CStream,
1603       folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
1604       cstream_{nullptr};
1605   std::unique_ptr<
1606       ZSTD_DStream,
1607       folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
1608       dstream_{nullptr};
1609 };
1610
1611 static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
1612
1613 std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
1614   return {prefixToStringLE(kZSTDMagicLE)};
1615 }
1616
1617 bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1618     const {
1619   return dataStartsWithLE(data, kZSTDMagicLE);
1620 }
1621
1622 std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
1623   return make_unique<ZSTDStreamCodec>(level, type);
1624 }
1625
1626 std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
1627     int level,
1628     CodecType type) {
1629   return make_unique<ZSTDStreamCodec>(level, type);
1630 }
1631
1632 ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
1633     : StreamCodec(type) {
1634   DCHECK(type == CodecType::ZSTD);
1635   switch (level) {
1636     case COMPRESSION_LEVEL_FASTEST:
1637       level = 1;
1638       break;
1639     case COMPRESSION_LEVEL_DEFAULT:
1640       level = 1;
1641       break;
1642     case COMPRESSION_LEVEL_BEST:
1643       level = 19;
1644       break;
1645   }
1646   if (level < 1 || level > ZSTD_maxCLevel()) {
1647     throw std::invalid_argument(
1648         to<std::string>("ZSTD: invalid level: ", level));
1649   }
1650   level_ = level;
1651 }
1652
1653 bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
1654   return false;
1655 }
1656
1657 uint64_t ZSTDStreamCodec::doMaxCompressedLength(
1658     uint64_t uncompressedLength) const {
1659   return ZSTD_compressBound(uncompressedLength);
1660 }
1661
1662 void zstdThrowIfError(size_t rc) {
1663   if (!ZSTD_isError(rc)) {
1664     return;
1665   }
1666   throw std::runtime_error(
1667       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1668 }
1669
1670 Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
1671     IOBuf const* data,
1672     Optional<uint64_t> uncompressedLength) const {
1673   // Read decompressed size from frame if available in first IOBuf.
1674   auto const decompressedSize =
1675       ZSTD_getDecompressedSize(data->data(), data->length());
1676   if (decompressedSize != 0) {
1677     if (uncompressedLength && *uncompressedLength != decompressedSize) {
1678       throw std::runtime_error("ZSTD: invalid uncompressed length");
1679     }
1680     uncompressedLength = decompressedSize;
1681   }
1682   return uncompressedLength;
1683 }
1684
1685 void ZSTDStreamCodec::doResetStream() {
1686   needReset_ = true;
1687 }
1688
1689 bool ZSTDStreamCodec::tryBlockCompress(
1690     ByteRange& input,
1691     MutableByteRange& output) const {
1692   DCHECK(needReset_);
1693   // We need to know that we have enough output space to use block compression
1694   if (output.size() < ZSTD_compressBound(input.size())) {
1695     return false;
1696   }
1697   size_t const length = ZSTD_compress(
1698       output.data(), output.size(), input.data(), input.size(), level_);
1699   zstdThrowIfError(length);
1700   input.uncheckedAdvance(input.size());
1701   output.uncheckedAdvance(length);
1702   return true;
1703 }
1704
1705 void ZSTDStreamCodec::resetCStream() {
1706   if (!cstream_) {
1707     cstream_.reset(ZSTD_createCStream());
1708     if (!cstream_) {
1709       throw std::bad_alloc{};
1710     }
1711   }
1712   // Advanced API usage works for all supported versions of zstd.
1713   // Required to set contentSizeFlag.
1714   auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
1715   params.fParams.contentSizeFlag = uncompressedLength().hasValue();
1716   zstdThrowIfError(ZSTD_initCStream_advanced(
1717       cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0)));
1718 }
1719
1720 bool ZSTDStreamCodec::doCompressStream(
1721     ByteRange& input,
1722     MutableByteRange& output,
1723     StreamCodec::FlushOp flushOp) {
1724   if (needReset_) {
1725     // If we are given all the input in one chunk try to use block compression
1726     if (flushOp == StreamCodec::FlushOp::END &&
1727         tryBlockCompress(input, output)) {
1728       return true;
1729     }
1730     resetCStream();
1731     needReset_ = false;
1732   }
1733   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1734   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1735   SCOPE_EXIT {
1736     input.uncheckedAdvance(in.pos);
1737     output.uncheckedAdvance(out.pos);
1738   };
1739   if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
1740     zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
1741   }
1742   if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
1743     size_t rc;
1744     switch (flushOp) {
1745       case StreamCodec::FlushOp::FLUSH:
1746         rc = ZSTD_flushStream(cstream_.get(), &out);
1747         break;
1748       case StreamCodec::FlushOp::END:
1749         rc = ZSTD_endStream(cstream_.get(), &out);
1750         break;
1751       default:
1752         throw std::invalid_argument("ZSTD: invalid FlushOp");
1753     }
1754     zstdThrowIfError(rc);
1755     if (rc == 0) {
1756       return true;
1757     }
1758   }
1759   return false;
1760 }
1761
1762 bool ZSTDStreamCodec::tryBlockUncompress(
1763     ByteRange& input,
1764     MutableByteRange& output) const {
1765   DCHECK(needReset_);
1766 #if ZSTD_VERSION_NUMBER < 10104
1767   // We require ZSTD_findFrameCompressedSize() to perform this optimization.
1768   return false;
1769 #else
1770   // We need to know the uncompressed length and have enough output space.
1771   if (!uncompressedLength() || output.size() < *uncompressedLength()) {
1772     return false;
1773   }
1774   size_t const compressedLength =
1775       ZSTD_findFrameCompressedSize(input.data(), input.size());
1776   zstdThrowIfError(compressedLength);
1777   size_t const length = ZSTD_decompress(
1778       output.data(), *uncompressedLength(), input.data(), compressedLength);
1779   zstdThrowIfError(length);
1780   DCHECK_EQ(length, *uncompressedLength());
1781   input.uncheckedAdvance(compressedLength);
1782   output.uncheckedAdvance(length);
1783   return true;
1784 #endif
1785 }
1786
1787 void ZSTDStreamCodec::resetDStream() {
1788   if (!dstream_) {
1789     dstream_.reset(ZSTD_createDStream());
1790     if (!dstream_) {
1791       throw std::bad_alloc{};
1792     }
1793   }
1794   zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
1795 }
1796
1797 bool ZSTDStreamCodec::doUncompressStream(
1798     ByteRange& input,
1799     MutableByteRange& output,
1800     StreamCodec::FlushOp flushOp) {
1801   if (needReset_) {
1802     // If we are given all the input in one chunk try to use block uncompression
1803     if (flushOp == StreamCodec::FlushOp::END &&
1804         tryBlockUncompress(input, output)) {
1805       return true;
1806     }
1807     resetDStream();
1808     needReset_ = false;
1809   }
1810   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1811   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1812   SCOPE_EXIT {
1813     input.uncheckedAdvance(in.pos);
1814     output.uncheckedAdvance(out.pos);
1815   };
1816   size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
1817   zstdThrowIfError(rc);
1818   return rc == 0;
1819 }
1820
1821 #endif // FOLLY_HAVE_LIBZSTD
1822
1823 #if FOLLY_HAVE_LIBBZ2
1824
1825 class Bzip2Codec final : public Codec {
1826  public:
1827   static std::unique_ptr<Codec> create(int level, CodecType type);
1828   explicit Bzip2Codec(int level, CodecType type);
1829
1830   std::vector<std::string> validPrefixes() const override;
1831   bool canUncompress(IOBuf const* data, Optional<uint64_t> uncompressedLength)
1832       const override;
1833
1834  private:
1835   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1836   std::unique_ptr<IOBuf> doCompress(IOBuf const* data) override;
1837   std::unique_ptr<IOBuf> doUncompress(
1838       IOBuf const* data,
1839       Optional<uint64_t> uncompressedLength) override;
1840
1841   int level_;
1842 };
1843
1844 /* static */ std::unique_ptr<Codec> Bzip2Codec::create(
1845     int level,
1846     CodecType type) {
1847   return std::make_unique<Bzip2Codec>(level, type);
1848 }
1849
1850 Bzip2Codec::Bzip2Codec(int level, CodecType type) : Codec(type) {
1851   DCHECK(type == CodecType::BZIP2);
1852   switch (level) {
1853     case COMPRESSION_LEVEL_FASTEST:
1854       level = 1;
1855       break;
1856     case COMPRESSION_LEVEL_DEFAULT:
1857       level = 9;
1858       break;
1859     case COMPRESSION_LEVEL_BEST:
1860       level = 9;
1861       break;
1862   }
1863   if (level < 1 || level > 9) {
1864     throw std::invalid_argument(
1865         to<std::string>("Bzip2: invalid level: ", level));
1866   }
1867   level_ = level;
1868 }
1869
1870 static uint32_t constexpr kBzip2MagicLE = 0x685a42;
1871 static uint64_t constexpr kBzip2MagicBytes = 3;
1872
1873 std::vector<std::string> Bzip2Codec::validPrefixes() const {
1874   return {prefixToStringLE(kBzip2MagicLE, kBzip2MagicBytes)};
1875 }
1876
1877 bool Bzip2Codec::canUncompress(IOBuf const* data, Optional<uint64_t>) const {
1878   return dataStartsWithLE(data, kBzip2MagicLE, kBzip2MagicBytes);
1879 }
1880
1881 uint64_t Bzip2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1882   // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress
1883   //   To guarantee that the compressed data will fit in its buffer, allocate an
1884   //   output buffer of size 1% larger than the uncompressed data, plus six
1885   //   hundred extra bytes.
1886   return uncompressedLength + uncompressedLength / 100 + 600;
1887 }
1888
1889 static bz_stream createBzStream() {
1890   bz_stream stream;
1891   stream.bzalloc = nullptr;
1892   stream.bzfree = nullptr;
1893   stream.opaque = nullptr;
1894   stream.next_in = stream.next_out = nullptr;
1895   stream.avail_in = stream.avail_out = 0;
1896   return stream;
1897 }
1898
1899 // Throws on error condition, otherwise returns the code.
1900 static int bzCheck(int const rc) {
1901   switch (rc) {
1902     case BZ_OK:
1903     case BZ_RUN_OK:
1904     case BZ_FLUSH_OK:
1905     case BZ_FINISH_OK:
1906     case BZ_STREAM_END:
1907       return rc;
1908     default:
1909       throw std::runtime_error(to<std::string>("Bzip2 error: ", rc));
1910   }
1911 }
1912
1913 static std::unique_ptr<IOBuf> addOutputBuffer(
1914     bz_stream* stream,
1915     uint64_t const bufferLength) {
1916   DCHECK_LE(bufferLength, std::numeric_limits<unsigned>::max());
1917   DCHECK_EQ(stream->avail_out, 0);
1918
1919   auto buf = IOBuf::create(bufferLength);
1920   buf->append(buf->capacity());
1921
1922   stream->next_out = reinterpret_cast<char*>(buf->writableData());
1923   stream->avail_out = buf->length();
1924
1925   return buf;
1926 }
1927
1928 std::unique_ptr<IOBuf> Bzip2Codec::doCompress(IOBuf const* data) {
1929   bz_stream stream = createBzStream();
1930   bzCheck(BZ2_bzCompressInit(&stream, level_, 0, 0));
1931   SCOPE_EXIT {
1932     bzCheck(BZ2_bzCompressEnd(&stream));
1933   };
1934
1935   uint64_t const uncompressedLength = data->computeChainDataLength();
1936   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
1937   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1938   uint64_t constexpr kDefaultBufferLength = uint64_t(4) << 20;
1939
1940   auto out = addOutputBuffer(
1941       &stream,
1942       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
1943                                                : kDefaultBufferLength);
1944
1945   for (auto range : *data) {
1946     while (!range.empty()) {
1947       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1948       stream.next_in =
1949           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1950       stream.avail_in = inSize;
1951
1952       if (stream.avail_out == 0) {
1953         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1954       }
1955
1956       bzCheck(BZ2_bzCompress(&stream, BZ_RUN));
1957       range.uncheckedAdvance(inSize - stream.avail_in);
1958     }
1959   }
1960   do {
1961     if (stream.avail_out == 0) {
1962       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1963     }
1964   } while (bzCheck(BZ2_bzCompress(&stream, BZ_FINISH)) != BZ_STREAM_END);
1965
1966   out->prev()->trimEnd(stream.avail_out);
1967
1968   return out;
1969 }
1970
1971 std::unique_ptr<IOBuf> Bzip2Codec::doUncompress(
1972     const IOBuf* data,
1973     Optional<uint64_t> uncompressedLength) {
1974   bz_stream stream = createBzStream();
1975   bzCheck(BZ2_bzDecompressInit(&stream, 0, 0));
1976   SCOPE_EXIT {
1977     bzCheck(BZ2_bzDecompressEnd(&stream));
1978   };
1979
1980   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1981   uint64_t const kBlockSize = uint64_t(100) << 10; // 100 KiB
1982   uint64_t const kDefaultBufferLength =
1983       computeBufferLength(data->computeChainDataLength(), kBlockSize);
1984
1985   auto out = addOutputBuffer(
1986       &stream,
1987       ((uncompressedLength && *uncompressedLength <= kMaxSingleStepLength)
1988            ? *uncompressedLength
1989            : kDefaultBufferLength));
1990
1991   int rc = BZ_OK;
1992   for (auto range : *data) {
1993     while (!range.empty()) {
1994       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1995       stream.next_in =
1996           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1997       stream.avail_in = inSize;
1998
1999       if (stream.avail_out == 0) {
2000         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2001       }
2002
2003       rc = bzCheck(BZ2_bzDecompress(&stream));
2004       range.uncheckedAdvance(inSize - stream.avail_in);
2005     }
2006   }
2007   while (rc != BZ_STREAM_END) {
2008     if (stream.avail_out == 0) {
2009       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2010     }
2011
2012     rc = bzCheck(BZ2_bzDecompress(&stream));
2013   }
2014
2015   out->prev()->trimEnd(stream.avail_out);
2016
2017   uint64_t const totalOut =
2018       (uint64_t(stream.total_out_hi32) << 32) + stream.total_out_lo32;
2019   if (uncompressedLength && uncompressedLength != totalOut) {
2020     throw std::runtime_error("Bzip2 error: Invalid uncompressed length");
2021   }
2022
2023   return out;
2024 }
2025
2026 #endif // FOLLY_HAVE_LIBBZ2
2027
2028 /**
2029  * Automatic decompression
2030  */
2031 class AutomaticCodec final : public Codec {
2032  public:
2033   static std::unique_ptr<Codec> create(
2034       std::vector<std::unique_ptr<Codec>> customCodecs);
2035   explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
2036
2037   std::vector<std::string> validPrefixes() const override;
2038   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
2039       const override;
2040
2041  private:
2042   bool doNeedsUncompressedLength() const override;
2043   uint64_t doMaxUncompressedLength() const override;
2044
2045   uint64_t doMaxCompressedLength(uint64_t) const override {
2046     throw std::runtime_error(
2047         "AutomaticCodec error: maxCompressedLength() not supported.");
2048   }
2049   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
2050     throw std::runtime_error("AutomaticCodec error: compress() not supported.");
2051   }
2052   std::unique_ptr<IOBuf> doUncompress(
2053       const IOBuf* data,
2054       Optional<uint64_t> uncompressedLength) override;
2055
2056   void addCodecIfSupported(CodecType type);
2057
2058   // Throws iff the codecs aren't compatible (very slow)
2059   void checkCompatibleCodecs() const;
2060
2061   std::vector<std::unique_ptr<Codec>> codecs_;
2062   bool needsUncompressedLength_;
2063   uint64_t maxUncompressedLength_;
2064 };
2065
2066 std::vector<std::string> AutomaticCodec::validPrefixes() const {
2067   std::unordered_set<std::string> prefixes;
2068   for (const auto& codec : codecs_) {
2069     const auto codecPrefixes = codec->validPrefixes();
2070     prefixes.insert(codecPrefixes.begin(), codecPrefixes.end());
2071   }
2072   return std::vector<std::string>{prefixes.begin(), prefixes.end()};
2073 }
2074
2075 bool AutomaticCodec::canUncompress(
2076     const IOBuf* data,
2077     Optional<uint64_t> uncompressedLength) const {
2078   return std::any_of(
2079       codecs_.begin(),
2080       codecs_.end(),
2081       [data, uncompressedLength](std::unique_ptr<Codec> const& codec) {
2082         return codec->canUncompress(data, uncompressedLength);
2083       });
2084 }
2085
2086 void AutomaticCodec::addCodecIfSupported(CodecType type) {
2087   const bool present = std::any_of(
2088       codecs_.begin(),
2089       codecs_.end(),
2090       [&type](std::unique_ptr<Codec> const& codec) {
2091         return codec->type() == type;
2092       });
2093   if (hasCodec(type) && !present) {
2094     codecs_.push_back(getCodec(type));
2095   }
2096 }
2097
2098 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
2099     std::vector<std::unique_ptr<Codec>> customCodecs) {
2100   return std::make_unique<AutomaticCodec>(std::move(customCodecs));
2101 }
2102
2103 AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
2104     : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
2105   // Fastest -> slowest
2106   addCodecIfSupported(CodecType::LZ4_FRAME);
2107   addCodecIfSupported(CodecType::ZSTD);
2108   addCodecIfSupported(CodecType::ZLIB);
2109   addCodecIfSupported(CodecType::GZIP);
2110   addCodecIfSupported(CodecType::LZMA2);
2111   addCodecIfSupported(CodecType::BZIP2);
2112   if (kIsDebug) {
2113     checkCompatibleCodecs();
2114   }
2115   // Check that none of the codes are are null
2116   DCHECK(std::none_of(
2117       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2118         return codec == nullptr;
2119       }));
2120
2121   needsUncompressedLength_ = std::any_of(
2122       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2123         return codec->needsUncompressedLength();
2124       });
2125
2126   const auto it = std::max_element(
2127       codecs_.begin(),
2128       codecs_.end(),
2129       [](std::unique_ptr<Codec> const& lhs, std::unique_ptr<Codec> const& rhs) {
2130         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
2131       });
2132   DCHECK(it != codecs_.end());
2133   maxUncompressedLength_ = (*it)->maxUncompressedLength();
2134 }
2135
2136 void AutomaticCodec::checkCompatibleCodecs() const {
2137   // Keep track of all the possible headers.
2138   std::unordered_set<std::string> headers;
2139   // The empty header is not allowed.
2140   headers.insert("");
2141   // Step 1:
2142   // Construct a set of headers and check that none of the headers occur twice.
2143   // Eliminate edge cases.
2144   for (auto&& codec : codecs_) {
2145     const auto codecHeaders = codec->validPrefixes();
2146     // Codecs without any valid headers are not allowed.
2147     if (codecHeaders.empty()) {
2148       throw std::invalid_argument{
2149           "AutomaticCodec: validPrefixes() must not be empty."};
2150     }
2151     // Insert all the headers for the current codec.
2152     const size_t beforeSize = headers.size();
2153     headers.insert(codecHeaders.begin(), codecHeaders.end());
2154     // Codecs are not compatible if any header occurred twice.
2155     if (beforeSize + codecHeaders.size() != headers.size()) {
2156       throw std::invalid_argument{
2157           "AutomaticCodec: Two valid prefixes collide."};
2158     }
2159   }
2160   // Step 2:
2161   // Check if any strict non-empty prefix of any header is a header.
2162   for (const auto& header : headers) {
2163     for (size_t i = 1; i < header.size(); ++i) {
2164       if (headers.count(header.substr(0, i))) {
2165         throw std::invalid_argument{
2166             "AutomaticCodec: One valid prefix is a prefix of another valid "
2167             "prefix."};
2168       }
2169     }
2170   }
2171 }
2172
2173 bool AutomaticCodec::doNeedsUncompressedLength() const {
2174   return needsUncompressedLength_;
2175 }
2176
2177 uint64_t AutomaticCodec::doMaxUncompressedLength() const {
2178   return maxUncompressedLength_;
2179 }
2180
2181 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
2182     const IOBuf* data,
2183     Optional<uint64_t> uncompressedLength) {
2184   for (auto&& codec : codecs_) {
2185     if (codec->canUncompress(data, uncompressedLength)) {
2186       return codec->uncompress(data, uncompressedLength);
2187     }
2188   }
2189   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
2190 }
2191
2192 using CodecFactory = std::unique_ptr<Codec> (*)(int, CodecType);
2193 using StreamCodecFactory = std::unique_ptr<StreamCodec> (*)(int, CodecType);
2194 struct Factory {
2195   CodecFactory codec;
2196   StreamCodecFactory stream;
2197 };
2198
2199 constexpr Factory
2200     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
2201         {}, // USER_DEFINED
2202         {NoCompressionCodec::create, nullptr},
2203
2204 #if FOLLY_HAVE_LIBLZ4
2205         {LZ4Codec::create, nullptr},
2206 #else
2207         {},
2208 #endif
2209
2210 #if FOLLY_HAVE_LIBSNAPPY
2211         {SnappyCodec::create, nullptr},
2212 #else
2213         {},
2214 #endif
2215
2216 #if FOLLY_HAVE_LIBZ
2217         {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream},
2218 #else
2219         {},
2220 #endif
2221
2222 #if FOLLY_HAVE_LIBLZ4
2223         {LZ4Codec::create, nullptr},
2224 #else
2225         {},
2226 #endif
2227
2228 #if FOLLY_HAVE_LIBLZMA
2229         {LZMA2Codec::create, nullptr},
2230         {LZMA2Codec::create, nullptr},
2231 #else
2232         {},
2233         {},
2234 #endif
2235
2236 #if FOLLY_HAVE_LIBZSTD
2237         {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
2238 #else
2239         {},
2240 #endif
2241
2242 #if FOLLY_HAVE_LIBZ
2243         {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream},
2244 #else
2245         {},
2246 #endif
2247
2248 #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301)
2249         {LZ4FrameCodec::create, nullptr},
2250 #else
2251         {},
2252 #endif
2253
2254 #if FOLLY_HAVE_LIBBZ2
2255         {Bzip2Codec::create, nullptr},
2256 #else
2257         {},
2258 #endif
2259 };
2260
2261 Factory const& getFactory(CodecType type) {
2262   size_t const idx = static_cast<size_t>(type);
2263   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
2264     throw std::invalid_argument(
2265         to<std::string>("Compression type ", idx, " invalid"));
2266   }
2267   return codecFactories[idx];
2268 }
2269 } // namespace
2270
2271 bool hasCodec(CodecType type) {
2272   return getFactory(type).codec != nullptr;
2273 }
2274
2275 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
2276   auto const factory = getFactory(type).codec;
2277   if (!factory) {
2278     throw std::invalid_argument(
2279         to<std::string>("Compression type ", type, " not supported"));
2280   }
2281   auto codec = (*factory)(level, type);
2282   DCHECK(codec->type() == type);
2283   return codec;
2284 }
2285
2286 bool hasStreamCodec(CodecType type) {
2287   return getFactory(type).stream != nullptr;
2288 }
2289
2290 std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
2291   auto const factory = getFactory(type).stream;
2292   if (!factory) {
2293     throw std::invalid_argument(
2294         to<std::string>("Compression type ", type, " not supported"));
2295   }
2296   auto codec = (*factory)(level, type);
2297   DCHECK(codec->type() == type);
2298   return codec;
2299 }
2300
2301 std::unique_ptr<Codec> getAutoUncompressionCodec(
2302     std::vector<std::unique_ptr<Codec>> customCodecs) {
2303   return AutomaticCodec::create(std::move(customCodecs));
2304 }
2305 }}  // namespaces