Add zstd streaming interface
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #if LZ4_VERSION_NUMBER >= 10301
23 #include <lz4frame.h>
24 #endif
25 #endif
26
27 #include <glog/logging.h>
28
29 #if FOLLY_HAVE_LIBSNAPPY
30 #include <snappy.h>
31 #include <snappy-sinksource.h>
32 #endif
33
34 #if FOLLY_HAVE_LIBZ
35 #include <zlib.h>
36 #endif
37
38 #if FOLLY_HAVE_LIBLZMA
39 #include <lzma.h>
40 #endif
41
42 #if FOLLY_HAVE_LIBZSTD
43 #define ZSTD_STATIC_LINKING_ONLY
44 #include <zstd.h>
45 #endif
46
47 #if FOLLY_HAVE_LIBBZ2
48 #include <bzlib.h>
49 #endif
50
51 #include <folly/Bits.h>
52 #include <folly/Conv.h>
53 #include <folly/Memory.h>
54 #include <folly/Portability.h>
55 #include <folly/ScopeGuard.h>
56 #include <folly/Varint.h>
57 #include <folly/io/Cursor.h>
58 #include <algorithm>
59 #include <unordered_set>
60
61 namespace folly { namespace io {
62
63 Codec::Codec(CodecType type) : type_(type) { }
64
65 // Ensure consistent behavior in the nullptr case
66 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
67   uint64_t len = data->computeChainDataLength();
68   if (len == 0) {
69     return IOBuf::create(0);
70   }
71   if (len > maxUncompressedLength()) {
72     throw std::runtime_error("Codec: uncompressed length too large");
73   }
74
75   return doCompress(data);
76 }
77
78 std::string Codec::compress(const StringPiece data) {
79   const uint64_t len = data.size();
80   if (len == 0) {
81     return "";
82   }
83   if (len > maxUncompressedLength()) {
84     throw std::runtime_error("Codec: uncompressed length too large");
85   }
86
87   return doCompressString(data);
88 }
89
90 std::unique_ptr<IOBuf> Codec::uncompress(
91     const IOBuf* data,
92     Optional<uint64_t> uncompressedLength) {
93   if (!uncompressedLength) {
94     if (needsUncompressedLength()) {
95       throw std::invalid_argument("Codec: uncompressed length required");
96     }
97   } else if (*uncompressedLength > maxUncompressedLength()) {
98     throw std::runtime_error("Codec: uncompressed length too large");
99   }
100
101   if (data->empty()) {
102     if (uncompressedLength.value_or(0) != 0) {
103       throw std::runtime_error("Codec: invalid uncompressed length");
104     }
105     return IOBuf::create(0);
106   }
107
108   return doUncompress(data, uncompressedLength);
109 }
110
111 std::string Codec::uncompress(
112     const StringPiece data,
113     Optional<uint64_t> uncompressedLength) {
114   if (!uncompressedLength) {
115     if (needsUncompressedLength()) {
116       throw std::invalid_argument("Codec: uncompressed length required");
117     }
118   } else if (*uncompressedLength > maxUncompressedLength()) {
119     throw std::runtime_error("Codec: uncompressed length too large");
120   }
121
122   if (data.empty()) {
123     if (uncompressedLength.value_or(0) != 0) {
124       throw std::runtime_error("Codec: invalid uncompressed length");
125     }
126     return "";
127   }
128
129   return doUncompressString(data, uncompressedLength);
130 }
131
132 bool Codec::needsUncompressedLength() const {
133   return doNeedsUncompressedLength();
134 }
135
136 uint64_t Codec::maxUncompressedLength() const {
137   return doMaxUncompressedLength();
138 }
139
140 bool Codec::doNeedsUncompressedLength() const {
141   return false;
142 }
143
144 uint64_t Codec::doMaxUncompressedLength() const {
145   return UNLIMITED_UNCOMPRESSED_LENGTH;
146 }
147
148 std::vector<std::string> Codec::validPrefixes() const {
149   return {};
150 }
151
152 bool Codec::canUncompress(const IOBuf*, Optional<uint64_t>) const {
153   return false;
154 }
155
156 std::string Codec::doCompressString(const StringPiece data) {
157   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
158   auto outputBuffer = doCompress(&inputBuffer);
159   std::string output;
160   output.reserve(outputBuffer->computeChainDataLength());
161   for (auto range : *outputBuffer) {
162     output.append(reinterpret_cast<const char*>(range.data()), range.size());
163   }
164   return output;
165 }
166
167 std::string Codec::doUncompressString(
168     const StringPiece data,
169     Optional<uint64_t> uncompressedLength) {
170   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
171   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
172   std::string output;
173   output.reserve(outputBuffer->computeChainDataLength());
174   for (auto range : *outputBuffer) {
175     output.append(reinterpret_cast<const char*>(range.data()), range.size());
176   }
177   return output;
178 }
179
180 uint64_t Codec::maxCompressedLength(uint64_t uncompressedLength) const {
181   if (uncompressedLength == 0) {
182     return 0;
183   }
184   return doMaxCompressedLength(uncompressedLength);
185 }
186
187 Optional<uint64_t> Codec::getUncompressedLength(
188     const folly::IOBuf* data,
189     Optional<uint64_t> uncompressedLength) const {
190   auto const compressedLength = data->computeChainDataLength();
191   if (uncompressedLength == uint64_t(0) || compressedLength == 0) {
192     if (uncompressedLength.value_or(0) != 0 || compressedLength != 0) {
193       throw std::runtime_error("Invalid uncompressed length");
194     }
195     return 0;
196   }
197   return doGetUncompressedLength(data, uncompressedLength);
198 }
199
200 Optional<uint64_t> Codec::doGetUncompressedLength(
201     const folly::IOBuf*,
202     Optional<uint64_t> uncompressedLength) const {
203   return uncompressedLength;
204 }
205
206 bool StreamCodec::needsDataLength() const {
207   return doNeedsDataLength();
208 }
209
210 bool StreamCodec::doNeedsDataLength() const {
211   return false;
212 }
213
214 void StreamCodec::assertStateIs(State expected) const {
215   if (state_ != expected) {
216     throw std::logic_error(folly::to<std::string>(
217         "Codec: state is ", state_, "; expected state ", expected));
218   }
219 }
220
221 void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) {
222   state_ = State::RESET;
223   uncompressedLength_ = uncompressedLength;
224   doResetStream();
225 }
226
227 bool StreamCodec::compressStream(
228     ByteRange& input,
229     MutableByteRange& output,
230     StreamCodec::FlushOp flushOp) {
231   if (state_ == State::RESET && input.empty()) {
232     if (flushOp == StreamCodec::FlushOp::NONE) {
233       return false;
234     }
235     if (flushOp == StreamCodec::FlushOp::END &&
236         uncompressedLength().value_or(0) != 0) {
237       throw std::runtime_error("Codec: invalid uncompressed length");
238     }
239     return true;
240   }
241   if (state_ == State::RESET && !input.empty() &&
242       uncompressedLength() == uint64_t(0)) {
243     throw std::runtime_error("Codec: invalid uncompressed length");
244   }
245   // Handle input state transitions
246   switch (flushOp) {
247     case StreamCodec::FlushOp::NONE:
248       if (state_ == State::RESET) {
249         state_ = State::COMPRESS;
250       }
251       assertStateIs(State::COMPRESS);
252       break;
253     case StreamCodec::FlushOp::FLUSH:
254       if (state_ == State::RESET || state_ == State::COMPRESS) {
255         state_ = State::COMPRESS_FLUSH;
256       }
257       assertStateIs(State::COMPRESS_FLUSH);
258       break;
259     case StreamCodec::FlushOp::END:
260       if (state_ == State::RESET || state_ == State::COMPRESS) {
261         state_ = State::COMPRESS_END;
262       }
263       assertStateIs(State::COMPRESS_END);
264       break;
265   }
266   bool const done = doCompressStream(input, output, flushOp);
267   // Handle output state transitions
268   if (done) {
269     if (state_ == State::COMPRESS_FLUSH) {
270       state_ = State::COMPRESS;
271     } else if (state_ == State::COMPRESS_END) {
272       state_ = State::END;
273     }
274     // Check internal invariants
275     DCHECK(input.empty());
276     DCHECK(flushOp != StreamCodec::FlushOp::NONE);
277   }
278   return done;
279 }
280
281 bool StreamCodec::uncompressStream(
282     ByteRange& input,
283     MutableByteRange& output,
284     StreamCodec::FlushOp flushOp) {
285   if (state_ == State::RESET && input.empty()) {
286     if (uncompressedLength().value_or(0) == 0) {
287       return true;
288     }
289     return false;
290   }
291   // Handle input state transitions
292   if (state_ == State::RESET) {
293     state_ = State::UNCOMPRESS;
294   }
295   assertStateIs(State::UNCOMPRESS);
296   bool const done = doUncompressStream(input, output, flushOp);
297   // Handle output state transitions
298   if (done) {
299     state_ = State::END;
300   }
301   return done;
302 }
303
304 static std::unique_ptr<IOBuf> addOutputBuffer(
305     MutableByteRange& output,
306     uint64_t size) {
307   DCHECK(output.empty());
308   auto buffer = IOBuf::create(size);
309   buffer->append(buffer->capacity());
310   output = {buffer->writableData(), buffer->length()};
311   return buffer;
312 }
313
314 std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
315   uint64_t const uncompressedLength = data->computeChainDataLength();
316   resetStream(uncompressedLength);
317   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
318
319   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
320   auto constexpr kDefaultBufferLength = uint64_t(4) << 20; // 4 MB
321
322   MutableByteRange output;
323   auto buffer = addOutputBuffer(
324       output,
325       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
326                                                : kDefaultBufferLength);
327
328   // Compress the entire IOBuf chain into the IOBuf chain pointed to by buffer
329   IOBuf const* current = data;
330   ByteRange input{current->data(), current->length()};
331   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
332   for (;;) {
333     while (input.empty() && current->next() != data) {
334       current = current->next();
335       input = {current->data(), current->length()};
336     }
337     if (current->next() == data) {
338       // This is the last input buffer so end the stream
339       flushOp = StreamCodec::FlushOp::END;
340     }
341     if (output.empty()) {
342       buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength));
343     }
344     bool const done = compressStream(input, output, flushOp);
345     if (done) {
346       DCHECK(input.empty());
347       DCHECK(flushOp == StreamCodec::FlushOp::END);
348       DCHECK_EQ(current->next(), data);
349       break;
350     }
351   }
352   buffer->prev()->trimEnd(output.size());
353   return buffer;
354 }
355
356 static uint64_t computeBufferLength(
357     uint64_t const compressedLength,
358     uint64_t const blockSize) {
359   uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB
360   uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength);
361   return std::min(goodBufferSize, kMaxBufferLength);
362 }
363
364 std::unique_ptr<IOBuf> StreamCodec::doUncompress(
365     IOBuf const* data,
366     Optional<uint64_t> uncompressedLength) {
367   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
368   auto constexpr kBlockSize = uint64_t(128) << 10;
369   auto const defaultBufferLength =
370       computeBufferLength(data->computeChainDataLength(), kBlockSize);
371
372   uncompressedLength = getUncompressedLength(data, uncompressedLength);
373   resetStream(uncompressedLength);
374
375   MutableByteRange output;
376   auto buffer = addOutputBuffer(
377       output,
378       (uncompressedLength && *uncompressedLength <= kMaxSingleStepLength
379            ? *uncompressedLength
380            : defaultBufferLength));
381
382   // Uncompress the entire IOBuf chain into the IOBuf chain pointed to by buffer
383   IOBuf const* current = data;
384   ByteRange input{current->data(), current->length()};
385   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
386   for (;;) {
387     while (input.empty() && current->next() != data) {
388       current = current->next();
389       input = {current->data(), current->length()};
390     }
391     if (current->next() == data) {
392       // Tell the uncompressor there is no more input (it may optimize)
393       flushOp = StreamCodec::FlushOp::END;
394     }
395     if (output.empty()) {
396       buffer->prependChain(addOutputBuffer(output, defaultBufferLength));
397     }
398     bool const done = uncompressStream(input, output, flushOp);
399     if (done) {
400       break;
401     }
402   }
403   if (!input.empty()) {
404     throw std::runtime_error("Codec: Junk after end of data");
405   }
406
407   buffer->prev()->trimEnd(output.size());
408   if (uncompressedLength &&
409       *uncompressedLength != buffer->computeChainDataLength()) {
410     throw std::runtime_error("Codec: invalid uncompressed length");
411   }
412
413   return buffer;
414 }
415
416 namespace {
417
418 /**
419  * No compression
420  */
421 class NoCompressionCodec final : public Codec {
422  public:
423   static std::unique_ptr<Codec> create(int level, CodecType type);
424   explicit NoCompressionCodec(int level, CodecType type);
425
426  private:
427   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
428   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
429   std::unique_ptr<IOBuf> doUncompress(
430       const IOBuf* data,
431       Optional<uint64_t> uncompressedLength) override;
432 };
433
434 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
435   return std::make_unique<NoCompressionCodec>(level, type);
436 }
437
438 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
439   : Codec(type) {
440   DCHECK(type == CodecType::NO_COMPRESSION);
441   switch (level) {
442   case COMPRESSION_LEVEL_DEFAULT:
443   case COMPRESSION_LEVEL_FASTEST:
444   case COMPRESSION_LEVEL_BEST:
445     level = 0;
446   }
447   if (level != 0) {
448     throw std::invalid_argument(to<std::string>(
449         "NoCompressionCodec: invalid level ", level));
450   }
451 }
452
453 uint64_t NoCompressionCodec::doMaxCompressedLength(
454     uint64_t uncompressedLength) const {
455   return uncompressedLength;
456 }
457
458 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
459     const IOBuf* data) {
460   return data->clone();
461 }
462
463 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
464     const IOBuf* data,
465     Optional<uint64_t> uncompressedLength) {
466   if (uncompressedLength &&
467       data->computeChainDataLength() != *uncompressedLength) {
468     throw std::runtime_error(
469         to<std::string>("NoCompressionCodec: invalid uncompressed length"));
470   }
471   return data->clone();
472 }
473
474 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
475
476 namespace {
477
478 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
479   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
480   out->append(encodeVarint(val, out->writableTail()));
481 }
482
483 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
484   uint64_t val = 0;
485   int8_t b = 0;
486   for (int shift = 0; shift <= 63; shift += 7) {
487     b = cursor.read<int8_t>();
488     val |= static_cast<uint64_t>(b & 0x7f) << shift;
489     if (b >= 0) {
490       break;
491     }
492   }
493   if (b < 0) {
494     throw std::invalid_argument("Invalid varint value. Too big.");
495   }
496   return val;
497 }
498
499 }  // namespace
500
501 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
502
503 namespace {
504 /**
505  * Reads sizeof(T) bytes, and returns false if not enough bytes are available.
506  * Returns true if the first n bytes are equal to prefix when interpreted as
507  * a little endian T.
508  */
509 template <typename T>
510 typename std::enable_if<std::is_unsigned<T>::value, bool>::type
511 dataStartsWithLE(const IOBuf* data, T prefix, uint64_t n = sizeof(T)) {
512   DCHECK_GT(n, 0);
513   DCHECK_LE(n, sizeof(T));
514   T value;
515   Cursor cursor{data};
516   if (!cursor.tryReadLE(value)) {
517     return false;
518   }
519   const T mask = n == sizeof(T) ? T(-1) : (T(1) << (8 * n)) - 1;
520   return prefix == (value & mask);
521 }
522
523 template <typename T>
524 typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
525 prefixToStringLE(T prefix, uint64_t n = sizeof(T)) {
526   DCHECK_GT(n, 0);
527   DCHECK_LE(n, sizeof(T));
528   prefix = Endian::little(prefix);
529   std::string result;
530   result.resize(n);
531   memcpy(&result[0], &prefix, n);
532   return result;
533 }
534 } // namespace
535
536 #if FOLLY_HAVE_LIBLZ4
537
538 /**
539  * LZ4 compression
540  */
541 class LZ4Codec final : public Codec {
542  public:
543   static std::unique_ptr<Codec> create(int level, CodecType type);
544   explicit LZ4Codec(int level, CodecType type);
545
546  private:
547   bool doNeedsUncompressedLength() const override;
548   uint64_t doMaxUncompressedLength() const override;
549   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
550
551   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
552
553   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
554   std::unique_ptr<IOBuf> doUncompress(
555       const IOBuf* data,
556       Optional<uint64_t> uncompressedLength) override;
557
558   bool highCompression_;
559 };
560
561 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
562   return std::make_unique<LZ4Codec>(level, type);
563 }
564
565 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
566   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
567
568   switch (level) {
569   case COMPRESSION_LEVEL_FASTEST:
570   case COMPRESSION_LEVEL_DEFAULT:
571     level = 1;
572     break;
573   case COMPRESSION_LEVEL_BEST:
574     level = 2;
575     break;
576   }
577   if (level < 1 || level > 2) {
578     throw std::invalid_argument(to<std::string>(
579         "LZ4Codec: invalid level: ", level));
580   }
581   highCompression_ = (level > 1);
582 }
583
584 bool LZ4Codec::doNeedsUncompressedLength() const {
585   return !encodeSize();
586 }
587
588 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
589 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
590 // here.
591 #ifndef LZ4_MAX_INPUT_SIZE
592 # define LZ4_MAX_INPUT_SIZE 0x7E000000
593 #endif
594
595 uint64_t LZ4Codec::doMaxUncompressedLength() const {
596   return LZ4_MAX_INPUT_SIZE;
597 }
598
599 uint64_t LZ4Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
600   return LZ4_compressBound(uncompressedLength) +
601       (encodeSize() ? kMaxVarintLength64 : 0);
602 }
603
604 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
605   IOBuf clone;
606   if (data->isChained()) {
607     // LZ4 doesn't support streaming, so we have to coalesce
608     clone = data->cloneCoalescedAsValue();
609     data = &clone;
610   }
611
612   auto out = IOBuf::create(maxCompressedLength(data->length()));
613   if (encodeSize()) {
614     encodeVarintToIOBuf(data->length(), out.get());
615   }
616
617   int n;
618   auto input = reinterpret_cast<const char*>(data->data());
619   auto output = reinterpret_cast<char*>(out->writableTail());
620   const auto inputLength = data->length();
621 #if LZ4_VERSION_NUMBER >= 10700
622   if (highCompression_) {
623     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
624   } else {
625     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
626   }
627 #else
628   if (highCompression_) {
629     n = LZ4_compressHC(input, output, inputLength);
630   } else {
631     n = LZ4_compress(input, output, inputLength);
632   }
633 #endif
634
635   CHECK_GE(n, 0);
636   CHECK_LE(n, out->capacity());
637
638   out->append(n);
639   return out;
640 }
641
642 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
643     const IOBuf* data,
644     Optional<uint64_t> uncompressedLength) {
645   IOBuf clone;
646   if (data->isChained()) {
647     // LZ4 doesn't support streaming, so we have to coalesce
648     clone = data->cloneCoalescedAsValue();
649     data = &clone;
650   }
651
652   folly::io::Cursor cursor(data);
653   uint64_t actualUncompressedLength;
654   if (encodeSize()) {
655     actualUncompressedLength = decodeVarintFromCursor(cursor);
656     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
657       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
658     }
659   } else {
660     // Invariants
661     DCHECK(uncompressedLength.hasValue());
662     DCHECK(*uncompressedLength <= maxUncompressedLength());
663     actualUncompressedLength = *uncompressedLength;
664   }
665
666   auto sp = StringPiece{cursor.peekBytes()};
667   auto out = IOBuf::create(actualUncompressedLength);
668   int n = LZ4_decompress_safe(
669       sp.data(),
670       reinterpret_cast<char*>(out->writableTail()),
671       sp.size(),
672       actualUncompressedLength);
673
674   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
675     throw std::runtime_error(to<std::string>(
676         "LZ4 decompression returned invalid value ", n));
677   }
678   out->append(actualUncompressedLength);
679   return out;
680 }
681
682 #if LZ4_VERSION_NUMBER >= 10301
683
684 class LZ4FrameCodec final : public Codec {
685  public:
686   static std::unique_ptr<Codec> create(int level, CodecType type);
687   explicit LZ4FrameCodec(int level, CodecType type);
688   ~LZ4FrameCodec() override;
689
690   std::vector<std::string> validPrefixes() const override;
691   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
692       const override;
693
694  private:
695   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
696
697   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
698   std::unique_ptr<IOBuf> doUncompress(
699       const IOBuf* data,
700       Optional<uint64_t> uncompressedLength) override;
701
702   // Reset the dctx_ if it is dirty or null.
703   void resetDCtx();
704
705   int level_;
706   LZ4F_decompressionContext_t dctx_{nullptr};
707   bool dirty_{false};
708 };
709
710 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
711     int level,
712     CodecType type) {
713   return std::make_unique<LZ4FrameCodec>(level, type);
714 }
715
716 static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204;
717
718 std::vector<std::string> LZ4FrameCodec::validPrefixes() const {
719   return {prefixToStringLE(kLZ4FrameMagicLE)};
720 }
721
722 bool LZ4FrameCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
723   return dataStartsWithLE(data, kLZ4FrameMagicLE);
724 }
725
726 uint64_t LZ4FrameCodec::doMaxCompressedLength(
727     uint64_t uncompressedLength) const {
728   LZ4F_preferences_t prefs{};
729   prefs.compressionLevel = level_;
730   prefs.frameInfo.contentSize = uncompressedLength;
731   return LZ4F_compressFrameBound(uncompressedLength, &prefs);
732 }
733
734 static size_t lz4FrameThrowOnError(size_t code) {
735   if (LZ4F_isError(code)) {
736     throw std::runtime_error(
737         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
738   }
739   return code;
740 }
741
742 void LZ4FrameCodec::resetDCtx() {
743   if (dctx_ && !dirty_) {
744     return;
745   }
746   if (dctx_) {
747     LZ4F_freeDecompressionContext(dctx_);
748   }
749   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
750   dirty_ = false;
751 }
752
753 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
754   DCHECK(type == CodecType::LZ4_FRAME);
755   switch (level) {
756     case COMPRESSION_LEVEL_FASTEST:
757     case COMPRESSION_LEVEL_DEFAULT:
758       level_ = 0;
759       break;
760     case COMPRESSION_LEVEL_BEST:
761       level_ = 16;
762       break;
763     default:
764       level_ = level;
765       break;
766   }
767 }
768
769 LZ4FrameCodec::~LZ4FrameCodec() {
770   if (dctx_) {
771     LZ4F_freeDecompressionContext(dctx_);
772   }
773 }
774
775 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
776   // LZ4 Frame compression doesn't support streaming so we have to coalesce
777   IOBuf clone;
778   if (data->isChained()) {
779     clone = data->cloneCoalescedAsValue();
780     data = &clone;
781   }
782   // Set preferences
783   const auto uncompressedLength = data->length();
784   LZ4F_preferences_t prefs{};
785   prefs.compressionLevel = level_;
786   prefs.frameInfo.contentSize = uncompressedLength;
787   // Compress
788   auto buf = IOBuf::create(maxCompressedLength(uncompressedLength));
789   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
790       buf->writableTail(),
791       buf->tailroom(),
792       data->data(),
793       data->length(),
794       &prefs));
795   buf->append(written);
796   return buf;
797 }
798
799 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
800     const IOBuf* data,
801     Optional<uint64_t> uncompressedLength) {
802   // Reset the dctx if any errors have occurred
803   resetDCtx();
804   // Coalesce the data
805   ByteRange in = *data->begin();
806   IOBuf clone;
807   if (data->isChained()) {
808     clone = data->cloneCoalescedAsValue();
809     in = clone.coalesce();
810   }
811   data = nullptr;
812   // Select decompression options
813   LZ4F_decompressOptions_t options;
814   options.stableDst = 1;
815   // Select blockSize and growthSize for the IOBufQueue
816   IOBufQueue queue(IOBufQueue::cacheChainLength());
817   auto blockSize = uint64_t{64} << 10;
818   auto growthSize = uint64_t{4} << 20;
819   if (uncompressedLength) {
820     // Allocate uncompressedLength in one chunk (up to 64 MB)
821     const auto allocateSize = std::min(*uncompressedLength, uint64_t{64} << 20);
822     queue.preallocate(allocateSize, allocateSize);
823     blockSize = std::min(*uncompressedLength, blockSize);
824     growthSize = std::min(*uncompressedLength, growthSize);
825   } else {
826     // Reduce growthSize for small data
827     const auto guessUncompressedLen =
828         4 * std::max<uint64_t>(blockSize, in.size());
829     growthSize = std::min(guessUncompressedLen, growthSize);
830   }
831   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
832   // returns 0
833   dirty_ = true;
834   // Decompress until the frame is over
835   size_t code = 0;
836   do {
837     // Allocate enough space to decompress at least a block
838     void* out;
839     size_t outSize;
840     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
841     // Decompress
842     size_t inSize = in.size();
843     code = lz4FrameThrowOnError(
844         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
845     if (in.empty() && outSize == 0 && code != 0) {
846       // We passed no input, no output was produced, and the frame isn't over
847       // No more forward progress is possible
848       throw std::runtime_error("LZ4Frame error: Incomplete frame");
849     }
850     in.uncheckedAdvance(inSize);
851     queue.postallocate(outSize);
852   } while (code != 0);
853   // At this point the decompression context can be reused
854   dirty_ = false;
855   if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
856     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
857   }
858   return queue.move();
859 }
860
861 #endif // LZ4_VERSION_NUMBER >= 10301
862 #endif // FOLLY_HAVE_LIBLZ4
863
864 #if FOLLY_HAVE_LIBSNAPPY
865
866 /**
867  * Snappy compression
868  */
869
870 /**
871  * Implementation of snappy::Source that reads from a IOBuf chain.
872  */
873 class IOBufSnappySource final : public snappy::Source {
874  public:
875   explicit IOBufSnappySource(const IOBuf* data);
876   size_t Available() const override;
877   const char* Peek(size_t* len) override;
878   void Skip(size_t n) override;
879  private:
880   size_t available_;
881   io::Cursor cursor_;
882 };
883
884 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
885   : available_(data->computeChainDataLength()),
886     cursor_(data) {
887 }
888
889 size_t IOBufSnappySource::Available() const {
890   return available_;
891 }
892
893 const char* IOBufSnappySource::Peek(size_t* len) {
894   auto sp = StringPiece{cursor_.peekBytes()};
895   *len = sp.size();
896   return sp.data();
897 }
898
899 void IOBufSnappySource::Skip(size_t n) {
900   CHECK_LE(n, available_);
901   cursor_.skip(n);
902   available_ -= n;
903 }
904
905 class SnappyCodec final : public Codec {
906  public:
907   static std::unique_ptr<Codec> create(int level, CodecType type);
908   explicit SnappyCodec(int level, CodecType type);
909
910  private:
911   uint64_t doMaxUncompressedLength() const override;
912   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
913   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
914   std::unique_ptr<IOBuf> doUncompress(
915       const IOBuf* data,
916       Optional<uint64_t> uncompressedLength) override;
917 };
918
919 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
920   return std::make_unique<SnappyCodec>(level, type);
921 }
922
923 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
924   DCHECK(type == CodecType::SNAPPY);
925   switch (level) {
926   case COMPRESSION_LEVEL_FASTEST:
927   case COMPRESSION_LEVEL_DEFAULT:
928   case COMPRESSION_LEVEL_BEST:
929     level = 1;
930   }
931   if (level != 1) {
932     throw std::invalid_argument(to<std::string>(
933         "SnappyCodec: invalid level: ", level));
934   }
935 }
936
937 uint64_t SnappyCodec::doMaxUncompressedLength() const {
938   // snappy.h uses uint32_t for lengths, so there's that.
939   return std::numeric_limits<uint32_t>::max();
940 }
941
942 uint64_t SnappyCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
943   return snappy::MaxCompressedLength(uncompressedLength);
944 }
945
946 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
947   IOBufSnappySource source(data);
948   auto out = IOBuf::create(maxCompressedLength(source.Available()));
949
950   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
951       out->writableTail()));
952
953   size_t n = snappy::Compress(&source, &sink);
954
955   CHECK_LE(n, out->capacity());
956   out->append(n);
957   return out;
958 }
959
960 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(
961     const IOBuf* data,
962     Optional<uint64_t> uncompressedLength) {
963   uint32_t actualUncompressedLength = 0;
964
965   {
966     IOBufSnappySource source(data);
967     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
968       throw std::runtime_error("snappy::GetUncompressedLength failed");
969     }
970     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
971       throw std::runtime_error("snappy: invalid uncompressed length");
972     }
973   }
974
975   auto out = IOBuf::create(actualUncompressedLength);
976
977   {
978     IOBufSnappySource source(data);
979     if (!snappy::RawUncompress(&source,
980                                reinterpret_cast<char*>(out->writableTail()))) {
981       throw std::runtime_error("snappy::RawUncompress failed");
982     }
983   }
984
985   out->append(actualUncompressedLength);
986   return out;
987 }
988
989 #endif  // FOLLY_HAVE_LIBSNAPPY
990
991 #if FOLLY_HAVE_LIBZ
992 /**
993  * Zlib codec
994  */
995 class ZlibCodec final : public Codec {
996  public:
997   static std::unique_ptr<Codec> create(int level, CodecType type);
998   explicit ZlibCodec(int level, CodecType type);
999
1000   std::vector<std::string> validPrefixes() const override;
1001   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1002       const override;
1003
1004  private:
1005   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1006   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1007   std::unique_ptr<IOBuf> doUncompress(
1008       const IOBuf* data,
1009       Optional<uint64_t> uncompressedLength) override;
1010
1011   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
1012   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
1013
1014   int level_;
1015 };
1016
1017 static constexpr uint16_t kGZIPMagicLE = 0x8B1F;
1018
1019 std::vector<std::string> ZlibCodec::validPrefixes() const {
1020   if (type() == CodecType::ZLIB) {
1021     // Zlib streams start with a 2 byte header.
1022     //
1023     //   0   1
1024     // +---+---+
1025     // |CMF|FLG|
1026     // +---+---+
1027     //
1028     // We won't restrict the values of any sub-fields except as described below.
1029     //
1030     // The lowest 4 bits of CMF is the compression method (CM).
1031     // CM == 0x8 is the deflate compression method, which is currently the only
1032     // supported compression method, so any valid prefix must have CM == 0x8.
1033     //
1034     // The lowest 5 bits of FLG is FCHECK.
1035     // FCHECK must be such that the two header bytes are a multiple of 31 when
1036     // interpreted as a big endian 16-bit number.
1037     std::vector<std::string> result;
1038     // 16 values for the first byte, 8 values for the second byte.
1039     // There are also 4 combinations where both 0x00 and 0x1F work as FCHECK.
1040     result.reserve(132);
1041     // Select all values for the CMF byte that use the deflate algorithm 0x8.
1042     for (uint32_t first = 0x0800; first <= 0xF800; first += 0x1000) {
1043       // Select all values for the FLG, but leave FCHECK as 0 since it's fixed.
1044       for (uint32_t second = 0x00; second <= 0xE0; second += 0x20) {
1045         uint16_t prefix = first | second;
1046         // Compute FCHECK.
1047         prefix += 31 - (prefix % 31);
1048         result.push_back(prefixToStringLE(Endian::big(prefix)));
1049         // zlib won't produce this, but it is a valid prefix.
1050         if ((prefix & 0x1F) == 31) {
1051           prefix -= 31;
1052           result.push_back(prefixToStringLE(Endian::big(prefix)));
1053         }
1054       }
1055     }
1056     return result;
1057   } else {
1058     // The gzip frame starts with 2 magic bytes.
1059     return {prefixToStringLE(kGZIPMagicLE)};
1060   }
1061 }
1062
1063 bool ZlibCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
1064   if (type() == CodecType::ZLIB) {
1065     uint16_t value;
1066     Cursor cursor{data};
1067     if (!cursor.tryReadBE(value)) {
1068       return false;
1069     }
1070     // zlib compressed if using deflate and is a multiple of 31.
1071     return (value & 0x0F00) == 0x0800 && value % 31 == 0;
1072   } else {
1073     return dataStartsWithLE(data, kGZIPMagicLE);
1074   }
1075 }
1076
1077 uint64_t ZlibCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1078   return deflateBound(nullptr, uncompressedLength);
1079 }
1080
1081 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
1082   return std::make_unique<ZlibCodec>(level, type);
1083 }
1084
1085 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
1086   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
1087   switch (level) {
1088   case COMPRESSION_LEVEL_FASTEST:
1089     level = 1;
1090     break;
1091   case COMPRESSION_LEVEL_DEFAULT:
1092     level = Z_DEFAULT_COMPRESSION;
1093     break;
1094   case COMPRESSION_LEVEL_BEST:
1095     level = 9;
1096     break;
1097   }
1098   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
1099     throw std::invalid_argument(to<std::string>(
1100         "ZlibCodec: invalid level: ", level));
1101   }
1102   level_ = level;
1103 }
1104
1105 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
1106                                                   uint32_t length) {
1107   CHECK_EQ(stream->avail_out, 0);
1108
1109   auto buf = IOBuf::create(length);
1110   buf->append(buf->capacity());
1111
1112   stream->next_out = buf->writableData();
1113   stream->avail_out = buf->length();
1114
1115   return buf;
1116 }
1117
1118 bool ZlibCodec::doInflate(z_stream* stream,
1119                           IOBuf* head,
1120                           uint32_t bufferLength) {
1121   if (stream->avail_out == 0) {
1122     head->prependChain(addOutputBuffer(stream, bufferLength));
1123   }
1124
1125   int rc = inflate(stream, Z_NO_FLUSH);
1126
1127   switch (rc) {
1128   case Z_OK:
1129     break;
1130   case Z_STREAM_END:
1131     return true;
1132   case Z_BUF_ERROR:
1133   case Z_NEED_DICT:
1134   case Z_DATA_ERROR:
1135   case Z_MEM_ERROR:
1136     throw std::runtime_error(to<std::string>(
1137         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
1138   default:
1139     CHECK(false) << rc << ": " << stream->msg;
1140   }
1141
1142   return false;
1143 }
1144
1145 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
1146   z_stream stream;
1147   stream.zalloc = nullptr;
1148   stream.zfree = nullptr;
1149   stream.opaque = nullptr;
1150
1151   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
1152   // base two logarithm of the maximum window size (...) The default value is
1153   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
1154   // around the compressed data instead of a zlib wrapper. The gzip header
1155   // will have no file name, no extra data, no comment, no modification time
1156   // (set to zero), no header crc, and the operating system will be set to 255
1157   // (unknown)."
1158   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1159   // All other parameters (method, memLevel, strategy) get default values from
1160   // the zlib manual.
1161   int rc = deflateInit2(&stream,
1162                         level_,
1163                         Z_DEFLATED,
1164                         windowBits,
1165                         /* memLevel */ 8,
1166                         Z_DEFAULT_STRATEGY);
1167   if (rc != Z_OK) {
1168     throw std::runtime_error(to<std::string>(
1169         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
1170   }
1171
1172   stream.next_in = stream.next_out = nullptr;
1173   stream.avail_in = stream.avail_out = 0;
1174   stream.total_in = stream.total_out = 0;
1175
1176   bool success = false;
1177
1178   SCOPE_EXIT {
1179     rc = deflateEnd(&stream);
1180     // If we're here because of an exception, it's okay if some data
1181     // got dropped.
1182     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
1183       << rc << ": " << stream.msg;
1184   };
1185
1186   uint64_t uncompressedLength = data->computeChainDataLength();
1187   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
1188
1189   // Max 64MiB in one go
1190   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1191   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1192
1193   auto out = addOutputBuffer(
1194       &stream,
1195       (maxCompressedLength <= maxSingleStepLength ?
1196        maxCompressedLength :
1197        defaultBufferLength));
1198
1199   for (auto& range : *data) {
1200     uint64_t remaining = range.size();
1201     uint64_t written = 0;
1202     while (remaining) {
1203       uint32_t step = (remaining > maxSingleStepLength ?
1204                        maxSingleStepLength : remaining);
1205       stream.next_in = const_cast<uint8_t*>(range.data() + written);
1206       stream.avail_in = step;
1207       remaining -= step;
1208       written += step;
1209
1210       while (stream.avail_in != 0) {
1211         if (stream.avail_out == 0) {
1212           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1213         }
1214
1215         rc = deflate(&stream, Z_NO_FLUSH);
1216
1217         CHECK_EQ(rc, Z_OK) << stream.msg;
1218       }
1219     }
1220   }
1221
1222   do {
1223     if (stream.avail_out == 0) {
1224       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1225     }
1226
1227     rc = deflate(&stream, Z_FINISH);
1228   } while (rc == Z_OK);
1229
1230   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
1231
1232   out->prev()->trimEnd(stream.avail_out);
1233
1234   success = true;  // we survived
1235
1236   return out;
1237 }
1238
1239 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(
1240     const IOBuf* data,
1241     Optional<uint64_t> uncompressedLength) {
1242   z_stream stream;
1243   stream.zalloc = nullptr;
1244   stream.zfree = nullptr;
1245   stream.opaque = nullptr;
1246
1247   // "The windowBits parameter is the base two logarithm of the maximum window
1248   // size (...) The default value is 15 (...) add 16 to decode only the gzip
1249   // format (the zlib format will return a Z_DATA_ERROR)."
1250   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
1251   int rc = inflateInit2(&stream, windowBits);
1252   if (rc != Z_OK) {
1253     throw std::runtime_error(to<std::string>(
1254         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
1255   }
1256
1257   stream.next_in = stream.next_out = nullptr;
1258   stream.avail_in = stream.avail_out = 0;
1259   stream.total_in = stream.total_out = 0;
1260
1261   bool success = false;
1262
1263   SCOPE_EXIT {
1264     rc = inflateEnd(&stream);
1265     // If we're here because of an exception, it's okay if some data
1266     // got dropped.
1267     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
1268       << rc << ": " << stream.msg;
1269   };
1270
1271   // Max 64MiB in one go
1272   constexpr uint64_t maxSingleStepLength = uint64_t(64) << 20; // 64MiB
1273   constexpr uint64_t kBlockSize = uint64_t(32) << 10; // 32 KiB
1274   const uint64_t defaultBufferLength =
1275       computeBufferLength(data->computeChainDataLength(), kBlockSize);
1276
1277   auto out = addOutputBuffer(
1278       &stream,
1279       ((uncompressedLength && *uncompressedLength <= maxSingleStepLength)
1280            ? *uncompressedLength
1281            : defaultBufferLength));
1282
1283   bool streamEnd = false;
1284   for (auto& range : *data) {
1285     if (range.empty()) {
1286       continue;
1287     }
1288
1289     stream.next_in = const_cast<uint8_t*>(range.data());
1290     stream.avail_in = range.size();
1291
1292     while (stream.avail_in != 0) {
1293       if (streamEnd) {
1294         throw std::runtime_error(to<std::string>(
1295             "ZlibCodec: junk after end of data"));
1296       }
1297
1298       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1299     }
1300   }
1301
1302   while (!streamEnd) {
1303     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1304   }
1305
1306   out->prev()->trimEnd(stream.avail_out);
1307
1308   if (uncompressedLength && *uncompressedLength != stream.total_out) {
1309     throw std::runtime_error(
1310         to<std::string>("ZlibCodec: invalid uncompressed length"));
1311   }
1312
1313   success = true;  // we survived
1314
1315   return out;
1316 }
1317
1318 #endif  // FOLLY_HAVE_LIBZ
1319
1320 #if FOLLY_HAVE_LIBLZMA
1321
1322 /**
1323  * LZMA2 compression
1324  */
1325 class LZMA2Codec final : public Codec {
1326  public:
1327   static std::unique_ptr<Codec> create(int level, CodecType type);
1328   explicit LZMA2Codec(int level, CodecType type);
1329
1330   std::vector<std::string> validPrefixes() const override;
1331   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1332       const override;
1333
1334  private:
1335   bool doNeedsUncompressedLength() const override;
1336   uint64_t doMaxUncompressedLength() const override;
1337   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1338
1339   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
1340
1341   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1342   std::unique_ptr<IOBuf> doUncompress(
1343       const IOBuf* data,
1344       Optional<uint64_t> uncompressedLength) override;
1345
1346   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
1347   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
1348
1349   int level_;
1350 };
1351
1352 static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD;
1353 static constexpr unsigned kLZMA2MagicBytes = 6;
1354
1355 std::vector<std::string> LZMA2Codec::validPrefixes() const {
1356   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1357     return {};
1358   }
1359   return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)};
1360 }
1361
1362 bool LZMA2Codec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
1363   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1364     return false;
1365   }
1366   // Returns false for all inputs less than 8 bytes.
1367   // This is okay, because no valid LZMA2 streams are less than 8 bytes.
1368   return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes);
1369 }
1370
1371 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
1372   return std::make_unique<LZMA2Codec>(level, type);
1373 }
1374
1375 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
1376   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
1377   switch (level) {
1378   case COMPRESSION_LEVEL_FASTEST:
1379     level = 0;
1380     break;
1381   case COMPRESSION_LEVEL_DEFAULT:
1382     level = LZMA_PRESET_DEFAULT;
1383     break;
1384   case COMPRESSION_LEVEL_BEST:
1385     level = 9;
1386     break;
1387   }
1388   if (level < 0 || level > 9) {
1389     throw std::invalid_argument(to<std::string>(
1390         "LZMA2Codec: invalid level: ", level));
1391   }
1392   level_ = level;
1393 }
1394
1395 bool LZMA2Codec::doNeedsUncompressedLength() const {
1396   return false;
1397 }
1398
1399 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
1400   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
1401   return uint64_t(1) << 63;
1402 }
1403
1404 uint64_t LZMA2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1405   return lzma_stream_buffer_bound(uncompressedLength) +
1406       (encodeSize() ? kMaxVarintLength64 : 0);
1407 }
1408
1409 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
1410     lzma_stream* stream,
1411     size_t length) {
1412
1413   CHECK_EQ(stream->avail_out, 0);
1414
1415   auto buf = IOBuf::create(length);
1416   buf->append(buf->capacity());
1417
1418   stream->next_out = buf->writableData();
1419   stream->avail_out = buf->length();
1420
1421   return buf;
1422 }
1423
1424 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
1425   lzma_ret rc;
1426   lzma_stream stream = LZMA_STREAM_INIT;
1427
1428   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
1429   if (rc != LZMA_OK) {
1430     throw std::runtime_error(folly::to<std::string>(
1431       "LZMA2Codec: lzma_easy_encoder error: ", rc));
1432   }
1433
1434   SCOPE_EXIT { lzma_end(&stream); };
1435
1436   uint64_t uncompressedLength = data->computeChainDataLength();
1437   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
1438
1439   // Max 64MiB in one go
1440   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1441   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1442
1443   auto out = addOutputBuffer(
1444     &stream,
1445     (maxCompressedLength <= maxSingleStepLength ?
1446      maxCompressedLength :
1447      defaultBufferLength));
1448
1449   if (encodeSize()) {
1450     auto size = IOBuf::createCombined(kMaxVarintLength64);
1451     encodeVarintToIOBuf(uncompressedLength, size.get());
1452     size->appendChain(std::move(out));
1453     out = std::move(size);
1454   }
1455
1456   for (auto& range : *data) {
1457     if (range.empty()) {
1458       continue;
1459     }
1460
1461     stream.next_in = const_cast<uint8_t*>(range.data());
1462     stream.avail_in = range.size();
1463
1464     while (stream.avail_in != 0) {
1465       if (stream.avail_out == 0) {
1466         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1467       }
1468
1469       rc = lzma_code(&stream, LZMA_RUN);
1470
1471       if (rc != LZMA_OK) {
1472         throw std::runtime_error(folly::to<std::string>(
1473           "LZMA2Codec: lzma_code error: ", rc));
1474       }
1475     }
1476   }
1477
1478   do {
1479     if (stream.avail_out == 0) {
1480       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1481     }
1482
1483     rc = lzma_code(&stream, LZMA_FINISH);
1484   } while (rc == LZMA_OK);
1485
1486   if (rc != LZMA_STREAM_END) {
1487     throw std::runtime_error(folly::to<std::string>(
1488       "LZMA2Codec: lzma_code ended with error: ", rc));
1489   }
1490
1491   out->prev()->trimEnd(stream.avail_out);
1492
1493   return out;
1494 }
1495
1496 bool LZMA2Codec::doInflate(lzma_stream* stream,
1497                           IOBuf* head,
1498                           size_t bufferLength) {
1499   if (stream->avail_out == 0) {
1500     head->prependChain(addOutputBuffer(stream, bufferLength));
1501   }
1502
1503   lzma_ret rc = lzma_code(stream, LZMA_RUN);
1504
1505   switch (rc) {
1506   case LZMA_OK:
1507     break;
1508   case LZMA_STREAM_END:
1509     return true;
1510   default:
1511     throw std::runtime_error(to<std::string>(
1512         "LZMA2Codec: lzma_code error: ", rc));
1513   }
1514
1515   return false;
1516 }
1517
1518 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(
1519     const IOBuf* data,
1520     Optional<uint64_t> uncompressedLength) {
1521   lzma_ret rc;
1522   lzma_stream stream = LZMA_STREAM_INIT;
1523
1524   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
1525   if (rc != LZMA_OK) {
1526     throw std::runtime_error(folly::to<std::string>(
1527       "LZMA2Codec: lzma_auto_decoder error: ", rc));
1528   }
1529
1530   SCOPE_EXIT { lzma_end(&stream); };
1531
1532   // Max 64MiB in one go
1533   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB
1534   constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB
1535
1536   folly::io::Cursor cursor(data);
1537   if (encodeSize()) {
1538     const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor);
1539     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
1540       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
1541     }
1542     uncompressedLength = actualUncompressedLength;
1543   }
1544
1545   auto out = addOutputBuffer(
1546       &stream,
1547       ((uncompressedLength && *uncompressedLength <= maxSingleStepLength)
1548            ? *uncompressedLength
1549            : defaultBufferLength));
1550
1551   bool streamEnd = false;
1552   auto buf = cursor.peekBytes();
1553   while (!buf.empty()) {
1554     stream.next_in = const_cast<uint8_t*>(buf.data());
1555     stream.avail_in = buf.size();
1556
1557     while (stream.avail_in != 0) {
1558       if (streamEnd) {
1559         throw std::runtime_error(to<std::string>(
1560             "LZMA2Codec: junk after end of data"));
1561       }
1562
1563       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1564     }
1565
1566     cursor.skip(buf.size());
1567     buf = cursor.peekBytes();
1568   }
1569
1570   while (!streamEnd) {
1571     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1572   }
1573
1574   out->prev()->trimEnd(stream.avail_out);
1575
1576   if (uncompressedLength && *uncompressedLength != stream.total_out) {
1577     throw std::runtime_error(
1578         to<std::string>("LZMA2Codec: invalid uncompressed length"));
1579   }
1580
1581   return out;
1582 }
1583
1584 #endif  // FOLLY_HAVE_LIBLZMA
1585
1586 #ifdef FOLLY_HAVE_LIBZSTD
1587
1588 namespace {
1589 void zstdFreeCStream(ZSTD_CStream* zcs) {
1590   ZSTD_freeCStream(zcs);
1591 }
1592
1593 void zstdFreeDStream(ZSTD_DStream* zds) {
1594   ZSTD_freeDStream(zds);
1595 }
1596 }
1597
1598 /**
1599  * ZSTD compression
1600  */
1601 class ZSTDStreamCodec final : public StreamCodec {
1602  public:
1603   static std::unique_ptr<Codec> createCodec(int level, CodecType);
1604   static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
1605   explicit ZSTDStreamCodec(int level, CodecType type);
1606
1607   std::vector<std::string> validPrefixes() const override;
1608   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1609       const override;
1610
1611  private:
1612   bool doNeedsUncompressedLength() const override;
1613   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1614   Optional<uint64_t> doGetUncompressedLength(
1615       IOBuf const* data,
1616       Optional<uint64_t> uncompressedLength) const override;
1617
1618   void doResetStream() override;
1619   bool doCompressStream(
1620       ByteRange& input,
1621       MutableByteRange& output,
1622       StreamCodec::FlushOp flushOp) override;
1623   bool doUncompressStream(
1624       ByteRange& input,
1625       MutableByteRange& output,
1626       StreamCodec::FlushOp flushOp) override;
1627
1628   void resetCStream();
1629   void resetDStream();
1630
1631   bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
1632   bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
1633
1634   int level_;
1635   bool needReset_{true};
1636   std::unique_ptr<
1637       ZSTD_CStream,
1638       folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
1639       cstream_{nullptr};
1640   std::unique_ptr<
1641       ZSTD_DStream,
1642       folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
1643       dstream_{nullptr};
1644 };
1645
1646 static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
1647
1648 std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
1649   return {prefixToStringLE(kZSTDMagicLE)};
1650 }
1651
1652 bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1653     const {
1654   return dataStartsWithLE(data, kZSTDMagicLE);
1655 }
1656
1657 std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
1658   return make_unique<ZSTDStreamCodec>(level, type);
1659 }
1660
1661 std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
1662     int level,
1663     CodecType type) {
1664   return make_unique<ZSTDStreamCodec>(level, type);
1665 }
1666
1667 ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
1668     : StreamCodec(type) {
1669   DCHECK(type == CodecType::ZSTD);
1670   switch (level) {
1671     case COMPRESSION_LEVEL_FASTEST:
1672       level = 1;
1673       break;
1674     case COMPRESSION_LEVEL_DEFAULT:
1675       level = 1;
1676       break;
1677     case COMPRESSION_LEVEL_BEST:
1678       level = 19;
1679       break;
1680   }
1681   if (level < 1 || level > ZSTD_maxCLevel()) {
1682     throw std::invalid_argument(
1683         to<std::string>("ZSTD: invalid level: ", level));
1684   }
1685   level_ = level;
1686 }
1687
1688 bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
1689   return false;
1690 }
1691
1692 uint64_t ZSTDStreamCodec::doMaxCompressedLength(
1693     uint64_t uncompressedLength) const {
1694   return ZSTD_compressBound(uncompressedLength);
1695 }
1696
1697 void zstdThrowIfError(size_t rc) {
1698   if (!ZSTD_isError(rc)) {
1699     return;
1700   }
1701   throw std::runtime_error(
1702       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1703 }
1704
1705 Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
1706     IOBuf const* data,
1707     Optional<uint64_t> uncompressedLength) const {
1708   // Read decompressed size from frame if available in first IOBuf.
1709   auto const decompressedSize =
1710       ZSTD_getDecompressedSize(data->data(), data->length());
1711   if (decompressedSize != 0) {
1712     if (uncompressedLength && *uncompressedLength != decompressedSize) {
1713       throw std::runtime_error("ZSTD: invalid uncompressed length");
1714     }
1715     uncompressedLength = decompressedSize;
1716   }
1717   return uncompressedLength;
1718 }
1719
1720 void ZSTDStreamCodec::doResetStream() {
1721   needReset_ = true;
1722 }
1723
1724 bool ZSTDStreamCodec::tryBlockCompress(
1725     ByteRange& input,
1726     MutableByteRange& output) const {
1727   DCHECK(needReset_);
1728   // We need to know that we have enough output space to use block compression
1729   if (output.size() < ZSTD_compressBound(input.size())) {
1730     return false;
1731   }
1732   size_t const length = ZSTD_compress(
1733       output.data(), output.size(), input.data(), input.size(), level_);
1734   zstdThrowIfError(length);
1735   input.uncheckedAdvance(input.size());
1736   output.uncheckedAdvance(length);
1737   return true;
1738 }
1739
1740 void ZSTDStreamCodec::resetCStream() {
1741   if (!cstream_) {
1742     cstream_.reset(ZSTD_createCStream());
1743     if (!cstream_) {
1744       throw std::bad_alloc{};
1745     }
1746   }
1747   // Advanced API usage works for all supported versions of zstd.
1748   // Required to set contentSizeFlag.
1749   auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
1750   params.fParams.contentSizeFlag = uncompressedLength().hasValue();
1751   zstdThrowIfError(ZSTD_initCStream_advanced(
1752       cstream_.get(), nullptr, 0, params, uncompressedLength().value_or(0)));
1753 }
1754
1755 bool ZSTDStreamCodec::doCompressStream(
1756     ByteRange& input,
1757     MutableByteRange& output,
1758     StreamCodec::FlushOp flushOp) {
1759   if (needReset_) {
1760     // If we are given all the input in one chunk try to use block compression
1761     if (flushOp == StreamCodec::FlushOp::END &&
1762         tryBlockCompress(input, output)) {
1763       return true;
1764     }
1765     resetCStream();
1766     needReset_ = false;
1767   }
1768   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1769   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1770   SCOPE_EXIT {
1771     input.uncheckedAdvance(in.pos);
1772     output.uncheckedAdvance(out.pos);
1773   };
1774   if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
1775     zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
1776   }
1777   if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
1778     size_t rc;
1779     switch (flushOp) {
1780       case StreamCodec::FlushOp::FLUSH:
1781         rc = ZSTD_flushStream(cstream_.get(), &out);
1782         break;
1783       case StreamCodec::FlushOp::END:
1784         rc = ZSTD_endStream(cstream_.get(), &out);
1785         break;
1786       default:
1787         throw std::invalid_argument("ZSTD: invalid FlushOp");
1788     }
1789     zstdThrowIfError(rc);
1790     if (rc == 0) {
1791       return true;
1792     }
1793   }
1794   return false;
1795 }
1796
1797 bool ZSTDStreamCodec::tryBlockUncompress(
1798     ByteRange& input,
1799     MutableByteRange& output) const {
1800   DCHECK(needReset_);
1801 #if ZSTD_VERSION_NUMBER < 10104
1802   // We require ZSTD_findFrameCompressedSize() to perform this optimization.
1803   return false;
1804 #else
1805   // We need to know the uncompressed length and have enough output space.
1806   if (!uncompressedLength() || output.size() < *uncompressedLength()) {
1807     return false;
1808   }
1809   size_t const compressedLength =
1810       ZSTD_findFrameCompressedSize(input.data(), input.size());
1811   zstdThrowIfError(compressedLength);
1812   size_t const length = ZSTD_decompress(
1813       output.data(), *uncompressedLength(), input.data(), compressedLength);
1814   zstdThrowIfError(length);
1815   DCHECK_EQ(length, *uncompressedLength());
1816   input.uncheckedAdvance(compressedLength);
1817   output.uncheckedAdvance(length);
1818   return true;
1819 #endif
1820 }
1821
1822 void ZSTDStreamCodec::resetDStream() {
1823   if (!dstream_) {
1824     dstream_.reset(ZSTD_createDStream());
1825     if (!dstream_) {
1826       throw std::bad_alloc{};
1827     }
1828   }
1829   zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
1830 }
1831
1832 bool ZSTDStreamCodec::doUncompressStream(
1833     ByteRange& input,
1834     MutableByteRange& output,
1835     StreamCodec::FlushOp flushOp) {
1836   if (needReset_) {
1837     // If we are given all the input in one chunk try to use block uncompression
1838     if (flushOp == StreamCodec::FlushOp::END &&
1839         tryBlockUncompress(input, output)) {
1840       return true;
1841     }
1842     resetDStream();
1843     needReset_ = false;
1844   }
1845   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1846   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1847   SCOPE_EXIT {
1848     input.uncheckedAdvance(in.pos);
1849     output.uncheckedAdvance(out.pos);
1850   };
1851   size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
1852   zstdThrowIfError(rc);
1853   return rc == 0;
1854 }
1855
1856 #endif // FOLLY_HAVE_LIBZSTD
1857
1858 #if FOLLY_HAVE_LIBBZ2
1859
1860 class Bzip2Codec final : public Codec {
1861  public:
1862   static std::unique_ptr<Codec> create(int level, CodecType type);
1863   explicit Bzip2Codec(int level, CodecType type);
1864
1865   std::vector<std::string> validPrefixes() const override;
1866   bool canUncompress(IOBuf const* data, Optional<uint64_t> uncompressedLength)
1867       const override;
1868
1869  private:
1870   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1871   std::unique_ptr<IOBuf> doCompress(IOBuf const* data) override;
1872   std::unique_ptr<IOBuf> doUncompress(
1873       IOBuf const* data,
1874       Optional<uint64_t> uncompressedLength) override;
1875
1876   int level_;
1877 };
1878
1879 /* static */ std::unique_ptr<Codec> Bzip2Codec::create(
1880     int level,
1881     CodecType type) {
1882   return std::make_unique<Bzip2Codec>(level, type);
1883 }
1884
1885 Bzip2Codec::Bzip2Codec(int level, CodecType type) : Codec(type) {
1886   DCHECK(type == CodecType::BZIP2);
1887   switch (level) {
1888     case COMPRESSION_LEVEL_FASTEST:
1889       level = 1;
1890       break;
1891     case COMPRESSION_LEVEL_DEFAULT:
1892       level = 9;
1893       break;
1894     case COMPRESSION_LEVEL_BEST:
1895       level = 9;
1896       break;
1897   }
1898   if (level < 1 || level > 9) {
1899     throw std::invalid_argument(
1900         to<std::string>("Bzip2: invalid level: ", level));
1901   }
1902   level_ = level;
1903 }
1904
1905 static uint32_t constexpr kBzip2MagicLE = 0x685a42;
1906 static uint64_t constexpr kBzip2MagicBytes = 3;
1907
1908 std::vector<std::string> Bzip2Codec::validPrefixes() const {
1909   return {prefixToStringLE(kBzip2MagicLE, kBzip2MagicBytes)};
1910 }
1911
1912 bool Bzip2Codec::canUncompress(IOBuf const* data, Optional<uint64_t>) const {
1913   return dataStartsWithLE(data, kBzip2MagicLE, kBzip2MagicBytes);
1914 }
1915
1916 uint64_t Bzip2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1917   // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress
1918   //   To guarantee that the compressed data will fit in its buffer, allocate an
1919   //   output buffer of size 1% larger than the uncompressed data, plus six
1920   //   hundred extra bytes.
1921   return uncompressedLength + uncompressedLength / 100 + 600;
1922 }
1923
1924 static bz_stream createBzStream() {
1925   bz_stream stream;
1926   stream.bzalloc = nullptr;
1927   stream.bzfree = nullptr;
1928   stream.opaque = nullptr;
1929   stream.next_in = stream.next_out = nullptr;
1930   stream.avail_in = stream.avail_out = 0;
1931   return stream;
1932 }
1933
1934 // Throws on error condition, otherwise returns the code.
1935 static int bzCheck(int const rc) {
1936   switch (rc) {
1937     case BZ_OK:
1938     case BZ_RUN_OK:
1939     case BZ_FLUSH_OK:
1940     case BZ_FINISH_OK:
1941     case BZ_STREAM_END:
1942       return rc;
1943     default:
1944       throw std::runtime_error(to<std::string>("Bzip2 error: ", rc));
1945   }
1946 }
1947
1948 static std::unique_ptr<IOBuf> addOutputBuffer(
1949     bz_stream* stream,
1950     uint64_t const bufferLength) {
1951   DCHECK_LE(bufferLength, std::numeric_limits<unsigned>::max());
1952   DCHECK_EQ(stream->avail_out, 0);
1953
1954   auto buf = IOBuf::create(bufferLength);
1955   buf->append(buf->capacity());
1956
1957   stream->next_out = reinterpret_cast<char*>(buf->writableData());
1958   stream->avail_out = buf->length();
1959
1960   return buf;
1961 }
1962
1963 std::unique_ptr<IOBuf> Bzip2Codec::doCompress(IOBuf const* data) {
1964   bz_stream stream = createBzStream();
1965   bzCheck(BZ2_bzCompressInit(&stream, level_, 0, 0));
1966   SCOPE_EXIT {
1967     bzCheck(BZ2_bzCompressEnd(&stream));
1968   };
1969
1970   uint64_t const uncompressedLength = data->computeChainDataLength();
1971   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
1972   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1973   uint64_t constexpr kDefaultBufferLength = uint64_t(4) << 20;
1974
1975   auto out = addOutputBuffer(
1976       &stream,
1977       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
1978                                                : kDefaultBufferLength);
1979
1980   for (auto range : *data) {
1981     while (!range.empty()) {
1982       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1983       stream.next_in =
1984           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1985       stream.avail_in = inSize;
1986
1987       if (stream.avail_out == 0) {
1988         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1989       }
1990
1991       bzCheck(BZ2_bzCompress(&stream, BZ_RUN));
1992       range.uncheckedAdvance(inSize - stream.avail_in);
1993     }
1994   }
1995   do {
1996     if (stream.avail_out == 0) {
1997       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1998     }
1999   } while (bzCheck(BZ2_bzCompress(&stream, BZ_FINISH)) != BZ_STREAM_END);
2000
2001   out->prev()->trimEnd(stream.avail_out);
2002
2003   return out;
2004 }
2005
2006 std::unique_ptr<IOBuf> Bzip2Codec::doUncompress(
2007     const IOBuf* data,
2008     Optional<uint64_t> uncompressedLength) {
2009   bz_stream stream = createBzStream();
2010   bzCheck(BZ2_bzDecompressInit(&stream, 0, 0));
2011   SCOPE_EXIT {
2012     bzCheck(BZ2_bzDecompressEnd(&stream));
2013   };
2014
2015   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
2016   uint64_t const kBlockSize = uint64_t(100) << 10; // 100 KiB
2017   uint64_t const kDefaultBufferLength =
2018       computeBufferLength(data->computeChainDataLength(), kBlockSize);
2019
2020   auto out = addOutputBuffer(
2021       &stream,
2022       ((uncompressedLength && *uncompressedLength <= kMaxSingleStepLength)
2023            ? *uncompressedLength
2024            : kDefaultBufferLength));
2025
2026   int rc = BZ_OK;
2027   for (auto range : *data) {
2028     while (!range.empty()) {
2029       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
2030       stream.next_in =
2031           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
2032       stream.avail_in = inSize;
2033
2034       if (stream.avail_out == 0) {
2035         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2036       }
2037
2038       rc = bzCheck(BZ2_bzDecompress(&stream));
2039       range.uncheckedAdvance(inSize - stream.avail_in);
2040     }
2041   }
2042   while (rc != BZ_STREAM_END) {
2043     if (stream.avail_out == 0) {
2044       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
2045     }
2046
2047     rc = bzCheck(BZ2_bzDecompress(&stream));
2048   }
2049
2050   out->prev()->trimEnd(stream.avail_out);
2051
2052   uint64_t const totalOut =
2053       (uint64_t(stream.total_out_hi32) << 32) + stream.total_out_lo32;
2054   if (uncompressedLength && uncompressedLength != totalOut) {
2055     throw std::runtime_error("Bzip2 error: Invalid uncompressed length");
2056   }
2057
2058   return out;
2059 }
2060
2061 #endif // FOLLY_HAVE_LIBBZ2
2062
2063 /**
2064  * Automatic decompression
2065  */
2066 class AutomaticCodec final : public Codec {
2067  public:
2068   static std::unique_ptr<Codec> create(
2069       std::vector<std::unique_ptr<Codec>> customCodecs);
2070   explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
2071
2072   std::vector<std::string> validPrefixes() const override;
2073   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
2074       const override;
2075
2076  private:
2077   bool doNeedsUncompressedLength() const override;
2078   uint64_t doMaxUncompressedLength() const override;
2079
2080   uint64_t doMaxCompressedLength(uint64_t) const override {
2081     throw std::runtime_error(
2082         "AutomaticCodec error: maxCompressedLength() not supported.");
2083   }
2084   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
2085     throw std::runtime_error("AutomaticCodec error: compress() not supported.");
2086   }
2087   std::unique_ptr<IOBuf> doUncompress(
2088       const IOBuf* data,
2089       Optional<uint64_t> uncompressedLength) override;
2090
2091   void addCodecIfSupported(CodecType type);
2092
2093   // Throws iff the codecs aren't compatible (very slow)
2094   void checkCompatibleCodecs() const;
2095
2096   std::vector<std::unique_ptr<Codec>> codecs_;
2097   bool needsUncompressedLength_;
2098   uint64_t maxUncompressedLength_;
2099 };
2100
2101 std::vector<std::string> AutomaticCodec::validPrefixes() const {
2102   std::unordered_set<std::string> prefixes;
2103   for (const auto& codec : codecs_) {
2104     const auto codecPrefixes = codec->validPrefixes();
2105     prefixes.insert(codecPrefixes.begin(), codecPrefixes.end());
2106   }
2107   return std::vector<std::string>{prefixes.begin(), prefixes.end()};
2108 }
2109
2110 bool AutomaticCodec::canUncompress(
2111     const IOBuf* data,
2112     Optional<uint64_t> uncompressedLength) const {
2113   return std::any_of(
2114       codecs_.begin(),
2115       codecs_.end(),
2116       [data, uncompressedLength](std::unique_ptr<Codec> const& codec) {
2117         return codec->canUncompress(data, uncompressedLength);
2118       });
2119 }
2120
2121 void AutomaticCodec::addCodecIfSupported(CodecType type) {
2122   const bool present = std::any_of(
2123       codecs_.begin(),
2124       codecs_.end(),
2125       [&type](std::unique_ptr<Codec> const& codec) {
2126         return codec->type() == type;
2127       });
2128   if (hasCodec(type) && !present) {
2129     codecs_.push_back(getCodec(type));
2130   }
2131 }
2132
2133 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
2134     std::vector<std::unique_ptr<Codec>> customCodecs) {
2135   return std::make_unique<AutomaticCodec>(std::move(customCodecs));
2136 }
2137
2138 AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
2139     : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
2140   // Fastest -> slowest
2141   addCodecIfSupported(CodecType::LZ4_FRAME);
2142   addCodecIfSupported(CodecType::ZSTD);
2143   addCodecIfSupported(CodecType::ZLIB);
2144   addCodecIfSupported(CodecType::GZIP);
2145   addCodecIfSupported(CodecType::LZMA2);
2146   addCodecIfSupported(CodecType::BZIP2);
2147   if (kIsDebug) {
2148     checkCompatibleCodecs();
2149   }
2150   // Check that none of the codes are are null
2151   DCHECK(std::none_of(
2152       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2153         return codec == nullptr;
2154       }));
2155
2156   needsUncompressedLength_ = std::any_of(
2157       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
2158         return codec->needsUncompressedLength();
2159       });
2160
2161   const auto it = std::max_element(
2162       codecs_.begin(),
2163       codecs_.end(),
2164       [](std::unique_ptr<Codec> const& lhs, std::unique_ptr<Codec> const& rhs) {
2165         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
2166       });
2167   DCHECK(it != codecs_.end());
2168   maxUncompressedLength_ = (*it)->maxUncompressedLength();
2169 }
2170
2171 void AutomaticCodec::checkCompatibleCodecs() const {
2172   // Keep track of all the possible headers.
2173   std::unordered_set<std::string> headers;
2174   // The empty header is not allowed.
2175   headers.insert("");
2176   // Step 1:
2177   // Construct a set of headers and check that none of the headers occur twice.
2178   // Eliminate edge cases.
2179   for (auto&& codec : codecs_) {
2180     const auto codecHeaders = codec->validPrefixes();
2181     // Codecs without any valid headers are not allowed.
2182     if (codecHeaders.empty()) {
2183       throw std::invalid_argument{
2184           "AutomaticCodec: validPrefixes() must not be empty."};
2185     }
2186     // Insert all the headers for the current codec.
2187     const size_t beforeSize = headers.size();
2188     headers.insert(codecHeaders.begin(), codecHeaders.end());
2189     // Codecs are not compatible if any header occurred twice.
2190     if (beforeSize + codecHeaders.size() != headers.size()) {
2191       throw std::invalid_argument{
2192           "AutomaticCodec: Two valid prefixes collide."};
2193     }
2194   }
2195   // Step 2:
2196   // Check if any strict non-empty prefix of any header is a header.
2197   for (const auto& header : headers) {
2198     for (size_t i = 1; i < header.size(); ++i) {
2199       if (headers.count(header.substr(0, i))) {
2200         throw std::invalid_argument{
2201             "AutomaticCodec: One valid prefix is a prefix of another valid "
2202             "prefix."};
2203       }
2204     }
2205   }
2206 }
2207
2208 bool AutomaticCodec::doNeedsUncompressedLength() const {
2209   return needsUncompressedLength_;
2210 }
2211
2212 uint64_t AutomaticCodec::doMaxUncompressedLength() const {
2213   return maxUncompressedLength_;
2214 }
2215
2216 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
2217     const IOBuf* data,
2218     Optional<uint64_t> uncompressedLength) {
2219   for (auto&& codec : codecs_) {
2220     if (codec->canUncompress(data, uncompressedLength)) {
2221       return codec->uncompress(data, uncompressedLength);
2222     }
2223   }
2224   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
2225 }
2226
2227 using CodecFactory = std::unique_ptr<Codec> (*)(int, CodecType);
2228 using StreamCodecFactory = std::unique_ptr<StreamCodec> (*)(int, CodecType);
2229 struct Factory {
2230   CodecFactory codec;
2231   StreamCodecFactory stream;
2232 };
2233
2234 constexpr Factory
2235     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
2236         {}, // USER_DEFINED
2237         {NoCompressionCodec::create, nullptr},
2238
2239 #if FOLLY_HAVE_LIBLZ4
2240         {LZ4Codec::create, nullptr},
2241 #else
2242         {},
2243 #endif
2244
2245 #if FOLLY_HAVE_LIBSNAPPY
2246         {SnappyCodec::create, nullptr},
2247 #else
2248         {},
2249 #endif
2250
2251 #if FOLLY_HAVE_LIBZ
2252         {ZlibCodec::create, nullptr},
2253 #else
2254         {},
2255 #endif
2256
2257 #if FOLLY_HAVE_LIBLZ4
2258         {LZ4Codec::create, nullptr},
2259 #else
2260         {},
2261 #endif
2262
2263 #if FOLLY_HAVE_LIBLZMA
2264         {LZMA2Codec::create, nullptr},
2265         {LZMA2Codec::create, nullptr},
2266 #else
2267         {},
2268         {},
2269 #endif
2270
2271 #if FOLLY_HAVE_LIBZSTD
2272         {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
2273 #else
2274         {},
2275 #endif
2276
2277 #if FOLLY_HAVE_LIBZ
2278         {ZlibCodec::create, nullptr},
2279 #else
2280         {},
2281 #endif
2282
2283 #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301)
2284         {LZ4FrameCodec::create, nullptr},
2285 #else
2286         {},
2287 #endif
2288
2289 #if FOLLY_HAVE_LIBBZ2
2290         {Bzip2Codec::create, nullptr},
2291 #else
2292         {},
2293 #endif
2294 };
2295
2296 Factory const& getFactory(CodecType type) {
2297   size_t const idx = static_cast<size_t>(type);
2298   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
2299     throw std::invalid_argument(
2300         to<std::string>("Compression type ", idx, " invalid"));
2301   }
2302   return codecFactories[idx];
2303 }
2304 } // namespace
2305
2306 bool hasCodec(CodecType type) {
2307   return getFactory(type).codec != nullptr;
2308 }
2309
2310 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
2311   auto const factory = getFactory(type).codec;
2312   if (!factory) {
2313     throw std::invalid_argument(
2314         to<std::string>("Compression type ", type, " not supported"));
2315   }
2316   auto codec = (*factory)(level, type);
2317   DCHECK(codec->type() == type);
2318   return codec;
2319 }
2320
2321 bool hasStreamCodec(CodecType type) {
2322   return getFactory(type).stream != nullptr;
2323 }
2324
2325 std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
2326   auto const factory = getFactory(type).stream;
2327   if (!factory) {
2328     throw std::invalid_argument(
2329         to<std::string>("Compression type ", type, " not supported"));
2330   }
2331   auto codec = (*factory)(level, type);
2332   DCHECK(codec->type() == type);
2333   return codec;
2334 }
2335
2336 std::unique_ptr<Codec> getAutoUncompressionCodec(
2337     std::vector<std::unique_ptr<Codec>> customCodecs) {
2338   return AutomaticCodec::create(std::move(customCodecs));
2339 }
2340 }}  // namespaces