Optimize ZSTDCodec::doUncompress()
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4frame.h>
22 #include <lz4hc.h>
23 #endif
24
25 #include <glog/logging.h>
26
27 #if FOLLY_HAVE_LIBSNAPPY
28 #include <snappy.h>
29 #include <snappy-sinksource.h>
30 #endif
31
32 #if FOLLY_HAVE_LIBZ
33 #include <zlib.h>
34 #endif
35
36 #if FOLLY_HAVE_LIBLZMA
37 #include <lzma.h>
38 #endif
39
40 #if FOLLY_HAVE_LIBZSTD
41 #include <zstd.h>
42 #endif
43
44 #include <folly/Conv.h>
45 #include <folly/Memory.h>
46 #include <folly/Portability.h>
47 #include <folly/ScopeGuard.h>
48 #include <folly/Varint.h>
49 #include <folly/io/Cursor.h>
50
51 namespace folly { namespace io {
52
53 Codec::Codec(CodecType type) : type_(type) { }
54
55 // Ensure consistent behavior in the nullptr case
56 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
57   uint64_t len = data->computeChainDataLength();
58   if (len == 0) {
59     return IOBuf::create(0);
60   }
61   if (len > maxUncompressedLength()) {
62     throw std::runtime_error("Codec: uncompressed length too large");
63   }
64
65   return doCompress(data);
66 }
67
68 std::string Codec::compress(const StringPiece data) {
69   const uint64_t len = data.size();
70   if (len == 0) {
71     return "";
72   }
73   if (len > maxUncompressedLength()) {
74     throw std::runtime_error("Codec: uncompressed length too large");
75   }
76
77   return doCompressString(data);
78 }
79
80 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
81                                          uint64_t uncompressedLength) {
82   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
83     if (needsUncompressedLength()) {
84       throw std::invalid_argument("Codec: uncompressed length required");
85     }
86   } else if (uncompressedLength > maxUncompressedLength()) {
87     throw std::runtime_error("Codec: uncompressed length too large");
88   }
89
90   if (data->empty()) {
91     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
92         uncompressedLength != 0) {
93       throw std::runtime_error("Codec: invalid uncompressed length");
94     }
95     return IOBuf::create(0);
96   }
97
98   return doUncompress(data, uncompressedLength);
99 }
100
101 std::string Codec::uncompress(
102     const StringPiece data,
103     uint64_t uncompressedLength) {
104   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
105     if (needsUncompressedLength()) {
106       throw std::invalid_argument("Codec: uncompressed length required");
107     }
108   } else if (uncompressedLength > maxUncompressedLength()) {
109     throw std::runtime_error("Codec: uncompressed length too large");
110   }
111
112   if (data.empty()) {
113     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
114         uncompressedLength != 0) {
115       throw std::runtime_error("Codec: invalid uncompressed length");
116     }
117     return "";
118   }
119
120   return doUncompressString(data, uncompressedLength);
121 }
122
123 bool Codec::needsUncompressedLength() const {
124   return doNeedsUncompressedLength();
125 }
126
127 uint64_t Codec::maxUncompressedLength() const {
128   return doMaxUncompressedLength();
129 }
130
131 bool Codec::doNeedsUncompressedLength() const {
132   return false;
133 }
134
135 uint64_t Codec::doMaxUncompressedLength() const {
136   return UNLIMITED_UNCOMPRESSED_LENGTH;
137 }
138
139 std::string Codec::doCompressString(const StringPiece data) {
140   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
141   auto outputBuffer = doCompress(&inputBuffer);
142   std::string output;
143   output.reserve(outputBuffer->computeChainDataLength());
144   for (auto range : *outputBuffer) {
145     output.append(reinterpret_cast<const char*>(range.data()), range.size());
146   }
147   return output;
148 }
149
150 std::string Codec::doUncompressString(
151     const StringPiece data,
152     uint64_t uncompressedLength) {
153   const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data};
154   auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength);
155   std::string output;
156   output.reserve(outputBuffer->computeChainDataLength());
157   for (auto range : *outputBuffer) {
158     output.append(reinterpret_cast<const char*>(range.data()), range.size());
159   }
160   return output;
161 }
162
163 namespace {
164
165 /**
166  * No compression
167  */
168 class NoCompressionCodec final : public Codec {
169  public:
170   static std::unique_ptr<Codec> create(int level, CodecType type);
171   explicit NoCompressionCodec(int level, CodecType type);
172
173  private:
174   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
175   std::unique_ptr<IOBuf> doUncompress(
176       const IOBuf* data,
177       uint64_t uncompressedLength) override;
178 };
179
180 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
181   return make_unique<NoCompressionCodec>(level, type);
182 }
183
184 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
185   : Codec(type) {
186   DCHECK(type == CodecType::NO_COMPRESSION);
187   switch (level) {
188   case COMPRESSION_LEVEL_DEFAULT:
189   case COMPRESSION_LEVEL_FASTEST:
190   case COMPRESSION_LEVEL_BEST:
191     level = 0;
192   }
193   if (level != 0) {
194     throw std::invalid_argument(to<std::string>(
195         "NoCompressionCodec: invalid level ", level));
196   }
197 }
198
199 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
200     const IOBuf* data) {
201   return data->clone();
202 }
203
204 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
205     const IOBuf* data,
206     uint64_t uncompressedLength) {
207   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
208       data->computeChainDataLength() != uncompressedLength) {
209     throw std::runtime_error(to<std::string>(
210         "NoCompressionCodec: invalid uncompressed length"));
211   }
212   return data->clone();
213 }
214
215 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
216
217 namespace {
218
219 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
220   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
221   out->append(encodeVarint(val, out->writableTail()));
222 }
223
224 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
225   uint64_t val = 0;
226   int8_t b = 0;
227   for (int shift = 0; shift <= 63; shift += 7) {
228     b = cursor.read<int8_t>();
229     val |= static_cast<uint64_t>(b & 0x7f) << shift;
230     if (b >= 0) {
231       break;
232     }
233   }
234   if (b < 0) {
235     throw std::invalid_argument("Invalid varint value. Too big.");
236   }
237   return val;
238 }
239
240 }  // namespace
241
242 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
243
244 #if FOLLY_HAVE_LIBLZ4
245
246 /**
247  * LZ4 compression
248  */
249 class LZ4Codec final : public Codec {
250  public:
251   static std::unique_ptr<Codec> create(int level, CodecType type);
252   explicit LZ4Codec(int level, CodecType type);
253
254  private:
255   bool doNeedsUncompressedLength() const override;
256   uint64_t doMaxUncompressedLength() const override;
257
258   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
259
260   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
261   std::unique_ptr<IOBuf> doUncompress(
262       const IOBuf* data,
263       uint64_t uncompressedLength) override;
264
265   bool highCompression_;
266 };
267
268 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
269   return make_unique<LZ4Codec>(level, type);
270 }
271
272 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
273   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
274
275   switch (level) {
276   case COMPRESSION_LEVEL_FASTEST:
277   case COMPRESSION_LEVEL_DEFAULT:
278     level = 1;
279     break;
280   case COMPRESSION_LEVEL_BEST:
281     level = 2;
282     break;
283   }
284   if (level < 1 || level > 2) {
285     throw std::invalid_argument(to<std::string>(
286         "LZ4Codec: invalid level: ", level));
287   }
288   highCompression_ = (level > 1);
289 }
290
291 bool LZ4Codec::doNeedsUncompressedLength() const {
292   return !encodeSize();
293 }
294
295 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
296 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
297 // here.
298 #ifndef LZ4_MAX_INPUT_SIZE
299 # define LZ4_MAX_INPUT_SIZE 0x7E000000
300 #endif
301
302 uint64_t LZ4Codec::doMaxUncompressedLength() const {
303   return LZ4_MAX_INPUT_SIZE;
304 }
305
306 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
307   IOBuf clone;
308   if (data->isChained()) {
309     // LZ4 doesn't support streaming, so we have to coalesce
310     clone = data->cloneCoalescedAsValue();
311     data = &clone;
312   }
313
314   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
315   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
316   if (encodeSize()) {
317     encodeVarintToIOBuf(data->length(), out.get());
318   }
319
320   int n;
321   auto input = reinterpret_cast<const char*>(data->data());
322   auto output = reinterpret_cast<char*>(out->writableTail());
323   const auto inputLength = data->length();
324 #if LZ4_VERSION_NUMBER >= 10700
325   if (highCompression_) {
326     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
327   } else {
328     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
329   }
330 #else
331   if (highCompression_) {
332     n = LZ4_compressHC(input, output, inputLength);
333   } else {
334     n = LZ4_compress(input, output, inputLength);
335   }
336 #endif
337
338   CHECK_GE(n, 0);
339   CHECK_LE(n, out->capacity());
340
341   out->append(n);
342   return out;
343 }
344
345 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
346     const IOBuf* data,
347     uint64_t uncompressedLength) {
348   IOBuf clone;
349   if (data->isChained()) {
350     // LZ4 doesn't support streaming, so we have to coalesce
351     clone = data->cloneCoalescedAsValue();
352     data = &clone;
353   }
354
355   folly::io::Cursor cursor(data);
356   uint64_t actualUncompressedLength;
357   if (encodeSize()) {
358     actualUncompressedLength = decodeVarintFromCursor(cursor);
359     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
360         uncompressedLength != actualUncompressedLength) {
361       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
362     }
363   } else {
364     actualUncompressedLength = uncompressedLength;
365     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
366         actualUncompressedLength > maxUncompressedLength()) {
367       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
368     }
369   }
370
371   auto sp = StringPiece{cursor.peekBytes()};
372   auto out = IOBuf::create(actualUncompressedLength);
373   int n = LZ4_decompress_safe(
374       sp.data(),
375       reinterpret_cast<char*>(out->writableTail()),
376       sp.size(),
377       actualUncompressedLength);
378
379   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
380     throw std::runtime_error(to<std::string>(
381         "LZ4 decompression returned invalid value ", n));
382   }
383   out->append(actualUncompressedLength);
384   return out;
385 }
386
387 class LZ4FrameCodec final : public Codec {
388  public:
389   static std::unique_ptr<Codec> create(int level, CodecType type);
390   explicit LZ4FrameCodec(int level, CodecType type);
391   ~LZ4FrameCodec();
392
393  private:
394   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
395   std::unique_ptr<IOBuf> doUncompress(
396       const IOBuf* data,
397       uint64_t uncompressedLength) override;
398
399   // Reset the dctx_ if it is dirty or null.
400   void resetDCtx();
401
402   int level_;
403   LZ4F_dctx* dctx_{nullptr};
404   bool dirty_{false};
405 };
406
407 /* static */ std::unique_ptr<Codec> LZ4FrameCodec::create(
408     int level,
409     CodecType type) {
410   return make_unique<LZ4FrameCodec>(level, type);
411 }
412
413 static size_t lz4FrameThrowOnError(size_t code) {
414   if (LZ4F_isError(code)) {
415     throw std::runtime_error(
416         to<std::string>("LZ4Frame error: ", LZ4F_getErrorName(code)));
417   }
418   return code;
419 }
420
421 void LZ4FrameCodec::resetDCtx() {
422   if (dctx_ && !dirty_) {
423     return;
424   }
425   if (dctx_) {
426     LZ4F_freeDecompressionContext(dctx_);
427   }
428   lz4FrameThrowOnError(LZ4F_createDecompressionContext(&dctx_, 100));
429   dirty_ = false;
430 }
431
432 LZ4FrameCodec::LZ4FrameCodec(int level, CodecType type) : Codec(type) {
433   DCHECK(type == CodecType::LZ4_FRAME);
434   switch (level) {
435     case COMPRESSION_LEVEL_FASTEST:
436     case COMPRESSION_LEVEL_DEFAULT:
437       level_ = 0;
438       break;
439     case COMPRESSION_LEVEL_BEST:
440       level_ = 16;
441       break;
442     default:
443       level_ = level;
444       break;
445   }
446 }
447
448 LZ4FrameCodec::~LZ4FrameCodec() {
449   if (dctx_) {
450     LZ4F_freeDecompressionContext(dctx_);
451   }
452 }
453
454 std::unique_ptr<IOBuf> LZ4FrameCodec::doCompress(const IOBuf* data) {
455   // LZ4 Frame compression doesn't support streaming so we have to coalesce
456   IOBuf clone;
457   if (data->isChained()) {
458     clone = data->cloneCoalescedAsValue();
459     data = &clone;
460   }
461   // Set preferences
462   const auto uncompressedLength = data->length();
463   LZ4F_preferences_t prefs{};
464   prefs.compressionLevel = level_;
465   prefs.frameInfo.contentSize = uncompressedLength;
466   // Compress
467   auto buf = IOBuf::create(LZ4F_compressFrameBound(uncompressedLength, &prefs));
468   const size_t written = lz4FrameThrowOnError(LZ4F_compressFrame(
469       buf->writableTail(),
470       buf->tailroom(),
471       data->data(),
472       data->length(),
473       &prefs));
474   buf->append(written);
475   return buf;
476 }
477
478 std::unique_ptr<IOBuf> LZ4FrameCodec::doUncompress(
479     const IOBuf* data,
480     uint64_t uncompressedLength) {
481   // Reset the dctx if any errors have occurred
482   resetDCtx();
483   // Coalesce the data
484   ByteRange in = *data->begin();
485   IOBuf clone;
486   if (data->isChained()) {
487     clone = data->cloneCoalescedAsValue();
488     in = clone.coalesce();
489   }
490   data = nullptr;
491   // Select decompression options
492   LZ4F_decompressOptions_t options;
493   options.stableDst = 1;
494   // Select blockSize and growthSize for the IOBufQueue
495   IOBufQueue queue(IOBufQueue::cacheChainLength());
496   auto blockSize = uint64_t{64} << 10;
497   auto growthSize = uint64_t{4} << 20;
498   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
499     // Allocate uncompressedLength in one chunk (up to 64 MB)
500     const auto allocateSize = std::min(uncompressedLength, uint64_t{64} << 20);
501     queue.preallocate(allocateSize, allocateSize);
502     blockSize = std::min(uncompressedLength, blockSize);
503     growthSize = std::min(uncompressedLength, growthSize);
504   } else {
505     // Reduce growthSize for small data
506     const auto guessUncompressedLen = 4 * std::max(blockSize, in.size());
507     growthSize = std::min(guessUncompressedLen, growthSize);
508   }
509   // Once LZ4_decompress() is called, the dctx_ cannot be reused until it
510   // returns 0
511   dirty_ = true;
512   // Decompress until the frame is over
513   size_t code = 0;
514   do {
515     // Allocate enough space to decompress at least a block
516     void* out;
517     size_t outSize;
518     std::tie(out, outSize) = queue.preallocate(blockSize, growthSize);
519     // Decompress
520     size_t inSize = in.size();
521     code = lz4FrameThrowOnError(
522         LZ4F_decompress(dctx_, out, &outSize, in.data(), &inSize, &options));
523     if (in.empty() && outSize == 0 && code != 0) {
524       // We passed no input, no output was produced, and the frame isn't over
525       // No more forward progress is possible
526       throw std::runtime_error("LZ4Frame error: Incomplete frame");
527     }
528     in.uncheckedAdvance(inSize);
529     queue.postallocate(outSize);
530   } while (code != 0);
531   // At this point the decompression context can be reused
532   dirty_ = false;
533   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
534       queue.chainLength() != uncompressedLength) {
535     throw std::runtime_error("LZ4Frame error: Invalid uncompressedLength");
536   }
537   return queue.move();
538 }
539
540 #endif // FOLLY_HAVE_LIBLZ4
541
542 #if FOLLY_HAVE_LIBSNAPPY
543
544 /**
545  * Snappy compression
546  */
547
548 /**
549  * Implementation of snappy::Source that reads from a IOBuf chain.
550  */
551 class IOBufSnappySource final : public snappy::Source {
552  public:
553   explicit IOBufSnappySource(const IOBuf* data);
554   size_t Available() const override;
555   const char* Peek(size_t* len) override;
556   void Skip(size_t n) override;
557  private:
558   size_t available_;
559   io::Cursor cursor_;
560 };
561
562 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
563   : available_(data->computeChainDataLength()),
564     cursor_(data) {
565 }
566
567 size_t IOBufSnappySource::Available() const {
568   return available_;
569 }
570
571 const char* IOBufSnappySource::Peek(size_t* len) {
572   auto sp = StringPiece{cursor_.peekBytes()};
573   *len = sp.size();
574   return sp.data();
575 }
576
577 void IOBufSnappySource::Skip(size_t n) {
578   CHECK_LE(n, available_);
579   cursor_.skip(n);
580   available_ -= n;
581 }
582
583 class SnappyCodec final : public Codec {
584  public:
585   static std::unique_ptr<Codec> create(int level, CodecType type);
586   explicit SnappyCodec(int level, CodecType type);
587
588  private:
589   uint64_t doMaxUncompressedLength() const override;
590   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
591   std::unique_ptr<IOBuf> doUncompress(
592       const IOBuf* data,
593       uint64_t uncompressedLength) override;
594 };
595
596 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
597   return make_unique<SnappyCodec>(level, type);
598 }
599
600 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
601   DCHECK(type == CodecType::SNAPPY);
602   switch (level) {
603   case COMPRESSION_LEVEL_FASTEST:
604   case COMPRESSION_LEVEL_DEFAULT:
605   case COMPRESSION_LEVEL_BEST:
606     level = 1;
607   }
608   if (level != 1) {
609     throw std::invalid_argument(to<std::string>(
610         "SnappyCodec: invalid level: ", level));
611   }
612 }
613
614 uint64_t SnappyCodec::doMaxUncompressedLength() const {
615   // snappy.h uses uint32_t for lengths, so there's that.
616   return std::numeric_limits<uint32_t>::max();
617 }
618
619 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
620   IOBufSnappySource source(data);
621   auto out =
622     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
623
624   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
625       out->writableTail()));
626
627   size_t n = snappy::Compress(&source, &sink);
628
629   CHECK_LE(n, out->capacity());
630   out->append(n);
631   return out;
632 }
633
634 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
635                                                  uint64_t uncompressedLength) {
636   uint32_t actualUncompressedLength = 0;
637
638   {
639     IOBufSnappySource source(data);
640     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
641       throw std::runtime_error("snappy::GetUncompressedLength failed");
642     }
643     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
644         uncompressedLength != actualUncompressedLength) {
645       throw std::runtime_error("snappy: invalid uncompressed length");
646     }
647   }
648
649   auto out = IOBuf::create(actualUncompressedLength);
650
651   {
652     IOBufSnappySource source(data);
653     if (!snappy::RawUncompress(&source,
654                                reinterpret_cast<char*>(out->writableTail()))) {
655       throw std::runtime_error("snappy::RawUncompress failed");
656     }
657   }
658
659   out->append(actualUncompressedLength);
660   return out;
661 }
662
663 #endif  // FOLLY_HAVE_LIBSNAPPY
664
665 #if FOLLY_HAVE_LIBZ
666 /**
667  * Zlib codec
668  */
669 class ZlibCodec final : public Codec {
670  public:
671   static std::unique_ptr<Codec> create(int level, CodecType type);
672   explicit ZlibCodec(int level, CodecType type);
673
674  private:
675   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
676   std::unique_ptr<IOBuf> doUncompress(
677       const IOBuf* data,
678       uint64_t uncompressedLength) override;
679
680   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
681   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
682
683   int level_;
684 };
685
686 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
687   return make_unique<ZlibCodec>(level, type);
688 }
689
690 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
691   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
692   switch (level) {
693   case COMPRESSION_LEVEL_FASTEST:
694     level = 1;
695     break;
696   case COMPRESSION_LEVEL_DEFAULT:
697     level = Z_DEFAULT_COMPRESSION;
698     break;
699   case COMPRESSION_LEVEL_BEST:
700     level = 9;
701     break;
702   }
703   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
704     throw std::invalid_argument(to<std::string>(
705         "ZlibCodec: invalid level: ", level));
706   }
707   level_ = level;
708 }
709
710 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
711                                                   uint32_t length) {
712   CHECK_EQ(stream->avail_out, 0);
713
714   auto buf = IOBuf::create(length);
715   buf->append(length);
716
717   stream->next_out = buf->writableData();
718   stream->avail_out = buf->length();
719
720   return buf;
721 }
722
723 bool ZlibCodec::doInflate(z_stream* stream,
724                           IOBuf* head,
725                           uint32_t bufferLength) {
726   if (stream->avail_out == 0) {
727     head->prependChain(addOutputBuffer(stream, bufferLength));
728   }
729
730   int rc = inflate(stream, Z_NO_FLUSH);
731
732   switch (rc) {
733   case Z_OK:
734     break;
735   case Z_STREAM_END:
736     return true;
737   case Z_BUF_ERROR:
738   case Z_NEED_DICT:
739   case Z_DATA_ERROR:
740   case Z_MEM_ERROR:
741     throw std::runtime_error(to<std::string>(
742         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
743   default:
744     CHECK(false) << rc << ": " << stream->msg;
745   }
746
747   return false;
748 }
749
750 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
751   z_stream stream;
752   stream.zalloc = nullptr;
753   stream.zfree = nullptr;
754   stream.opaque = nullptr;
755
756   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
757   // base two logarithm of the maximum window size (...) The default value is
758   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
759   // around the compressed data instead of a zlib wrapper. The gzip header
760   // will have no file name, no extra data, no comment, no modification time
761   // (set to zero), no header crc, and the operating system will be set to 255
762   // (unknown)."
763   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
764   // All other parameters (method, memLevel, strategy) get default values from
765   // the zlib manual.
766   int rc = deflateInit2(&stream,
767                         level_,
768                         Z_DEFLATED,
769                         windowBits,
770                         /* memLevel */ 8,
771                         Z_DEFAULT_STRATEGY);
772   if (rc != Z_OK) {
773     throw std::runtime_error(to<std::string>(
774         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
775   }
776
777   stream.next_in = stream.next_out = nullptr;
778   stream.avail_in = stream.avail_out = 0;
779   stream.total_in = stream.total_out = 0;
780
781   bool success = false;
782
783   SCOPE_EXIT {
784     rc = deflateEnd(&stream);
785     // If we're here because of an exception, it's okay if some data
786     // got dropped.
787     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
788       << rc << ": " << stream.msg;
789   };
790
791   uint64_t uncompressedLength = data->computeChainDataLength();
792   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
793
794   // Max 64MiB in one go
795   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
796   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
797
798   auto out = addOutputBuffer(
799       &stream,
800       (maxCompressedLength <= maxSingleStepLength ?
801        maxCompressedLength :
802        defaultBufferLength));
803
804   for (auto& range : *data) {
805     uint64_t remaining = range.size();
806     uint64_t written = 0;
807     while (remaining) {
808       uint32_t step = (remaining > maxSingleStepLength ?
809                        maxSingleStepLength : remaining);
810       stream.next_in = const_cast<uint8_t*>(range.data() + written);
811       stream.avail_in = step;
812       remaining -= step;
813       written += step;
814
815       while (stream.avail_in != 0) {
816         if (stream.avail_out == 0) {
817           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
818         }
819
820         rc = deflate(&stream, Z_NO_FLUSH);
821
822         CHECK_EQ(rc, Z_OK) << stream.msg;
823       }
824     }
825   }
826
827   do {
828     if (stream.avail_out == 0) {
829       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
830     }
831
832     rc = deflate(&stream, Z_FINISH);
833   } while (rc == Z_OK);
834
835   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
836
837   out->prev()->trimEnd(stream.avail_out);
838
839   success = true;  // we survived
840
841   return out;
842 }
843
844 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
845                                                uint64_t uncompressedLength) {
846   z_stream stream;
847   stream.zalloc = nullptr;
848   stream.zfree = nullptr;
849   stream.opaque = nullptr;
850
851   // "The windowBits parameter is the base two logarithm of the maximum window
852   // size (...) The default value is 15 (...) add 16 to decode only the gzip
853   // format (the zlib format will return a Z_DATA_ERROR)."
854   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
855   int rc = inflateInit2(&stream, windowBits);
856   if (rc != Z_OK) {
857     throw std::runtime_error(to<std::string>(
858         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
859   }
860
861   stream.next_in = stream.next_out = nullptr;
862   stream.avail_in = stream.avail_out = 0;
863   stream.total_in = stream.total_out = 0;
864
865   bool success = false;
866
867   SCOPE_EXIT {
868     rc = inflateEnd(&stream);
869     // If we're here because of an exception, it's okay if some data
870     // got dropped.
871     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
872       << rc << ": " << stream.msg;
873   };
874
875   // Max 64MiB in one go
876   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
877   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
878
879   auto out = addOutputBuffer(
880       &stream,
881       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
882         uncompressedLength <= maxSingleStepLength) ?
883        uncompressedLength :
884        defaultBufferLength));
885
886   bool streamEnd = false;
887   for (auto& range : *data) {
888     if (range.empty()) {
889       continue;
890     }
891
892     stream.next_in = const_cast<uint8_t*>(range.data());
893     stream.avail_in = range.size();
894
895     while (stream.avail_in != 0) {
896       if (streamEnd) {
897         throw std::runtime_error(to<std::string>(
898             "ZlibCodec: junk after end of data"));
899       }
900
901       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
902     }
903   }
904
905   while (!streamEnd) {
906     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
907   }
908
909   out->prev()->trimEnd(stream.avail_out);
910
911   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
912       uncompressedLength != stream.total_out) {
913     throw std::runtime_error(to<std::string>(
914         "ZlibCodec: invalid uncompressed length"));
915   }
916
917   success = true;  // we survived
918
919   return out;
920 }
921
922 #endif  // FOLLY_HAVE_LIBZ
923
924 #if FOLLY_HAVE_LIBLZMA
925
926 /**
927  * LZMA2 compression
928  */
929 class LZMA2Codec final : public Codec {
930  public:
931   static std::unique_ptr<Codec> create(int level, CodecType type);
932   explicit LZMA2Codec(int level, CodecType type);
933
934  private:
935   bool doNeedsUncompressedLength() const override;
936   uint64_t doMaxUncompressedLength() const override;
937
938   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
939
940   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
941   std::unique_ptr<IOBuf> doUncompress(
942       const IOBuf* data,
943       uint64_t uncompressedLength) override;
944
945   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
946   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
947
948   int level_;
949 };
950
951 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
952   return make_unique<LZMA2Codec>(level, type);
953 }
954
955 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
956   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
957   switch (level) {
958   case COMPRESSION_LEVEL_FASTEST:
959     level = 0;
960     break;
961   case COMPRESSION_LEVEL_DEFAULT:
962     level = LZMA_PRESET_DEFAULT;
963     break;
964   case COMPRESSION_LEVEL_BEST:
965     level = 9;
966     break;
967   }
968   if (level < 0 || level > 9) {
969     throw std::invalid_argument(to<std::string>(
970         "LZMA2Codec: invalid level: ", level));
971   }
972   level_ = level;
973 }
974
975 bool LZMA2Codec::doNeedsUncompressedLength() const {
976   return !encodeSize();
977 }
978
979 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
980   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
981   return uint64_t(1) << 63;
982 }
983
984 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
985     lzma_stream* stream,
986     size_t length) {
987
988   CHECK_EQ(stream->avail_out, 0);
989
990   auto buf = IOBuf::create(length);
991   buf->append(length);
992
993   stream->next_out = buf->writableData();
994   stream->avail_out = buf->length();
995
996   return buf;
997 }
998
999 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
1000   lzma_ret rc;
1001   lzma_stream stream = LZMA_STREAM_INIT;
1002
1003   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
1004   if (rc != LZMA_OK) {
1005     throw std::runtime_error(folly::to<std::string>(
1006       "LZMA2Codec: lzma_easy_encoder error: ", rc));
1007   }
1008
1009   SCOPE_EXIT { lzma_end(&stream); };
1010
1011   uint64_t uncompressedLength = data->computeChainDataLength();
1012   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
1013
1014   // Max 64MiB in one go
1015   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1016   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1017
1018   auto out = addOutputBuffer(
1019     &stream,
1020     (maxCompressedLength <= maxSingleStepLength ?
1021      maxCompressedLength :
1022      defaultBufferLength));
1023
1024   if (encodeSize()) {
1025     auto size = IOBuf::createCombined(kMaxVarintLength64);
1026     encodeVarintToIOBuf(uncompressedLength, size.get());
1027     size->appendChain(std::move(out));
1028     out = std::move(size);
1029   }
1030
1031   for (auto& range : *data) {
1032     if (range.empty()) {
1033       continue;
1034     }
1035
1036     stream.next_in = const_cast<uint8_t*>(range.data());
1037     stream.avail_in = range.size();
1038
1039     while (stream.avail_in != 0) {
1040       if (stream.avail_out == 0) {
1041         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1042       }
1043
1044       rc = lzma_code(&stream, LZMA_RUN);
1045
1046       if (rc != LZMA_OK) {
1047         throw std::runtime_error(folly::to<std::string>(
1048           "LZMA2Codec: lzma_code error: ", rc));
1049       }
1050     }
1051   }
1052
1053   do {
1054     if (stream.avail_out == 0) {
1055       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
1056     }
1057
1058     rc = lzma_code(&stream, LZMA_FINISH);
1059   } while (rc == LZMA_OK);
1060
1061   if (rc != LZMA_STREAM_END) {
1062     throw std::runtime_error(folly::to<std::string>(
1063       "LZMA2Codec: lzma_code ended with error: ", rc));
1064   }
1065
1066   out->prev()->trimEnd(stream.avail_out);
1067
1068   return out;
1069 }
1070
1071 bool LZMA2Codec::doInflate(lzma_stream* stream,
1072                           IOBuf* head,
1073                           size_t bufferLength) {
1074   if (stream->avail_out == 0) {
1075     head->prependChain(addOutputBuffer(stream, bufferLength));
1076   }
1077
1078   lzma_ret rc = lzma_code(stream, LZMA_RUN);
1079
1080   switch (rc) {
1081   case LZMA_OK:
1082     break;
1083   case LZMA_STREAM_END:
1084     return true;
1085   default:
1086     throw std::runtime_error(to<std::string>(
1087         "LZMA2Codec: lzma_code error: ", rc));
1088   }
1089
1090   return false;
1091 }
1092
1093 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
1094                                                uint64_t uncompressedLength) {
1095   lzma_ret rc;
1096   lzma_stream stream = LZMA_STREAM_INIT;
1097
1098   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
1099   if (rc != LZMA_OK) {
1100     throw std::runtime_error(folly::to<std::string>(
1101       "LZMA2Codec: lzma_auto_decoder error: ", rc));
1102   }
1103
1104   SCOPE_EXIT { lzma_end(&stream); };
1105
1106   // Max 64MiB in one go
1107   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
1108   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
1109
1110   folly::io::Cursor cursor(data);
1111   uint64_t actualUncompressedLength;
1112   if (encodeSize()) {
1113     actualUncompressedLength = decodeVarintFromCursor(cursor);
1114     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1115         uncompressedLength != actualUncompressedLength) {
1116       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
1117     }
1118   } else {
1119     actualUncompressedLength = uncompressedLength;
1120     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
1121   }
1122
1123   auto out = addOutputBuffer(
1124       &stream,
1125       (actualUncompressedLength <= maxSingleStepLength ?
1126        actualUncompressedLength :
1127        defaultBufferLength));
1128
1129   bool streamEnd = false;
1130   auto buf = cursor.peekBytes();
1131   while (!buf.empty()) {
1132     stream.next_in = const_cast<uint8_t*>(buf.data());
1133     stream.avail_in = buf.size();
1134
1135     while (stream.avail_in != 0) {
1136       if (streamEnd) {
1137         throw std::runtime_error(to<std::string>(
1138             "LZMA2Codec: junk after end of data"));
1139       }
1140
1141       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1142     }
1143
1144     cursor.skip(buf.size());
1145     buf = cursor.peekBytes();
1146   }
1147
1148   while (!streamEnd) {
1149     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
1150   }
1151
1152   out->prev()->trimEnd(stream.avail_out);
1153
1154   if (actualUncompressedLength != stream.total_out) {
1155     throw std::runtime_error(to<std::string>(
1156         "LZMA2Codec: invalid uncompressed length"));
1157   }
1158
1159   return out;
1160 }
1161
1162 #endif  // FOLLY_HAVE_LIBLZMA
1163
1164 #ifdef FOLLY_HAVE_LIBZSTD
1165
1166 /**
1167  * ZSTD compression
1168  */
1169 class ZSTDCodec final : public Codec {
1170  public:
1171   static std::unique_ptr<Codec> create(int level, CodecType);
1172   explicit ZSTDCodec(int level, CodecType type);
1173
1174  private:
1175   bool doNeedsUncompressedLength() const override;
1176   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
1177   std::unique_ptr<IOBuf> doUncompress(
1178       const IOBuf* data,
1179       uint64_t uncompressedLength) override;
1180
1181   int level_;
1182 };
1183
1184 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
1185   return make_unique<ZSTDCodec>(level, type);
1186 }
1187
1188 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
1189   DCHECK(type == CodecType::ZSTD);
1190   switch (level) {
1191     case COMPRESSION_LEVEL_FASTEST:
1192       level = 1;
1193       break;
1194     case COMPRESSION_LEVEL_DEFAULT:
1195       level = 1;
1196       break;
1197     case COMPRESSION_LEVEL_BEST:
1198       level = 19;
1199       break;
1200   }
1201   if (level < 1 || level > ZSTD_maxCLevel()) {
1202     throw std::invalid_argument(
1203         to<std::string>("ZSTD: invalid level: ", level));
1204   }
1205   level_ = level;
1206 }
1207
1208 bool ZSTDCodec::doNeedsUncompressedLength() const {
1209   return false;
1210 }
1211
1212 void zstdThrowIfError(size_t rc) {
1213   if (!ZSTD_isError(rc)) {
1214     return;
1215   }
1216   throw std::runtime_error(
1217       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1218 }
1219
1220 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1221   // Support earlier versions of the codec (working with a single IOBuf,
1222   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1223   // which isn't populated by streaming API).
1224   if (!data->isChained()) {
1225     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1226     const auto rc = ZSTD_compress(
1227         out->writableData(),
1228         out->capacity(),
1229         data->data(),
1230         data->length(),
1231         level_);
1232     zstdThrowIfError(rc);
1233     out->append(rc);
1234     return out;
1235   }
1236
1237   auto zcs = ZSTD_createCStream();
1238   SCOPE_EXIT {
1239     ZSTD_freeCStream(zcs);
1240   };
1241
1242   auto rc = ZSTD_initCStream(zcs, level_);
1243   zstdThrowIfError(rc);
1244
1245   Cursor cursor(data);
1246   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1247
1248   ZSTD_outBuffer out;
1249   out.dst = result->writableTail();
1250   out.size = result->capacity();
1251   out.pos = 0;
1252
1253   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1254     ZSTD_inBuffer in;
1255     in.src = buffer.data();
1256     in.size = buffer.size();
1257     for (in.pos = 0; in.pos != in.size;) {
1258       rc = ZSTD_compressStream(zcs, &out, &in);
1259       zstdThrowIfError(rc);
1260     }
1261     cursor.skip(in.size);
1262     buffer = cursor.peekBytes();
1263   }
1264
1265   rc = ZSTD_endStream(zcs, &out);
1266   zstdThrowIfError(rc);
1267   CHECK_EQ(rc, 0);
1268
1269   result->append(out.pos);
1270   return result;
1271 }
1272
1273 static std::unique_ptr<IOBuf> zstdUncompressBuffer(
1274     const IOBuf* data,
1275     uint64_t uncompressedLength) {
1276   // Check preconditions
1277   DCHECK(!data->isChained());
1278   DCHECK(uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH);
1279
1280   auto uncompressed = IOBuf::create(uncompressedLength);
1281   const auto decompressedSize = ZSTD_decompress(
1282       uncompressed->writableTail(),
1283       uncompressed->tailroom(),
1284       data->data(),
1285       data->length());
1286   zstdThrowIfError(decompressedSize);
1287   if (decompressedSize != uncompressedLength) {
1288     throw std::runtime_error("ZSTD: invalid uncompressed length");
1289   }
1290   uncompressed->append(decompressedSize);
1291   return uncompressed;
1292 }
1293
1294 static std::unique_ptr<IOBuf> zstdUncompressStream(
1295     const IOBuf* data,
1296     uint64_t uncompressedLength) {
1297   auto zds = ZSTD_createDStream();
1298   SCOPE_EXIT {
1299     ZSTD_freeDStream(zds);
1300   };
1301
1302   auto rc = ZSTD_initDStream(zds);
1303   zstdThrowIfError(rc);
1304
1305   ZSTD_outBuffer out{};
1306   ZSTD_inBuffer in{};
1307
1308   auto outputSize = ZSTD_DStreamOutSize();
1309   if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH) {
1310     outputSize = uncompressedLength;
1311   }
1312
1313   IOBufQueue queue(IOBufQueue::cacheChainLength());
1314
1315   Cursor cursor(data);
1316   for (rc = 0;;) {
1317     if (in.pos == in.size) {
1318       auto buffer = cursor.peekBytes();
1319       in.src = buffer.data();
1320       in.size = buffer.size();
1321       in.pos = 0;
1322       cursor.skip(in.size);
1323       if (rc > 1 && in.size == 0) {
1324         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1325       }
1326     }
1327     if (out.pos == out.size) {
1328       if (out.pos != 0) {
1329         queue.postallocate(out.pos);
1330       }
1331       auto buffer = queue.preallocate(outputSize, outputSize);
1332       out.dst = buffer.first;
1333       out.size = buffer.second;
1334       out.pos = 0;
1335       outputSize = ZSTD_DStreamOutSize();
1336     }
1337     rc = ZSTD_decompressStream(zds, &out, &in);
1338     zstdThrowIfError(rc);
1339     if (rc == 0) {
1340       break;
1341     }
1342   }
1343   if (out.pos != 0) {
1344     queue.postallocate(out.pos);
1345   }
1346   if (in.pos != in.size || !cursor.isAtEnd()) {
1347     throw std::runtime_error("ZSTD: junk after end of data");
1348   }
1349   if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
1350       queue.chainLength() != uncompressedLength) {
1351     throw std::runtime_error("ZSTD: invalid uncompressed length");
1352   }
1353
1354   return queue.move();
1355 }
1356
1357 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1358     const IOBuf* data,
1359     uint64_t uncompressedLength) {
1360   {
1361     // Read decompressed size from frame if available in first IOBuf.
1362     const auto decompressedSize =
1363         ZSTD_getDecompressedSize(data->data(), data->length());
1364     if (decompressedSize != 0) {
1365       if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
1366           uncompressedLength != decompressedSize) {
1367         throw std::runtime_error("ZSTD: invalid uncompressed length");
1368       }
1369       uncompressedLength = decompressedSize;
1370     }
1371   }
1372   // Faster to decompress using ZSTD_decompress() if we can.
1373   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && !data->isChained()) {
1374     return zstdUncompressBuffer(data, uncompressedLength);
1375   }
1376   // Fall back to slower streaming decompression.
1377   return zstdUncompressStream(data, uncompressedLength);
1378 }
1379
1380 #endif  // FOLLY_HAVE_LIBZSTD
1381
1382 }  // namespace
1383
1384 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1385 static constexpr CodecFactory
1386     codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1387         nullptr, // USER_DEFINED
1388         NoCompressionCodec::create,
1389
1390 #if FOLLY_HAVE_LIBLZ4
1391         LZ4Codec::create,
1392 #else
1393         nullptr,
1394 #endif
1395
1396 #if FOLLY_HAVE_LIBSNAPPY
1397         SnappyCodec::create,
1398 #else
1399         nullptr,
1400 #endif
1401
1402 #if FOLLY_HAVE_LIBZ
1403         ZlibCodec::create,
1404 #else
1405         nullptr,
1406 #endif
1407
1408 #if FOLLY_HAVE_LIBLZ4
1409         LZ4Codec::create,
1410 #else
1411         nullptr,
1412 #endif
1413
1414 #if FOLLY_HAVE_LIBLZMA
1415         LZMA2Codec::create,
1416         LZMA2Codec::create,
1417 #else
1418         nullptr,
1419         nullptr,
1420 #endif
1421
1422 #if FOLLY_HAVE_LIBZSTD
1423         ZSTDCodec::create,
1424 #else
1425         nullptr,
1426 #endif
1427
1428 #if FOLLY_HAVE_LIBZ
1429         ZlibCodec::create,
1430 #else
1431         nullptr,
1432 #endif
1433
1434 #if FOLLY_HAVE_LIBLZ4
1435         LZ4FrameCodec::create,
1436 #else
1437         nullptr,
1438 #endif
1439 };
1440
1441 bool hasCodec(CodecType type) {
1442   size_t idx = static_cast<size_t>(type);
1443   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1444     throw std::invalid_argument(
1445         to<std::string>("Compression type ", idx, " invalid"));
1446   }
1447   return codecFactories[idx] != nullptr;
1448 }
1449
1450 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1451   size_t idx = static_cast<size_t>(type);
1452   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1453     throw std::invalid_argument(
1454         to<std::string>("Compression type ", idx, " invalid"));
1455   }
1456   auto factory = codecFactories[idx];
1457   if (!factory) {
1458     throw std::invalid_argument(to<std::string>(
1459         "Compression type ", idx, " not supported"));
1460   }
1461   auto codec = (*factory)(level, type);
1462   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1463   return codec;
1464 }
1465
1466 }}  // namespaces