Timestamping callback interface in folly::AsyncSocket
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/Compression.h>
18
19 #if FOLLY_HAVE_LIBLZ4
20 #include <lz4.h>
21 #include <lz4hc.h>
22 #endif
23
24 #include <glog/logging.h>
25
26 #if FOLLY_HAVE_LIBSNAPPY
27 #include <snappy.h>
28 #include <snappy-sinksource.h>
29 #endif
30
31 #if FOLLY_HAVE_LIBZ
32 #include <zlib.h>
33 #endif
34
35 #if FOLLY_HAVE_LIBLZMA
36 #include <lzma.h>
37 #endif
38
39 #if FOLLY_HAVE_LIBZSTD
40 #include <zstd.h>
41 #endif
42
43 #include <folly/Conv.h>
44 #include <folly/Memory.h>
45 #include <folly/Portability.h>
46 #include <folly/ScopeGuard.h>
47 #include <folly/Varint.h>
48 #include <folly/io/Cursor.h>
49
50 namespace folly { namespace io {
51
52 Codec::Codec(CodecType type) : type_(type) { }
53
54 // Ensure consistent behavior in the nullptr case
55 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
56   uint64_t len = data->computeChainDataLength();
57   if (len == 0) {
58     return IOBuf::create(0);
59   } else if (len > maxUncompressedLength()) {
60     throw std::runtime_error("Codec: uncompressed length too large");
61   }
62
63   return doCompress(data);
64 }
65
66 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
67                                          uint64_t uncompressedLength) {
68   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
69     if (needsUncompressedLength()) {
70       throw std::invalid_argument("Codec: uncompressed length required");
71     }
72   } else if (uncompressedLength > maxUncompressedLength()) {
73     throw std::runtime_error("Codec: uncompressed length too large");
74   }
75
76   if (data->empty()) {
77     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
78         uncompressedLength != 0) {
79       throw std::runtime_error("Codec: invalid uncompressed length");
80     }
81     return IOBuf::create(0);
82   }
83
84   return doUncompress(data, uncompressedLength);
85 }
86
87 bool Codec::needsUncompressedLength() const {
88   return doNeedsUncompressedLength();
89 }
90
91 uint64_t Codec::maxUncompressedLength() const {
92   return doMaxUncompressedLength();
93 }
94
95 bool Codec::doNeedsUncompressedLength() const {
96   return false;
97 }
98
99 uint64_t Codec::doMaxUncompressedLength() const {
100   return UNLIMITED_UNCOMPRESSED_LENGTH;
101 }
102
103 namespace {
104
105 /**
106  * No compression
107  */
108 class NoCompressionCodec final : public Codec {
109  public:
110   static std::unique_ptr<Codec> create(int level, CodecType type);
111   explicit NoCompressionCodec(int level, CodecType type);
112
113  private:
114   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
115   std::unique_ptr<IOBuf> doUncompress(
116       const IOBuf* data,
117       uint64_t uncompressedLength) override;
118 };
119
120 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
121   return make_unique<NoCompressionCodec>(level, type);
122 }
123
124 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
125   : Codec(type) {
126   DCHECK(type == CodecType::NO_COMPRESSION);
127   switch (level) {
128   case COMPRESSION_LEVEL_DEFAULT:
129   case COMPRESSION_LEVEL_FASTEST:
130   case COMPRESSION_LEVEL_BEST:
131     level = 0;
132   }
133   if (level != 0) {
134     throw std::invalid_argument(to<std::string>(
135         "NoCompressionCodec: invalid level ", level));
136   }
137 }
138
139 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
140     const IOBuf* data) {
141   return data->clone();
142 }
143
144 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
145     const IOBuf* data,
146     uint64_t uncompressedLength) {
147   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
148       data->computeChainDataLength() != uncompressedLength) {
149     throw std::runtime_error(to<std::string>(
150         "NoCompressionCodec: invalid uncompressed length"));
151   }
152   return data->clone();
153 }
154
155 #if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
156
157 namespace {
158
159 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
160   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
161   out->append(encodeVarint(val, out->writableTail()));
162 }
163
164 inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
165   uint64_t val = 0;
166   int8_t b = 0;
167   for (int shift = 0; shift <= 63; shift += 7) {
168     b = cursor.read<int8_t>();
169     val |= static_cast<uint64_t>(b & 0x7f) << shift;
170     if (b >= 0) {
171       break;
172     }
173   }
174   if (b < 0) {
175     throw std::invalid_argument("Invalid varint value. Too big.");
176   }
177   return val;
178 }
179
180 }  // namespace
181
182 #endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
183
184 #if FOLLY_HAVE_LIBLZ4
185
186 /**
187  * LZ4 compression
188  */
189 class LZ4Codec final : public Codec {
190  public:
191   static std::unique_ptr<Codec> create(int level, CodecType type);
192   explicit LZ4Codec(int level, CodecType type);
193
194  private:
195   bool doNeedsUncompressedLength() const override;
196   uint64_t doMaxUncompressedLength() const override;
197
198   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
199
200   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
201   std::unique_ptr<IOBuf> doUncompress(
202       const IOBuf* data,
203       uint64_t uncompressedLength) override;
204
205   bool highCompression_;
206 };
207
208 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
209   return make_unique<LZ4Codec>(level, type);
210 }
211
212 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
213   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
214
215   switch (level) {
216   case COMPRESSION_LEVEL_FASTEST:
217   case COMPRESSION_LEVEL_DEFAULT:
218     level = 1;
219     break;
220   case COMPRESSION_LEVEL_BEST:
221     level = 2;
222     break;
223   }
224   if (level < 1 || level > 2) {
225     throw std::invalid_argument(to<std::string>(
226         "LZ4Codec: invalid level: ", level));
227   }
228   highCompression_ = (level > 1);
229 }
230
231 bool LZ4Codec::doNeedsUncompressedLength() const {
232   return !encodeSize();
233 }
234
235 // The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
236 // define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
237 // here.
238 #ifndef LZ4_MAX_INPUT_SIZE
239 # define LZ4_MAX_INPUT_SIZE 0x7E000000
240 #endif
241
242 uint64_t LZ4Codec::doMaxUncompressedLength() const {
243   return LZ4_MAX_INPUT_SIZE;
244 }
245
246 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
247   std::unique_ptr<IOBuf> clone;
248   if (data->isChained()) {
249     // LZ4 doesn't support streaming, so we have to coalesce
250     clone = data->clone();
251     clone->coalesce();
252     data = clone.get();
253   }
254
255   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
256   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
257   if (encodeSize()) {
258     encodeVarintToIOBuf(data->length(), out.get());
259   }
260
261   int n;
262   auto input = reinterpret_cast<const char*>(data->data());
263   auto output = reinterpret_cast<char*>(out->writableTail());
264   const auto inputLength = data->length();
265 #if LZ4_VERSION_NUMBER >= 10700
266   if (highCompression_) {
267     n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0);
268   } else {
269     n = LZ4_compress_default(input, output, inputLength, out->tailroom());
270   }
271 #else
272   if (highCompression_) {
273     n = LZ4_compressHC(input, output, inputLength);
274   } else {
275     n = LZ4_compress(input, output, inputLength);
276   }
277 #endif
278
279   CHECK_GE(n, 0);
280   CHECK_LE(n, out->capacity());
281
282   out->append(n);
283   return out;
284 }
285
286 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
287     const IOBuf* data,
288     uint64_t uncompressedLength) {
289   std::unique_ptr<IOBuf> clone;
290   if (data->isChained()) {
291     // LZ4 doesn't support streaming, so we have to coalesce
292     clone = data->clone();
293     clone->coalesce();
294     data = clone.get();
295   }
296
297   folly::io::Cursor cursor(data);
298   uint64_t actualUncompressedLength;
299   if (encodeSize()) {
300     actualUncompressedLength = decodeVarintFromCursor(cursor);
301     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
302         uncompressedLength != actualUncompressedLength) {
303       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
304     }
305   } else {
306     actualUncompressedLength = uncompressedLength;
307     if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
308         actualUncompressedLength > maxUncompressedLength()) {
309       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
310     }
311   }
312
313   auto sp = StringPiece{cursor.peekBytes()};
314   auto out = IOBuf::create(actualUncompressedLength);
315   int n = LZ4_decompress_safe(
316       sp.data(),
317       reinterpret_cast<char*>(out->writableTail()),
318       sp.size(),
319       actualUncompressedLength);
320
321   if (n < 0 || uint64_t(n) != actualUncompressedLength) {
322     throw std::runtime_error(to<std::string>(
323         "LZ4 decompression returned invalid value ", n));
324   }
325   out->append(actualUncompressedLength);
326   return out;
327 }
328
329 #endif  // FOLLY_HAVE_LIBLZ4
330
331 #if FOLLY_HAVE_LIBSNAPPY
332
333 /**
334  * Snappy compression
335  */
336
337 /**
338  * Implementation of snappy::Source that reads from a IOBuf chain.
339  */
340 class IOBufSnappySource final : public snappy::Source {
341  public:
342   explicit IOBufSnappySource(const IOBuf* data);
343   size_t Available() const override;
344   const char* Peek(size_t* len) override;
345   void Skip(size_t n) override;
346  private:
347   size_t available_;
348   io::Cursor cursor_;
349 };
350
351 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
352   : available_(data->computeChainDataLength()),
353     cursor_(data) {
354 }
355
356 size_t IOBufSnappySource::Available() const {
357   return available_;
358 }
359
360 const char* IOBufSnappySource::Peek(size_t* len) {
361   auto sp = StringPiece{cursor_.peekBytes()};
362   *len = sp.size();
363   return sp.data();
364 }
365
366 void IOBufSnappySource::Skip(size_t n) {
367   CHECK_LE(n, available_);
368   cursor_.skip(n);
369   available_ -= n;
370 }
371
372 class SnappyCodec final : public Codec {
373  public:
374   static std::unique_ptr<Codec> create(int level, CodecType type);
375   explicit SnappyCodec(int level, CodecType type);
376
377  private:
378   uint64_t doMaxUncompressedLength() const override;
379   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
380   std::unique_ptr<IOBuf> doUncompress(
381       const IOBuf* data,
382       uint64_t uncompressedLength) override;
383 };
384
385 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
386   return make_unique<SnappyCodec>(level, type);
387 }
388
389 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
390   DCHECK(type == CodecType::SNAPPY);
391   switch (level) {
392   case COMPRESSION_LEVEL_FASTEST:
393   case COMPRESSION_LEVEL_DEFAULT:
394   case COMPRESSION_LEVEL_BEST:
395     level = 1;
396   }
397   if (level != 1) {
398     throw std::invalid_argument(to<std::string>(
399         "SnappyCodec: invalid level: ", level));
400   }
401 }
402
403 uint64_t SnappyCodec::doMaxUncompressedLength() const {
404   // snappy.h uses uint32_t for lengths, so there's that.
405   return std::numeric_limits<uint32_t>::max();
406 }
407
408 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
409   IOBufSnappySource source(data);
410   auto out =
411     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
412
413   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
414       out->writableTail()));
415
416   size_t n = snappy::Compress(&source, &sink);
417
418   CHECK_LE(n, out->capacity());
419   out->append(n);
420   return out;
421 }
422
423 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
424                                                  uint64_t uncompressedLength) {
425   uint32_t actualUncompressedLength = 0;
426
427   {
428     IOBufSnappySource source(data);
429     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
430       throw std::runtime_error("snappy::GetUncompressedLength failed");
431     }
432     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
433         uncompressedLength != actualUncompressedLength) {
434       throw std::runtime_error("snappy: invalid uncompressed length");
435     }
436   }
437
438   auto out = IOBuf::create(actualUncompressedLength);
439
440   {
441     IOBufSnappySource source(data);
442     if (!snappy::RawUncompress(&source,
443                                reinterpret_cast<char*>(out->writableTail()))) {
444       throw std::runtime_error("snappy::RawUncompress failed");
445     }
446   }
447
448   out->append(actualUncompressedLength);
449   return out;
450 }
451
452 #endif  // FOLLY_HAVE_LIBSNAPPY
453
454 #if FOLLY_HAVE_LIBZ
455 /**
456  * Zlib codec
457  */
458 class ZlibCodec final : public Codec {
459  public:
460   static std::unique_ptr<Codec> create(int level, CodecType type);
461   explicit ZlibCodec(int level, CodecType type);
462
463  private:
464   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
465   std::unique_ptr<IOBuf> doUncompress(
466       const IOBuf* data,
467       uint64_t uncompressedLength) override;
468
469   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
470   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
471
472   int level_;
473 };
474
475 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
476   return make_unique<ZlibCodec>(level, type);
477 }
478
479 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
480   DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
481   switch (level) {
482   case COMPRESSION_LEVEL_FASTEST:
483     level = 1;
484     break;
485   case COMPRESSION_LEVEL_DEFAULT:
486     level = Z_DEFAULT_COMPRESSION;
487     break;
488   case COMPRESSION_LEVEL_BEST:
489     level = 9;
490     break;
491   }
492   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
493     throw std::invalid_argument(to<std::string>(
494         "ZlibCodec: invalid level: ", level));
495   }
496   level_ = level;
497 }
498
499 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
500                                                   uint32_t length) {
501   CHECK_EQ(stream->avail_out, 0);
502
503   auto buf = IOBuf::create(length);
504   buf->append(length);
505
506   stream->next_out = buf->writableData();
507   stream->avail_out = buf->length();
508
509   return buf;
510 }
511
512 bool ZlibCodec::doInflate(z_stream* stream,
513                           IOBuf* head,
514                           uint32_t bufferLength) {
515   if (stream->avail_out == 0) {
516     head->prependChain(addOutputBuffer(stream, bufferLength));
517   }
518
519   int rc = inflate(stream, Z_NO_FLUSH);
520
521   switch (rc) {
522   case Z_OK:
523     break;
524   case Z_STREAM_END:
525     return true;
526   case Z_BUF_ERROR:
527   case Z_NEED_DICT:
528   case Z_DATA_ERROR:
529   case Z_MEM_ERROR:
530     throw std::runtime_error(to<std::string>(
531         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
532   default:
533     CHECK(false) << rc << ": " << stream->msg;
534   }
535
536   return false;
537 }
538
539 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
540   z_stream stream;
541   stream.zalloc = nullptr;
542   stream.zfree = nullptr;
543   stream.opaque = nullptr;
544
545   // Using deflateInit2() to support gzip.  "The windowBits parameter is the
546   // base two logarithm of the maximum window size (...) The default value is
547   // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
548   // around the compressed data instead of a zlib wrapper. The gzip header
549   // will have no file name, no extra data, no comment, no modification time
550   // (set to zero), no header crc, and the operating system will be set to 255
551   // (unknown)."
552   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
553   // All other parameters (method, memLevel, strategy) get default values from
554   // the zlib manual.
555   int rc = deflateInit2(&stream,
556                         level_,
557                         Z_DEFLATED,
558                         windowBits,
559                         /* memLevel */ 8,
560                         Z_DEFAULT_STRATEGY);
561   if (rc != Z_OK) {
562     throw std::runtime_error(to<std::string>(
563         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
564   }
565
566   stream.next_in = stream.next_out = nullptr;
567   stream.avail_in = stream.avail_out = 0;
568   stream.total_in = stream.total_out = 0;
569
570   bool success = false;
571
572   SCOPE_EXIT {
573     rc = deflateEnd(&stream);
574     // If we're here because of an exception, it's okay if some data
575     // got dropped.
576     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
577       << rc << ": " << stream.msg;
578   };
579
580   uint64_t uncompressedLength = data->computeChainDataLength();
581   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
582
583   // Max 64MiB in one go
584   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
585   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
586
587   auto out = addOutputBuffer(
588       &stream,
589       (maxCompressedLength <= maxSingleStepLength ?
590        maxCompressedLength :
591        defaultBufferLength));
592
593   for (auto& range : *data) {
594     uint64_t remaining = range.size();
595     uint64_t written = 0;
596     while (remaining) {
597       uint32_t step = (remaining > maxSingleStepLength ?
598                        maxSingleStepLength : remaining);
599       stream.next_in = const_cast<uint8_t*>(range.data() + written);
600       stream.avail_in = step;
601       remaining -= step;
602       written += step;
603
604       while (stream.avail_in != 0) {
605         if (stream.avail_out == 0) {
606           out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
607         }
608
609         rc = deflate(&stream, Z_NO_FLUSH);
610
611         CHECK_EQ(rc, Z_OK) << stream.msg;
612       }
613     }
614   }
615
616   do {
617     if (stream.avail_out == 0) {
618       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
619     }
620
621     rc = deflate(&stream, Z_FINISH);
622   } while (rc == Z_OK);
623
624   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
625
626   out->prev()->trimEnd(stream.avail_out);
627
628   success = true;  // we survived
629
630   return out;
631 }
632
633 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
634                                                uint64_t uncompressedLength) {
635   z_stream stream;
636   stream.zalloc = nullptr;
637   stream.zfree = nullptr;
638   stream.opaque = nullptr;
639
640   // "The windowBits parameter is the base two logarithm of the maximum window
641   // size (...) The default value is 15 (...) add 16 to decode only the gzip
642   // format (the zlib format will return a Z_DATA_ERROR)."
643   int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
644   int rc = inflateInit2(&stream, windowBits);
645   if (rc != Z_OK) {
646     throw std::runtime_error(to<std::string>(
647         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
648   }
649
650   stream.next_in = stream.next_out = nullptr;
651   stream.avail_in = stream.avail_out = 0;
652   stream.total_in = stream.total_out = 0;
653
654   bool success = false;
655
656   SCOPE_EXIT {
657     rc = inflateEnd(&stream);
658     // If we're here because of an exception, it's okay if some data
659     // got dropped.
660     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
661       << rc << ": " << stream.msg;
662   };
663
664   // Max 64MiB in one go
665   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
666   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
667
668   auto out = addOutputBuffer(
669       &stream,
670       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
671         uncompressedLength <= maxSingleStepLength) ?
672        uncompressedLength :
673        defaultBufferLength));
674
675   bool streamEnd = false;
676   for (auto& range : *data) {
677     if (range.empty()) {
678       continue;
679     }
680
681     stream.next_in = const_cast<uint8_t*>(range.data());
682     stream.avail_in = range.size();
683
684     while (stream.avail_in != 0) {
685       if (streamEnd) {
686         throw std::runtime_error(to<std::string>(
687             "ZlibCodec: junk after end of data"));
688       }
689
690       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
691     }
692   }
693
694   while (!streamEnd) {
695     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
696   }
697
698   out->prev()->trimEnd(stream.avail_out);
699
700   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
701       uncompressedLength != stream.total_out) {
702     throw std::runtime_error(to<std::string>(
703         "ZlibCodec: invalid uncompressed length"));
704   }
705
706   success = true;  // we survived
707
708   return out;
709 }
710
711 #endif  // FOLLY_HAVE_LIBZ
712
713 #if FOLLY_HAVE_LIBLZMA
714
715 /**
716  * LZMA2 compression
717  */
718 class LZMA2Codec final : public Codec {
719  public:
720   static std::unique_ptr<Codec> create(int level, CodecType type);
721   explicit LZMA2Codec(int level, CodecType type);
722
723  private:
724   bool doNeedsUncompressedLength() const override;
725   uint64_t doMaxUncompressedLength() const override;
726
727   bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
728
729   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
730   std::unique_ptr<IOBuf> doUncompress(
731       const IOBuf* data,
732       uint64_t uncompressedLength) override;
733
734   std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
735   bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
736
737   int level_;
738 };
739
740 std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
741   return make_unique<LZMA2Codec>(level, type);
742 }
743
744 LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
745   DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
746   switch (level) {
747   case COMPRESSION_LEVEL_FASTEST:
748     level = 0;
749     break;
750   case COMPRESSION_LEVEL_DEFAULT:
751     level = LZMA_PRESET_DEFAULT;
752     break;
753   case COMPRESSION_LEVEL_BEST:
754     level = 9;
755     break;
756   }
757   if (level < 0 || level > 9) {
758     throw std::invalid_argument(to<std::string>(
759         "LZMA2Codec: invalid level: ", level));
760   }
761   level_ = level;
762 }
763
764 bool LZMA2Codec::doNeedsUncompressedLength() const {
765   return !encodeSize();
766 }
767
768 uint64_t LZMA2Codec::doMaxUncompressedLength() const {
769   // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
770   return uint64_t(1) << 63;
771 }
772
773 std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
774     lzma_stream* stream,
775     size_t length) {
776
777   CHECK_EQ(stream->avail_out, 0);
778
779   auto buf = IOBuf::create(length);
780   buf->append(length);
781
782   stream->next_out = buf->writableData();
783   stream->avail_out = buf->length();
784
785   return buf;
786 }
787
788 std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
789   lzma_ret rc;
790   lzma_stream stream = LZMA_STREAM_INIT;
791
792   rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
793   if (rc != LZMA_OK) {
794     throw std::runtime_error(folly::to<std::string>(
795       "LZMA2Codec: lzma_easy_encoder error: ", rc));
796   }
797
798   SCOPE_EXIT { lzma_end(&stream); };
799
800   uint64_t uncompressedLength = data->computeChainDataLength();
801   uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
802
803   // Max 64MiB in one go
804   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
805   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
806
807   auto out = addOutputBuffer(
808     &stream,
809     (maxCompressedLength <= maxSingleStepLength ?
810      maxCompressedLength :
811      defaultBufferLength));
812
813   if (encodeSize()) {
814     auto size = IOBuf::createCombined(kMaxVarintLength64);
815     encodeVarintToIOBuf(uncompressedLength, size.get());
816     size->appendChain(std::move(out));
817     out = std::move(size);
818   }
819
820   for (auto& range : *data) {
821     if (range.empty()) {
822       continue;
823     }
824
825     stream.next_in = const_cast<uint8_t*>(range.data());
826     stream.avail_in = range.size();
827
828     while (stream.avail_in != 0) {
829       if (stream.avail_out == 0) {
830         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
831       }
832
833       rc = lzma_code(&stream, LZMA_RUN);
834
835       if (rc != LZMA_OK) {
836         throw std::runtime_error(folly::to<std::string>(
837           "LZMA2Codec: lzma_code error: ", rc));
838       }
839     }
840   }
841
842   do {
843     if (stream.avail_out == 0) {
844       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
845     }
846
847     rc = lzma_code(&stream, LZMA_FINISH);
848   } while (rc == LZMA_OK);
849
850   if (rc != LZMA_STREAM_END) {
851     throw std::runtime_error(folly::to<std::string>(
852       "LZMA2Codec: lzma_code ended with error: ", rc));
853   }
854
855   out->prev()->trimEnd(stream.avail_out);
856
857   return out;
858 }
859
860 bool LZMA2Codec::doInflate(lzma_stream* stream,
861                           IOBuf* head,
862                           size_t bufferLength) {
863   if (stream->avail_out == 0) {
864     head->prependChain(addOutputBuffer(stream, bufferLength));
865   }
866
867   lzma_ret rc = lzma_code(stream, LZMA_RUN);
868
869   switch (rc) {
870   case LZMA_OK:
871     break;
872   case LZMA_STREAM_END:
873     return true;
874   default:
875     throw std::runtime_error(to<std::string>(
876         "LZMA2Codec: lzma_code error: ", rc));
877   }
878
879   return false;
880 }
881
882 std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
883                                                uint64_t uncompressedLength) {
884   lzma_ret rc;
885   lzma_stream stream = LZMA_STREAM_INIT;
886
887   rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
888   if (rc != LZMA_OK) {
889     throw std::runtime_error(folly::to<std::string>(
890       "LZMA2Codec: lzma_auto_decoder error: ", rc));
891   }
892
893   SCOPE_EXIT { lzma_end(&stream); };
894
895   // Max 64MiB in one go
896   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
897   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
898
899   folly::io::Cursor cursor(data);
900   uint64_t actualUncompressedLength;
901   if (encodeSize()) {
902     actualUncompressedLength = decodeVarintFromCursor(cursor);
903     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
904         uncompressedLength != actualUncompressedLength) {
905       throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
906     }
907   } else {
908     actualUncompressedLength = uncompressedLength;
909     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
910   }
911
912   auto out = addOutputBuffer(
913       &stream,
914       (actualUncompressedLength <= maxSingleStepLength ?
915        actualUncompressedLength :
916        defaultBufferLength));
917
918   bool streamEnd = false;
919   auto buf = cursor.peekBytes();
920   while (!buf.empty()) {
921     stream.next_in = const_cast<uint8_t*>(buf.data());
922     stream.avail_in = buf.size();
923
924     while (stream.avail_in != 0) {
925       if (streamEnd) {
926         throw std::runtime_error(to<std::string>(
927             "LZMA2Codec: junk after end of data"));
928       }
929
930       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
931     }
932
933     cursor.skip(buf.size());
934     buf = cursor.peekBytes();
935   }
936
937   while (!streamEnd) {
938     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
939   }
940
941   out->prev()->trimEnd(stream.avail_out);
942
943   if (actualUncompressedLength != stream.total_out) {
944     throw std::runtime_error(to<std::string>(
945         "LZMA2Codec: invalid uncompressed length"));
946   }
947
948   return out;
949 }
950
951 #endif  // FOLLY_HAVE_LIBLZMA
952
953 #ifdef FOLLY_HAVE_LIBZSTD
954
955 /**
956  * ZSTD compression
957  */
958 class ZSTDCodec final : public Codec {
959  public:
960   static std::unique_ptr<Codec> create(int level, CodecType);
961   explicit ZSTDCodec(int level, CodecType type);
962
963  private:
964   bool doNeedsUncompressedLength() const override;
965   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
966   std::unique_ptr<IOBuf> doUncompress(
967       const IOBuf* data,
968       uint64_t uncompressedLength) override;
969
970   int level_;
971 };
972
973 std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
974   return make_unique<ZSTDCodec>(level, type);
975 }
976
977 ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
978   DCHECK(type == CodecType::ZSTD);
979   switch (level) {
980     case COMPRESSION_LEVEL_FASTEST:
981       level = 1;
982       break;
983     case COMPRESSION_LEVEL_DEFAULT:
984       level = 1;
985       break;
986     case COMPRESSION_LEVEL_BEST:
987       level = 19;
988       break;
989   }
990   if (level < 1 || level > ZSTD_maxCLevel()) {
991     throw std::invalid_argument(
992         to<std::string>("ZSTD: invalid level: ", level));
993   }
994   level_ = level;
995 }
996
997 bool ZSTDCodec::doNeedsUncompressedLength() const {
998   return false;
999 }
1000
1001 void zstdThrowIfError(size_t rc) {
1002   if (!ZSTD_isError(rc)) {
1003     return;
1004   }
1005   throw std::runtime_error(
1006       to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
1007 }
1008
1009 std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
1010   // Support earlier versions of the codec (working with a single IOBuf,
1011   // and using ZSTD_decompress which requires ZSTD frame to contain size,
1012   // which isn't populated by streaming API).
1013   if (!data->isChained()) {
1014     auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
1015     const auto rc = ZSTD_compress(
1016         out->writableData(),
1017         out->capacity(),
1018         data->data(),
1019         data->length(),
1020         level_);
1021     zstdThrowIfError(rc);
1022     out->append(rc);
1023     return out;
1024   }
1025
1026   auto zcs = ZSTD_createCStream();
1027   SCOPE_EXIT {
1028     ZSTD_freeCStream(zcs);
1029   };
1030
1031   auto rc = ZSTD_initCStream(zcs, level_);
1032   zstdThrowIfError(rc);
1033
1034   Cursor cursor(data);
1035   auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
1036
1037   ZSTD_outBuffer out;
1038   out.dst = result->writableTail();
1039   out.size = result->capacity();
1040   out.pos = 0;
1041
1042   for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
1043     ZSTD_inBuffer in;
1044     in.src = buffer.data();
1045     in.size = buffer.size();
1046     for (in.pos = 0; in.pos != in.size;) {
1047       rc = ZSTD_compressStream(zcs, &out, &in);
1048       zstdThrowIfError(rc);
1049     }
1050     cursor.skip(in.size);
1051     buffer = cursor.peekBytes();
1052   }
1053
1054   rc = ZSTD_endStream(zcs, &out);
1055   zstdThrowIfError(rc);
1056   CHECK_EQ(rc, 0);
1057
1058   result->append(out.pos);
1059   return result;
1060 }
1061
1062 std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
1063     const IOBuf* data,
1064     uint64_t uncompressedLength) {
1065   auto zds = ZSTD_createDStream();
1066   SCOPE_EXIT {
1067     ZSTD_freeDStream(zds);
1068   };
1069
1070   auto rc = ZSTD_initDStream(zds);
1071   zstdThrowIfError(rc);
1072
1073   ZSTD_outBuffer out{};
1074   ZSTD_inBuffer in{};
1075
1076   auto outputSize = ZSTD_DStreamOutSize();
1077   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
1078     outputSize = uncompressedLength;
1079   } else {
1080     auto decompressedSize =
1081         ZSTD_getDecompressedSize(data->data(), data->length());
1082     if (decompressedSize != 0 && decompressedSize < outputSize) {
1083       outputSize = decompressedSize;
1084     }
1085   }
1086
1087   IOBufQueue queue(IOBufQueue::cacheChainLength());
1088
1089   Cursor cursor(data);
1090   for (rc = 0;;) {
1091     if (in.pos == in.size) {
1092       auto buffer = cursor.peekBytes();
1093       in.src = buffer.data();
1094       in.size = buffer.size();
1095       in.pos = 0;
1096       cursor.skip(in.size);
1097       if (rc > 1 && in.size == 0) {
1098         throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
1099       }
1100     }
1101     if (out.pos == out.size) {
1102       if (out.pos != 0) {
1103         queue.postallocate(out.pos);
1104       }
1105       auto buffer = queue.preallocate(outputSize, outputSize);
1106       out.dst = buffer.first;
1107       out.size = buffer.second;
1108       out.pos = 0;
1109       outputSize = ZSTD_DStreamOutSize();
1110     }
1111     rc = ZSTD_decompressStream(zds, &out, &in);
1112     zstdThrowIfError(rc);
1113     if (rc == 0) {
1114       break;
1115     }
1116   }
1117   if (out.pos != 0) {
1118     queue.postallocate(out.pos);
1119   }
1120   if (in.pos != in.size || !cursor.isAtEnd()) {
1121     throw std::runtime_error("ZSTD: junk after end of data");
1122   }
1123   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
1124       queue.chainLength() != uncompressedLength) {
1125     throw std::runtime_error("ZSTD: invalid uncompressed length");
1126   }
1127
1128   return queue.move();
1129 }
1130
1131 #endif  // FOLLY_HAVE_LIBZSTD
1132
1133 }  // namespace
1134
1135 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
1136   typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
1137
1138   static CodecFactory codecFactories[
1139     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
1140     nullptr,  // USER_DEFINED
1141     NoCompressionCodec::create,
1142
1143 #if FOLLY_HAVE_LIBLZ4
1144     LZ4Codec::create,
1145 #else
1146     nullptr,
1147 #endif
1148
1149 #if FOLLY_HAVE_LIBSNAPPY
1150     SnappyCodec::create,
1151 #else
1152     nullptr,
1153 #endif
1154
1155 #if FOLLY_HAVE_LIBZ
1156     ZlibCodec::create,
1157 #else
1158     nullptr,
1159 #endif
1160
1161 #if FOLLY_HAVE_LIBLZ4
1162     LZ4Codec::create,
1163 #else
1164     nullptr,
1165 #endif
1166
1167 #if FOLLY_HAVE_LIBLZMA
1168     LZMA2Codec::create,
1169     LZMA2Codec::create,
1170 #else
1171     nullptr,
1172     nullptr,
1173 #endif
1174
1175 #if FOLLY_HAVE_LIBZSTD
1176     ZSTDCodec::create,
1177 #else
1178     nullptr,
1179 #endif
1180
1181 #if FOLLY_HAVE_LIBZ
1182     ZlibCodec::create,
1183 #else
1184     nullptr,
1185 #endif
1186   };
1187
1188   size_t idx = static_cast<size_t>(type);
1189   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
1190     throw std::invalid_argument(to<std::string>(
1191         "Compression type ", idx, " not supported"));
1192   }
1193   auto factory = codecFactories[idx];
1194   if (!factory) {
1195     throw std::invalid_argument(to<std::string>(
1196         "Compression type ", idx, " not supported"));
1197   }
1198   auto codec = (*factory)(level, type);
1199   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
1200   return codec;
1201 }
1202
1203 }}  // namespaces