98f332b85eb55d34f047d47ef51e7cd368369e16
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #if LZ4_VERSION_NUMBER >= 10301
23 #include <lz4frame.h>
24 #endif
25 #endif
26
27 #include <glog/logging.h>
28
29 #if FOLLY_HAVE_LIBSNAPPY
30 #include <snappy.h>
31 #include <snappy-sinksource.h>
32 #endif
33
34 #if FOLLY_HAVE_LIBZ
35 #include <zlib.h>
36 #endif
37
38 #if FOLLY_HAVE_LIBLZMA
39 #include <lzma.h>
40 #endif
41
42 #if FOLLY_HAVE_LIBZSTD
43 #define ZSTD_STATIC_LINKING_ONLY
44 #include <zstd.h>
45 #endif
46
47 #if FOLLY_HAVE_LIBBZ2
48 #include <bzlib.h>
49 #endif
50
51 #include <folly/Bits.h>
52 #include <folly/Conv.h>
53 #include <folly/Memory.h>
54 #include <folly/Portability.h>
55 #include <folly/ScopeGuard.h>
56 #include <folly/Varint.h>
57 #include <folly/io/Cursor.h>
58 #include <algorithm>
59 #include <unordered_set>
60
61 namespace folly {
62 namespace io {
63
64 Codec::Codec(CodecType type) : type_(type) { }
65
66 // Ensure consistent behavior in the nullptr case
67 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
68   if (data == nullptr) {
69     throw std::invalid_argument("Codec: data must not be nullptr");
70   }
71   uint64_t len = data->computeChainDataLength();
72   if (len == 0) {
73     return IOBuf::create(0);
74   }
75   if (len > maxUncompressedLength()) {
76     throw std::runtime_error("Codec: uncompressed length too large");
77   }
78
79   return doCompress(data);
80 }
81
82 std::string Codec::compress(const StringPiece data) {
83   const uint64_t len = data.size();
84   if (len == 0) {
85     return "";
86   }
87   if (len > maxUncompressedLength()) {
88     throw std::runtime_error("Codec: uncompressed length too large");
89   }
90
91   return doCompressString(data);
92 }
93
94 std::unique_ptr<IOBuf> Codec::uncompress(
95     const IOBuf* data,
96     Optional<uint64_t> uncompressedLength) {
97   if (data == nullptr) {
98     throw std::invalid_argument("Codec: data must not be nullptr");
99   }
100   if (!uncompressedLength) {
101     if (needsUncompressedLength()) {
102       throw std::invalid_argument("Codec: uncompressed length required");
103     }
104   } else if (*uncompressedLength > maxUncompressedLength()) {
105     throw std::runtime_error("Codec: uncompressed length too large");
106   }
107
108   if (data->empty()) {
109     if (uncompressedLength.value_or(0) != 0) {
110       throw std::runtime_error("Codec: invalid uncompressed length");
111     }
112     return IOBuf::create(0);
113   }
114
115   return doUncompress(data, uncompressedLength);
116 }
117
118 std::string Codec::uncompress(
119     const StringPiece data,
120     Optional<uint64_t> uncompressedLength) {
121   if (!uncompressedLength) {
122     if (needsUncompressedLength()) {
123       throw std::invalid_argument("Codec: uncompressed length required");
124     }
125   } else if (*uncompressedLength > maxUncompressedLength()) {
126     throw std::runtime_error("Codec: uncompressed length too large");
127   }
128
129   if (data.empty()) {
130     if (uncompressedLength.value_or(0) != 0) {
131       throw std::runtime_error("Codec: invalid uncompressed length");
132     }
133     return "";
134   }
135
136   return doUncompressString(data, uncompressedLength);
137 }
138
139 bool Codec::needsUncompressedLength() const {
140   return doNeedsUncompressedLength();
141 }
142
143 uint64_t Codec::maxUncompressedLength() const {
144   return doMaxUncompressedLength();
145 }
146
147 bool Codec::doNeedsUncompressedLength() const {
148   return false;
149 }
150
151 uint64_t Codec::doMaxUncompressedLength() const {
152   return UNLIMITED_UNCOMPRESSED_LENGTH;
153 }
154
155 std::vector<std::string> Codec::validPrefixes() const {
156   return {};
157 }
158
159 bool Codec::canUncompress(const IOBuf*, Optional<uint64_t>) const {
160   return false;
161 }
162
163 std::string Codec::doCompressString(const StringPiece data) {
164   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
165   auto outputBuffer = doCompress(&inputBuffer);
166   std::string output;
167   output.reserve(outputBuffer->computeChainDataLength());
168   for (auto range : *outputBuffer) {
169     output.append(reinterpret_cast<const char*>(range.data()), range.size());
170   }
171   return output;
172 }
173
174 std::string Codec::doUncompressString(
175     const StringPiece data,
176     Optional<uint64_t> uncompressedLength) {
177   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
178   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
179   std::string output;
180   output.reserve(outputBuffer->computeChainDataLength());
181   for (auto range : *outputBuffer) {
182     output.append(reinterpret_cast<const char*>(range.data()), range.size());
183   }
184   return output;
185 }
186
187 uint64_t Codec::maxCompressedLength(uint64_t uncompressedLength) const {
188   if (uncompressedLength == 0) {
189     return 0;
190   }
191   return doMaxCompressedLength(uncompressedLength);
192 }
193
194 Optional<uint64_t> Codec::getUncompressedLength(
195     const folly::IOBuf* data,
196     Optional<uint64_t> uncompressedLength) const {
197   auto const compressedLength = data->computeChainDataLength();
198   if (uncompressedLength == uint64_t(0) || compressedLength == 0) {
199     if (uncompressedLength.value_or(0) != 0 || compressedLength != 0) {
200       throw std::runtime_error("Invalid uncompressed length");
201     }
202     return 0;
203   }
204   return doGetUncompressedLength(data, uncompressedLength);
205 }
206
207 Optional<uint64_t> Codec::doGetUncompressedLength(
208     const folly::IOBuf*,
209     Optional<uint64_t> uncompressedLength) const {
210   return uncompressedLength;
211 }
212
213 bool StreamCodec::needsDataLength() const {
214   return doNeedsDataLength();
215 }
216
217 bool StreamCodec::doNeedsDataLength() const {
218   return false;
219 }
220
221 void StreamCodec::assertStateIs(State expected) const {
222   if (state_ != expected) {
223     throw std::logic_error(folly::to<std::string>(
224         "Codec: state is ", state_, "; expected state ", expected));
225   }
226 }
227
228 void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) {
229   state_ = State::RESET;
230   uncompressedLength_ = uncompressedLength;
231   doResetStream();
232 }
233
234 bool StreamCodec::compressStream(
235     ByteRange& input,
236     MutableByteRange& output,
237     StreamCodec::FlushOp flushOp) {
238   if (state_ == State::RESET && input.empty()) {
239     if (flushOp == StreamCodec::FlushOp::NONE) {
240       return false;
241     }
242     if (flushOp == StreamCodec::FlushOp::END &&
243         uncompressedLength().value_or(0) != 0) {
244       throw std::runtime_error("Codec: invalid uncompressed length");
245     }
246     return true;
247   }
248   if (state_ == State::RESET && !input.empty() &&
249       uncompressedLength() == uint64_t(0)) {
250     throw std::runtime_error("Codec: invalid uncompressed length");
251   }
252   // Handle input state transitions
253   switch (flushOp) {
254     case StreamCodec::FlushOp::NONE:
255       if (state_ == State::RESET) {
256         state_ = State::COMPRESS;
257       }
258       assertStateIs(State::COMPRESS);
259       break;
260     case StreamCodec::FlushOp::FLUSH:
261       if (state_ == State::RESET || state_ == State::COMPRESS) {
262         state_ = State::COMPRESS_FLUSH;
263       }
264       assertStateIs(State::COMPRESS_FLUSH);
265       break;
266     case StreamCodec::FlushOp::END:
267       if (state_ == State::RESET || state_ == State::COMPRESS) {
268         state_ = State::COMPRESS_END;
269       }
270       assertStateIs(State::COMPRESS_END);
271       break;
272   }
273   bool const done = doCompressStream(input, output, flushOp);
274   // Handle output state transitions
275   if (done) {
276     if (state_ == State::COMPRESS_FLUSH) {
277       state_ = State::COMPRESS;
278     } else if (state_ == State::COMPRESS_END) {
279       state_ = State::END;
280     }
281     // Check internal invariants
282     DCHECK(input.empty());
283     DCHECK(flushOp != StreamCodec::FlushOp::NONE);
284   }
285   return done;
286 }
287
288 bool StreamCodec::uncompressStream(
289     ByteRange& input,
290     MutableByteRange& output,
291     StreamCodec::FlushOp flushOp) {
292   if (state_ == State::RESET && input.empty()) {
293     if (uncompressedLength().value_or(0) == 0) {
294       return true;
295     }
296     return false;
297   }
298   // Handle input state transitions
299   if (state_ == State::RESET) {
300     state_ = State::UNCOMPRESS;
301   }
302   assertStateIs(State::UNCOMPRESS);
303   bool const done = doUncompressStream(input, output, flushOp);
304   // Handle output state transitions
305   if (done) {
306     state_ = State::END;
307   }
308   return done;
309 }
310
311 static std::unique_ptr<IOBuf> addOutputBuffer(
312     MutableByteRange& output,
313     uint64_t size) {
314   DCHECK(output.empty());
315   auto buffer = IOBuf::create(size);
316   buffer->append(buffer->capacity());
317   output = {buffer->writableData(), buffer->length()};
318   return buffer;
319 }
320
321 std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
322   uint64_t const uncompressedLength = data->computeChainDataLength();
323   resetStream(uncompressedLength);
324   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
325
326   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
327   auto constexpr kDefaultBufferLength = uint64_t(4) << 20; // 4 MB
328
329   MutableByteRange output;
330   auto buffer = addOutputBuffer(
331       output,
332       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
333                                                : kDefaultBufferLength);
334
335   // Compress the entire IOBuf chain into the IOBuf chain pointed to by buffer
336   IOBuf const* current = data;
337   ByteRange input{current->data(), current->length()};
338   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
339   for (;;) {
340     while (input.empty() && current->next() != data) {
341       current = current->next();
342       input = {current->data(), current->length()};
343     }
344     if (current->next() == data) {
345       // This is the last input buffer so end the stream
346       flushOp = StreamCodec::FlushOp::END;
347     }
348     if (output.empty()) {
349       buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength));
350     }
351     size_t const inputSize = input.size();
352     size_t const outputSize = output.size();
353     bool const done = compressStream(input, output, flushOp);
354     if (done) {
355       DCHECK(input.empty());
356       DCHECK(flushOp == StreamCodec::FlushOp::END);
357       DCHECK_EQ(current->next(), data);
358       break;
359     }
360     if (inputSize == input.size() && outputSize == output.size()) {
361       throw std::runtime_error("Codec: No forward progress made");
362     }
363   }
364   buffer->prev()->trimEnd(output.size());
365   return buffer;
366 }
367
368 static uint64_t computeBufferLength(
369     uint64_t const compressedLength,
370     uint64_t const blockSize) {
371   uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB
372   uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength);
373   return std::min(goodBufferSize, kMaxBufferLength);
374 }
375
376 std::unique_ptr<IOBuf> StreamCodec::doUncompress(
377     IOBuf const* data,
378     Optional<uint64_t> uncompressedLength) {
379   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
380   auto constexpr kBlockSize = uint64_t(128) << 10;
381   auto const defaultBufferLength =
382       computeBufferLength(data->computeChainDataLength(), kBlockSize);
383
384   uncompressedLength = getUncompressedLength(data, uncompressedLength);
385   resetStream(uncompressedLength);
386
387   MutableByteRange output;
388   auto buffer = addOutputBuffer(
389       output,
390       (uncompressedLength && *uncompressedLength <= kMaxSingleStepLength
391            ? *uncompressedLength
392            : defaultBufferLength));
393
394   // Uncompress the entire IOBuf chain into the IOBuf chain pointed to by buffer
395   IOBuf const* current = data;
396   ByteRange input{current->data(), current->length()};
397   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
398   for (;;) {
399     while (input.empty() && current->next() != data) {
400       current = current->next();
401       input = {current->data(), current->length()};
402     }
403     if (current->next() == data) {
404       // Tell the uncompressor there is no more input (it may optimize)
405       flushOp = StreamCodec::FlushOp::END;
406     }
407     if (output.empty()) {
408       buffer->prependChain(addOutputBuffer(output, defaultBufferLength));
409     }
410     size_t const inputSize = input.size();
411     size_t const outputSize = output.size();
412     bool const done = uncompressStream(input, output, flushOp);
413     if (done) {
414       break;
415     }
416     if (inputSize == input.size() && outputSize == output.size()) {
417       throw std::runtime_error("Codec: Truncated data");
418     }
419   }
420   if (!input.empty()) {
421     throw std::runtime_error("Codec: Junk after end of data");
422   }
423
424   buffer->prev()->trimEnd(output.size());
425   if (uncompressedLength &&
426       *uncompressedLength != buffer->computeChainDataLength()) {
427     throw std::runtime_error("Codec: invalid uncompressed length");
428   }
429
430   return buffer;
431 }
432
433 namespace {
434
435 /**
436  * No compression
437  */
438 class NoCompressionCodec final : public Codec {
439  public:
440   static std::unique_ptr<Codec> create(int level, CodecType type);
441   explicit NoCompressionCodec(int level, CodecType type);
442
443  private:
444   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
445   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
446   std::unique_ptr<IOBuf> doUncompress(
447       const IOBuf* data,
448       Optional<uint64_t> uncompressedLength) override;
449 };
450
451 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
452   return std::make_unique<NoCompressionCodec>(level, type);
453 }
454
455 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
456   : Codec(type) {
457   DCHECK(type == CodecType::NO_COMPRESSION);
458   switch (level) {
459   case COMPRESSION_LEVEL_DEFAULT:
460   case COMPRESSION_LEVEL_FASTEST:
461   case COMPRESSION_LEVEL_BEST:
462     level = 0;
463   }
464   if (level != 0) {
465     throw std::invalid_argument(to<std::string>(
466         "NoCompressionCodec: invalid level ", level));
467   }
468 }
469
470 uint64_t NoCompressionCodec::doMaxCompressedLength(
471     uint64_t uncompressedLength) const {
472   return uncompressedLength;
473 }
474
475 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
476     const IOBuf* data) {
477   return data->clone();
478 }
479
480 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
481     const IOBuf* data,
482     Optional<uint64_t> uncompressedLength) {
483   if (uncompressedLength &&
484       data->computeChainDataLength() != *uncompressedLength) {
485     throw std::runtime_error(
486         to<std::string>("NoCompressionCodec: invalid uncompressed length"));
487   }
488   return data->clone();
489 }
490
491 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
492
493 namespace {
494
495 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
496   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
497   out->append(encodeVarint(val, out->writableTail()));
498 }
499
500 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
501   uint64_t val = 0;
502   int8_t b = 0;
503   for (int shift = 0; shift <= 63; shift += 7) {
504     b = cursor.read<int8_t>();
505     val |= static_cast<uint64_t>(b & 0x7f) << shift;
506     if (b >= 0) {
507       break;
508     }
509   }
510   if (b < 0) {
511     throw std::invalid_argument("Invalid varint value. Too big.");
512   }
513   return val;
514 }
515
516 } // namespace
517
518 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
519
520 namespace {
521 /**
522  * Reads sizeof(T) bytes, and returns false if not enough bytes are available.
523  * Returns true if the first n bytes are equal to prefix when interpreted as
524  * a little endian T.
525  */
526 template <typename T>
527 typename std::enable_if<std::is_unsigned<T>::value, bool>::type
528 dataStartsWithLE(const IOBuf* data, T prefix, uint64_t n = sizeof(T)) {
529   DCHECK_GT(n, 0);
530   DCHECK_LE(n, sizeof(T));
531   T value;
532   Cursor cursor{data};
533   if (!cursor.tryReadLE(value)) {
534     return false;
535   }
536   const T mask = n == sizeof(T) ? T(-1) : (T(1) << (8 * n)) - 1;
537   return prefix == (value & mask);
538 }
539
540 template <typename T>
541 typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
542 prefixToStringLE(T prefix, uint64_t n = sizeof(T)) {
543   DCHECK_GT(n, 0);
544   DCHECK_LE(n, sizeof(T));
545   prefix = Endian::little(prefix);
546   std::string result;
547   result.resize(n);
548   memcpy(&result[0], &prefix, n);
549   return result;
550 }
551 } // namespace
552
553 #if FOLLY_HAVE_LIBLZ4
554
555 /**
556  * LZ4 compression
557  */
558 class LZ4Codec final : public Codec {
559  public:
560   static std::unique_ptr<Codec> create(int level, CodecType type);
561   explicit LZ4Codec(int level, CodecType type);
562
563  private:
564   bool doNeedsUncompressedLength() const override;
565   uint64_t doMaxUncompressedLength() const override;
566   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
567
568   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
569
570   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
571   std::unique_ptr<IOBuf> doUncompress(
572       const IOBuf* data,
573       Optional<uint64_t> uncompressedLength) override;
574
575   bool highCompression_;
576 };
577
578 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
579   return std::make_unique<LZ4Codec>(level, type);
580 }
581
582 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
583   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
584
585   switch (level) {
586   case COMPRESSION_LEVEL_FASTEST:
587   case COMPRESSION_LEVEL_DEFAULT:
588     level = 1;
589     break;
590   case COMPRESSION_LEVEL_BEST:
591     level = 2;
592     break;
593   }
594   if (level < 1 || level > 2) {
595     throw std::invalid_argument(to<std::string>(
596         "LZ4Codec: invalid level: ", level));
597   }
598   highCompression_ = (level > 1);
599 }
600
601 bool LZ4Codec::doNeedsUncompressedLength() const {
602   return !encodeSize();
603 }
604
605 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
606 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
607 // here.
608 #ifndef LZ4_MAX_INPUT_SIZE
609 # define LZ4_MAX_INPUT_SIZE 0x7E000000
610 #endif
611
612 uint64_t LZ4Codec::doMaxUncompressedLength() const {
613   return LZ4_MAX_INPUT_SIZE;
614 }
615
616 uint64_t LZ4Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
617   return LZ4_compressBound(uncompressedLength) +
618       (encodeSize() ? kMaxVarintLength64 : 0);
619 }
620
621 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
622   IOBuf clone;
623   if (data->isChained()) {
624     // LZ4 doesn't support streaming, so we have to coalesce
625     clone = data->cloneCoalescedAsValue();
626     data = &clone;
627   }
628
629   auto out = IOBuf::create(maxCompressedLength(data->length()));
630   if (encodeSize()) {
631     encodeVarintToIOBuf(data->length(), out.get());
632   }
633
634   int n;
635   auto input = reinterpret_cast<const char*>(data->data());
636   auto output = reinterpret_cast<char*>(out->writableTail());
637   const auto inputLength = data->length();
638 #if LZ4_VERSION_NUMBER >= 10700
639   if (highCompression_) {
640     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
641   } else {
642     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
643   }
644 #else
645   if (highCompression_) {
646     n = LZ4_compressHC(input, output, inputLength);
647   } else {
648     n = LZ4_compress(input, output, inputLength);
649   }
650 #endif
651
652   CHECK_GE(n, 0);
653   CHECK_LE(n, out->capacity());
654
655   out->append(n);
656   return out;
657 }
658
659 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
660     const IOBuf* data,
661     Optional<uint64_t> uncompressedLength) {
662   IOBuf clone;
663   if (data->isChained()) {
664     // LZ4 doesn't support streaming, so we have to coalesce
665     clone = data->cloneCoalescedAsValue();
666     data = &clone;
667   }
668
669   folly::io::Cursor cursor(data);
670   uint64_t actualUncompressedLength;
671   if (encodeSize()) {
672     actualUncompressedLength = decodeVarintFromCursor(cursor);
673     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
674       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
675     }
676   } else {
677     // Invariants
678     DCHECK(uncompressedLength.hasValue());
679     DCHECK(*uncompressedLength <= maxUncompressedLength());
680     actualUncompressedLength = *uncompressedLength;
681   }
682
683   auto sp = StringPiece{cursor.peekBytes()};
684   auto out = IOBuf::create(actualUncompressedLength);
685   int n = LZ4_decompress_safe(
686       sp.data(),
687       reinterpret_cast<char*>(out->writableTail()),
688       sp.size(),
689       actualUncompressedLength);
690
691   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
692     throw std::runtime_error(to<std::string>(
693         "LZ4 decompression returned invalid value ", n));
694   }
695   out->append(actualUncompressedLength);
696   return out;
697 }
698
699 #if LZ4_VERSION_NUMBER >= 10301
700
701 class LZ4FrameCodec final : public Codec {
702  public:
703   static std::unique_ptr<Codec> create(int level, CodecType type);
704   explicit LZ4FrameCodec(int level, CodecType type);
705   ~LZ4FrameCodec() override;
706
707   std::vector<std::string> validPrefixes() const override;
708   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
709       const override;
710
711  private:
712   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
713
714   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
715   std::unique_ptr<IOBuf> doUncompress(
716       const IOBuf* data,
717       Optional<uint64_t> uncompressedLength) override;
718
719   // Reset the dctx_ if it is dirty or null.
720   void resetDCtx();
721
722   int level_;
723   LZ4F_decompressionContext_t dctx_{nullptr};
724   bool dirty_{false};
725 };
726
727 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
728     int level,
729     CodecType type) {
730   return std::make_unique<LZ4FrameCodec>(level, type);
731 }
732
733 static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204;
734
735 std::vector<std::string> LZ4FrameCodec::validPrefixes() const {
736   return {prefixToStringLE(kLZ4FrameMagicLE)};
737 }
738
739 bool LZ4FrameCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
740   return dataStartsWithLE(data, kLZ4FrameMagicLE);
741 }
742
743 uint64_t LZ4FrameCodec::doMaxCompressedLength(
744     uint64_t uncompressedLength) const {
745   LZ4F_preferences_t prefs{};
746   prefs.compressionLevel = level_;
747   prefs.frameInfo.contentSize = uncompressedLength;
748   return LZ4F_compressFrameBound(uncompressedLength, &prefs);
749 }
750
751 static size_t lz4FrameThrowOnError(size_t code) {
752   if (LZ4F_isError(code)) {
753     throw std::runtime_error(
754         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
755   }
756   return code;
757 }
758
759 void LZ4FrameCodec::resetDCtx() {
760   if (dctx_ && !dirty_) {
761     return;
762   }
763   if (dctx_) {
764     LZ4F_freeDecompressionContext(dctx_);
765   }
766   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
767   dirty_ = false;
768 }
769
770 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
771   DCHECK(type == CodecType::LZ4_FRAME);
772   switch (level) {
773     case COMPRESSION_LEVEL_FASTEST:
774     case COMPRESSION_LEVEL_DEFAULT:
775       level_ = 0;
776       break;
777     case COMPRESSION_LEVEL_BEST:
778       level_ = 16;
779       break;
780     default:
781       level_ = level;
782       break;
783   }
784 }
785
786 LZ4FrameCodec::~LZ4FrameCodec() {
787   if (dctx_) {
788     LZ4F_freeDecompressionContext(dctx_);
789   }
790 }
791
792 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
793   // LZ4 Frame compression doesn't support streaming so we have to coalesce
794   IOBuf clone;
795   if (data->isChained()) {
796     clone = data->cloneCoalescedAsValue();
797     data = &clone;
798   }
799   // Set preferences
800   const auto uncompressedLength = data->length();
801   LZ4F_preferences_t prefs{};
802   prefs.compressionLevel = level_;
803   prefs.frameInfo.contentSize = uncompressedLength;
804   // Compress
805   auto buf = IOBuf::create(maxCompressedLength(uncompressedLength));
806   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
807       buf->writableTail(),
808       buf->tailroom(),
809       data->data(),
810       data->length(),
811       &prefs));
812   buf->append(written);
813   return buf;
814 }
815
816 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
817     const IOBuf* data,
818     Optional<uint64_t> uncompressedLength) {
819   // Reset the dctx if any errors have occurred
820   resetDCtx();
821   // Coalesce the data
822   ByteRange in = *data->begin();
823   IOBuf clone;
824   if (data->isChained()) {
825     clone = data->cloneCoalescedAsValue();
826     in = clone.coalesce();
827   }
828   data = nullptr;
829   // Select decompression options
830   LZ4F_decompressOptions_t options;
831   options.stableDst = 1;
832   // Select blockSize and growthSize for the IOBufQueue
833   IOBufQueue queue(IOBufQueue::cacheChainLength());
834   auto blockSize = uint64_t{64} << 10;
835   auto growthSize = uint64_t{4} << 20;
836   if (uncompressedLength) {
837     // Allocate uncompressedLength in one chunk (up to 64 MB)
838     const auto allocateSize = std::min(*uncompressedLength, uint64_t{64} << 20);
839     queue.preallocate(allocateSize, allocateSize);
840     blockSize = std::min(*uncompressedLength, blockSize);
841     growthSize = std::min(*uncompressedLength, growthSize);
842   } else {
843     // Reduce growthSize for small data
844     const auto guessUncompressedLen =
845         4 * std::max<uint64_t>(blockSize, in.size());
846     growthSize = std::min(guessUncompressedLen, growthSize);
847   }
848   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
849   // returns 0
850   dirty_ = true;
851   // Decompress until the frame is over
852   size_t code = 0;
853   do {
854     // Allocate enough space to decompress at least a block
855     void* out;
856     size_t outSize;
857     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
858     // Decompress
859     size_t inSize = in.size();
860     code = lz4FrameThrowOnError(
861         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
862     if (in.empty() && outSize == 0 && code != 0) {
863       // We passed no input, no output was produced, and the frame isn't over
864       // No more forward progress is possible
865       throw std::runtime_error("LZ4Frame error: Incomplete frame");
866     }
867     in.uncheckedAdvance(inSize);
868     queue.postallocate(outSize);
869   } while (code != 0);
870   // At this point the decompression context can be reused
871   dirty_ = false;
872   if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
873     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
874   }
875   return queue.move();
876 }
877
878 #endif // LZ4_VERSION_NUMBER >= 10301
879 #endif // FOLLY_HAVE_LIBLZ4
880
881 #if FOLLY_HAVE_LIBSNAPPY
882
883 /**
884  * Snappy compression
885  */
886
887 /**
888  * Implementation of snappy::Source that reads from a IOBuf chain.
889  */
890 class IOBufSnappySource final : public snappy::Source {
891  public:
892   explicit IOBufSnappySource(const IOBuf* data);
893   size_t Available() const override;
894   const char* Peek(size_t* len) override;
895   void Skip(size_t n) override;
896  private:
897   size_t available_;
898   io::Cursor cursor_;
899 };
900
901 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
902   : available_(data->computeChainDataLength()),
903     cursor_(data) {
904 }
905
906 size_t IOBufSnappySource::Available() const {
907   return available_;
908 }
909
910 const char* IOBufSnappySource::Peek(size_t* len) {
911   auto sp = StringPiece{cursor_.peekBytes()};
912   *len = sp.size();
913   return sp.data();
914 }
915
916 void IOBufSnappySource::Skip(size_t n) {
917   CHECK_LE(n, available_);
918   cursor_.skip(n);
919   available_ -= n;
920 }
921
922 class SnappyCodec final : public Codec {
923  public:
924   static std::unique_ptr<Codec> create(int level, CodecType type);
925   explicit SnappyCodec(int level, CodecType type);
926
927  private:
928   uint64_t doMaxUncompressedLength() const override;
929   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
930   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
931   std::unique_ptr<IOBuf> doUncompress(
932       const IOBuf* data,
933       Optional<uint64_t> uncompressedLength) override;
934 };
935
936 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
937   return std::make_unique<SnappyCodec>(level, type);
938 }
939
940 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
941   DCHECK(type == CodecType::SNAPPY);
942   switch (level) {
943   case COMPRESSION_LEVEL_FASTEST:
944   case COMPRESSION_LEVEL_DEFAULT:
945   case COMPRESSION_LEVEL_BEST:
946     level = 1;
947   }
948   if (level != 1) {
949     throw std::invalid_argument(to<std::string>(
950         "SnappyCodec: invalid level: ", level));
951   }
952 }
953
954 uint64_t SnappyCodec::doMaxUncompressedLength() const {
955   // snappy.h uses uint32_t for lengths, so there's that.
956   return std::numeric_limits<uint32_t>::max();
957 }
958
959 uint64_t SnappyCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
960   return snappy::MaxCompressedLength(uncompressedLength);
961 }
962
963 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
964   IOBufSnappySource source(data);
965   auto out = IOBuf::create(maxCompressedLength(source.Available()));
966
967   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
968       out->writableTail()));
969
970   size_t n = snappy::Compress(&source, &sink);
971
972   CHECK_LE(n, out->capacity());
973   out->append(n);
974   return out;
975 }
976
977 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(
978     const IOBuf* data,
979     Optional<uint64_t> uncompressedLength) {
980   uint32_t actualUncompressedLength = 0;
981
982   {
983     IOBufSnappySource source(data);
984     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
985       throw std::runtime_error("snappy::GetUncompressedLength failed");
986     }
987     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
988       throw std::runtime_error("snappy: invalid uncompressed length");
989     }
990   }
991
992   auto out = IOBuf::create(actualUncompressedLength);
993
994   {
995     IOBufSnappySource source(data);
996     if (!snappy::RawUncompress(&source,
997                                reinterpret_cast<char*>(out->writableTail()))) {
998       throw std::runtime_error("snappy::RawUncompress failed");
999     }
1000   }
1001
1002   out->append(actualUncompressedLength);
1003   return out;
1004 }
1005
1006 #endif  // FOLLY_HAVE_LIBSNAPPY
1007
1008 #if FOLLY_HAVE_LIBZ
1009 /**
1010  * Zlib codec
1011  */
1012 class ZlibStreamCodec final : public StreamCodec {
1013  public:
1014   static std::unique_ptr<Codec> createCodec(int level, CodecType type);
1015   static std::unique_ptr<StreamCodec> createStream(int level, CodecType type);
1016   explicit ZlibStreamCodec(int level, CodecType type);
1017   ~ZlibStreamCodec() override;
1018
1019   std::vector<std::string> validPrefixes() const override;
1020   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1021       const override;
1022
1023  private:
1024   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1025
1026   void doResetStream() override;
1027   bool doCompressStream(
1028       ByteRange& input,
1029       MutableByteRange& output,
1030       StreamCodec::FlushOp flush) override;
1031   bool doUncompressStream(
1032       ByteRange& input,
1033       MutableByteRange& output,
1034       StreamCodec::FlushOp flush) override;
1035
1036   void resetDeflateStream();
1037   void resetInflateStream();
1038
1039   Optional<z_stream> deflateStream_{};
1040   Optional<z_stream> inflateStream_{};
1041   int level_;
1042   bool needReset_{true};
1043 };
1044
1045 static constexpr uint16_t kGZIPMagicLE = 0x8B1F;
1046
1047 std::vector<std::string> ZlibStreamCodec::validPrefixes() const {
1048   if (type() == CodecType::ZLIB) {
1049     // Zlib streams start with a 2 byte header.
1050     //
1051     //   0   1
1052     // +---+---+
1053     // |CMF|FLG|
1054     // +---+---+
1055     //
1056     // We won't restrict the values of any sub-fields except as described below.
1057     //
1058     // The lowest 4 bits of CMF is the compression method (CM).
1059     // CM == 0x8 is the deflate compression method, which is currently the only
1060     // supported compression method, so any valid prefix must have CM == 0x8.
1061     //
1062     // The lowest 5 bits of FLG is FCHECK.
1063     // FCHECK must be such that the two header bytes are a multiple of 31 when
1064     // interpreted as a big endian 16-bit number.
1065     std::vector<std::string> result;
1066     // 16 values for the first byte, 8 values for the second byte.
1067     // There are also 4 combinations where both 0x00 and 0x1F work as FCHECK.
1068     result.reserve(132);
1069     // Select all values for the CMF byte that use the deflate algorithm 0x8.
1070     for (uint32_t first = 0x0800; first <= 0xF800; first += 0x1000) {
1071       // Select all values for the FLG, but leave FCHECK as 0 since it's fixed.
1072       for (uint32_t second = 0x00; second <= 0xE0; second += 0x20) {
1073         uint16_t prefix = first | second;
1074         // Compute FCHECK.
1075         prefix += 31 - (prefix % 31);
1076         result.push_back(prefixToStringLE(Endian::big(prefix)));
1077         // zlib won't produce this, but it is a valid prefix.
1078         if ((prefix & 0x1F) == 31) {
1079           prefix -= 31;
1080           result.push_back(prefixToStringLE(Endian::big(prefix)));
1081         }
1082       }
1083     }
1084     return result;
1085   } else {
1086     // The gzip frame starts with 2 magic bytes.
1087     return {prefixToStringLE(kGZIPMagicLE)};
1088   }
1089 }
1090
1091 bool ZlibStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1092     const {
1093   if (type() == CodecType::ZLIB) {
1094     uint16_t value;
1095     Cursor cursor{data};
1096     if (!cursor.tryReadBE(value)) {
1097       return false;
1098     }
1099     // zlib compressed if using deflate and is a multiple of 31.
1100     return (value & 0x0F00) == 0x0800 && value % 31 == 0;
1101   } else {
1102     return dataStartsWithLE(data, kGZIPMagicLE);
1103   }
1104 }
1105
1106 uint64_t ZlibStreamCodec::doMaxCompressedLength(
1107     uint64_t uncompressedLength) const {
1108   return deflateBound(nullptr, uncompressedLength);
1109 }
1110
1111 std::unique_ptr<Codec> ZlibStreamCodec::createCodec(int level, CodecType type) {
1112   return std::make_unique<ZlibStreamCodec>(level, type);
1113 }
1114
1115 std::unique_ptr<StreamCodec> ZlibStreamCodec::createStream(
1116     int level,
1117     CodecType type) {
1118   return std::make_unique<ZlibStreamCodec>(level, type);
1119 }
1120
1121 ZlibStreamCodec::ZlibStreamCodec(int level, CodecType type)
1122     : StreamCodec(type) {
1123   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
1124   switch (level) {
1125     case COMPRESSION_LEVEL_FASTEST:
1126       level = 1;
1127       break;
1128     case COMPRESSION_LEVEL_DEFAULT:
1129       level = Z_DEFAULT_COMPRESSION;
1130       break;
1131     case COMPRESSION_LEVEL_BEST:
1132       level = 9;
1133       break;
1134   }
1135   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
1136     throw std::invalid_argument(
1137         to<std::string>("ZlibStreamCodec: invalid level: ", level));
1138   }
1139   level_ = level;
1140 }
1141
1142 ZlibStreamCodec::~ZlibStreamCodec() {
1143   if (deflateStream_) {
1144     deflateEnd(deflateStream_.get_pointer());
1145     deflateStream_.clear();
1146   }
1147   if (inflateStream_) {
1148     inflateEnd(inflateStream_.get_pointer());
1149     inflateStream_.clear();
1150   }
1151 }
1152
1153 void ZlibStreamCodec::doResetStream() {
1154   needReset_ = true;
1155 }
1156
1157 void ZlibStreamCodec::resetDeflateStream() {
1158   if (deflateStream_) {
1159     int const rc = deflateReset(deflateStream_.get_pointer());
1160     if (rc != Z_OK) {
1161       deflateStream_.clear();
1162       throw std::runtime_error(
1163           to<std::string>("ZlibStreamCodec: deflateReset error: ", rc));
1164     }
1165     return;
1166   }
1167   deflateStream_ = z_stream{};
1168   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
1169   // base two logarithm of the maximum window size (...) The default value is
1170   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
1171   // around the compressed data instead of a zlib wrapper. The gzip header
1172   // will have no file name, no extra data, no comment, no modification time
1173   // (set to zero), no header crc, and the operating system will be set to 255
1174   // (unknown)."
1175   int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1176   // All other parameters (method, memLevel, strategy) get default values from
1177   // the zlib manual.
1178   int const rc = deflateInit2(
1179       deflateStream_.get_pointer(),
1180       level_,
1181       Z_DEFLATED,
1182       windowBits,
1183       /* memLevel */ 8,
1184       Z_DEFAULT_STRATEGY);
1185   if (rc != Z_OK) {
1186     deflateStream_.clear();
1187     throw std::runtime_error(
1188         to<std::string>("ZlibStreamCodec: deflateInit error: ", rc));
1189   }
1190 }
1191
1192 void ZlibStreamCodec::resetInflateStream() {
1193   if (inflateStream_) {
1194     int const rc = inflateReset(inflateStream_.get_pointer());
1195     if (rc != Z_OK) {
1196       inflateStream_.clear();
1197       throw std::runtime_error(
1198           to<std::string>("ZlibStreamCodec: inflateReset error: ", rc));
1199     }
1200     return;
1201   }
1202   inflateStream_ = z_stream{};
1203   // "The windowBits parameter is the base two logarithm of the maximum window
1204   // size (...) The default value is 15 (...) add 16 to decode only the gzip
1205   // format (the zlib format will return a Z_DATA_ERROR)."
1206   int const windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1207   int const rc = inflateInit2(inflateStream_.get_pointer(), windowBits);
1208   if (rc != Z_OK) {
1209     inflateStream_.clear();
1210     throw std::runtime_error(
1211         to<std::string>("ZlibStreamCodec: inflateInit error: ", rc));
1212   }
1213 }
1214
1215 static int zlibTranslateFlush(StreamCodec::FlushOp flush) {
1216   switch (flush) {
1217     case StreamCodec::FlushOp::NONE:
1218       return Z_NO_FLUSH;
1219     case StreamCodec::FlushOp::FLUSH:
1220       return Z_SYNC_FLUSH;
1221     case StreamCodec::FlushOp::END:
1222       return Z_FINISH;
1223     default:
1224       throw std::invalid_argument("ZlibStreamCodec: Invalid flush");
1225   }
1226 }
1227
1228 static int zlibThrowOnError(int rc) {
1229   switch (rc) {
1230     case Z_OK:
1231     case Z_BUF_ERROR:
1232     case Z_STREAM_END:
1233       return rc;
1234     default:
1235       throw std::runtime_error(to<std::string>("ZlibStreamCodec: error: ", rc));
1236   }
1237 }
1238
1239 bool ZlibStreamCodec::doCompressStream(
1240     ByteRange& input,
1241     MutableByteRange& output,
1242     StreamCodec::FlushOp flush) {
1243   if (needReset_) {
1244     resetDeflateStream();
1245     needReset_ = false;
1246   }
1247   DCHECK(deflateStream_.hasValue());
1248   // zlib will return Z_STREAM_ERROR if output.data() is null.
1249   if (output.data() == nullptr) {
1250     return false;
1251   }
1252   deflateStream_->next_in = const_cast<uint8_t*>(input.data());
1253   deflateStream_->avail_in = input.size();
1254   deflateStream_->next_out = output.data();
1255   deflateStream_->avail_out = output.size();
1256   SCOPE_EXIT {
1257     input.uncheckedAdvance(input.size() - deflateStream_->avail_in);
1258     output.uncheckedAdvance(output.size() - deflateStream_->avail_out);
1259   };
1260   int const rc = zlibThrowOnError(
1261       deflate(deflateStream_.get_pointer(), zlibTranslateFlush(flush)));
1262   switch (flush) {
1263     case StreamCodec::FlushOp::NONE:
1264       return false;
1265     case StreamCodec::FlushOp::FLUSH:
1266       return deflateStream_->avail_in == 0 && deflateStream_->avail_out != 0;
1267     case StreamCodec::FlushOp::END:
1268       return rc == Z_STREAM_END;
1269     default:
1270       throw std::invalid_argument("ZlibStreamCodec: Invalid flush");
1271   }
1272 }
1273
1274 bool ZlibStreamCodec::doUncompressStream(
1275     ByteRange& input,
1276     MutableByteRange& output,
1277     StreamCodec::FlushOp flush) {
1278   if (needReset_) {
1279     resetInflateStream();
1280     needReset_ = false;
1281   }
1282   DCHECK(inflateStream_.hasValue());
1283   // zlib will return Z_STREAM_ERROR if output.data() is null.
1284   if (output.data() == nullptr) {
1285     return false;
1286   }
1287   inflateStream_->next_in = const_cast<uint8_t*>(input.data());
1288   inflateStream_->avail_in = input.size();
1289   inflateStream_->next_out = output.data();
1290   inflateStream_->avail_out = output.size();
1291   SCOPE_EXIT {
1292     input.advance(input.size() - inflateStream_->avail_in);
1293     output.advance(output.size() - inflateStream_->avail_out);
1294   };
1295   int const rc = zlibThrowOnError(
1296       inflate(inflateStream_.get_pointer(), zlibTranslateFlush(flush)));
1297   return rc == Z_STREAM_END;
1298 }
1299
1300 #endif // FOLLY_HAVE_LIBZ
1301
1302 #if FOLLY_HAVE_LIBLZMA
1303
1304 /**
1305  * LZMA2 compression
1306  */
1307 class LZMA2Codec final : public Codec {
1308  public:
1309   static std::unique_ptr<Codec> create(int level, CodecType type);
1310   explicit LZMA2Codec(int level, CodecType type);
1311
1312   std::vector<std::string> validPrefixes() const override;
1313   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1314       const override;
1315
1316  private:
1317   bool doNeedsUncompressedLength() const override;
1318   uint64_t doMaxUncompressedLength() const override;
1319   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1320
1321   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
1322
1323   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1324   std::unique_ptr<IOBuf> doUncompress(
1325       const IOBuf* data,
1326       Optional<uint64_t> uncompressedLength) override;
1327
1328   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
1329   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
1330
1331   int level_;
1332 };
1333
1334 static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD;
1335 static constexpr unsigned kLZMA2MagicBytes = 6;
1336
1337 std::vector<std::string> LZMA2Codec::validPrefixes() const {
1338   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1339     return {};
1340   }
1341   return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)};
1342 }
1343
1344 bool LZMA2Codec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
1345   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1346     return false;
1347   }
1348   // Returns false for all inputs less than 8 bytes.
1349   // This is okay, because no valid LZMA2 streams are less than 8 bytes.
1350   return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes);
1351 }
1352
1353 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
1354   return std::make_unique<LZMA2Codec>(level, type);
1355 }
1356
1357 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
1358   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
1359   switch (level) {
1360   case COMPRESSION_LEVEL_FASTEST:
1361     level = 0;
1362     break;
1363   case COMPRESSION_LEVEL_DEFAULT:
1364     level = LZMA_PRESET_DEFAULT;
1365     break;
1366   case COMPRESSION_LEVEL_BEST:
1367     level = 9;
1368     break;
1369   }
1370   if (level < 0 || level > 9) {
1371     throw std::invalid_argument(to<std::string>(
1372         "LZMA2Codec: invalid level: ", level));
1373   }
1374   level_ = level;
1375 }
1376
1377 bool LZMA2Codec::doNeedsUncompressedLength() const {
1378   return false;
1379 }
1380
1381 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
1382   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
1383   return uint64_t(1) << 63;
1384 }
1385
1386 uint64_t LZMA2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1387   return lzma_stream_buffer_bound(uncompressedLength) +
1388       (encodeSize() ? kMaxVarintLength64 : 0);
1389 }
1390
1391 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
1392     lzma_stream* stream,
1393     size_t length) {
1394
1395   CHECK_EQ(stream->avail_out, 0);
1396
1397   auto buf = IOBuf::create(length);
1398   buf->append(buf->capacity());
1399
1400   stream->next_out = buf->writableData();
1401   stream->avail_out = buf->length();
1402
1403   return buf;
1404 }
1405
1406 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
1407   lzma_ret rc;
1408   lzma_stream stream = LZMA_STREAM_INIT;
1409
1410   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
1411   if (rc != LZMA_OK) {
1412     throw std::runtime_error(folly::to<std::string>(
1413       "LZMA2Codec: lzma_easy_encoder error: ", rc));
1414   }
1415
1416   SCOPE_EXIT { lzma_end(&stream); };
1417
1418   uint64_t uncompressedLength = data->computeChainDataLength();
1419   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
1420
1421   // Max 64MiB in one go
1422   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1423   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1424
1425   auto out = addOutputBuffer(
1426     &stream,
1427     (maxCompressedLength <= maxSingleStepLength ?
1428      maxCompressedLength :
1429      defaultBufferLength));
1430
1431   if (encodeSize()) {
1432     auto size = IOBuf::createCombined(kMaxVarintLength64);
1433     encodeVarintToIOBuf(uncompressedLength, size.get());
1434     size->appendChain(std::move(out));
1435     out = std::move(size);
1436   }
1437
1438   for (auto& range : *data) {
1439     if (range.empty()) {
1440       continue;
1441     }
1442
1443     stream.next_in = const_cast<uint8_t*>(range.data());
1444     stream.avail_in = range.size();
1445
1446     while (stream.avail_in != 0) {
1447       if (stream.avail_out == 0) {
1448         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1449       }
1450
1451       rc = lzma_code(&stream, LZMA_RUN);
1452
1453       if (rc != LZMA_OK) {
1454         throw std::runtime_error(folly::to<std::string>(
1455           "LZMA2Codec: lzma_code error: ", rc));
1456       }
1457     }
1458   }
1459
1460   do {
1461     if (stream.avail_out == 0) {
1462       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1463     }
1464
1465     rc = lzma_code(&stream, LZMA_FINISH);
1466   } while (rc == LZMA_OK);
1467
1468   if (rc != LZMA_STREAM_END) {
1469     throw std::runtime_error(folly::to<std::string>(
1470       "LZMA2Codec: lzma_code ended with error: ", rc));
1471   }
1472
1473   out->prev()->trimEnd(stream.avail_out);
1474
1475   return out;
1476 }
1477
1478 bool LZMA2Codec::doInflate(lzma_stream* stream,
1479                           IOBuf* head,
1480                           size_t bufferLength) {
1481   if (stream->avail_out == 0) {
1482     head->prependChain(addOutputBuffer(stream, bufferLength));
1483   }
1484
1485   lzma_ret rc = lzma_code(stream, LZMA_RUN);
1486
1487   switch (rc) {
1488   case LZMA_OK:
1489     break;
1490   case LZMA_STREAM_END:
1491     return true;
1492   default:
1493     throw std::runtime_error(to<std::string>(
1494         "LZMA2Codec: lzma_code error: ", rc));
1495   }
1496
1497   return false;
1498 }
1499
1500 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(
1501     const IOBuf* data,
1502     Optional<uint64_t> uncompressedLength) {
1503   lzma_ret rc;
1504   lzma_stream stream = LZMA_STREAM_INIT;
1505
1506   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
1507   if (rc != LZMA_OK) {
1508     throw std::runtime_error(folly::to<std::string>(
1509       "LZMA2Codec: lzma_auto_decoder error: ", rc));
1510   }
1511
1512   SCOPE_EXIT { lzma_end(&stream); };
1513
1514   // Max 64MiB in one go
1515   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB
1516   constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB
1517
1518   folly::io::Cursor cursor(data);
1519   if (encodeSize()) {
1520     const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor);
1521     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
1522       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
1523     }
1524     uncompressedLength = actualUncompressedLength;
1525   }
1526
1527   auto out = addOutputBuffer(
1528       &stream,
1529       ((uncompressedLength && *uncompressedLength <= maxSingleStepLength)
1530            ? *uncompressedLength
1531            : defaultBufferLength));
1532
1533   bool streamEnd = false;
1534   auto buf = cursor.peekBytes();
1535   while (!buf.empty()) {
1536     stream.next_in = const_cast<uint8_t*>(buf.data());
1537     stream.avail_in = buf.size();
1538
1539     while (stream.avail_in != 0) {
1540       if (streamEnd) {
1541         throw std::runtime_error(to<std::string>(
1542             "LZMA2Codec: junk after end of data"));
1543       }
1544
1545       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1546     }
1547
1548     cursor.skip(buf.size());
1549     buf = cursor.peekBytes();
1550   }
1551
1552   while (!streamEnd) {
1553     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1554   }
1555
1556   out->prev()->trimEnd(stream.avail_out);
1557
1558   if (uncompressedLength && *uncompressedLength != stream.total_out) {
1559     throw std::runtime_error(
1560         to<std::string>("LZMA2Codec: invalid uncompressed length"));
1561   }
1562
1563   return out;
1564 }
1565
1566 #endif  // FOLLY_HAVE_LIBLZMA
1567
1568 #ifdef FOLLY_HAVE_LIBZSTD
1569
1570 namespace {
1571 void zstdFreeCStream(ZSTD_CStream* zcs) {
1572   ZSTD_freeCStream(zcs);
1573 }
1574
1575 void zstdFreeDStream(ZSTD_DStream* zds) {
1576   ZSTD_freeDStream(zds);
1577 }
1578 }
1579
1580 /**
1581  * ZSTD compression
1582  */
1583 class ZSTDStreamCodec final : public StreamCodec {
1584  public:
1585   static std::unique_ptr<Codec> createCodec(int level, CodecType);
1586   static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
1587   explicit ZSTDStreamCodec(int level, CodecType type);
1588
1589   std::vector<std::string> validPrefixes() const override;
1590   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1591       const override;
1592
1593  private:
1594   bool doNeedsUncompressedLength() const override;
1595   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1596   Optional<uint64_t> doGetUncompressedLength(
1597       IOBuf const* data,
1598       Optional<uint64_t> uncompressedLength) const override;
1599
1600   void doResetStream() override;
1601   bool doCompressStream(
1602       ByteRange& input,
1603       MutableByteRange& output,
1604       StreamCodec::FlushOp flushOp) override;
1605   bool doUncompressStream(
1606       ByteRange& input,
1607       MutableByteRange& output,
1608       StreamCodec::FlushOp flushOp) override;
1609
1610   void resetCStream();
1611   void resetDStream();
1612
1613   bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
1614   bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
1615
1616   int level_;
1617   bool needReset_{true};
1618   std::unique_ptr<
1619       ZSTD_CStream,
1620       folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
1621       cstream_{nullptr};
1622   std::unique_ptr<
1623       ZSTD_DStream,
1624       folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
1625       dstream_{nullptr};
1626 };
1627
1628 static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
1629
1630 std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
1631   return {prefixToStringLE(kZSTDMagicLE)};
1632 }
1633
1634 bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1635     const {
1636   return dataStartsWithLE(data, kZSTDMagicLE);
1637 }
1638
1639 std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
1640   return make_unique<ZSTDStreamCodec>(level, type);
1641 }
1642
1643 std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
1644     int level,
1645     CodecType type) {
1646   return make_unique<ZSTDStreamCodec>(level, type);
1647 }
1648
1649 ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
1650     : StreamCodec(type) {
1651   DCHECK(type == CodecType::ZSTD);
1652   switch (level) {
1653     case COMPRESSION_LEVEL_FASTEST:
1654       level = 1;
1655       break;
1656     case COMPRESSION_LEVEL_DEFAULT:
1657       level = 1;
1658       break;
1659     case COMPRESSION_LEVEL_BEST:
1660       level = 19;
1661       break;
1662   }
1663   if (level < 1 || level > ZSTD_maxCLevel()) {
1664     throw std::invalid_argument(
1665         to<std::string>("ZSTD: invalid level: ", level));
1666   }
1667   level_ = level;
1668 }
1669
1670 bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
1671   return false;
1672 }
1673
1674 uint64_t ZSTDStreamCodec::doMaxCompressedLength(
1675     uint64_t uncompressedLength) const {
1676   return ZSTD_compressBound(uncompressedLength);
1677 }
1678
1679 void zstdThrowIfError(size_t rc) {
1680   if (!ZSTD_isError(rc)) {
1681     return;
1682   }
1683   throw std::runtime_error(
1684       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1685 }
1686
1687 Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
1688     IOBuf const* data,
1689     Optional<uint64_t> uncompressedLength) const {
1690   // Read decompressed size from frame if available in first IOBuf.
1691   auto const decompressedSize =
1692       ZSTD_getDecompressedSize(data->data(), data->length());
1693   if (decompressedSize != 0) {
1694     if (uncompressedLength && *uncompressedLength != decompressedSize) {
1695       throw std::runtime_error("ZSTD: invalid uncompressed length");
1696     }
1697     uncompressedLength = decompressedSize;
1698   }
1699   return uncompressedLength;
1700 }
1701
1702 void ZSTDStreamCodec::doResetStream() {
1703   needReset_ = true;
1704 }
1705
1706 bool ZSTDStreamCodec::tryBlockCompress(
1707     ByteRange& input,
1708     MutableByteRange& output) const {
1709   DCHECK(needReset_);
1710   // We need to know that we have enough output space to use block compression
1711   if (output.size() < ZSTD_compressBound(input.size())) {
1712     return false;
1713   }
1714   size_t const length = ZSTD_compress(
1715       output.data(), output.size(), input.data(), input.size(), level_);
1716   zstdThrowIfError(length);
1717   input.uncheckedAdvance(input.size());
1718   output.uncheckedAdvance(length);
1719   return true;
1720 }
1721
1722 void ZSTDStreamCodec::resetCStream() {
1723   if (!cstream_) {
1724     cstream_.reset(ZSTD_createCStream());
1725     if (!cstream_) {
1726       throw std::bad_alloc{};
1727     }
1728   }
1729   // Advanced API usage works for all supported versions of zstd.
1730   // Required to set contentSizeFlag.
1731   auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
1732   params.fParams.contentSizeFlag = uncompressedLength().hasValue();
1733   zstdThrowIfError(ZSTD_initCStream_advanced(
1734       cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0)));
1735 }
1736
1737 bool ZSTDStreamCodec::doCompressStream(
1738     ByteRange& input,
1739     MutableByteRange& output,
1740     StreamCodec::FlushOp flushOp) {
1741   if (needReset_) {
1742     // If we are given all the input in one chunk try to use block compression
1743     if (flushOp == StreamCodec::FlushOp::END &&
1744         tryBlockCompress(input, output)) {
1745       return true;
1746     }
1747     resetCStream();
1748     needReset_ = false;
1749   }
1750   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1751   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1752   SCOPE_EXIT {
1753     input.uncheckedAdvance(in.pos);
1754     output.uncheckedAdvance(out.pos);
1755   };
1756   if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
1757     zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
1758   }
1759   if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
1760     size_t rc;
1761     switch (flushOp) {
1762       case StreamCodec::FlushOp::FLUSH:
1763         rc = ZSTD_flushStream(cstream_.get(), &out);
1764         break;
1765       case StreamCodec::FlushOp::END:
1766         rc = ZSTD_endStream(cstream_.get(), &out);
1767         break;
1768       default:
1769         throw std::invalid_argument("ZSTD: invalid FlushOp");
1770     }
1771     zstdThrowIfError(rc);
1772     if (rc == 0) {
1773       return true;
1774     }
1775   }
1776   return false;
1777 }
1778
1779 bool ZSTDStreamCodec::tryBlockUncompress(
1780     ByteRange& input,
1781     MutableByteRange& output) const {
1782   DCHECK(needReset_);
1783 #if ZSTD_VERSION_NUMBER < 10104
1784   // We require ZSTD_findFrameCompressedSize() to perform this optimization.
1785   return false;
1786 #else
1787   // We need to know the uncompressed length and have enough output space.
1788   if (!uncompressedLength() || output.size() < *uncompressedLength()) {
1789     return false;
1790   }
1791   size_t const compressedLength =
1792       ZSTD_findFrameCompressedSize(input.data(), input.size());
1793   zstdThrowIfError(compressedLength);
1794   size_t const length = ZSTD_decompress(
1795       output.data(), *uncompressedLength(), input.data(), compressedLength);
1796   zstdThrowIfError(length);
1797   if (length != *uncompressedLength()) {
1798     throw std::runtime_error("ZSTDStreamCodec: Incorrect uncompressed length");
1799   }
1800   input.uncheckedAdvance(compressedLength);
1801   output.uncheckedAdvance(length);
1802   return true;
1803 #endif
1804 }
1805
1806 void ZSTDStreamCodec::resetDStream() {
1807   if (!dstream_) {
1808     dstream_.reset(ZSTD_createDStream());
1809     if (!dstream_) {
1810       throw std::bad_alloc{};
1811     }
1812   }
1813   zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
1814 }
1815
1816 bool ZSTDStreamCodec::doUncompressStream(
1817     ByteRange& input,
1818     MutableByteRange& output,
1819     StreamCodec::FlushOp flushOp) {
1820   if (needReset_) {
1821     // If we are given all the input in one chunk try to use block uncompression
1822     if (flushOp == StreamCodec::FlushOp::END &&
1823         tryBlockUncompress(input, output)) {
1824       return true;
1825     }
1826     resetDStream();
1827     needReset_ = false;
1828   }
1829   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1830   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1831   SCOPE_EXIT {
1832     input.uncheckedAdvance(in.pos);
1833     output.uncheckedAdvance(out.pos);
1834   };
1835   size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
1836   zstdThrowIfError(rc);
1837   return rc == 0;
1838 }
1839
1840 #endif // FOLLY_HAVE_LIBZSTD
1841
1842 #if FOLLY_HAVE_LIBBZ2
1843
1844 class Bzip2Codec final : public Codec {
1845  public:
1846   static std::unique_ptr<Codec> create(int level, CodecType type);
1847   explicit Bzip2Codec(int level, CodecType type);
1848
1849   std::vector<std::string> validPrefixes() const override;
1850   bool canUncompress(IOBuf const* data, Optional<uint64_t> uncompressedLength)
1851       const override;
1852
1853  private:
1854   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1855   std::unique_ptr<IOBuf> doCompress(IOBuf const* data) override;
1856   std::unique_ptr<IOBuf> doUncompress(
1857       IOBuf const* data,
1858       Optional<uint64_t> uncompressedLength) override;
1859
1860   int level_;
1861 };
1862
1863 /* static */ std::unique_ptr<Codec> Bzip2Codec::create(
1864     int level,
1865     CodecType type) {
1866   return std::make_unique<Bzip2Codec>(level, type);
1867 }
1868
1869 Bzip2Codec::Bzip2Codec(int level, CodecType type) : Codec(type) {
1870   DCHECK(type == CodecType::BZIP2);
1871   switch (level) {
1872     case COMPRESSION_LEVEL_FASTEST:
1873       level = 1;
1874       break;
1875     case COMPRESSION_LEVEL_DEFAULT:
1876       level = 9;
1877       break;
1878     case COMPRESSION_LEVEL_BEST:
1879       level = 9;
1880       break;
1881   }
1882   if (level < 1 || level > 9) {
1883     throw std::invalid_argument(
1884         to<std::string>("Bzip2: invalid level: ", level));
1885   }
1886   level_ = level;
1887 }
1888
1889 static uint32_t constexpr kBzip2MagicLE = 0x685a42;
1890 static uint64_t constexpr kBzip2MagicBytes = 3;
1891
1892 std::vector<std::string> Bzip2Codec::validPrefixes() const {
1893   return {prefixToStringLE(kBzip2MagicLE, kBzip2MagicBytes)};
1894 }
1895
1896 bool Bzip2Codec::canUncompress(IOBuf const* data, Optional<uint64_t>) const {
1897   return dataStartsWithLE(data, kBzip2MagicLE, kBzip2MagicBytes);
1898 }
1899
1900 uint64_t Bzip2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1901   // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress
1902   //   To guarantee that the compressed data will fit in its buffer, allocate an
1903   //   output buffer of size 1% larger than the uncompressed data, plus six
1904   //   hundred extra bytes.
1905   return uncompressedLength + uncompressedLength / 100 + 600;
1906 }
1907
1908 static bz_stream createBzStream() {
1909   bz_stream stream;
1910   stream.bzalloc = nullptr;
1911   stream.bzfree = nullptr;
1912   stream.opaque = nullptr;
1913   stream.next_in = stream.next_out = nullptr;
1914   stream.avail_in = stream.avail_out = 0;
1915   return stream;
1916 }
1917
1918 // Throws on error condition, otherwise returns the code.
1919 static int bzCheck(int const rc) {
1920   switch (rc) {
1921     case BZ_OK:
1922     case BZ_RUN_OK:
1923     case BZ_FLUSH_OK:
1924     case BZ_FINISH_OK:
1925     case BZ_STREAM_END:
1926       return rc;
1927     default:
1928       throw std::runtime_error(to<std::string>("Bzip2 error: ", rc));
1929   }
1930 }
1931
1932 static std::unique_ptr<IOBuf> addOutputBuffer(
1933     bz_stream* stream,
1934     uint64_t const bufferLength) {
1935   DCHECK_LE(bufferLength, std::numeric_limits<unsigned>::max());
1936   DCHECK_EQ(stream->avail_out, 0);
1937
1938   auto buf = IOBuf::create(bufferLength);
1939   buf->append(buf->capacity());
1940
1941   stream->next_out = reinterpret_cast<char*>(buf->writableData());
1942   stream->avail_out = buf->length();
1943
1944   return buf;
1945 }
1946
1947 std::unique_ptr<IOBuf> Bzip2Codec::doCompress(IOBuf const* data) {
1948   bz_stream stream = createBzStream();
1949   bzCheck(BZ2_bzCompressInit(&stream, level_, 0, 0));
1950   SCOPE_EXIT {
1951     bzCheck(BZ2_bzCompressEnd(&stream));
1952   };
1953
1954   uint64_t const uncompressedLength = data->computeChainDataLength();
1955   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
1956   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1957   uint64_t constexpr kDefaultBufferLength = uint64_t(4) << 20;
1958
1959   auto out = addOutputBuffer(
1960       &stream,
1961       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
1962                                                : kDefaultBufferLength);
1963
1964   for (auto range : *data) {
1965     while (!range.empty()) {
1966       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1967       stream.next_in =
1968           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1969       stream.avail_in = inSize;
1970
1971       if (stream.avail_out == 0) {
1972         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1973       }
1974
1975       bzCheck(BZ2_bzCompress(&stream, BZ_RUN));
1976       range.uncheckedAdvance(inSize - stream.avail_in);
1977     }
1978   }
1979   do {
1980     if (stream.avail_out == 0) {
1981       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1982     }
1983   } while (bzCheck(BZ2_bzCompress(&stream, BZ_FINISH)) != BZ_STREAM_END);
1984
1985   out->prev()->trimEnd(stream.avail_out);
1986
1987   return out;
1988 }
1989
1990 std::unique_ptr<IOBuf> Bzip2Codec::doUncompress(
1991     const IOBuf* data,
1992     Optional<uint64_t> uncompressedLength) {
1993   bz_stream stream = createBzStream();
1994   bzCheck(BZ2_bzDecompressInit(&stream, 0, 0));
1995   SCOPE_EXIT {
1996     bzCheck(BZ2_bzDecompressEnd(&stream));
1997   };
1998
1999   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
2000   uint64_t const kBlockSize = uint64_t(100) << 10; // 100 KiB
2001   uint64_t const kDefaultBufferLength =
2002       computeBufferLength(data->computeChainDataLength(), kBlockSize);
2003
2004   auto out = addOutputBuffer(
2005       &stream,
2006       ((uncompressedLength && *uncompressedLength <= kMaxSingleStepLength)
2007            ? *uncompressedLength
2008            : kDefaultBufferLength));
2009
2010   int rc = BZ_OK;
2011   for (auto range : *data) {
2012     while (!range.empty()) {
2013       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
2014       stream.next_in =
2015           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
2016       stream.avail_in = inSize;
2017
2018       if (stream.avail_out == 0) {
2019         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2020       }
2021
2022       rc = bzCheck(BZ2_bzDecompress(&stream));
2023       range.uncheckedAdvance(inSize - stream.avail_in);
2024     }
2025   }
2026   while (rc != BZ_STREAM_END) {
2027     if (stream.avail_out == 0) {
2028       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2029     }
2030     size_t const outputSize = stream.avail_out;
2031     rc = bzCheck(BZ2_bzDecompress(&stream));
2032     if (outputSize == stream.avail_out) {
2033       throw std::runtime_error("Bzip2Codec: Truncated input");
2034     }
2035   }
2036
2037   out->prev()->trimEnd(stream.avail_out);
2038
2039   uint64_t const totalOut =
2040       (uint64_t(stream.total_out_hi32) << 32) + stream.total_out_lo32;
2041   if (uncompressedLength && uncompressedLength != totalOut) {
2042     throw std::runtime_error("Bzip2 error: Invalid uncompressed length");
2043   }
2044
2045   return out;
2046 }
2047
2048 #endif // FOLLY_HAVE_LIBBZ2
2049
2050 /**
2051  * Automatic decompression
2052  */
2053 class AutomaticCodec final : public Codec {
2054  public:
2055   static std::unique_ptr<Codec> create(
2056       std::vector<std::unique_ptr<Codec>> customCodecs);
2057   explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
2058
2059   std::vector<std::string> validPrefixes() const override;
2060   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
2061       const override;
2062
2063  private:
2064   bool doNeedsUncompressedLength() const override;
2065   uint64_t doMaxUncompressedLength() const override;
2066
2067   uint64_t doMaxCompressedLength(uint64_t) const override {
2068     throw std::runtime_error(
2069         "AutomaticCodec error: maxCompressedLength() not supported.");
2070   }
2071   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
2072     throw std::runtime_error("AutomaticCodec error: compress() not supported.");
2073   }
2074   std::unique_ptr<IOBuf> doUncompress(
2075       const IOBuf* data,
2076       Optional<uint64_t> uncompressedLength) override;
2077
2078   void addCodecIfSupported(CodecType type);
2079
2080   // Throws iff the codecs aren't compatible (very slow)
2081   void checkCompatibleCodecs() const;
2082
2083   std::vector<std::unique_ptr<Codec>> codecs_;
2084   bool needsUncompressedLength_;
2085   uint64_t maxUncompressedLength_;
2086 };
2087
2088 std::vector<std::string> AutomaticCodec::validPrefixes() const {
2089   std::unordered_set<std::string> prefixes;
2090   for (const auto& codec : codecs_) {
2091     const auto codecPrefixes = codec->validPrefixes();
2092     prefixes.insert(codecPrefixes.begin(), codecPrefixes.end());
2093   }
2094   return std::vector<std::string>{prefixes.begin(), prefixes.end()};
2095 }
2096
2097 bool AutomaticCodec::canUncompress(
2098     const IOBuf* data,
2099     Optional<uint64_t> uncompressedLength) const {
2100   return std::any_of(
2101       codecs_.begin(),
2102       codecs_.end(),
2103       [data, uncompressedLength](std::unique_ptr<Codec> const& codec) {
2104         return codec->canUncompress(data, uncompressedLength);
2105       });
2106 }
2107
2108 void AutomaticCodec::addCodecIfSupported(CodecType type) {
2109   const bool present = std::any_of(
2110       codecs_.begin(),
2111       codecs_.end(),
2112       [&type](std::unique_ptr<Codec> const& codec) {
2113         return codec->type() == type;
2114       });
2115   if (hasCodec(type) && !present) {
2116     codecs_.push_back(getCodec(type));
2117   }
2118 }
2119
2120 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
2121     std::vector<std::unique_ptr<Codec>> customCodecs) {
2122   return std::make_unique<AutomaticCodec>(std::move(customCodecs));
2123 }
2124
2125 AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
2126     : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
2127   // Fastest -> slowest
2128   addCodecIfSupported(CodecType::LZ4_FRAME);
2129   addCodecIfSupported(CodecType::ZSTD);
2130   addCodecIfSupported(CodecType::ZLIB);
2131   addCodecIfSupported(CodecType::GZIP);
2132   addCodecIfSupported(CodecType::LZMA2);
2133   addCodecIfSupported(CodecType::BZIP2);
2134   if (kIsDebug) {
2135     checkCompatibleCodecs();
2136   }
2137   // Check that none of the codes are are null
2138   DCHECK(std::none_of(
2139       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2140         return codec == nullptr;
2141       }));
2142
2143   needsUncompressedLength_ = std::any_of(
2144       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2145         return codec->needsUncompressedLength();
2146       });
2147
2148   const auto it = std::max_element(
2149       codecs_.begin(),
2150       codecs_.end(),
2151       [](std::unique_ptr<Codec> const& lhs, std::unique_ptr<Codec> const& rhs) {
2152         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
2153       });
2154   DCHECK(it != codecs_.end());
2155   maxUncompressedLength_ = (*it)->maxUncompressedLength();
2156 }
2157
2158 void AutomaticCodec::checkCompatibleCodecs() const {
2159   // Keep track of all the possible headers.
2160   std::unordered_set<std::string> headers;
2161   // The empty header is not allowed.
2162   headers.insert("");
2163   // Step 1:
2164   // Construct a set of headers and check that none of the headers occur twice.
2165   // Eliminate edge cases.
2166   for (auto&& codec : codecs_) {
2167     const auto codecHeaders = codec->validPrefixes();
2168     // Codecs without any valid headers are not allowed.
2169     if (codecHeaders.empty()) {
2170       throw std::invalid_argument{
2171           "AutomaticCodec: validPrefixes() must not be empty."};
2172     }
2173     // Insert all the headers for the current codec.
2174     const size_t beforeSize = headers.size();
2175     headers.insert(codecHeaders.begin(), codecHeaders.end());
2176     // Codecs are not compatible if any header occurred twice.
2177     if (beforeSize + codecHeaders.size() != headers.size()) {
2178       throw std::invalid_argument{
2179           "AutomaticCodec: Two valid prefixes collide."};
2180     }
2181   }
2182   // Step 2:
2183   // Check if any strict non-empty prefix of any header is a header.
2184   for (const auto& header : headers) {
2185     for (size_t i = 1; i < header.size(); ++i) {
2186       if (headers.count(header.substr(0, i))) {
2187         throw std::invalid_argument{
2188             "AutomaticCodec: One valid prefix is a prefix of another valid "
2189             "prefix."};
2190       }
2191     }
2192   }
2193 }
2194
2195 bool AutomaticCodec::doNeedsUncompressedLength() const {
2196   return needsUncompressedLength_;
2197 }
2198
2199 uint64_t AutomaticCodec::doMaxUncompressedLength() const {
2200   return maxUncompressedLength_;
2201 }
2202
2203 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
2204     const IOBuf* data,
2205     Optional<uint64_t> uncompressedLength) {
2206   for (auto&& codec : codecs_) {
2207     if (codec->canUncompress(data, uncompressedLength)) {
2208       return codec->uncompress(data, uncompressedLength);
2209     }
2210   }
2211   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
2212 }
2213
2214 using CodecFactory = std::unique_ptr<Codec> (*)(int, CodecType);
2215 using StreamCodecFactory = std::unique_ptr<StreamCodec> (*)(int, CodecType);
2216 struct Factory {
2217   CodecFactory codec;
2218   StreamCodecFactory stream;
2219 };
2220
2221 constexpr Factory
2222     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
2223         {}, // USER_DEFINED
2224         {NoCompressionCodec::create, nullptr},
2225
2226 #if FOLLY_HAVE_LIBLZ4
2227         {LZ4Codec::create, nullptr},
2228 #else
2229         {},
2230 #endif
2231
2232 #if FOLLY_HAVE_LIBSNAPPY
2233         {SnappyCodec::create, nullptr},
2234 #else
2235         {},
2236 #endif
2237
2238 #if FOLLY_HAVE_LIBZ
2239         {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream},
2240 #else
2241         {},
2242 #endif
2243
2244 #if FOLLY_HAVE_LIBLZ4
2245         {LZ4Codec::create, nullptr},
2246 #else
2247         {},
2248 #endif
2249
2250 #if FOLLY_HAVE_LIBLZMA
2251         {LZMA2Codec::create, nullptr},
2252         {LZMA2Codec::create, nullptr},
2253 #else
2254         {},
2255         {},
2256 #endif
2257
2258 #if FOLLY_HAVE_LIBZSTD
2259         {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
2260 #else
2261         {},
2262 #endif
2263
2264 #if FOLLY_HAVE_LIBZ
2265         {ZlibStreamCodec::createCodec, ZlibStreamCodec::createStream},
2266 #else
2267         {},
2268 #endif
2269
2270 #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301)
2271         {LZ4FrameCodec::create, nullptr},
2272 #else
2273         {},
2274 #endif
2275
2276 #if FOLLY_HAVE_LIBBZ2
2277         {Bzip2Codec::create, nullptr},
2278 #else
2279         {},
2280 #endif
2281 };
2282
2283 Factory const& getFactory(CodecType type) {
2284   size_t const idx = static_cast<size_t>(type);
2285   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
2286     throw std::invalid_argument(
2287         to<std::string>("Compression type ", idx, " invalid"));
2288   }
2289   return codecFactories[idx];
2290 }
2291 } // namespace
2292
2293 bool hasCodec(CodecType type) {
2294   return getFactory(type).codec != nullptr;
2295 }
2296
2297 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
2298   auto const factory = getFactory(type).codec;
2299   if (!factory) {
2300     throw std::invalid_argument(
2301         to<std::string>("Compression type ", type, " not supported"));
2302   }
2303   auto codec = (*factory)(level, type);
2304   DCHECK(codec->type() == type);
2305   return codec;
2306 }
2307
2308 bool hasStreamCodec(CodecType type) {
2309   return getFactory(type).stream != nullptr;
2310 }
2311
2312 std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
2313   auto const factory = getFactory(type).stream;
2314   if (!factory) {
2315     throw std::invalid_argument(
2316         to<std::string>("Compression type ", type, " not supported"));
2317   }
2318   auto codec = (*factory)(level, type);
2319   DCHECK(codec->type() == type);
2320   return codec;
2321 }
2322
2323 std::unique_ptr<Codec> getAutoUncompressionCodec(
2324     std::vector<std::unique_ptr<Codec>> customCodecs) {
2325   return AutomaticCodec::create(std::move(customCodecs));
2326 }
2327 } // namespace io
2328 } // namespace folly