5b83a6f0bff8b1665c43b9a7c8ded285d1e170b8
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #if LZ4_VERSION_NUMBER >= 10301
23 #include <lz4frame.h>
24 #endif
25 #endif
26
27 #include <glog/logging.h>
28
29 #if FOLLY_HAVE_LIBSNAPPY
30 #include <snappy.h>
31 #include <snappy-sinksource.h>
32 #endif
33
34 #if FOLLY_HAVE_LIBZ
35 #include <zlib.h>
36 #endif
37
38 #if FOLLY_HAVE_LIBLZMA
39 #include <lzma.h>
40 #endif
41
42 #if FOLLY_HAVE_LIBZSTD
43 #include <zstd.h>
44 #endif
45
46 #include <folly/Bits.h>
47 #include <folly/Conv.h>
48 #include <folly/Memory.h>
49 #include <folly/Portability.h>
50 #include <folly/ScopeGuard.h>
51 #include <folly/Varint.h>
52 #include <folly/io/Cursor.h>
53 #include <algorithm>
54 #include <unordered_set>
55
56 namespace folly { namespace io {
57
58 Codec::Codec(CodecType type) : type_(type) { }
59
60 // Ensure consistent behavior in the nullptr case
61 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
62   uint64_t len = data->computeChainDataLength();
63   if (len == 0) {
64     return IOBuf::create(0);
65   }
66   if (len > maxUncompressedLength()) {
67     throw std::runtime_error("Codec: uncompressed length too large");
68   }
69
70   return doCompress(data);
71 }
72
73 std::string Codec::compress(const StringPiece data) {
74   const uint64_t len = data.size();
75   if (len == 0) {
76     return "";
77   }
78   if (len > maxUncompressedLength()) {
79     throw std::runtime_error("Codec: uncompressed length too large");
80   }
81
82   return doCompressString(data);
83 }
84
85 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
86                                          uint64_t uncompressedLength) {
87   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
88     if (needsUncompressedLength()) {
89       throw std::invalid_argument("Codec: uncompressed length required");
90     }
91   } else if (uncompressedLength > maxUncompressedLength()) {
92     throw std::runtime_error("Codec: uncompressed length too large");
93   }
94
95   if (data->empty()) {
96     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
97         uncompressedLength != 0) {
98       throw std::runtime_error("Codec: invalid uncompressed length");
99     }
100     return IOBuf::create(0);
101   }
102
103   return doUncompress(data, uncompressedLength);
104 }
105
106 std::string Codec::uncompress(
107     const StringPiece data,
108     uint64_t uncompressedLength) {
109   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
110     if (needsUncompressedLength()) {
111       throw std::invalid_argument("Codec: uncompressed length required");
112     }
113   } else if (uncompressedLength > maxUncompressedLength()) {
114     throw std::runtime_error("Codec: uncompressed length too large");
115   }
116
117   if (data.empty()) {
118     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
119         uncompressedLength != 0) {
120       throw std::runtime_error("Codec: invalid uncompressed length");
121     }
122     return "";
123   }
124
125   return doUncompressString(data, uncompressedLength);
126 }
127
128 bool Codec::needsUncompressedLength() const {
129   return doNeedsUncompressedLength();
130 }
131
132 uint64_t Codec::maxUncompressedLength() const {
133   return doMaxUncompressedLength();
134 }
135
136 bool Codec::doNeedsUncompressedLength() const {
137   return false;
138 }
139
140 uint64_t Codec::doMaxUncompressedLength() const {
141   return UNLIMITED_UNCOMPRESSED_LENGTH;
142 }
143
144 std::vector<std::string> Codec::validPrefixes() const {
145   return {};
146 }
147
148 bool Codec::canUncompress(const IOBuf*, uint64_t) const {
149   return false;
150 }
151
152 std::string Codec::doCompressString(const StringPiece data) {
153   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
154   auto outputBuffer = doCompress(&inputBuffer);
155   std::string output;
156   output.reserve(outputBuffer->computeChainDataLength());
157   for (auto range : *outputBuffer) {
158     output.append(reinterpret_cast<const char*>(range.data()), range.size());
159   }
160   return output;
161 }
162
163 std::string Codec::doUncompressString(
164     const StringPiece data,
165     uint64_t uncompressedLength) {
166   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
167   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
168   std::string output;
169   output.reserve(outputBuffer->computeChainDataLength());
170   for (auto range : *outputBuffer) {
171     output.append(reinterpret_cast<const char*>(range.data()), range.size());
172   }
173   return output;
174 }
175
176 namespace {
177
178 /**
179  * No compression
180  */
181 class NoCompressionCodec final : public Codec {
182  public:
183   static std::unique_ptr<Codec> create(int level, CodecType type);
184   explicit NoCompressionCodec(int level, CodecType type);
185
186  private:
187   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
188   std::unique_ptr<IOBuf> doUncompress(
189       const IOBuf* data,
190       uint64_t uncompressedLength) override;
191 };
192
193 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
194   return make_unique<NoCompressionCodec>(level, type);
195 }
196
197 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
198   : Codec(type) {
199   DCHECK(type == CodecType::NO_COMPRESSION);
200   switch (level) {
201   case COMPRESSION_LEVEL_DEFAULT:
202   case COMPRESSION_LEVEL_FASTEST:
203   case COMPRESSION_LEVEL_BEST:
204     level = 0;
205   }
206   if (level != 0) {
207     throw std::invalid_argument(to<std::string>(
208         "NoCompressionCodec: invalid level ", level));
209   }
210 }
211
212 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
213     const IOBuf* data) {
214   return data->clone();
215 }
216
217 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
218     const IOBuf* data,
219     uint64_t uncompressedLength) {
220   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
221       data->computeChainDataLength() != uncompressedLength) {
222     throw std::runtime_error(to<std::string>(
223         "NoCompressionCodec: invalid uncompressed length"));
224   }
225   return data->clone();
226 }
227
228 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
229
230 namespace {
231
232 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
233   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
234   out->append(encodeVarint(val, out->writableTail()));
235 }
236
237 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
238   uint64_t val = 0;
239   int8_t b = 0;
240   for (int shift = 0; shift <= 63; shift += 7) {
241     b = cursor.read<int8_t>();
242     val |= static_cast<uint64_t>(b & 0x7f) << shift;
243     if (b >= 0) {
244       break;
245     }
246   }
247   if (b < 0) {
248     throw std::invalid_argument("Invalid varint value. Too big.");
249   }
250   return val;
251 }
252
253 }  // namespace
254
255 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
256
257 namespace {
258 /**
259  * Reads sizeof(T) bytes, and returns false if not enough bytes are available.
260  * Returns true if the first n bytes are equal to prefix when interpreted as
261  * a little endian T.
262  */
263 template <typename T>
264 typename std::enable_if<std::is_unsigned<T>::value, bool>::type
265 dataStartsWithLE(const IOBuf* data, T prefix, uint64_t n = sizeof(T)) {
266   DCHECK_GT(n, 0);
267   DCHECK_LE(n, sizeof(T));
268   T value;
269   Cursor cursor{data};
270   if (!cursor.tryReadLE(value)) {
271     return false;
272   }
273   const T mask = n == sizeof(T) ? T(-1) : (T(1) << (8 * n)) - 1;
274   return prefix == (value & mask);
275 }
276
277 template <typename T>
278 typename std::enable_if<std::is_arithmetic<T>::value, std::string>::type
279 prefixToStringLE(T prefix, uint64_t n = sizeof(T)) {
280   DCHECK_GT(n, 0);
281   DCHECK_LE(n, sizeof(T));
282   prefix = Endian::little(prefix);
283   std::string result;
284   result.resize(n);
285   memcpy(&result[0], &prefix, n);
286   return result;
287 }
288 } // namespace
289
290 #if FOLLY_HAVE_LIBLZ4
291
292 /**
293  * LZ4 compression
294  */
295 class LZ4Codec final : public Codec {
296  public:
297   static std::unique_ptr<Codec> create(int level, CodecType type);
298   explicit LZ4Codec(int level, CodecType type);
299
300  private:
301   bool doNeedsUncompressedLength() const override;
302   uint64_t doMaxUncompressedLength() const override;
303
304   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
305
306   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
307   std::unique_ptr<IOBuf> doUncompress(
308       const IOBuf* data,
309       uint64_t uncompressedLength) override;
310
311   bool highCompression_;
312 };
313
314 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
315   return make_unique<LZ4Codec>(level, type);
316 }
317
318 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
319   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
320
321   switch (level) {
322   case COMPRESSION_LEVEL_FASTEST:
323   case COMPRESSION_LEVEL_DEFAULT:
324     level = 1;
325     break;
326   case COMPRESSION_LEVEL_BEST:
327     level = 2;
328     break;
329   }
330   if (level < 1 || level > 2) {
331     throw std::invalid_argument(to<std::string>(
332         "LZ4Codec: invalid level: ", level));
333   }
334   highCompression_ = (level > 1);
335 }
336
337 bool LZ4Codec::doNeedsUncompressedLength() const {
338   return !encodeSize();
339 }
340
341 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
342 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
343 // here.
344 #ifndef LZ4_MAX_INPUT_SIZE
345 # define LZ4_MAX_INPUT_SIZE 0x7E000000
346 #endif
347
348 uint64_t LZ4Codec::doMaxUncompressedLength() const {
349   return LZ4_MAX_INPUT_SIZE;
350 }
351
352 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
353   IOBuf clone;
354   if (data->isChained()) {
355     // LZ4 doesn't support streaming, so we have to coalesce
356     clone = data->cloneCoalescedAsValue();
357     data = &clone;
358   }
359
360   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
361   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
362   if (encodeSize()) {
363     encodeVarintToIOBuf(data->length(), out.get());
364   }
365
366   int n;
367   auto input = reinterpret_cast<const char*>(data->data());
368   auto output = reinterpret_cast<char*>(out->writableTail());
369   const auto inputLength = data->length();
370 #if LZ4_VERSION_NUMBER >= 10700
371   if (highCompression_) {
372     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
373   } else {
374     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
375   }
376 #else
377   if (highCompression_) {
378     n = LZ4_compressHC(input, output, inputLength);
379   } else {
380     n = LZ4_compress(input, output, inputLength);
381   }
382 #endif
383
384   CHECK_GE(n, 0);
385   CHECK_LE(n, out->capacity());
386
387   out->append(n);
388   return out;
389 }
390
391 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
392     const IOBuf* data,
393     uint64_t uncompressedLength) {
394   IOBuf clone;
395   if (data->isChained()) {
396     // LZ4 doesn't support streaming, so we have to coalesce
397     clone = data->cloneCoalescedAsValue();
398     data = &clone;
399   }
400
401   folly::io::Cursor cursor(data);
402   uint64_t actualUncompressedLength;
403   if (encodeSize()) {
404     actualUncompressedLength = decodeVarintFromCursor(cursor);
405     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
406         uncompressedLength != actualUncompressedLength) {
407       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
408     }
409   } else {
410     actualUncompressedLength = uncompressedLength;
411     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
412         actualUncompressedLength > maxUncompressedLength()) {
413       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
414     }
415   }
416
417   auto sp = StringPiece{cursor.peekBytes()};
418   auto out = IOBuf::create(actualUncompressedLength);
419   int n = LZ4_decompress_safe(
420       sp.data(),
421       reinterpret_cast<char*>(out->writableTail()),
422       sp.size(),
423       actualUncompressedLength);
424
425   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
426     throw std::runtime_error(to<std::string>(
427         "LZ4 decompression returned invalid value ", n));
428   }
429   out->append(actualUncompressedLength);
430   return out;
431 }
432
433 #if LZ4_VERSION_NUMBER >= 10301
434
435 class LZ4FrameCodec final : public Codec {
436  public:
437   static std::unique_ptr<Codec> create(int level, CodecType type);
438   explicit LZ4FrameCodec(int level, CodecType type);
439   ~LZ4FrameCodec();
440
441   std::vector<std::string> validPrefixes() const override;
442   bool canUncompress(const IOBuf* data, uint64_t uncompressedLength)
443       const override;
444
445  private:
446   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
447   std::unique_ptr<IOBuf> doUncompress(
448       const IOBuf* data,
449       uint64_t uncompressedLength) override;
450
451   // Reset the dctx_ if it is dirty or null.
452   void resetDCtx();
453
454   int level_;
455   LZ4F_decompressionContext_t dctx_{nullptr};
456   bool dirty_{false};
457 };
458
459 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
460     int level,
461     CodecType type) {
462   return make_unique<LZ4FrameCodec>(level, type);
463 }
464
465 static constexpr uint32_t kLZ4FrameMagicLE = 0x184D2204;
466
467 std::vector<std::string> LZ4FrameCodec::validPrefixes() const {
468   return {prefixToStringLE(kLZ4FrameMagicLE)};
469 }
470
471 bool LZ4FrameCodec::canUncompress(const IOBuf* data, uint64_t) const {
472   return dataStartsWithLE(data, kLZ4FrameMagicLE);
473 }
474
475 static size_t lz4FrameThrowOnError(size_t code) {
476   if (LZ4F_isError(code)) {
477     throw std::runtime_error(
478         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
479   }
480   return code;
481 }
482
483 void LZ4FrameCodec::resetDCtx() {
484   if (dctx_ && !dirty_) {
485     return;
486   }
487   if (dctx_) {
488     LZ4F_freeDecompressionContext(dctx_);
489   }
490   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
491   dirty_ = false;
492 }
493
494 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
495   DCHECK(type == CodecType::LZ4_FRAME);
496   switch (level) {
497     case COMPRESSION_LEVEL_FASTEST:
498     case COMPRESSION_LEVEL_DEFAULT:
499       level_ = 0;
500       break;
501     case COMPRESSION_LEVEL_BEST:
502       level_ = 16;
503       break;
504     default:
505       level_ = level;
506       break;
507   }
508 }
509
510 LZ4FrameCodec::~LZ4FrameCodec() {
511   if (dctx_) {
512     LZ4F_freeDecompressionContext(dctx_);
513   }
514 }
515
516 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
517   // LZ4 Frame compression doesn't support streaming so we have to coalesce
518   IOBuf clone;
519   if (data->isChained()) {
520     clone = data->cloneCoalescedAsValue();
521     data = &clone;
522   }
523   // Set preferences
524   const auto uncompressedLength = data->length();
525   LZ4F_preferences_t prefs{};
526   prefs.compressionLevel = level_;
527   prefs.frameInfo.contentSize = uncompressedLength;
528   // Compress
529   auto buf = IOBuf::create(LZ4F_compressFrameBound(uncompressedLength, &prefs));
530   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
531       buf->writableTail(),
532       buf->tailroom(),
533       data->data(),
534       data->length(),
535       &prefs));
536   buf->append(written);
537   return buf;
538 }
539
540 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
541     const IOBuf* data,
542     uint64_t uncompressedLength) {
543   // Reset the dctx if any errors have occurred
544   resetDCtx();
545   // Coalesce the data
546   ByteRange in = *data->begin();
547   IOBuf clone;
548   if (data->isChained()) {
549     clone = data->cloneCoalescedAsValue();
550     in = clone.coalesce();
551   }
552   data = nullptr;
553   // Select decompression options
554   LZ4F_decompressOptions_t options;
555   options.stableDst = 1;
556   // Select blockSize and growthSize for the IOBufQueue
557   IOBufQueue queue(IOBufQueue::cacheChainLength());
558   auto blockSize = uint64_t{64} << 10;
559   auto growthSize = uint64_t{4} << 20;
560   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
561     // Allocate uncompressedLength in one chunk (up to 64 MB)
562     const auto allocateSize = std::min(uncompressedLength, uint64_t{64} << 20);
563     queue.preallocate(allocateSize, allocateSize);
564     blockSize = std::min(uncompressedLength, blockSize);
565     growthSize = std::min(uncompressedLength, growthSize);
566   } else {
567     // Reduce growthSize for small data
568     const auto guessUncompressedLen = 4 * std::max(blockSize, in.size());
569     growthSize = std::min(guessUncompressedLen, growthSize);
570   }
571   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
572   // returns 0
573   dirty_ = true;
574   // Decompress until the frame is over
575   size_t code = 0;
576   do {
577     // Allocate enough space to decompress at least a block
578     void* out;
579     size_t outSize;
580     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
581     // Decompress
582     size_t inSize = in.size();
583     code = lz4FrameThrowOnError(
584         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
585     if (in.empty() && outSize == 0 && code != 0) {
586       // We passed no input, no output was produced, and the frame isn't over
587       // No more forward progress is possible
588       throw std::runtime_error("LZ4Frame error: Incomplete frame");
589     }
590     in.uncheckedAdvance(inSize);
591     queue.postallocate(outSize);
592   } while (code != 0);
593   // At this point the decompression context can be reused
594   dirty_ = false;
595   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
596       queue.chainLength() != uncompressedLength) {
597     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
598   }
599   return queue.move();
600 }
601
602 #endif // LZ4_VERSION_NUMBER >= 10301
603 #endif // FOLLY_HAVE_LIBLZ4
604
605 #if FOLLY_HAVE_LIBSNAPPY
606
607 /**
608  * Snappy compression
609  */
610
611 /**
612  * Implementation of snappy::Source that reads from a IOBuf chain.
613  */
614 class IOBufSnappySource final : public snappy::Source {
615  public:
616   explicit IOBufSnappySource(const IOBuf* data);
617   size_t Available() const override;
618   const char* Peek(size_t* len) override;
619   void Skip(size_t n) override;
620  private:
621   size_t available_;
622   io::Cursor cursor_;
623 };
624
625 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
626   : available_(data->computeChainDataLength()),
627     cursor_(data) {
628 }
629
630 size_t IOBufSnappySource::Available() const {
631   return available_;
632 }
633
634 const char* IOBufSnappySource::Peek(size_t* len) {
635   auto sp = StringPiece{cursor_.peekBytes()};
636   *len = sp.size();
637   return sp.data();
638 }
639
640 void IOBufSnappySource::Skip(size_t n) {
641   CHECK_LE(n, available_);
642   cursor_.skip(n);
643   available_ -= n;
644 }
645
646 class SnappyCodec final : public Codec {
647  public:
648   static std::unique_ptr<Codec> create(int level, CodecType type);
649   explicit SnappyCodec(int level, CodecType type);
650
651  private:
652   uint64_t doMaxUncompressedLength() const override;
653   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
654   std::unique_ptr<IOBuf> doUncompress(
655       const IOBuf* data,
656       uint64_t uncompressedLength) override;
657 };
658
659 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
660   return make_unique<SnappyCodec>(level, type);
661 }
662
663 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
664   DCHECK(type == CodecType::SNAPPY);
665   switch (level) {
666   case COMPRESSION_LEVEL_FASTEST:
667   case COMPRESSION_LEVEL_DEFAULT:
668   case COMPRESSION_LEVEL_BEST:
669     level = 1;
670   }
671   if (level != 1) {
672     throw std::invalid_argument(to<std::string>(
673         "SnappyCodec: invalid level: ", level));
674   }
675 }
676
677 uint64_t SnappyCodec::doMaxUncompressedLength() const {
678   // snappy.h uses uint32_t for lengths, so there's that.
679   return std::numeric_limits<uint32_t>::max();
680 }
681
682 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
683   IOBufSnappySource source(data);
684   auto out =
685     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
686
687   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
688       out->writableTail()));
689
690   size_t n = snappy::Compress(&source, &sink);
691
692   CHECK_LE(n, out->capacity());
693   out->append(n);
694   return out;
695 }
696
697 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
698                                                  uint64_t uncompressedLength) {
699   uint32_t actualUncompressedLength = 0;
700
701   {
702     IOBufSnappySource source(data);
703     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
704       throw std::runtime_error("snappy::GetUncompressedLength failed");
705     }
706     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
707         uncompressedLength != actualUncompressedLength) {
708       throw std::runtime_error("snappy: invalid uncompressed length");
709     }
710   }
711
712   auto out = IOBuf::create(actualUncompressedLength);
713
714   {
715     IOBufSnappySource source(data);
716     if (!snappy::RawUncompress(&source,
717                                reinterpret_cast<char*>(out->writableTail()))) {
718       throw std::runtime_error("snappy::RawUncompress failed");
719     }
720   }
721
722   out->append(actualUncompressedLength);
723   return out;
724 }
725
726 #endif  // FOLLY_HAVE_LIBSNAPPY
727
728 #if FOLLY_HAVE_LIBZ
729 /**
730  * Zlib codec
731  */
732 class ZlibCodec final : public Codec {
733  public:
734   static std::unique_ptr<Codec> create(int level, CodecType type);
735   explicit ZlibCodec(int level, CodecType type);
736
737   std::vector<std::string> validPrefixes() const override;
738   bool canUncompress(const IOBuf* data, uint64_t uncompressedLength)
739       const override;
740
741  private:
742   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
743   std::unique_ptr<IOBuf> doUncompress(
744       const IOBuf* data,
745       uint64_t uncompressedLength) override;
746
747   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
748   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
749
750   int level_;
751 };
752
753 static constexpr uint16_t kGZIPMagicLE = 0x8B1F;
754
755 std::vector<std::string> ZlibCodec::validPrefixes() const {
756   if (type() == CodecType::ZLIB) {
757     // Zlib streams start with a 2 byte header.
758     //
759     //   0   1
760     // +---+---+
761     // |CMF|FLG|
762     // +---+---+
763     //
764     // We won't restrict the values of any sub-fields except as described below.
765     //
766     // The lowest 4 bits of CMF is the compression method (CM).
767     // CM == 0x8 is the deflate compression method, which is currently the only
768     // supported compression method, so any valid prefix must have CM == 0x8.
769     //
770     // The lowest 5 bits of FLG is FCHECK.
771     // FCHECK must be such that the two header bytes are a multiple of 31 when
772     // interpreted as a big endian 16-bit number.
773     std::vector<std::string> result;
774     // 16 values for the first byte, 8 values for the second byte.
775     // There are also 4 combinations where both 0x00 and 0x1F work as FCHECK.
776     result.reserve(132);
777     // Select all values for the CMF byte that use the deflate algorithm 0x8.
778     for (uint32_t first = 0x0800; first <= 0xF800; first += 0x1000) {
779       // Select all values for the FLG, but leave FCHECK as 0 since it's fixed.
780       for (uint32_t second = 0x00; second <= 0xE0; second += 0x20) {
781         uint16_t prefix = first | second;
782         // Compute FCHECK.
783         prefix += 31 - (prefix % 31);
784         result.push_back(prefixToStringLE(Endian::big(prefix)));
785         // zlib won't produce this, but it is a valid prefix.
786         if ((prefix & 0x1F) == 31) {
787           prefix -= 31;
788           result.push_back(prefixToStringLE(Endian::big(prefix)));
789         }
790       }
791     }
792     return result;
793   } else {
794     // The gzip frame starts with 2 magic bytes.
795     return {prefixToStringLE(kGZIPMagicLE)};
796   }
797 }
798
799 bool ZlibCodec::canUncompress(const IOBuf* data, uint64_t) const {
800   if (type() == CodecType::ZLIB) {
801     uint16_t value;
802     Cursor cursor{data};
803     if (!cursor.tryReadBE(value)) {
804       return false;
805     }
806     // zlib compressed if using deflate and is a multiple of 31.
807     return (value & 0x0F00) == 0x0800 && value % 31 == 0;
808   } else {
809     return dataStartsWithLE(data, kGZIPMagicLE);
810   }
811 }
812
813 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
814   return make_unique<ZlibCodec>(level, type);
815 }
816
817 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
818   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
819   switch (level) {
820   case COMPRESSION_LEVEL_FASTEST:
821     level = 1;
822     break;
823   case COMPRESSION_LEVEL_DEFAULT:
824     level = Z_DEFAULT_COMPRESSION;
825     break;
826   case COMPRESSION_LEVEL_BEST:
827     level = 9;
828     break;
829   }
830   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
831     throw std::invalid_argument(to<std::string>(
832         "ZlibCodec: invalid level: ", level));
833   }
834   level_ = level;
835 }
836
837 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
838                                                   uint32_t length) {
839   CHECK_EQ(stream->avail_out, 0);
840
841   auto buf = IOBuf::create(length);
842   buf->append(length);
843
844   stream->next_out = buf->writableData();
845   stream->avail_out = buf->length();
846
847   return buf;
848 }
849
850 bool ZlibCodec::doInflate(z_stream* stream,
851                           IOBuf* head,
852                           uint32_t bufferLength) {
853   if (stream->avail_out == 0) {
854     head->prependChain(addOutputBuffer(stream, bufferLength));
855   }
856
857   int rc = inflate(stream, Z_NO_FLUSH);
858
859   switch (rc) {
860   case Z_OK:
861     break;
862   case Z_STREAM_END:
863     return true;
864   case Z_BUF_ERROR:
865   case Z_NEED_DICT:
866   case Z_DATA_ERROR:
867   case Z_MEM_ERROR:
868     throw std::runtime_error(to<std::string>(
869         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
870   default:
871     CHECK(false) << rc << ": " << stream->msg;
872   }
873
874   return false;
875 }
876
877 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
878   z_stream stream;
879   stream.zalloc = nullptr;
880   stream.zfree = nullptr;
881   stream.opaque = nullptr;
882
883   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
884   // base two logarithm of the maximum window size (...) The default value is
885   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
886   // around the compressed data instead of a zlib wrapper. The gzip header
887   // will have no file name, no extra data, no comment, no modification time
888   // (set to zero), no header crc, and the operating system will be set to 255
889   // (unknown)."
890   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
891   // All other parameters (method, memLevel, strategy) get default values from
892   // the zlib manual.
893   int rc = deflateInit2(&stream,
894                         level_,
895                         Z_DEFLATED,
896                         windowBits,
897                         /* memLevel */ 8,
898                         Z_DEFAULT_STRATEGY);
899   if (rc != Z_OK) {
900     throw std::runtime_error(to<std::string>(
901         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
902   }
903
904   stream.next_in = stream.next_out = nullptr;
905   stream.avail_in = stream.avail_out = 0;
906   stream.total_in = stream.total_out = 0;
907
908   bool success = false;
909
910   SCOPE_EXIT {
911     rc = deflateEnd(&stream);
912     // If we're here because of an exception, it's okay if some data
913     // got dropped.
914     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
915       << rc << ": " << stream.msg;
916   };
917
918   uint64_t uncompressedLength = data->computeChainDataLength();
919   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
920
921   // Max 64MiB in one go
922   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
923   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
924
925   auto out = addOutputBuffer(
926       &stream,
927       (maxCompressedLength <= maxSingleStepLength ?
928        maxCompressedLength :
929        defaultBufferLength));
930
931   for (auto& range : *data) {
932     uint64_t remaining = range.size();
933     uint64_t written = 0;
934     while (remaining) {
935       uint32_t step = (remaining > maxSingleStepLength ?
936                        maxSingleStepLength : remaining);
937       stream.next_in = const_cast<uint8_t*>(range.data() + written);
938       stream.avail_in = step;
939       remaining -= step;
940       written += step;
941
942       while (stream.avail_in != 0) {
943         if (stream.avail_out == 0) {
944           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
945         }
946
947         rc = deflate(&stream, Z_NO_FLUSH);
948
949         CHECK_EQ(rc, Z_OK) << stream.msg;
950       }
951     }
952   }
953
954   do {
955     if (stream.avail_out == 0) {
956       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
957     }
958
959     rc = deflate(&stream, Z_FINISH);
960   } while (rc == Z_OK);
961
962   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
963
964   out->prev()->trimEnd(stream.avail_out);
965
966   success = true;  // we survived
967
968   return out;
969 }
970
971 static uint64_t computeBufferLength(uint64_t const compressedLength) {
972   constexpr uint64_t kMaxBufferLength = uint64_t(4) << 20; // 4 MiB
973   constexpr uint64_t kBlockSize = uint64_t(32) << 10; // 32 KiB
974   const uint64_t goodBufferSize = 4 * std::max(kBlockSize, compressedLength);
975   return std::min(goodBufferSize, kMaxBufferLength);
976 }
977
978 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
979                                                uint64_t uncompressedLength) {
980   z_stream stream;
981   stream.zalloc = nullptr;
982   stream.zfree = nullptr;
983   stream.opaque = nullptr;
984
985   // "The windowBits parameter is the base two logarithm of the maximum window
986   // size (...) The default value is 15 (...) add 16 to decode only the gzip
987   // format (the zlib format will return a Z_DATA_ERROR)."
988   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
989   int rc = inflateInit2(&stream, windowBits);
990   if (rc != Z_OK) {
991     throw std::runtime_error(to<std::string>(
992         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
993   }
994
995   stream.next_in = stream.next_out = nullptr;
996   stream.avail_in = stream.avail_out = 0;
997   stream.total_in = stream.total_out = 0;
998
999   bool success = false;
1000
1001   SCOPE_EXIT {
1002     rc = inflateEnd(&stream);
1003     // If we're here because of an exception, it's okay if some data
1004     // got dropped.
1005     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
1006       << rc << ": " << stream.msg;
1007   };
1008
1009   // Max 64MiB in one go
1010   constexpr uint64_t maxSingleStepLength = uint64_t(64) << 20; // 64MiB
1011   const uint64_t defaultBufferLength =
1012       computeBufferLength(data->computeChainDataLength());
1013
1014   auto out = addOutputBuffer(
1015       &stream,
1016       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1017         uncompressedLength <= maxSingleStepLength) ?
1018        uncompressedLength :
1019        defaultBufferLength));
1020
1021   bool streamEnd = false;
1022   for (auto& range : *data) {
1023     if (range.empty()) {
1024       continue;
1025     }
1026
1027     stream.next_in = const_cast<uint8_t*>(range.data());
1028     stream.avail_in = range.size();
1029
1030     while (stream.avail_in != 0) {
1031       if (streamEnd) {
1032         throw std::runtime_error(to<std::string>(
1033             "ZlibCodec: junk after end of data"));
1034       }
1035
1036       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1037     }
1038   }
1039
1040   while (!streamEnd) {
1041     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1042   }
1043
1044   out->prev()->trimEnd(stream.avail_out);
1045
1046   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1047       uncompressedLength != stream.total_out) {
1048     throw std::runtime_error(to<std::string>(
1049         "ZlibCodec: invalid uncompressed length"));
1050   }
1051
1052   success = true;  // we survived
1053
1054   return out;
1055 }
1056
1057 #endif  // FOLLY_HAVE_LIBZ
1058
1059 #if FOLLY_HAVE_LIBLZMA
1060
1061 /**
1062  * LZMA2 compression
1063  */
1064 class LZMA2Codec final : public Codec {
1065  public:
1066   static std::unique_ptr<Codec> create(int level, CodecType type);
1067   explicit LZMA2Codec(int level, CodecType type);
1068
1069   std::vector<std::string> validPrefixes() const override;
1070   bool canUncompress(const IOBuf* data, uint64_t uncompressedLength)
1071       const override;
1072
1073  private:
1074   bool doNeedsUncompressedLength() const override;
1075   uint64_t doMaxUncompressedLength() const override;
1076
1077   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
1078
1079   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1080   std::unique_ptr<IOBuf> doUncompress(
1081       const IOBuf* data,
1082       uint64_t uncompressedLength) override;
1083
1084   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
1085   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
1086
1087   int level_;
1088 };
1089
1090 static constexpr uint64_t kLZMA2MagicLE = 0x005A587A37FD;
1091 static constexpr unsigned kLZMA2MagicBytes = 6;
1092
1093 std::vector<std::string> LZMA2Codec::validPrefixes() const {
1094   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1095     return {};
1096   }
1097   return {prefixToStringLE(kLZMA2MagicLE, kLZMA2MagicBytes)};
1098 }
1099
1100 bool LZMA2Codec::canUncompress(const IOBuf* data, uint64_t) const {
1101   if (type() == CodecType::LZMA2_VARINT_SIZE) {
1102     return false;
1103   }
1104   // Returns false for all inputs less than 8 bytes.
1105   // This is okay, because no valid LZMA2 streams are less than 8 bytes.
1106   return dataStartsWithLE(data, kLZMA2MagicLE, kLZMA2MagicBytes);
1107 }
1108
1109 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
1110   return make_unique<LZMA2Codec>(level, type);
1111 }
1112
1113 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
1114   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
1115   switch (level) {
1116   case COMPRESSION_LEVEL_FASTEST:
1117     level = 0;
1118     break;
1119   case COMPRESSION_LEVEL_DEFAULT:
1120     level = LZMA_PRESET_DEFAULT;
1121     break;
1122   case COMPRESSION_LEVEL_BEST:
1123     level = 9;
1124     break;
1125   }
1126   if (level < 0 || level > 9) {
1127     throw std::invalid_argument(to<std::string>(
1128         "LZMA2Codec: invalid level: ", level));
1129   }
1130   level_ = level;
1131 }
1132
1133 bool LZMA2Codec::doNeedsUncompressedLength() const {
1134   return false;
1135 }
1136
1137 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
1138   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
1139   return uint64_t(1) << 63;
1140 }
1141
1142 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
1143     lzma_stream* stream,
1144     size_t length) {
1145
1146   CHECK_EQ(stream->avail_out, 0);
1147
1148   auto buf = IOBuf::create(length);
1149   buf->append(length);
1150
1151   stream->next_out = buf->writableData();
1152   stream->avail_out = buf->length();
1153
1154   return buf;
1155 }
1156
1157 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
1158   lzma_ret rc;
1159   lzma_stream stream = LZMA_STREAM_INIT;
1160
1161   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
1162   if (rc != LZMA_OK) {
1163     throw std::runtime_error(folly::to<std::string>(
1164       "LZMA2Codec: lzma_easy_encoder error: ", rc));
1165   }
1166
1167   SCOPE_EXIT { lzma_end(&stream); };
1168
1169   uint64_t uncompressedLength = data->computeChainDataLength();
1170   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
1171
1172   // Max 64MiB in one go
1173   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1174   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1175
1176   auto out = addOutputBuffer(
1177     &stream,
1178     (maxCompressedLength <= maxSingleStepLength ?
1179      maxCompressedLength :
1180      defaultBufferLength));
1181
1182   if (encodeSize()) {
1183     auto size = IOBuf::createCombined(kMaxVarintLength64);
1184     encodeVarintToIOBuf(uncompressedLength, size.get());
1185     size->appendChain(std::move(out));
1186     out = std::move(size);
1187   }
1188
1189   for (auto& range : *data) {
1190     if (range.empty()) {
1191       continue;
1192     }
1193
1194     stream.next_in = const_cast<uint8_t*>(range.data());
1195     stream.avail_in = range.size();
1196
1197     while (stream.avail_in != 0) {
1198       if (stream.avail_out == 0) {
1199         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1200       }
1201
1202       rc = lzma_code(&stream, LZMA_RUN);
1203
1204       if (rc != LZMA_OK) {
1205         throw std::runtime_error(folly::to<std::string>(
1206           "LZMA2Codec: lzma_code error: ", rc));
1207       }
1208     }
1209   }
1210
1211   do {
1212     if (stream.avail_out == 0) {
1213       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1214     }
1215
1216     rc = lzma_code(&stream, LZMA_FINISH);
1217   } while (rc == LZMA_OK);
1218
1219   if (rc != LZMA_STREAM_END) {
1220     throw std::runtime_error(folly::to<std::string>(
1221       "LZMA2Codec: lzma_code ended with error: ", rc));
1222   }
1223
1224   out->prev()->trimEnd(stream.avail_out);
1225
1226   return out;
1227 }
1228
1229 bool LZMA2Codec::doInflate(lzma_stream* stream,
1230                           IOBuf* head,
1231                           size_t bufferLength) {
1232   if (stream->avail_out == 0) {
1233     head->prependChain(addOutputBuffer(stream, bufferLength));
1234   }
1235
1236   lzma_ret rc = lzma_code(stream, LZMA_RUN);
1237
1238   switch (rc) {
1239   case LZMA_OK:
1240     break;
1241   case LZMA_STREAM_END:
1242     return true;
1243   default:
1244     throw std::runtime_error(to<std::string>(
1245         "LZMA2Codec: lzma_code error: ", rc));
1246   }
1247
1248   return false;
1249 }
1250
1251 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
1252                                                uint64_t uncompressedLength) {
1253   lzma_ret rc;
1254   lzma_stream stream = LZMA_STREAM_INIT;
1255
1256   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
1257   if (rc != LZMA_OK) {
1258     throw std::runtime_error(folly::to<std::string>(
1259       "LZMA2Codec: lzma_auto_decoder error: ", rc));
1260   }
1261
1262   SCOPE_EXIT { lzma_end(&stream); };
1263
1264   // Max 64MiB in one go
1265   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20; // 64MiB
1266   constexpr uint32_t defaultBufferLength = uint32_t(256) << 10; // 256 KiB
1267
1268   folly::io::Cursor cursor(data);
1269   if (encodeSize()) {
1270     const uint64_t actualUncompressedLength = decodeVarintFromCursor(cursor);
1271     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1272         uncompressedLength != actualUncompressedLength) {
1273       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
1274     }
1275     uncompressedLength = actualUncompressedLength;
1276   }
1277
1278   auto out = addOutputBuffer(
1279       &stream,
1280       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1281         uncompressedLength <= maxSingleStepLength)
1282            ? uncompressedLength
1283            : defaultBufferLength));
1284
1285   bool streamEnd = false;
1286   auto buf = cursor.peekBytes();
1287   while (!buf.empty()) {
1288     stream.next_in = const_cast<uint8_t*>(buf.data());
1289     stream.avail_in = buf.size();
1290
1291     while (stream.avail_in != 0) {
1292       if (streamEnd) {
1293         throw std::runtime_error(to<std::string>(
1294             "LZMA2Codec: junk after end of data"));
1295       }
1296
1297       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1298     }
1299
1300     cursor.skip(buf.size());
1301     buf = cursor.peekBytes();
1302   }
1303
1304   while (!streamEnd) {
1305     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1306   }
1307
1308   out->prev()->trimEnd(stream.avail_out);
1309
1310   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1311       uncompressedLength != stream.total_out) {
1312     throw std::runtime_error(
1313         to<std::string>("LZMA2Codec: invalid uncompressed length"));
1314   }
1315
1316   return out;
1317 }
1318
1319 #endif  // FOLLY_HAVE_LIBLZMA
1320
1321 #ifdef FOLLY_HAVE_LIBZSTD
1322
1323 /**
1324  * ZSTD compression
1325  */
1326 class ZSTDCodec final : public Codec {
1327  public:
1328   static std::unique_ptr<Codec> create(int level, CodecType);
1329   explicit ZSTDCodec(int level, CodecType type);
1330
1331   std::vector<std::string> validPrefixes() const override;
1332   bool canUncompress(const IOBuf* data, uint64_t uncompressedLength)
1333       const override;
1334
1335  private:
1336   bool doNeedsUncompressedLength() const override;
1337   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1338   std::unique_ptr<IOBuf> doUncompress(
1339       const IOBuf* data,
1340       uint64_t uncompressedLength) override;
1341
1342   int level_;
1343 };
1344
1345 static constexpr uint32_t kZSTDMagicLE = 0xFD2FB528;
1346
1347 std::vector<std::string> ZSTDCodec::validPrefixes() const {
1348   return {prefixToStringLE(kZSTDMagicLE)};
1349 }
1350
1351 bool ZSTDCodec::canUncompress(const IOBuf* data, uint64_t) const {
1352   return dataStartsWithLE(data, kZSTDMagicLE);
1353 }
1354
1355 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
1356   return make_unique<ZSTDCodec>(level, type);
1357 }
1358
1359 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
1360   DCHECK(type == CodecType::ZSTD);
1361   switch (level) {
1362     case COMPRESSION_LEVEL_FASTEST:
1363       level = 1;
1364       break;
1365     case COMPRESSION_LEVEL_DEFAULT:
1366       level = 1;
1367       break;
1368     case COMPRESSION_LEVEL_BEST:
1369       level = 19;
1370       break;
1371   }
1372   if (level < 1 || level > ZSTD_maxCLevel()) {
1373     throw std::invalid_argument(
1374         to<std::string>("ZSTD: invalid level: ", level));
1375   }
1376   level_ = level;
1377 }
1378
1379 bool ZSTDCodec::doNeedsUncompressedLength() const {
1380   return false;
1381 }
1382
1383 void zstdThrowIfError(size_t rc) {
1384   if (!ZSTD_isError(rc)) {
1385     return;
1386   }
1387   throw std::runtime_error(
1388       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1389 }
1390
1391 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1392   // Support earlier versions of the codec (working with a single IOBuf,
1393   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1394   // which isn't populated by streaming API).
1395   if (!data->isChained()) {
1396     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1397     const auto rc = ZSTD_compress(
1398         out->writableData(),
1399         out->capacity(),
1400         data->data(),
1401         data->length(),
1402         level_);
1403     zstdThrowIfError(rc);
1404     out->append(rc);
1405     return out;
1406   }
1407
1408   auto zcs = ZSTD_createCStream();
1409   SCOPE_EXIT {
1410     ZSTD_freeCStream(zcs);
1411   };
1412
1413   auto rc = ZSTD_initCStream(zcs, level_);
1414   zstdThrowIfError(rc);
1415
1416   Cursor cursor(data);
1417   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1418
1419   ZSTD_outBuffer out;
1420   out.dst = result->writableTail();
1421   out.size = result->capacity();
1422   out.pos = 0;
1423
1424   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1425     ZSTD_inBuffer in;
1426     in.src = buffer.data();
1427     in.size = buffer.size();
1428     for (in.pos = 0; in.pos != in.size;) {
1429       rc = ZSTD_compressStream(zcs, &out, &in);
1430       zstdThrowIfError(rc);
1431     }
1432     cursor.skip(in.size);
1433     buffer = cursor.peekBytes();
1434   }
1435
1436   rc = ZSTD_endStream(zcs, &out);
1437   zstdThrowIfError(rc);
1438   CHECK_EQ(rc, 0);
1439
1440   result->append(out.pos);
1441   return result;
1442 }
1443
1444 static std::unique_ptr<IOBuf> zstdUncompressBuffer(
1445     const IOBuf* data,
1446     uint64_t uncompressedLength) {
1447   // Check preconditions
1448   DCHECK(!data->isChained());
1449   DCHECK(uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH);
1450
1451   auto uncompressed = IOBuf::create(uncompressedLength);
1452   const auto decompressedSize = ZSTD_decompress(
1453       uncompressed->writableTail(),
1454       uncompressed->tailroom(),
1455       data->data(),
1456       data->length());
1457   zstdThrowIfError(decompressedSize);
1458   if (decompressedSize != uncompressedLength) {
1459     throw std::runtime_error("ZSTD: invalid uncompressed length");
1460   }
1461   uncompressed->append(decompressedSize);
1462   return uncompressed;
1463 }
1464
1465 static std::unique_ptr<IOBuf> zstdUncompressStream(
1466     const IOBuf* data,
1467     uint64_t uncompressedLength) {
1468   auto zds = ZSTD_createDStream();
1469   SCOPE_EXIT {
1470     ZSTD_freeDStream(zds);
1471   };
1472
1473   auto rc = ZSTD_initDStream(zds);
1474   zstdThrowIfError(rc);
1475
1476   ZSTD_outBuffer out{};
1477   ZSTD_inBuffer in{};
1478
1479   auto outputSize = ZSTD_DStreamOutSize();
1480   if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH) {
1481     outputSize = uncompressedLength;
1482   }
1483
1484   IOBufQueue queue(IOBufQueue::cacheChainLength());
1485
1486   Cursor cursor(data);
1487   for (rc = 0;;) {
1488     if (in.pos == in.size) {
1489       auto buffer = cursor.peekBytes();
1490       in.src = buffer.data();
1491       in.size = buffer.size();
1492       in.pos = 0;
1493       cursor.skip(in.size);
1494       if (rc > 1 && in.size == 0) {
1495         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1496       }
1497     }
1498     if (out.pos == out.size) {
1499       if (out.pos != 0) {
1500         queue.postallocate(out.pos);
1501       }
1502       auto buffer = queue.preallocate(outputSize, outputSize);
1503       out.dst = buffer.first;
1504       out.size = buffer.second;
1505       out.pos = 0;
1506       outputSize = ZSTD_DStreamOutSize();
1507     }
1508     rc = ZSTD_decompressStream(zds, &out, &in);
1509     zstdThrowIfError(rc);
1510     if (rc == 0) {
1511       break;
1512     }
1513   }
1514   if (out.pos != 0) {
1515     queue.postallocate(out.pos);
1516   }
1517   if (in.pos != in.size || !cursor.isAtEnd()) {
1518     throw std::runtime_error("ZSTD: junk after end of data");
1519   }
1520   if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
1521       queue.chainLength() != uncompressedLength) {
1522     throw std::runtime_error("ZSTD: invalid uncompressed length");
1523   }
1524
1525   return queue.move();
1526 }
1527
1528 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1529     const IOBuf* data,
1530     uint64_t uncompressedLength) {
1531   {
1532     // Read decompressed size from frame if available in first IOBuf.
1533     const auto decompressedSize =
1534         ZSTD_getDecompressedSize(data->data(), data->length());
1535     if (decompressedSize != 0) {
1536       if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
1537           uncompressedLength != decompressedSize) {
1538         throw std::runtime_error("ZSTD: invalid uncompressed length");
1539       }
1540       uncompressedLength = decompressedSize;
1541     }
1542   }
1543   // Faster to decompress using ZSTD_decompress() if we can.
1544   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && !data->isChained()) {
1545     return zstdUncompressBuffer(data, uncompressedLength);
1546   }
1547   // Fall back to slower streaming decompression.
1548   return zstdUncompressStream(data, uncompressedLength);
1549 }
1550
1551 #endif  // FOLLY_HAVE_LIBZSTD
1552
1553 /**
1554  * Automatic decompression
1555  */
1556 class AutomaticCodec final : public Codec {
1557  public:
1558   static std::unique_ptr<Codec> create(
1559       std::vector<std::unique_ptr<Codec>> customCodecs);
1560   explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
1561
1562   std::vector<std::string> validPrefixes() const override;
1563   bool canUncompress(const IOBuf* data, uint64_t uncompressedLength)
1564       const override;
1565
1566  private:
1567   bool doNeedsUncompressedLength() const override;
1568   uint64_t doMaxUncompressedLength() const override;
1569
1570   std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
1571     throw std::runtime_error("AutomaticCodec error: compress() not supported.");
1572   }
1573   std::unique_ptr<IOBuf> doUncompress(
1574       const IOBuf* data,
1575       uint64_t uncompressedLength) override;
1576
1577   void addCodecIfSupported(CodecType type);
1578
1579   // Throws iff the codecs aren't compatible (very slow)
1580   void checkCompatibleCodecs() const;
1581
1582   std::vector<std::unique_ptr<Codec>> codecs_;
1583   bool needsUncompressedLength_;
1584   uint64_t maxUncompressedLength_;
1585 };
1586
1587 std::vector<std::string> AutomaticCodec::validPrefixes() const {
1588   std::unordered_set<std::string> prefixes;
1589   for (const auto& codec : codecs_) {
1590     const auto codecPrefixes = codec->validPrefixes();
1591     prefixes.insert(codecPrefixes.begin(), codecPrefixes.end());
1592   }
1593   return std::vector<std::string>{prefixes.begin(), prefixes.end()};
1594 }
1595
1596 bool AutomaticCodec::canUncompress(
1597     const IOBuf* data,
1598     uint64_t uncompressedLength) const {
1599   return std::any_of(
1600       codecs_.begin(),
1601       codecs_.end(),
1602       [data, uncompressedLength](const auto& codec) {
1603         return codec->canUncompress(data, uncompressedLength);
1604       });
1605 }
1606
1607 void AutomaticCodec::addCodecIfSupported(CodecType type) {
1608   const bool present =
1609       std::any_of(codecs_.begin(), codecs_.end(), [&type](const auto& codec) {
1610         return codec->type() == type;
1611       });
1612   if (hasCodec(type) && !present) {
1613     codecs_.push_back(getCodec(type));
1614   }
1615 }
1616
1617 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
1618     std::vector<std::unique_ptr<Codec>> customCodecs) {
1619   return make_unique<AutomaticCodec>(std::move(customCodecs));
1620 }
1621
1622 AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
1623     : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
1624   // Fastest -> slowest
1625   addCodecIfSupported(CodecType::LZ4_FRAME);
1626   addCodecIfSupported(CodecType::ZSTD);
1627   addCodecIfSupported(CodecType::ZLIB);
1628   addCodecIfSupported(CodecType::GZIP);
1629   addCodecIfSupported(CodecType::LZMA2);
1630   if (kIsDebug) {
1631     checkCompatibleCodecs();
1632   }
1633   // Check that none of the codes are are null
1634   DCHECK(std::none_of(codecs_.begin(), codecs_.end(), [](const auto& codec) {
1635     return codec == nullptr;
1636   }));
1637
1638   needsUncompressedLength_ =
1639       std::any_of(codecs_.begin(), codecs_.end(), [](const auto& codec) {
1640         return codec->needsUncompressedLength();
1641       });
1642
1643   const auto it = std::max_element(
1644       codecs_.begin(), codecs_.end(), [](const auto& lhs, const auto& rhs) {
1645         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
1646       });
1647   DCHECK(it != codecs_.end());
1648   maxUncompressedLength_ = (*it)->maxUncompressedLength();
1649 }
1650
1651 void AutomaticCodec::checkCompatibleCodecs() const {
1652   // Keep track of all the possible headers.
1653   std::unordered_set<std::string> headers;
1654   // The empty header is not allowed.
1655   headers.insert("");
1656   // Step 1:
1657   // Construct a set of headers and check that none of the headers occur twice.
1658   // Eliminate edge cases.
1659   for (auto&& codec : codecs_) {
1660     const auto codecHeaders = codec->validPrefixes();
1661     // Codecs without any valid headers are not allowed.
1662     if (codecHeaders.empty()) {
1663       throw std::invalid_argument{
1664           "AutomaticCodec: validPrefixes() must not be empty."};
1665     }
1666     // Insert all the headers for the current codec.
1667     const size_t beforeSize = headers.size();
1668     headers.insert(codecHeaders.begin(), codecHeaders.end());
1669     // Codecs are not compatible if any header occurred twice.
1670     if (beforeSize + codecHeaders.size() != headers.size()) {
1671       throw std::invalid_argument{
1672           "AutomaticCodec: Two valid prefixes collide."};
1673     }
1674   }
1675   // Step 2:
1676   // Check if any strict non-empty prefix of any header is a header.
1677   for (const auto& header : headers) {
1678     for (size_t i = 1; i < header.size(); ++i) {
1679       if (headers.count(header.substr(0, i))) {
1680         throw std::invalid_argument{
1681             "AutomaticCodec: One valid prefix is a prefix of another valid "
1682             "prefix."};
1683       }
1684     }
1685   }
1686 }
1687
1688 bool AutomaticCodec::doNeedsUncompressedLength() const {
1689   return needsUncompressedLength_;
1690 }
1691
1692 uint64_t AutomaticCodec::doMaxUncompressedLength() const {
1693   return maxUncompressedLength_;
1694 }
1695
1696 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
1697     const IOBuf* data,
1698     uint64_t uncompressedLength) {
1699   for (auto&& codec : codecs_) {
1700     if (codec->canUncompress(data, uncompressedLength)) {
1701       return codec->uncompress(data, uncompressedLength);
1702     }
1703   }
1704   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
1705 }
1706
1707 }  // namespace
1708
1709 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1710 static constexpr CodecFactory
1711     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1712         nullptr, // USER_DEFINED
1713         NoCompressionCodec::create,
1714
1715 #if FOLLY_HAVE_LIBLZ4
1716         LZ4Codec::create,
1717 #else
1718         nullptr,
1719 #endif
1720
1721 #if FOLLY_HAVE_LIBSNAPPY
1722         SnappyCodec::create,
1723 #else
1724         nullptr,
1725 #endif
1726
1727 #if FOLLY_HAVE_LIBZ
1728         ZlibCodec::create,
1729 #else
1730         nullptr,
1731 #endif
1732
1733 #if FOLLY_HAVE_LIBLZ4
1734         LZ4Codec::create,
1735 #else
1736         nullptr,
1737 #endif
1738
1739 #if FOLLY_HAVE_LIBLZMA
1740         LZMA2Codec::create,
1741         LZMA2Codec::create,
1742 #else
1743         nullptr,
1744         nullptr,
1745 #endif
1746
1747 #if FOLLY_HAVE_LIBZSTD
1748         ZSTDCodec::create,
1749 #else
1750         nullptr,
1751 #endif
1752
1753 #if FOLLY_HAVE_LIBZ
1754         ZlibCodec::create,
1755 #else
1756         nullptr,
1757 #endif
1758
1759 #if (FOLLY_HAVE_LIBLZ4 && LZ4_VERSION_NUMBER >= 10301)
1760         LZ4FrameCodec::create,
1761 #else
1762         nullptr,
1763 #endif
1764 };
1765
1766 bool hasCodec(CodecType type) {
1767   size_t idx = static_cast<size_t>(type);
1768   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1769     throw std::invalid_argument(
1770         to<std::string>("Compression type ", idx, " invalid"));
1771   }
1772   return codecFactories[idx] != nullptr;
1773 }
1774
1775 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1776   size_t idx = static_cast<size_t>(type);
1777   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1778     throw std::invalid_argument(
1779         to<std::string>("Compression type ", idx, " invalid"));
1780   }
1781   auto factory = codecFactories[idx];
1782   if (!factory) {
1783     throw std::invalid_argument(to<std::string>(
1784         "Compression type ", idx, " not supported"));
1785   }
1786   auto codec = (*factory)(level, type);
1787   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1788   return codec;
1789 }
1790
1791 std::unique_ptr<Codec> getAutoUncompressionCodec(
1792     std::vector<std::unique_ptr<Codec>> customCodecs) {
1793   return AutomaticCodec::create(std::move(customCodecs));
1794 }
1795 }}  // namespaces