Revert D6745720: [folly][compression] Log (de)compression bytes
[folly.git] / folly / compression / Compression.cpp
1 /*
2  * Copyright 2013-present Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/compression/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #if LZ4_VERSION_NUMBER >= 10301
23 #include <lz4frame.h>
24 #endif
25 #endif
26
27 #include <glog/logging.h>
28
29 #if FOLLY_HAVE_LIBSNAPPY
30 #include <snappy-sinksource.h>
31 #include <snappy.h>
32 #endif
33
34 #if FOLLY_HAVE_LIBZ
35 #include <folly/compression/Zlib.h>
36 #endif
37
38 #if FOLLY_HAVE_LIBLZMA
39 #include <lzma.h>
40 #endif
41
42 #if FOLLY_HAVE_LIBZSTD
43 #define ZSTD_STATIC_LINKING_ONLY
44 #include <zstd.h>
45 #endif
46
47 #if FOLLY_HAVE_LIBBZ2
48 #include <bzlib.h>
49 #endif
50
51 #include <folly/Conv.h>
52 #include <folly/Memory.h>
53 #include <folly/Portability.h>
54 #include <folly/ScopeGuard.h>
55 #include <folly/Varint.h>
56 #include <folly/compression/Utils.h>
57 #include <folly/io/Cursor.h>
58 #include <folly/lang/Bits.h>
59 #include <algorithm>
60 #include <unordered_set>
61
62 using folly::io::compression::detail::dataStartsWithLE;
63 using folly::io::compression::detail::prefixToStringLE;
64
65 namespace folly {
66 namespace io {
67
68 Codec::Codec(CodecType type) : type_(type) { }
69
70 // Ensure consistent behavior in the nullptr case
71 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
72   if (data == nullptr) {
73     throw std::invalid_argument("Codec: data must not be nullptr");
74   }
75   uint64_t len = data->computeChainDataLength();
76   if (len > maxUncompressedLength()) {
77     throw std::runtime_error("Codec: uncompressed length too large");
78   }
79
80   return doCompress(data);
81 }
82
83 std::string Codec::compress(const StringPiece data) {
84   const uint64_t len = data.size();
85   if (len > maxUncompressedLength()) {
86     throw std::runtime_error("Codec: uncompressed length too large");
87   }
88
89   return doCompressString(data);
90 }
91
92 std::unique_ptr<IOBuf> Codec::uncompress(
93     const IOBuf* data,
94     Optional<uint64_t> uncompressedLength) {
95   if (data == nullptr) {
96     throw std::invalid_argument("Codec: data must not be nullptr");
97   }
98   if (!uncompressedLength) {
99     if (needsUncompressedLength()) {
100       throw std::invalid_argument("Codec: uncompressed length required");
101     }
102   } else if (*uncompressedLength > maxUncompressedLength()) {
103     throw std::runtime_error("Codec: uncompressed length too large");
104   }
105
106   if (data->empty()) {
107     if (uncompressedLength.value_or(0) != 0) {
108       throw std::runtime_error("Codec: invalid uncompressed length");
109     }
110     return IOBuf::create(0);
111   }
112
113   return doUncompress(data, uncompressedLength);
114 }
115
116 std::string Codec::uncompress(
117     const StringPiece data,
118     Optional<uint64_t> uncompressedLength) {
119   if (!uncompressedLength) {
120     if (needsUncompressedLength()) {
121       throw std::invalid_argument("Codec: uncompressed length required");
122     }
123   } else if (*uncompressedLength > maxUncompressedLength()) {
124     throw std::runtime_error("Codec: uncompressed length too large");
125   }
126
127   if (data.empty()) {
128     if (uncompressedLength.value_or(0) != 0) {
129       throw std::runtime_error("Codec: invalid uncompressed length");
130     }
131     return "";
132   }
133
134   return doUncompressString(data, uncompressedLength);
135 }
136
137 bool Codec::needsUncompressedLength() const {
138   return doNeedsUncompressedLength();
139 }
140
141 uint64_t Codec::maxUncompressedLength() const {
142   return doMaxUncompressedLength();
143 }
144
145 bool Codec::doNeedsUncompressedLength() const {
146   return false;
147 }
148
149 uint64_t Codec::doMaxUncompressedLength() const {
150   return UNLIMITED_UNCOMPRESSED_LENGTH;
151 }
152
153 std::vector<std::string> Codec::validPrefixes() const {
154   return {};
155 }
156
157 bool Codec::canUncompress(const IOBuf*, Optional<uint64_t>) const {
158   return false;
159 }
160
161 std::string Codec::doCompressString(const StringPiece data) {
162   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
163   auto outputBuffer = doCompress(&inputBuffer);
164   std::string output;
165   output.reserve(outputBuffer->computeChainDataLength());
166   for (auto range : *outputBuffer) {
167     output.append(reinterpret_cast<const char*>(range.data()), range.size());
168   }
169   return output;
170 }
171
172 std::string Codec::doUncompressString(
173     const StringPiece data,
174     Optional<uint64_t> uncompressedLength) {
175   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
176   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
177   std::string output;
178   output.reserve(outputBuffer->computeChainDataLength());
179   for (auto range : *outputBuffer) {
180     output.append(reinterpret_cast<const char*>(range.data()), range.size());
181   }
182   return output;
183 }
184
185 uint64_t Codec::maxCompressedLength(uint64_t uncompressedLength) const {
186   return doMaxCompressedLength(uncompressedLength);
187 }
188
189 Optional<uint64_t> Codec::getUncompressedLength(
190     const folly::IOBuf* data,
191     Optional<uint64_t> uncompressedLength) const {
192   auto const compressedLength = data->computeChainDataLength();
193   if (compressedLength == 0) {
194     if (uncompressedLength.value_or(0) != 0) {
195       throw std::runtime_error("Invalid uncompressed length");
196     }
197     return 0;
198   }
199   return doGetUncompressedLength(data, uncompressedLength);
200 }
201
202 Optional<uint64_t> Codec::doGetUncompressedLength(
203     const folly::IOBuf*,
204     Optional<uint64_t> uncompressedLength) const {
205   return uncompressedLength;
206 }
207
208 bool StreamCodec::needsDataLength() const {
209   return doNeedsDataLength();
210 }
211
212 bool StreamCodec::doNeedsDataLength() const {
213   return false;
214 }
215
216 void StreamCodec::assertStateIs(State expected) const {
217   if (state_ != expected) {
218     throw std::logic_error(folly::to<std::string>(
219         "Codec: state is ", state_, "; expected state ", expected));
220   }
221 }
222
223 void StreamCodec::resetStream(Optional<uint64_t> uncompressedLength) {
224   state_ = State::RESET;
225   uncompressedLength_ = uncompressedLength;
226   progressMade_ = true;
227   doResetStream();
228 }
229
230 bool StreamCodec::compressStream(
231     ByteRange& input,
232     MutableByteRange& output,
233     StreamCodec::FlushOp flushOp) {
234   if (state_ == State::RESET && input.empty() &&
235       flushOp == StreamCodec::FlushOp::END &&
236       uncompressedLength().value_or(0) != 0) {
237     throw std::runtime_error("Codec: invalid uncompressed length");
238   }
239
240   if (!uncompressedLength() && needsDataLength()) {
241     throw std::runtime_error("Codec: uncompressed length required");
242   }
243   if (state_ == State::RESET && !input.empty() &&
244       uncompressedLength() == uint64_t(0)) {
245     throw std::runtime_error("Codec: invalid uncompressed length");
246   }
247   // Handle input state transitions
248   switch (flushOp) {
249     case StreamCodec::FlushOp::NONE:
250       if (state_ == State::RESET) {
251         state_ = State::COMPRESS;
252       }
253       assertStateIs(State::COMPRESS);
254       break;
255     case StreamCodec::FlushOp::FLUSH:
256       if (state_ == State::RESET || state_ == State::COMPRESS) {
257         state_ = State::COMPRESS_FLUSH;
258       }
259       assertStateIs(State::COMPRESS_FLUSH);
260       break;
261     case StreamCodec::FlushOp::END:
262       if (state_ == State::RESET || state_ == State::COMPRESS) {
263         state_ = State::COMPRESS_END;
264       }
265       assertStateIs(State::COMPRESS_END);
266       break;
267   }
268   size_t const inputSize = input.size();
269   size_t const outputSize = output.size();
270   bool const done = doCompressStream(input, output, flushOp);
271   if (!done && inputSize == input.size() && outputSize == output.size()) {
272     if (!progressMade_) {
273       throw std::runtime_error("Codec: No forward progress made");
274     }
275     // Throw an exception if there is no progress again next time
276     progressMade_ = false;
277   } else {
278     progressMade_ = true;
279   }
280   // Handle output state transitions
281   if (done) {
282     if (state_ == State::COMPRESS_FLUSH) {
283       state_ = State::COMPRESS;
284     } else if (state_ == State::COMPRESS_END) {
285       state_ = State::END;
286     }
287     // Check internal invariants
288     DCHECK(input.empty());
289     DCHECK(flushOp != StreamCodec::FlushOp::NONE);
290   }
291   return done;
292 }
293
294 bool StreamCodec::uncompressStream(
295     ByteRange& input,
296     MutableByteRange& output,
297     StreamCodec::FlushOp flushOp) {
298   if (state_ == State::RESET && input.empty()) {
299     if (uncompressedLength().value_or(0) == 0) {
300       return true;
301     }
302     return false;
303   }
304   // Handle input state transitions
305   if (state_ == State::RESET) {
306     state_ = State::UNCOMPRESS;
307   }
308   assertStateIs(State::UNCOMPRESS);
309   size_t const inputSize = input.size();
310   size_t const outputSize = output.size();
311   bool const done = doUncompressStream(input, output, flushOp);
312   if (!done && inputSize == input.size() && outputSize == output.size()) {
313     if (!progressMade_) {
314       throw std::runtime_error("Codec: no forward progress made");
315     }
316     // Throw an exception if there is no progress again next time
317     progressMade_ = false;
318   } else {
319     progressMade_ = true;
320   }
321   // Handle output state transitions
322   if (done) {
323     state_ = State::END;
324   }
325   return done;
326 }
327
328 static std::unique_ptr<IOBuf> addOutputBuffer(
329     MutableByteRange& output,
330     uint64_t size) {
331   DCHECK(output.empty());
332   auto buffer = IOBuf::create(size);
333   buffer->append(buffer->capacity());
334   output = {buffer->writableData(), buffer->length()};
335   return buffer;
336 }
337
338 std::unique_ptr<IOBuf> StreamCodec::doCompress(IOBuf const* data) {
339   uint64_t const uncompressedLength = data->computeChainDataLength();
340   resetStream(uncompressedLength);
341   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
342
343   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
344   auto constexpr kDefaultBufferLength = uint64_t(4) << 20; // 4 MB
345
346   MutableByteRange output;
347   auto buffer = addOutputBuffer(
348       output,
349       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
350                                                : kDefaultBufferLength);
351
352   // Compress the entire IOBuf chain into the IOBuf chain pointed to by buffer
353   IOBuf const* current = data;
354   ByteRange input{current->data(), current->length()};
355   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
356   bool done = false;
357   while (!done) {
358     while (input.empty() && current->next() != data) {
359       current = current->next();
360       input = {current->data(), current->length()};
361     }
362     if (current->next() == data) {
363       // This is the last input buffer so end the stream
364       flushOp = StreamCodec::FlushOp::END;
365     }
366     if (output.empty()) {
367       buffer->prependChain(addOutputBuffer(output, kDefaultBufferLength));
368     }
369     done = compressStream(input, output, flushOp);
370     if (done) {
371       DCHECK(input.empty());
372       DCHECK(flushOp == StreamCodec::FlushOp::END);
373       DCHECK_EQ(current->next(), data);
374     }
375   }
376   buffer->prev()->trimEnd(output.size());
377   return buffer;
378 }
379
380 static uint64_t computeBufferLength(
381     uint64_t const compressedLength,
382     uint64_t const blockSize) {
383   uint64_t constexpr kMaxBufferLength = uint64_t(4) << 20; // 4 MiB
384   uint64_t const goodBufferSize = 4 * std::max(blockSize, compressedLength);
385   return std::min(goodBufferSize, kMaxBufferLength);
386 }
387
388 std::unique_ptr<IOBuf> StreamCodec::doUncompress(
389     IOBuf const* data,
390     Optional<uint64_t> uncompressedLength) {
391   auto constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MB
392   auto constexpr kBlockSize = uint64_t(128) << 10;
393   auto const defaultBufferLength =
394       computeBufferLength(data->computeChainDataLength(), kBlockSize);
395
396   uncompressedLength = getUncompressedLength(data, uncompressedLength);
397   resetStream(uncompressedLength);
398
399   MutableByteRange output;
400   auto buffer = addOutputBuffer(
401       output,
402       (uncompressedLength && *uncompressedLength <= kMaxSingleStepLength
403            ? *uncompressedLength
404            : defaultBufferLength));
405
406   // Uncompress the entire IOBuf chain into the IOBuf chain pointed to by buffer
407   IOBuf const* current = data;
408   ByteRange input{current->data(), current->length()};
409   StreamCodec::FlushOp flushOp = StreamCodec::FlushOp::NONE;
410   bool done = false;
411   while (!done) {
412     while (input.empty() && current->next() != data) {
413       current = current->next();
414       input = {current->data(), current->length()};
415     }
416     if (current->next() == data) {
417       // Tell the uncompressor there is no more input (it may optimize)
418       flushOp = StreamCodec::FlushOp::END;
419     }
420     if (output.empty()) {
421       buffer->prependChain(addOutputBuffer(output, defaultBufferLength));
422     }
423     done = uncompressStream(input, output, flushOp);
424   }
425   if (!input.empty()) {
426     throw std::runtime_error("Codec: Junk after end of data");
427   }
428
429   buffer->prev()->trimEnd(output.size());
430   if (uncompressedLength &&
431       *uncompressedLength != buffer->computeChainDataLength()) {
432     throw std::runtime_error("Codec: invalid uncompressed length");
433   }
434
435   return buffer;
436 }
437
438 namespace {
439
440 /**
441  * No compression
442  */
443 class NoCompressionCodec final : public Codec {
444  public:
445   static std::unique_ptr<Codec> create(int level, CodecType type);
446   explicit NoCompressionCodec(int level, CodecType type);
447
448  private:
449   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
450   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
451   std::unique_ptr<IOBuf> doUncompress(
452       const IOBuf* data,
453       Optional<uint64_t> uncompressedLength) override;
454 };
455
456 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
457   return std::make_unique<NoCompressionCodec>(level, type);
458 }
459
460 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
461     : Codec(type) {
462   DCHECK(type == CodecType::NO_COMPRESSION);
463   switch (level) {
464     case COMPRESSION_LEVEL_DEFAULT:
465     case COMPRESSION_LEVEL_FASTEST:
466     case COMPRESSION_LEVEL_BEST:
467       level = 0;
468   }
469   if (level != 0) {
470     throw std::invalid_argument(to<std::string>(
471         "NoCompressionCodec: invalid level ", level));
472   }
473 }
474
475 uint64_t NoCompressionCodec::doMaxCompressedLength(
476     uint64_t uncompressedLength) const {
477   return uncompressedLength;
478 }
479
480 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
481     const IOBuf* data) {
482   return data->clone();
483 }
484
485 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
486     const IOBuf* data,
487     Optional<uint64_t> uncompressedLength) {
488   if (uncompressedLength &&
489       data->computeChainDataLength() != *uncompressedLength) {
490     throw std::runtime_error(
491         to<std::string>("NoCompressionCodec: invalid uncompressed length"));
492   }
493   return data->clone();
494 }
495
496 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
497
498 namespace {
499
500 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
501   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
502   out->append(encodeVarint(val, out->writableTail()));
503 }
504
505 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
506   uint64_t val = 0;
507   int8_t b = 0;
508   for (int shift = 0; shift <= 63; shift += 7) {
509     b = cursor.read<int8_t>();
510     val |= static_cast<uint64_t>(b & 0x7f) << shift;
511     if (b >= 0) {
512       break;
513     }
514   }
515   if (b < 0) {
516     throw std::invalid_argument("Invalid varint value. Too big.");
517   }
518   return val;
519 }
520
521 } // namespace
522
523 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
524
525 #if FOLLY_HAVE_LIBLZ4
526
527 /**
528  * LZ4 compression
529  */
530 class LZ4Codec final : public Codec {
531  public:
532   static std::unique_ptr<Codec> create(int level, CodecType type);
533   explicit LZ4Codec(int level, CodecType type);
534
535  private:
536   bool doNeedsUncompressedLength() const override;
537   uint64_t doMaxUncompressedLength() const override;
538   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
539
540   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
541
542   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
543   std::unique_ptr<IOBuf> doUncompress(
544       const IOBuf* data,
545       Optional<uint64_t> uncompressedLength) override;
546
547   bool highCompression_;
548 };
549
550 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
551   return std::make_unique<LZ4Codec>(level, type);
552 }
553
554 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
555   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
556
557   switch (level) {
558     case COMPRESSION_LEVEL_FASTEST:
559     case COMPRESSION_LEVEL_DEFAULT:
560       level = 1;
561       break;
562     case COMPRESSION_LEVEL_BEST:
563       level = 2;
564       break;
565   }
566   if (level < 1 || level > 2) {
567     throw std::invalid_argument(to<std::string>(
568         "LZ4Codec: invalid level: ", level));
569   }
570   highCompression_ = (level > 1);
571 }
572
573 bool LZ4Codec::doNeedsUncompressedLength() const {
574   return !encodeSize();
575 }
576
577 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
578 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
579 // here.
580 #ifndef LZ4_MAX_INPUT_SIZE
581 # define LZ4_MAX_INPUT_SIZE 0x7E000000
582 #endif
583
584 uint64_t LZ4Codec::doMaxUncompressedLength() const {
585   return LZ4_MAX_INPUT_SIZE;
586 }
587
588 uint64_t LZ4Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
589   return LZ4_compressBound(uncompressedLength) +
590       (encodeSize() ? kMaxVarintLength64 : 0);
591 }
592
593 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
594   IOBuf clone;
595   if (data->isChained()) {
596     // LZ4 doesn't support streaming, so we have to coalesce
597     clone = data->cloneCoalescedAsValue();
598     data = &clone;
599   }
600
601   auto out = IOBuf::create(maxCompressedLength(data->length()));
602   if (encodeSize()) {
603     encodeVarintToIOBuf(data->length(), out.get());
604   }
605
606   int n;
607   auto input = reinterpret_cast<const char*>(data->data());
608   auto output = reinterpret_cast<char*>(out->writableTail());
609   const auto inputLength = data->length();
610 #if LZ4_VERSION_NUMBER >= 10700
611   if (highCompression_) {
612     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
613   } else {
614     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
615   }
616 #else
617   if (highCompression_) {
618     n = LZ4_compressHC(input, output, inputLength);
619   } else {
620     n = LZ4_compress(input, output, inputLength);
621   }
622 #endif
623
624   CHECK_GE(n, 0);
625   CHECK_LE(n, out->capacity());
626
627   out->append(n);
628   return out;
629 }
630
631 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
632     const IOBuf* data,
633     Optional<uint64_t> uncompressedLength) {
634   IOBuf clone;
635   if (data->isChained()) {
636     // LZ4 doesn't support streaming, so we have to coalesce
637     clone = data->cloneCoalescedAsValue();
638     data = &clone;
639   }
640
641   folly::io::Cursor cursor(data);
642   uint64_t actualUncompressedLength;
643   if (encodeSize()) {
644     actualUncompressedLength = decodeVarintFromCursor(cursor);
645     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
646       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
647     }
648   } else {
649     // Invariants
650     DCHECK(uncompressedLength.hasValue());
651     DCHECK(*uncompressedLength <= maxUncompressedLength());
652     actualUncompressedLength = *uncompressedLength;
653   }
654
655   auto sp = StringPiece{cursor.peekBytes()};
656   auto out = IOBuf::create(actualUncompressedLength);
657   int n = LZ4_decompress_safe(
658       sp.data(),
659       reinterpret_cast<char*>(out->writableTail()),
660       sp.size(),
661       actualUncompressedLength);
662
663   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
664     throw std::runtime_error(to<std::string>(
665         "LZ4 decompression returned invalid value ", n));
666   }
667   out->append(actualUncompressedLength);
668   return out;
669 }
670
671 #if LZ4_VERSION_NUMBER >= 10301
672
673 class LZ4FrameCodec final : public Codec {
674  public:
675   static std::unique_ptr<Codec> create(int level, CodecType type);
676   explicit LZ4FrameCodec(int level, CodecType type);
677   ~LZ4FrameCodec() override;
678
679   std::vector<std::string> validPrefixes() const override;
680   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
681       const override;
682
683  private:
684   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
685
686   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
687   std::unique_ptr<IOBuf> doUncompress(
688       const IOBuf* data,
689       Optional<uint64_t> uncompressedLength) override;
690
691   // Reset the dctx_ if it is dirty or null.
692   void resetDCtx();
693
694   int level_;
695   LZ4F_decompressionContext_t dctx_{nullptr};
696   bool dirty_{false};
697 };
698
699 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
700     int level,
701     CodecType type) {
702   return std::make_unique<LZ4FrameCodec>(level, type);
703 }
704
705 static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204;
706
707 std::vector<std::string> LZ4FrameCodec::validPrefixes() const {
708   return {prefixToStringLE(kLZ4FrameMagicLE)};
709 }
710
711 bool LZ4FrameCodec::canUncompress(const IOBuf* data, Optional<uint64_t>) const {
712   return dataStartsWithLE(data, kLZ4FrameMagicLE);
713 }
714
715 uint64_t LZ4FrameCodec::doMaxCompressedLength(
716     uint64_t uncompressedLength) const {
717   LZ4F_preferences_t prefs{};
718   prefs.compressionLevel = level_;
719   prefs.frameInfo.contentSize = uncompressedLength;
720   return LZ4F_compressFrameBound(uncompressedLength, &prefs);
721 }
722
723 static size_t lz4FrameThrowOnError(size_t code) {
724   if (LZ4F_isError(code)) {
725     throw std::runtime_error(
726         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
727   }
728   return code;
729 }
730
731 void LZ4FrameCodec::resetDCtx() {
732   if (dctx_ && !dirty_) {
733     return;
734   }
735   if (dctx_) {
736     LZ4F_freeDecompressionContext(dctx_);
737   }
738   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
739   dirty_ = false;
740 }
741
742 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
743   DCHECK(type == CodecType::LZ4_FRAME);
744   switch (level) {
745     case COMPRESSION_LEVEL_FASTEST:
746     case COMPRESSION_LEVEL_DEFAULT:
747       level_ = 0;
748       break;
749     case COMPRESSION_LEVEL_BEST:
750       level_ = 16;
751       break;
752     default:
753       level_ = level;
754       break;
755   }
756 }
757
758 LZ4FrameCodec::~LZ4FrameCodec() {
759   if (dctx_) {
760     LZ4F_freeDecompressionContext(dctx_);
761   }
762 }
763
764 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
765   // LZ4 Frame compression doesn't support streaming so we have to coalesce
766   IOBuf clone;
767   if (data->isChained()) {
768     clone = data->cloneCoalescedAsValue();
769     data = &clone;
770   }
771   // Set preferences
772   const auto uncompressedLength = data->length();
773   LZ4F_preferences_t prefs{};
774   prefs.compressionLevel = level_;
775   prefs.frameInfo.contentSize = uncompressedLength;
776   // Compress
777   auto buf = IOBuf::create(maxCompressedLength(uncompressedLength));
778   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
779       buf->writableTail(),
780       buf->tailroom(),
781       data->data(),
782       data->length(),
783       &prefs));
784   buf->append(written);
785   return buf;
786 }
787
788 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
789     const IOBuf* data,
790     Optional<uint64_t> uncompressedLength) {
791   // Reset the dctx if any errors have occurred
792   resetDCtx();
793   // Coalesce the data
794   ByteRange in = *data->begin();
795   IOBuf clone;
796   if (data->isChained()) {
797     clone = data->cloneCoalescedAsValue();
798     in = clone.coalesce();
799   }
800   data = nullptr;
801   // Select decompression options
802   LZ4F_decompressOptions_t options;
803   options.stableDst = 1;
804   // Select blockSize and growthSize for the IOBufQueue
805   IOBufQueue queue(IOBufQueue::cacheChainLength());
806   auto blockSize = uint64_t{64} << 10;
807   auto growthSize = uint64_t{4} << 20;
808   if (uncompressedLength) {
809     // Allocate uncompressedLength in one chunk (up to 64 MB)
810     const auto allocateSize = std::min(*uncompressedLength, uint64_t{64} << 20);
811     queue.preallocate(allocateSize, allocateSize);
812     blockSize = std::min(*uncompressedLength, blockSize);
813     growthSize = std::min(*uncompressedLength, growthSize);
814   } else {
815     // Reduce growthSize for small data
816     const auto guessUncompressedLen =
817         4 * std::max<uint64_t>(blockSize, in.size());
818     growthSize = std::min(guessUncompressedLen, growthSize);
819   }
820   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
821   // returns 0
822   dirty_ = true;
823   // Decompress until the frame is over
824   size_t code = 0;
825   do {
826     // Allocate enough space to decompress at least a block
827     void* out;
828     size_t outSize;
829     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
830     // Decompress
831     size_t inSize = in.size();
832     code = lz4FrameThrowOnError(
833         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
834     if (in.empty() && outSize == 0 && code != 0) {
835       // We passed no input, no output was produced, and the frame isn't over
836       // No more forward progress is possible
837       throw std::runtime_error("LZ4Frame error: Incomplete frame");
838     }
839     in.uncheckedAdvance(inSize);
840     queue.postallocate(outSize);
841   } while (code != 0);
842   // At this point the decompression context can be reused
843   dirty_ = false;
844   if (uncompressedLength && queue.chainLength() != *uncompressedLength) {
845     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
846   }
847   return queue.move();
848 }
849
850 #endif // LZ4_VERSION_NUMBER >= 10301
851 #endif // FOLLY_HAVE_LIBLZ4
852
853 #if FOLLY_HAVE_LIBSNAPPY
854
855 /**
856  * Snappy compression
857  */
858
859 /**
860  * Implementation of snappy::Source that reads from a IOBuf chain.
861  */
862 class IOBufSnappySource final : public snappy::Source {
863  public:
864   explicit IOBufSnappySource(const IOBuf* data);
865   size_t Available() const override;
866   const char* Peek(size_t* len) override;
867   void Skip(size_t n) override;
868  private:
869   size_t available_;
870   io::Cursor cursor_;
871 };
872
873 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
874   : available_(data->computeChainDataLength()),
875     cursor_(data) {
876 }
877
878 size_t IOBufSnappySource::Available() const {
879   return available_;
880 }
881
882 const char* IOBufSnappySource::Peek(size_t* len) {
883   auto sp = StringPiece{cursor_.peekBytes()};
884   *len = sp.size();
885   return sp.data();
886 }
887
888 void IOBufSnappySource::Skip(size_t n) {
889   CHECK_LE(n, available_);
890   cursor_.skip(n);
891   available_ -= n;
892 }
893
894 class SnappyCodec final : public Codec {
895  public:
896   static std::unique_ptr<Codec> create(int level, CodecType type);
897   explicit SnappyCodec(int level, CodecType type);
898
899  private:
900   uint64_t doMaxUncompressedLength() const override;
901   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
902   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
903   std::unique_ptr<IOBuf> doUncompress(
904       const IOBuf* data,
905       Optional<uint64_t> uncompressedLength) override;
906 };
907
908 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
909   return std::make_unique<SnappyCodec>(level, type);
910 }
911
912 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
913   DCHECK(type == CodecType::SNAPPY);
914   switch (level) {
915     case COMPRESSION_LEVEL_FASTEST:
916     case COMPRESSION_LEVEL_DEFAULT:
917     case COMPRESSION_LEVEL_BEST:
918       level = 1;
919   }
920   if (level != 1) {
921     throw std::invalid_argument(to<std::string>(
922         "SnappyCodec: invalid level: ", level));
923   }
924 }
925
926 uint64_t SnappyCodec::doMaxUncompressedLength() const {
927   // snappy.h uses uint32_t for lengths, so there's that.
928   return std::numeric_limits<uint32_t>::max();
929 }
930
931 uint64_t SnappyCodec::doMaxCompressedLength(uint64_t uncompressedLength) const {
932   return snappy::MaxCompressedLength(uncompressedLength);
933 }
934
935 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
936   IOBufSnappySource source(data);
937   auto out = IOBuf::create(maxCompressedLength(source.Available()));
938
939   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
940       out->writableTail()));
941
942   size_t n = snappy::Compress(&source, &sink);
943
944   CHECK_LE(n, out->capacity());
945   out->append(n);
946   return out;
947 }
948
949 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(
950     const IOBuf* data,
951     Optional<uint64_t> uncompressedLength) {
952   uint32_t actualUncompressedLength = 0;
953
954   {
955     IOBufSnappySource source(data);
956     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
957       throw std::runtime_error("snappy::GetUncompressedLength failed");
958     }
959     if (uncompressedLength && *uncompressedLength != actualUncompressedLength) {
960       throw std::runtime_error("snappy: invalid uncompressed length");
961     }
962   }
963
964   auto out = IOBuf::create(actualUncompressedLength);
965
966   {
967     IOBufSnappySource source(data);
968     if (!snappy::RawUncompress(&source,
969                                reinterpret_cast<char*>(out->writableTail()))) {
970       throw std::runtime_error("snappy::RawUncompress failed");
971     }
972   }
973
974   out->append(actualUncompressedLength);
975   return out;
976 }
977
978 #endif  // FOLLY_HAVE_LIBSNAPPY
979
980 #if FOLLY_HAVE_LIBLZMA
981
982 /**
983  * LZMA2 compression
984  */
985 class LZMA2StreamCodec final : public StreamCodec {
986  public:
987   static std::unique_ptr<Codec> createCodec(int level, CodecType type);
988   static std::unique_ptr<StreamCodec> createStream(int level, CodecType type);
989   explicit LZMA2StreamCodec(int level, CodecType type);
990   ~LZMA2StreamCodec() override;
991
992   std::vector<std::string> validPrefixes() const override;
993   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
994       const override;
995
996  private:
997   bool doNeedsDataLength() const override;
998   uint64_t doMaxUncompressedLength() const override;
999   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1000
1001   bool encodeSize() const {
1002     return type() == CodecType::LZMA2_VARINT_SIZE;
1003   }
1004
1005   void doResetStream() override;
1006   bool doCompressStream(
1007       ByteRange& input,
1008       MutableByteRange& output,
1009       StreamCodec::FlushOp flushOp) override;
1010   bool doUncompressStream(
1011       ByteRange& input,
1012       MutableByteRange& output,
1013       StreamCodec::FlushOp flushOp) override;
1014
1015   void resetCStream();
1016   void resetDStream();
1017
1018   bool decodeAndCheckVarint(ByteRange& input);
1019   bool flushVarintBuffer(MutableByteRange& output);
1020   void resetVarintBuffer();
1021
1022   Optional<lzma_stream> cstream_{};
1023   Optional<lzma_stream> dstream_{};
1024
1025   std::array<uint8_t, kMaxVarintLength64> varintBuffer_;
1026   ByteRange varintToEncode_;
1027   size_t varintBufferPos_{0};
1028
1029   int level_;
1030   bool needReset_{true};
1031   bool needDecodeSize_{false};
1032 };
1033
1034 static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD;
1035 static constexpr unsigned kLZMA2MagicBytes = 6;
1036
1037 std::vector<std::string> LZMA2StreamCodec::validPrefixes() const {
1038   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1039     return {};
1040   }
1041   return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)};
1042 }
1043
1044 bool LZMA2StreamCodec::doNeedsDataLength() const {
1045   return encodeSize();
1046 }
1047
1048 bool LZMA2StreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1049     const {
1050   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1051     return false;
1052   }
1053   // Returns false for all inputs less than 8 bytes.
1054   // This is okay, because no valid LZMA2 streams are less than 8 bytes.
1055   return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes);
1056 }
1057
1058 std::unique_ptr<Codec> LZMA2StreamCodec::createCodec(
1059     int level,
1060     CodecType type) {
1061   return make_unique<LZMA2StreamCodec>(level, type);
1062 }
1063
1064 std::unique_ptr<StreamCodec> LZMA2StreamCodec::createStream(
1065     int level,
1066     CodecType type) {
1067   return make_unique<LZMA2StreamCodec>(level, type);
1068 }
1069
1070 LZMA2StreamCodec::LZMA2StreamCodec(int level, CodecType type)
1071     : StreamCodec(type) {
1072   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
1073   switch (level) {
1074     case COMPRESSION_LEVEL_FASTEST:
1075       level = 0;
1076       break;
1077     case COMPRESSION_LEVEL_DEFAULT:
1078       level = LZMA_PRESET_DEFAULT;
1079       break;
1080     case COMPRESSION_LEVEL_BEST:
1081       level = 9;
1082       break;
1083   }
1084   if (level < 0 || level > 9) {
1085     throw std::invalid_argument(
1086         to<std::string>("LZMA2Codec: invalid level: ", level));
1087   }
1088   level_ = level;
1089 }
1090
1091 LZMA2StreamCodec::~LZMA2StreamCodec() {
1092   if (cstream_) {
1093     lzma_end(cstream_.get_pointer());
1094     cstream_.clear();
1095   }
1096   if (dstream_) {
1097     lzma_end(dstream_.get_pointer());
1098     dstream_.clear();
1099   }
1100 }
1101
1102 uint64_t LZMA2StreamCodec::doMaxUncompressedLength() const {
1103   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
1104   return uint64_t(1) << 63;
1105 }
1106
1107 uint64_t LZMA2StreamCodec::doMaxCompressedLength(
1108     uint64_t uncompressedLength) const {
1109   return lzma_stream_buffer_bound(uncompressedLength) +
1110       (encodeSize() ? kMaxVarintLength64 : 0);
1111 }
1112
1113 void LZMA2StreamCodec::doResetStream() {
1114   needReset_ = true;
1115 }
1116
1117 void LZMA2StreamCodec::resetCStream() {
1118   if (!cstream_) {
1119     cstream_.assign(LZMA_STREAM_INIT);
1120   }
1121   lzma_ret const rc =
1122       lzma_easy_encoder(cstream_.get_pointer(), level_, LZMA_CHECK_NONE);
1123   if (rc != LZMA_OK) {
1124     throw std::runtime_error(folly::to<std::string>(
1125         "LZMA2StreamCodec: lzma_easy_encoder error: ", rc));
1126   }
1127 }
1128
1129 void LZMA2StreamCodec::resetDStream() {
1130   if (!dstream_) {
1131     dstream_.assign(LZMA_STREAM_INIT);
1132   }
1133   lzma_ret const rc = lzma_auto_decoder(
1134       dstream_.get_pointer(), std::numeric_limits<uint64_t>::max(), 0);
1135   if (rc != LZMA_OK) {
1136     throw std::runtime_error(folly::to<std::string>(
1137         "LZMA2StreamCodec: lzma_auto_decoder error: ", rc));
1138   }
1139 }
1140
1141 static lzma_ret lzmaThrowOnError(lzma_ret const rc) {
1142   switch (rc) {
1143     case LZMA_OK:
1144     case LZMA_STREAM_END:
1145     case LZMA_BUF_ERROR: // not fatal: returned if no progress was made twice
1146       return rc;
1147     default:
1148       throw std::runtime_error(
1149           to<std::string>("LZMA2StreamCodec: error: ", rc));
1150   }
1151 }
1152
1153 static lzma_action lzmaTranslateFlush(StreamCodec::FlushOp flush) {
1154   switch (flush) {
1155     case StreamCodec::FlushOp::NONE:
1156       return LZMA_RUN;
1157     case StreamCodec::FlushOp::FLUSH:
1158       return LZMA_SYNC_FLUSH;
1159     case StreamCodec::FlushOp::END:
1160       return LZMA_FINISH;
1161     default:
1162       throw std::invalid_argument("LZMA2StreamCodec: Invalid flush");
1163   }
1164 }
1165
1166 /**
1167  * Flushes the varint buffer.
1168  * Advances output by the number of bytes written.
1169  * Returns true when flushing is complete.
1170  */
1171 bool LZMA2StreamCodec::flushVarintBuffer(MutableByteRange& output) {
1172   if (varintToEncode_.empty()) {
1173     return true;
1174   }
1175   const size_t numBytesToCopy = std::min(varintToEncode_.size(), output.size());
1176   if (numBytesToCopy > 0) {
1177     memcpy(output.data(), varintToEncode_.data(), numBytesToCopy);
1178   }
1179   varintToEncode_.advance(numBytesToCopy);
1180   output.advance(numBytesToCopy);
1181   return varintToEncode_.empty();
1182 }
1183
1184 bool LZMA2StreamCodec::doCompressStream(
1185     ByteRange& input,
1186     MutableByteRange& output,
1187     StreamCodec::FlushOp flushOp) {
1188   if (needReset_) {
1189     resetCStream();
1190     if (encodeSize()) {
1191       varintBufferPos_ = 0;
1192       size_t const varintSize =
1193           encodeVarint(*uncompressedLength(), varintBuffer_.data());
1194       varintToEncode_ = {varintBuffer_.data(), varintSize};
1195     }
1196     needReset_ = false;
1197   }
1198
1199   if (!flushVarintBuffer(output)) {
1200     return false;
1201   }
1202
1203   cstream_->next_in = const_cast<uint8_t*>(input.data());
1204   cstream_->avail_in = input.size();
1205   cstream_->next_out = output.data();
1206   cstream_->avail_out = output.size();
1207   SCOPE_EXIT {
1208     input.uncheckedAdvance(input.size() - cstream_->avail_in);
1209     output.uncheckedAdvance(output.size() - cstream_->avail_out);
1210   };
1211   lzma_ret const rc = lzmaThrowOnError(
1212       lzma_code(cstream_.get_pointer(), lzmaTranslateFlush(flushOp)));
1213   switch (flushOp) {
1214     case StreamCodec::FlushOp::NONE:
1215       return false;
1216     case StreamCodec::FlushOp::FLUSH:
1217       return cstream_->avail_in == 0 && cstream_->avail_out != 0;
1218     case StreamCodec::FlushOp::END:
1219       return rc == LZMA_STREAM_END;
1220     default:
1221       throw std::invalid_argument("LZMA2StreamCodec: invalid FlushOp");
1222   }
1223 }
1224
1225 /**
1226  * Attempts to decode a varint from input.
1227  * The function advances input by the number of bytes read.
1228  *
1229  * If there are too many bytes and the varint is not valid, throw a
1230  * runtime_error.
1231  *
1232  * If the uncompressed length was provided and a decoded varint does not match
1233  * the provided length, throw a runtime_error.
1234  *
1235  * Returns true if the varint was successfully decoded and matches the
1236  * uncompressed length if provided, and false if more bytes are needed.
1237  */
1238 bool LZMA2StreamCodec::decodeAndCheckVarint(ByteRange& input) {
1239   if (input.empty()) {
1240     return false;
1241   }
1242   size_t const numBytesToCopy =
1243       std::min(kMaxVarintLength64 - varintBufferPos_, input.size());
1244   memcpy(varintBuffer_.data() + varintBufferPos_, input.data(), numBytesToCopy);
1245
1246   size_t const rangeSize = varintBufferPos_ + numBytesToCopy;
1247   ByteRange range{varintBuffer_.data(), rangeSize};
1248   auto const ret = tryDecodeVarint(range);
1249
1250   if (ret.hasValue()) {
1251     size_t const varintSize = rangeSize - range.size();
1252     input.advance(varintSize - varintBufferPos_);
1253     if (uncompressedLength() && *uncompressedLength() != ret.value()) {
1254       throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length");
1255     }
1256     return true;
1257   } else if (ret.error() == DecodeVarintError::TooManyBytes) {
1258     throw std::runtime_error("LZMA2StreamCodec: invalid uncompressed length");
1259   } else {
1260     // Too few bytes
1261     input.advance(numBytesToCopy);
1262     varintBufferPos_ += numBytesToCopy;
1263     return false;
1264   }
1265 }
1266
1267 bool LZMA2StreamCodec::doUncompressStream(
1268     ByteRange& input,
1269     MutableByteRange& output,
1270     StreamCodec::FlushOp flushOp) {
1271   if (needReset_) {
1272     resetDStream();
1273     needReset_ = false;
1274     needDecodeSize_ = encodeSize();
1275     if (encodeSize()) {
1276       // Reset buffer
1277       varintBufferPos_ = 0;
1278     }
1279   }
1280
1281   if (needDecodeSize_) {
1282     // Try decoding the varint. If the input does not contain the entire varint,
1283     // buffer the input. If the varint can not be decoded, fail.
1284     if (!decodeAndCheckVarint(input)) {
1285       return false;
1286     }
1287     needDecodeSize_ = false;
1288   }
1289
1290   dstream_->next_in = const_cast<uint8_t*>(input.data());
1291   dstream_->avail_in = input.size();
1292   dstream_->next_out = output.data();
1293   dstream_->avail_out = output.size();
1294   SCOPE_EXIT {
1295     input.advance(input.size() - dstream_->avail_in);
1296     output.advance(output.size() - dstream_->avail_out);
1297   };
1298
1299   lzma_ret rc;
1300   switch (flushOp) {
1301     case StreamCodec::FlushOp::NONE:
1302     case StreamCodec::FlushOp::FLUSH:
1303       rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_RUN));
1304       break;
1305     case StreamCodec::FlushOp::END:
1306       rc = lzmaThrowOnError(lzma_code(dstream_.get_pointer(), LZMA_FINISH));
1307       break;
1308     default:
1309       throw std::invalid_argument("LZMA2StreamCodec: invalid flush");
1310   }
1311   return rc == LZMA_STREAM_END;
1312 }
1313 #endif // FOLLY_HAVE_LIBLZMA
1314
1315 #ifdef FOLLY_HAVE_LIBZSTD
1316
1317 namespace {
1318 void zstdFreeCStream(ZSTD_CStream* zcs) {
1319   ZSTD_freeCStream(zcs);
1320 }
1321
1322 void zstdFreeDStream(ZSTD_DStream* zds) {
1323   ZSTD_freeDStream(zds);
1324 }
1325 } // namespace
1326
1327 /**
1328  * ZSTD compression
1329  */
1330 class ZSTDStreamCodec final : public StreamCodec {
1331  public:
1332   static std::unique_ptr<Codec> createCodec(int level, CodecType);
1333   static std::unique_ptr<StreamCodec> createStream(int level, CodecType);
1334   explicit ZSTDStreamCodec(int level, CodecType type);
1335
1336   std::vector<std::string> validPrefixes() const override;
1337   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1338       const override;
1339
1340  private:
1341   bool doNeedsUncompressedLength() const override;
1342   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1343   Optional<uint64_t> doGetUncompressedLength(
1344       IOBuf const* data,
1345       Optional<uint64_t> uncompressedLength) const override;
1346
1347   void doResetStream() override;
1348   bool doCompressStream(
1349       ByteRange& input,
1350       MutableByteRange& output,
1351       StreamCodec::FlushOp flushOp) override;
1352   bool doUncompressStream(
1353       ByteRange& input,
1354       MutableByteRange& output,
1355       StreamCodec::FlushOp flushOp) override;
1356
1357   void resetCStream();
1358   void resetDStream();
1359
1360   bool tryBlockCompress(ByteRange& input, MutableByteRange& output) const;
1361   bool tryBlockUncompress(ByteRange& input, MutableByteRange& output) const;
1362
1363   int level_;
1364   bool needReset_{true};
1365   std::unique_ptr<
1366       ZSTD_CStream,
1367       folly::static_function_deleter<ZSTD_CStream, &zstdFreeCStream>>
1368       cstream_{nullptr};
1369   std::unique_ptr<
1370       ZSTD_DStream,
1371       folly::static_function_deleter<ZSTD_DStream, &zstdFreeDStream>>
1372       dstream_{nullptr};
1373 };
1374
1375 static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
1376
1377 std::vector<std::string> ZSTDStreamCodec::validPrefixes() const {
1378   return {prefixToStringLE(kZSTDMagicLE)};
1379 }
1380
1381 bool ZSTDStreamCodec::canUncompress(const IOBuf* data, Optional<uint64_t>)
1382     const {
1383   return dataStartsWithLE(data, kZSTDMagicLE);
1384 }
1385
1386 std::unique_ptr<Codec> ZSTDStreamCodec::createCodec(int level, CodecType type) {
1387   return make_unique<ZSTDStreamCodec>(level, type);
1388 }
1389
1390 std::unique_ptr<StreamCodec> ZSTDStreamCodec::createStream(
1391     int level,
1392     CodecType type) {
1393   return make_unique<ZSTDStreamCodec>(level, type);
1394 }
1395
1396 ZSTDStreamCodec::ZSTDStreamCodec(int level, CodecType type)
1397     : StreamCodec(type) {
1398   DCHECK(type == CodecType::ZSTD);
1399   switch (level) {
1400     case COMPRESSION_LEVEL_FASTEST:
1401       level = 1;
1402       break;
1403     case COMPRESSION_LEVEL_DEFAULT:
1404       level = 1;
1405       break;
1406     case COMPRESSION_LEVEL_BEST:
1407       level = 19;
1408       break;
1409   }
1410   if (level < 1 || level > ZSTD_maxCLevel()) {
1411     throw std::invalid_argument(
1412         to<std::string>("ZSTD: invalid level: ", level));
1413   }
1414   level_ = level;
1415 }
1416
1417 bool ZSTDStreamCodec::doNeedsUncompressedLength() const {
1418   return false;
1419 }
1420
1421 uint64_t ZSTDStreamCodec::doMaxCompressedLength(
1422     uint64_t uncompressedLength) const {
1423   return ZSTD_compressBound(uncompressedLength);
1424 }
1425
1426 void zstdThrowIfError(size_t rc) {
1427   if (!ZSTD_isError(rc)) {
1428     return;
1429   }
1430   throw std::runtime_error(
1431       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1432 }
1433
1434 Optional<uint64_t> ZSTDStreamCodec::doGetUncompressedLength(
1435     IOBuf const* data,
1436     Optional<uint64_t> uncompressedLength) const {
1437   // Read decompressed size from frame if available in first IOBuf.
1438   auto const decompressedSize =
1439       ZSTD_getDecompressedSize(data->data(), data->length());
1440   if (decompressedSize != 0) {
1441     if (uncompressedLength && *uncompressedLength != decompressedSize) {
1442       throw std::runtime_error("ZSTD: invalid uncompressed length");
1443     }
1444     uncompressedLength = decompressedSize;
1445   }
1446   return uncompressedLength;
1447 }
1448
1449 void ZSTDStreamCodec::doResetStream() {
1450   needReset_ = true;
1451 }
1452
1453 bool ZSTDStreamCodec::tryBlockCompress(
1454     ByteRange& input,
1455     MutableByteRange& output) const {
1456   DCHECK(needReset_);
1457   // We need to know that we have enough output space to use block compression
1458   if (output.size() < ZSTD_compressBound(input.size())) {
1459     return false;
1460   }
1461   size_t const length = ZSTD_compress(
1462       output.data(), output.size(), input.data(), input.size(), level_);
1463   zstdThrowIfError(length);
1464   input.uncheckedAdvance(input.size());
1465   output.uncheckedAdvance(length);
1466   return true;
1467 }
1468
1469 void ZSTDStreamCodec::resetCStream() {
1470   if (!cstream_) {
1471     cstream_.reset(ZSTD_createCStream());
1472     if (!cstream_) {
1473       throw std::bad_alloc{};
1474     }
1475   }
1476   // As of 1.3.2 ZSTD_initCStream_advanced() interprets content size 0 as
1477   // unknown if contentSizeFlag == 0, but this behavior is deprecated, and will
1478   // be removed in the future. Starting with version 1.3.2 start passing the
1479   // correct value, ZSTD_CONTENTSIZE_UNKNOWN.
1480 #if ZSTD_VERSION_NUMBER >= 10302
1481   constexpr uint64_t kZstdUnknownContentSize = ZSTD_CONTENTSIZE_UNKNOWN;
1482 #else
1483   constexpr uint64_t kZstdUnknownContentSize = 0;
1484 #endif
1485   // Advanced API usage works for all supported versions of zstd.
1486   // Required to set contentSizeFlag.
1487   auto params = ZSTD_getParams(level_, uncompressedLength().value_or(0), 0);
1488   params.fParams.contentSizeFlag = uncompressedLength().hasValue();
1489   zstdThrowIfError(ZSTD_initCStream_advanced(
1490       cstream_.get(),
1491       nullptr,
1492       0,
1493       params,
1494       uncompressedLength().value_or(kZstdUnknownContentSize)));
1495 }
1496
1497 bool ZSTDStreamCodec::doCompressStream(
1498     ByteRange& input,
1499     MutableByteRange& output,
1500     StreamCodec::FlushOp flushOp) {
1501   if (needReset_) {
1502     // If we are given all the input in one chunk try to use block compression
1503     if (flushOp == StreamCodec::FlushOp::END &&
1504         tryBlockCompress(input, output)) {
1505       return true;
1506     }
1507     resetCStream();
1508     needReset_ = false;
1509   }
1510   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1511   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1512   SCOPE_EXIT {
1513     input.uncheckedAdvance(in.pos);
1514     output.uncheckedAdvance(out.pos);
1515   };
1516   if (flushOp == StreamCodec::FlushOp::NONE || !input.empty()) {
1517     zstdThrowIfError(ZSTD_compressStream(cstream_.get(), &out, &in));
1518   }
1519   if (in.pos == in.size && flushOp != StreamCodec::FlushOp::NONE) {
1520     size_t rc;
1521     switch (flushOp) {
1522       case StreamCodec::FlushOp::FLUSH:
1523         rc = ZSTD_flushStream(cstream_.get(), &out);
1524         break;
1525       case StreamCodec::FlushOp::END:
1526         rc = ZSTD_endStream(cstream_.get(), &out);
1527         break;
1528       default:
1529         throw std::invalid_argument("ZSTD: invalid FlushOp");
1530     }
1531     zstdThrowIfError(rc);
1532     if (rc == 0) {
1533       return true;
1534     }
1535   }
1536   return false;
1537 }
1538
1539 bool ZSTDStreamCodec::tryBlockUncompress(
1540     ByteRange& input,
1541     MutableByteRange& output) const {
1542   DCHECK(needReset_);
1543 #if ZSTD_VERSION_NUMBER < 10104
1544   // We require ZSTD_findFrameCompressedSize() to perform this optimization.
1545   return false;
1546 #else
1547   // We need to know the uncompressed length and have enough output space.
1548   if (!uncompressedLength() || output.size() < *uncompressedLength()) {
1549     return false;
1550   }
1551   size_t const compressedLength =
1552       ZSTD_findFrameCompressedSize(input.data(), input.size());
1553   zstdThrowIfError(compressedLength);
1554   size_t const length = ZSTD_decompress(
1555       output.data(), *uncompressedLength(), input.data(), compressedLength);
1556   zstdThrowIfError(length);
1557   if (length != *uncompressedLength()) {
1558     throw std::runtime_error("ZSTDStreamCodec: Incorrect uncompressed length");
1559   }
1560   input.uncheckedAdvance(compressedLength);
1561   output.uncheckedAdvance(length);
1562   return true;
1563 #endif
1564 }
1565
1566 void ZSTDStreamCodec::resetDStream() {
1567   if (!dstream_) {
1568     dstream_.reset(ZSTD_createDStream());
1569     if (!dstream_) {
1570       throw std::bad_alloc{};
1571     }
1572   }
1573   zstdThrowIfError(ZSTD_initDStream(dstream_.get()));
1574 }
1575
1576 bool ZSTDStreamCodec::doUncompressStream(
1577     ByteRange& input,
1578     MutableByteRange& output,
1579     StreamCodec::FlushOp flushOp) {
1580   if (needReset_) {
1581     // If we are given all the input in one chunk try to use block uncompression
1582     if (flushOp == StreamCodec::FlushOp::END &&
1583         tryBlockUncompress(input, output)) {
1584       return true;
1585     }
1586     resetDStream();
1587     needReset_ = false;
1588   }
1589   ZSTD_inBuffer in = {input.data(), input.size(), 0};
1590   ZSTD_outBuffer out = {output.data(), output.size(), 0};
1591   SCOPE_EXIT {
1592     input.uncheckedAdvance(in.pos);
1593     output.uncheckedAdvance(out.pos);
1594   };
1595   size_t const rc = ZSTD_decompressStream(dstream_.get(), &out, &in);
1596   zstdThrowIfError(rc);
1597   return rc == 0;
1598 }
1599
1600 #endif // FOLLY_HAVE_LIBZSTD
1601
1602 #if FOLLY_HAVE_LIBBZ2
1603
1604 class Bzip2Codec final : public Codec {
1605  public:
1606   static std::unique_ptr<Codec> create(int level, CodecType type);
1607   explicit Bzip2Codec(int level, CodecType type);
1608
1609   std::vector<std::string> validPrefixes() const override;
1610   bool canUncompress(IOBuf const* data, Optional<uint64_t> uncompressedLength)
1611       const override;
1612
1613  private:
1614   uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override;
1615   std::unique_ptr<IOBuf> doCompress(IOBuf const* data) override;
1616   std::unique_ptr<IOBuf> doUncompress(
1617       IOBuf const* data,
1618       Optional<uint64_t> uncompressedLength) override;
1619
1620   int level_;
1621 };
1622
1623 /* static */ std::unique_ptr<Codec> Bzip2Codec::create(
1624     int level,
1625     CodecType type) {
1626   return std::make_unique<Bzip2Codec>(level, type);
1627 }
1628
1629 Bzip2Codec::Bzip2Codec(int level, CodecType type) : Codec(type) {
1630   DCHECK(type == CodecType::BZIP2);
1631   switch (level) {
1632     case COMPRESSION_LEVEL_FASTEST:
1633       level = 1;
1634       break;
1635     case COMPRESSION_LEVEL_DEFAULT:
1636       level = 9;
1637       break;
1638     case COMPRESSION_LEVEL_BEST:
1639       level = 9;
1640       break;
1641   }
1642   if (level < 1 || level > 9) {
1643     throw std::invalid_argument(
1644         to<std::string>("Bzip2: invalid level: ", level));
1645   }
1646   level_ = level;
1647 }
1648
1649 static uint32_t constexpr kBzip2MagicLE = 0x685a42;
1650 static uint64_t constexpr kBzip2MagicBytes = 3;
1651
1652 std::vector<std::string> Bzip2Codec::validPrefixes() const {
1653   return {prefixToStringLE(kBzip2MagicLE, kBzip2MagicBytes)};
1654 }
1655
1656 bool Bzip2Codec::canUncompress(IOBuf const* data, Optional<uint64_t>) const {
1657   return dataStartsWithLE(data, kBzip2MagicLE, kBzip2MagicBytes);
1658 }
1659
1660 uint64_t Bzip2Codec::doMaxCompressedLength(uint64_t uncompressedLength) const {
1661   // http://www.bzip.org/1.0.5/bzip2-manual-1.0.5.html#bzbufftobuffcompress
1662   //   To guarantee that the compressed data will fit in its buffer, allocate an
1663   //   output buffer of size 1% larger than the uncompressed data, plus six
1664   //   hundred extra bytes.
1665   return uncompressedLength + uncompressedLength / 100 + 600;
1666 }
1667
1668 static bz_stream createBzStream() {
1669   bz_stream stream;
1670   stream.bzalloc = nullptr;
1671   stream.bzfree = nullptr;
1672   stream.opaque = nullptr;
1673   stream.next_in = stream.next_out = nullptr;
1674   stream.avail_in = stream.avail_out = 0;
1675   return stream;
1676 }
1677
1678 // Throws on error condition, otherwise returns the code.
1679 static int bzCheck(int const rc) {
1680   switch (rc) {
1681     case BZ_OK:
1682     case BZ_RUN_OK:
1683     case BZ_FLUSH_OK:
1684     case BZ_FINISH_OK:
1685     case BZ_STREAM_END:
1686       return rc;
1687     default:
1688       throw std::runtime_error(to<std::string>("Bzip2 error: ", rc));
1689   }
1690 }
1691
1692 static std::unique_ptr<IOBuf> addOutputBuffer(
1693     bz_stream* stream,
1694     uint64_t const bufferLength) {
1695   DCHECK_LE(bufferLength, std::numeric_limits<unsigned>::max());
1696   DCHECK_EQ(stream->avail_out, 0);
1697
1698   auto buf = IOBuf::create(bufferLength);
1699   buf->append(buf->capacity());
1700
1701   stream->next_out = reinterpret_cast<char*>(buf->writableData());
1702   stream->avail_out = buf->length();
1703
1704   return buf;
1705 }
1706
1707 std::unique_ptr<IOBuf> Bzip2Codec::doCompress(IOBuf const* data) {
1708   bz_stream stream = createBzStream();
1709   bzCheck(BZ2_bzCompressInit(&stream, level_, 0, 0));
1710   SCOPE_EXIT {
1711     bzCheck(BZ2_bzCompressEnd(&stream));
1712   };
1713
1714   uint64_t const uncompressedLength = data->computeChainDataLength();
1715   uint64_t const maxCompressedLen = maxCompressedLength(uncompressedLength);
1716   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1717   uint64_t constexpr kDefaultBufferLength = uint64_t(4) << 20;
1718
1719   auto out = addOutputBuffer(
1720       &stream,
1721       maxCompressedLen <= kMaxSingleStepLength ? maxCompressedLen
1722                                                : kDefaultBufferLength);
1723
1724   for (auto range : *data) {
1725     while (!range.empty()) {
1726       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1727       stream.next_in =
1728           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1729       stream.avail_in = inSize;
1730
1731       if (stream.avail_out == 0) {
1732         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1733       }
1734
1735       bzCheck(BZ2_bzCompress(&stream, BZ_RUN));
1736       range.uncheckedAdvance(inSize - stream.avail_in);
1737     }
1738   }
1739   do {
1740     if (stream.avail_out == 0) {
1741       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1742     }
1743   } while (bzCheck(BZ2_bzCompress(&stream, BZ_FINISH)) != BZ_STREAM_END);
1744
1745   out->prev()->trimEnd(stream.avail_out);
1746
1747   return out;
1748 }
1749
1750 std::unique_ptr<IOBuf> Bzip2Codec::doUncompress(
1751     const IOBuf* data,
1752     Optional<uint64_t> uncompressedLength) {
1753   bz_stream stream = createBzStream();
1754   bzCheck(BZ2_bzDecompressInit(&stream, 0, 0));
1755   SCOPE_EXIT {
1756     bzCheck(BZ2_bzDecompressEnd(&stream));
1757   };
1758
1759   uint64_t constexpr kMaxSingleStepLength = uint64_t(64) << 20; // 64 MiB
1760   uint64_t const kBlockSize = uint64_t(100) << 10; // 100 KiB
1761   uint64_t const kDefaultBufferLength =
1762       computeBufferLength(data->computeChainDataLength(), kBlockSize);
1763
1764   auto out = addOutputBuffer(
1765       &stream,
1766       ((uncompressedLength && *uncompressedLength <= kMaxSingleStepLength)
1767            ? *uncompressedLength
1768            : kDefaultBufferLength));
1769
1770   int rc = BZ_OK;
1771   for (auto range : *data) {
1772     while (!range.empty()) {
1773       auto const inSize = std::min<size_t>(range.size(), kMaxSingleStepLength);
1774       stream.next_in =
1775           const_cast<char*>(reinterpret_cast<char const*>(range.data()));
1776       stream.avail_in = inSize;
1777
1778       if (stream.avail_out == 0) {
1779         out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1780       }
1781
1782       rc = bzCheck(BZ2_bzDecompress(&stream));
1783       range.uncheckedAdvance(inSize - stream.avail_in);
1784     }
1785   }
1786   while (rc != BZ_STREAM_END) {
1787     if (stream.avail_out == 0) {
1788       out->prependChain(addOutputBuffer(&stream, kDefaultBufferLength));
1789     }
1790     size_t const outputSize = stream.avail_out;
1791     rc = bzCheck(BZ2_bzDecompress(&stream));
1792     if (outputSize == stream.avail_out) {
1793       throw std::runtime_error("Bzip2Codec: Truncated input");
1794     }
1795   }
1796
1797   out->prev()->trimEnd(stream.avail_out);
1798
1799   uint64_t const totalOut =
1800       (uint64_t(stream.total_out_hi32) << 32) + stream.total_out_lo32;
1801   if (uncompressedLength && uncompressedLength != totalOut) {
1802     throw std::runtime_error("Bzip2 error: Invalid uncompressed length");
1803   }
1804
1805   return out;
1806 }
1807
1808 #endif // FOLLY_HAVE_LIBBZ2
1809
1810 #if FOLLY_HAVE_LIBZ
1811
1812 zlib::Options getZlibOptions(CodecType type) {
1813   DCHECK(type == CodecType::GZIP || type == CodecType::ZLIB);
1814   return type == CodecType::GZIP ? zlib::defaultGzipOptions()
1815                                  : zlib::defaultZlibOptions();
1816 }
1817
1818 std::unique_ptr<Codec> getZlibCodec(int level, CodecType type) {
1819   return zlib::getCodec(getZlibOptions(type), level);
1820 }
1821
1822 std::unique_ptr<StreamCodec> getZlibStreamCodec(int level, CodecType type) {
1823   return zlib::getStreamCodec(getZlibOptions(type), level);
1824 }
1825
1826 #endif // FOLLY_HAVE_LIBZ
1827
1828 /**
1829  * Automatic decompression
1830  */
1831 class AutomaticCodec final : public Codec {
1832  public:
1833   static std::unique_ptr<Codec> create(
1834       std::vector<std::unique_ptr<Codec>> customCodecs,
1835       std::unique_ptr<Codec> terminalCodec);
1836   explicit AutomaticCodec(
1837       std::vector<std::unique_ptr<Codec>> customCodecs,
1838       std::unique_ptr<Codec> terminalCodec);
1839
1840   std::vector<std::string> validPrefixes() const override;
1841   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
1842       const override;
1843
1844  private:
1845   bool doNeedsUncompressedLength() const override;
1846   uint64_t doMaxUncompressedLength() const override;
1847
1848   uint64_t doMaxCompressedLength(uint64_t) const override {
1849     throw std::runtime_error(
1850         "AutomaticCodec error: maxCompressedLength() not supported.");
1851   }
1852   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
1853     throw std::runtime_error("AutomaticCodec error: compress() not supported.");
1854   }
1855   std::unique_ptr<IOBuf> doUncompress(
1856       const IOBuf* data,
1857       Optional<uint64_t> uncompressedLength) override;
1858
1859   void addCodecIfSupported(CodecType type);
1860
1861   // Throws iff the codecs aren't compatible (very slow)
1862   void checkCompatibleCodecs() const;
1863
1864   std::vector<std::unique_ptr<Codec>> codecs_;
1865   std::unique_ptr<Codec> terminalCodec_;
1866   bool needsUncompressedLength_;
1867   uint64_t maxUncompressedLength_;
1868 };
1869
1870 std::vector<std::string> AutomaticCodec::validPrefixes() const {
1871   std::unordered_set<std::string> prefixes;
1872   for (const auto& codec : codecs_) {
1873     const auto codecPrefixes = codec->validPrefixes();
1874     prefixes.insert(codecPrefixes.begin(), codecPrefixes.end());
1875   }
1876   return std::vector<std::string>{prefixes.begin(), prefixes.end()};
1877 }
1878
1879 bool AutomaticCodec::canUncompress(
1880     const IOBuf* data,
1881     Optional<uint64_t> uncompressedLength) const {
1882   return std::any_of(
1883       codecs_.begin(),
1884       codecs_.end(),
1885       [data, uncompressedLength](std::unique_ptr<Codec> const& codec) {
1886         return codec->canUncompress(data, uncompressedLength);
1887       });
1888 }
1889
1890 void AutomaticCodec::addCodecIfSupported(CodecType type) {
1891   const bool present = std::any_of(
1892       codecs_.begin(),
1893       codecs_.end(),
1894       [&type](std::unique_ptr<Codec> const& codec) {
1895         return codec->type() == type;
1896       });
1897   bool const isTerminalType = terminalCodec_ && terminalCodec_->type() == type;
1898   if (hasCodec(type) && !present && !isTerminalType) {
1899     codecs_.push_back(getCodec(type));
1900   }
1901 }
1902
1903 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
1904     std::vector<std::unique_ptr<Codec>> customCodecs,
1905     std::unique_ptr<Codec> terminalCodec) {
1906   return std::make_unique<AutomaticCodec>(
1907       std::move(customCodecs), std::move(terminalCodec));
1908 }
1909
1910 AutomaticCodec::AutomaticCodec(
1911     std::vector<std::unique_ptr<Codec>> customCodecs,
1912     std::unique_ptr<Codec> terminalCodec)
1913     : Codec(CodecType::USER_DEFINED),
1914       codecs_(std::move(customCodecs)),
1915       terminalCodec_(std::move(terminalCodec)) {
1916   // Fastest -> slowest
1917   std::array<CodecType, 6> defaultTypes{{
1918       CodecType::LZ4_FRAME,
1919       CodecType::ZSTD,
1920       CodecType::ZLIB,
1921       CodecType::GZIP,
1922       CodecType::LZMA2,
1923       CodecType::BZIP2,
1924   }};
1925
1926   for (auto type : defaultTypes) {
1927     addCodecIfSupported(type);
1928   }
1929
1930   if (kIsDebug) {
1931     checkCompatibleCodecs();
1932   }
1933
1934   // Check that none of the codecs are null
1935   DCHECK(std::none_of(
1936       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
1937         return codec == nullptr;
1938       }));
1939
1940   // Check that the terminal codec's type is not duplicated (with the exception
1941   // of USER_DEFINED).
1942   if (terminalCodec_) {
1943     DCHECK(std::none_of(
1944         codecs_.begin(),
1945         codecs_.end(),
1946         [&](std::unique_ptr<Codec> const& codec) {
1947           return codec->type() != CodecType::USER_DEFINED &&
1948               codec->type() == terminalCodec_->type();
1949         }));
1950   }
1951
1952   bool const terminalNeedsUncompressedLength =
1953       terminalCodec_ && terminalCodec_->needsUncompressedLength();
1954   needsUncompressedLength_ = std::any_of(
1955                                  codecs_.begin(),
1956                                  codecs_.end(),
1957                                  [](std::unique_ptr<Codec> const& codec) {
1958                                    return codec->needsUncompressedLength();
1959                                  }) ||
1960       terminalNeedsUncompressedLength;
1961
1962   const auto it = std::max_element(
1963       codecs_.begin(),
1964       codecs_.end(),
1965       [](std::unique_ptr<Codec> const& lhs, std::unique_ptr<Codec> const& rhs) {
1966         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
1967       });
1968   DCHECK(it != codecs_.end());
1969   auto const terminalMaxUncompressedLength =
1970       terminalCodec_ ? terminalCodec_->maxUncompressedLength() : 0;
1971   maxUncompressedLength_ =
1972       std::max((*it)->maxUncompressedLength(), terminalMaxUncompressedLength);
1973 }
1974
1975 void AutomaticCodec::checkCompatibleCodecs() const {
1976   // Keep track of all the possible headers.
1977   std::unordered_set<std::string> headers;
1978   // The empty header is not allowed.
1979   headers.insert("");
1980   // Step 1:
1981   // Construct a set of headers and check that none of the headers occur twice.
1982   // Eliminate edge cases.
1983   for (auto&& codec : codecs_) {
1984     const auto codecHeaders = codec->validPrefixes();
1985     // Codecs without any valid headers are not allowed.
1986     if (codecHeaders.empty()) {
1987       throw std::invalid_argument{
1988           "AutomaticCodec: validPrefixes() must not be empty."};
1989     }
1990     // Insert all the headers for the current codec.
1991     const size_t beforeSize = headers.size();
1992     headers.insert(codecHeaders.begin(), codecHeaders.end());
1993     // Codecs are not compatible if any header occurred twice.
1994     if (beforeSize + codecHeaders.size() != headers.size()) {
1995       throw std::invalid_argument{
1996           "AutomaticCodec: Two valid prefixes collide."};
1997     }
1998   }
1999   // Step 2:
2000   // Check if any strict non-empty prefix of any header is a header.
2001   for (const auto& header : headers) {
2002     for (size_t i = 1; i < header.size(); ++i) {
2003       if (headers.count(header.substr(0, i))) {
2004         throw std::invalid_argument{
2005             "AutomaticCodec: One valid prefix is a prefix of another valid "
2006             "prefix."};
2007       }
2008     }
2009   }
2010 }
2011
2012 bool AutomaticCodec::doNeedsUncompressedLength() const {
2013   return needsUncompressedLength_;
2014 }
2015
2016 uint64_t AutomaticCodec::doMaxUncompressedLength() const {
2017   return maxUncompressedLength_;
2018 }
2019
2020 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
2021     const IOBuf* data,
2022     Optional<uint64_t> uncompressedLength) {
2023   try {
2024     for (auto&& codec : codecs_) {
2025       if (codec->canUncompress(data, uncompressedLength)) {
2026         return codec->uncompress(data, uncompressedLength);
2027       }
2028     }
2029   } catch (std::exception const& e) {
2030     if (!terminalCodec_) {
2031       throw e;
2032     }
2033   }
2034
2035   // Try terminal codec
2036   if (terminalCodec_) {
2037     return terminalCodec_->uncompress(data, uncompressedLength);
2038   }
2039
2040   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
2041 }
2042
2043 using CodecFactory = std::unique_ptr<Codec> (*)(int, CodecType);
2044 using StreamCodecFactory = std::unique_ptr<StreamCodec> (*)(int, CodecType);
2045 struct Factory {
2046   CodecFactory codec;
2047   StreamCodecFactory stream;
2048 };
2049
2050 constexpr Factory
2051     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
2052         {}, // USER_DEFINED
2053         {NoCompressionCodec::create, nullptr},
2054
2055 #if FOLLY_HAVE_LIBLZ4
2056         {LZ4Codec::create, nullptr},
2057 #else
2058         {},
2059 #endif
2060
2061 #if FOLLY_HAVE_LIBSNAPPY
2062         {SnappyCodec::create, nullptr},
2063 #else
2064         {},
2065 #endif
2066
2067 #if FOLLY_HAVE_LIBZ
2068         {getZlibCodec, getZlibStreamCodec},
2069 #else
2070         {},
2071 #endif
2072
2073 #if FOLLY_HAVE_LIBLZ4
2074         {LZ4Codec::create, nullptr},
2075 #else
2076         {},
2077 #endif
2078
2079 #if FOLLY_HAVE_LIBLZMA
2080         {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream},
2081         {LZMA2StreamCodec::createCodec, LZMA2StreamCodec::createStream},
2082 #else
2083         {},
2084         {},
2085 #endif
2086
2087 #if FOLLY_HAVE_LIBZSTD
2088         {ZSTDStreamCodec::createCodec, ZSTDStreamCodec::createStream},
2089 #else
2090         {},
2091 #endif
2092
2093 #if FOLLY_HAVE_LIBZ
2094         {getZlibCodec, getZlibStreamCodec},
2095 #else
2096         {},
2097 #endif
2098
2099 #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301)
2100         {LZ4FrameCodec::create, nullptr},
2101 #else
2102         {},
2103 #endif
2104
2105 #if FOLLY_HAVE_LIBBZ2
2106         {Bzip2Codec::create, nullptr},
2107 #else
2108         {},
2109 #endif
2110 };
2111
2112 Factory const& getFactory(CodecType type) {
2113   size_t const idx = static_cast<size_t>(type);
2114   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
2115     throw std::invalid_argument(
2116         to<std::string>("Compression type ", idx, " invalid"));
2117   }
2118   return codecFactories[idx];
2119 }
2120 } // namespace
2121
2122 bool hasCodec(CodecType type) {
2123   return getFactory(type).codec != nullptr;
2124 }
2125
2126 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
2127   auto const factory = getFactory(type).codec;
2128   if (!factory) {
2129     throw std::invalid_argument(
2130         to<std::string>("Compression type ", type, " not supported"));
2131   }
2132   auto codec = (*factory)(level, type);
2133   DCHECK(codec->type() == type);
2134   return codec;
2135 }
2136
2137 bool hasStreamCodec(CodecType type) {
2138   return getFactory(type).stream != nullptr;
2139 }
2140
2141 std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
2142   auto const factory = getFactory(type).stream;
2143   if (!factory) {
2144     throw std::invalid_argument(
2145         to<std::string>("Compression type ", type, " not supported"));
2146   }
2147   auto codec = (*factory)(level, type);
2148   DCHECK(codec->type() == type);
2149   return codec;
2150 }
2151
2152 std::unique_ptr<Codec> getAutoUncompressionCodec(
2153     std::vector<std::unique_ptr<Codec>> customCodecs,
2154     std::unique_ptr<Codec> terminalCodec) {
2155   return AutomaticCodec::create(
2156       std::move(customCodecs), std::move(terminalCodec));
2157 }
2158 } // namespace io
2159 } // namespace folly