Add Varint-length-prefixed flavor of LZ4
[folly.git] / folly / io / Compression.cpp
1 /*
2  * Copyright 2013 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "folly/io/Compression.h"
18
19 #include <lz4.h>
20 #include <lz4hc.h>
21 #include <glog/logging.h>
22 #include <snappy.h>
23 #include <snappy-sinksource.h>
24 #include <zlib.h>
25
26 #include "folly/Conv.h"
27 #include "folly/Memory.h"
28 #include "folly/Portability.h"
29 #include "folly/ScopeGuard.h"
30 #include "folly/Varint.h"
31 #include "folly/io/Cursor.h"
32
33 namespace folly { namespace io {
34
35 Codec::Codec(CodecType type) : type_(type) { }
36
37 // Ensure consistent behavior in the nullptr case
38 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
39   return !data->empty() ? doCompress(data) : IOBuf::create(0);
40 }
41
42 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
43                                          uint64_t uncompressedLength) {
44   if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) {
45     if (needsUncompressedLength()) {
46       throw std::invalid_argument("Codec: uncompressed length required");
47     }
48   } else if (uncompressedLength > maxUncompressedLength()) {
49     throw std::runtime_error("Codec: uncompressed length too large");
50   }
51
52   if (data->empty()) {
53     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
54         uncompressedLength != 0) {
55       throw std::runtime_error("Codec: invalid uncompressed length");
56     }
57     return IOBuf::create(0);
58   }
59
60   return doUncompress(data, uncompressedLength);
61 }
62
63 bool Codec::needsUncompressedLength() const {
64   return doNeedsUncompressedLength();
65 }
66
67 uint64_t Codec::maxUncompressedLength() const {
68   return doMaxUncompressedLength();
69 }
70
71 bool Codec::doNeedsUncompressedLength() const {
72   return false;
73 }
74
75 uint64_t Codec::doMaxUncompressedLength() const {
76   return std::numeric_limits<uint64_t>::max() - 1;
77 }
78
79 namespace {
80
81 /**
82  * No compression
83  */
84 class NoCompressionCodec FOLLY_FINAL : public Codec {
85  public:
86   static std::unique_ptr<Codec> create(int level, CodecType type);
87   explicit NoCompressionCodec(int level, CodecType type);
88
89  private:
90   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
91   std::unique_ptr<IOBuf> doUncompress(
92       const IOBuf* data,
93       uint64_t uncompressedLength) FOLLY_OVERRIDE;
94 };
95
96 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
97   return make_unique<NoCompressionCodec>(level, type);
98 }
99
100 NoCompressionCodec::NoCompressionCodec(int level, CodecType type)
101   : Codec(type) {
102   DCHECK(type == CodecType::NO_COMPRESSION);
103   switch (level) {
104   case COMPRESSION_LEVEL_DEFAULT:
105   case COMPRESSION_LEVEL_FASTEST:
106   case COMPRESSION_LEVEL_BEST:
107     level = 0;
108   }
109   if (level != 0) {
110     throw std::invalid_argument(to<std::string>(
111         "NoCompressionCodec: invalid level ", level));
112   }
113 }
114
115 std::unique_ptr<IOBuf> NoCompressionCodec::doCompress(
116     const IOBuf* data) {
117   return data->clone();
118 }
119
120 std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
121     const IOBuf* data,
122     uint64_t uncompressedLength) {
123   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
124       data->computeChainDataLength() != uncompressedLength) {
125     throw std::runtime_error(to<std::string>(
126         "NoCompressionCodec: invalid uncompressed length"));
127   }
128   return data->clone();
129 }
130
131 /**
132  * LZ4 compression
133  */
134 class LZ4Codec FOLLY_FINAL : public Codec {
135  public:
136   static std::unique_ptr<Codec> create(int level, CodecType type);
137   explicit LZ4Codec(int level, CodecType type);
138
139  private:
140   bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
141   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
142
143   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
144
145   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
146   std::unique_ptr<IOBuf> doUncompress(
147       const IOBuf* data,
148       uint64_t uncompressedLength) FOLLY_OVERRIDE;
149
150   bool highCompression_;
151 };
152
153 std::unique_ptr<Codec> LZ4Codec::create(int level, CodecType type) {
154   return make_unique<LZ4Codec>(level, type);
155 }
156
157 LZ4Codec::LZ4Codec(int level, CodecType type) : Codec(type) {
158   DCHECK(type == CodecType::LZ4 || type == CodecType::LZ4_VARINT_SIZE);
159
160   switch (level) {
161   case COMPRESSION_LEVEL_FASTEST:
162   case COMPRESSION_LEVEL_DEFAULT:
163     level = 1;
164     break;
165   case COMPRESSION_LEVEL_BEST:
166     level = 2;
167     break;
168   }
169   if (level < 1 || level > 2) {
170     throw std::invalid_argument(to<std::string>(
171         "LZ4Codec: invalid level: ", level));
172   }
173   highCompression_ = (level > 1);
174 }
175
176 bool LZ4Codec::doNeedsUncompressedLength() const {
177   return !encodeSize();
178 }
179
180 uint64_t LZ4Codec::doMaxUncompressedLength() const {
181   // From lz4.h: "Max supported value is ~1.9GB"; I wish we had something
182   // more accurate.
183   return 1.8 * (uint64_t(1) << 30);
184 }
185
186 namespace {
187
188 void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
189   DCHECK_GE(out->tailroom(), kMaxVarintLength64);
190   out->append(encodeVarint(val, out->writableTail()));
191 }
192
193 uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
194   // Must have enough room in *this* buffer.
195   auto p = cursor.peek();
196   folly::ByteRange range(p.first, p.second);
197   uint64_t val = decodeVarint(range);
198   cursor.skip(range.data() - p.first);
199   return val;
200 }
201
202 }  // namespace
203
204 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
205   std::unique_ptr<IOBuf> clone;
206   if (data->isChained()) {
207     // LZ4 doesn't support streaming, so we have to coalesce
208     clone = data->clone();
209     clone->coalesce();
210     data = clone.get();
211   }
212
213   uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0;
214   auto out = IOBuf::create(extraSize + LZ4_compressBound(data->length()));
215   if (encodeSize()) {
216     encodeVarintToIOBuf(data->length(), out.get());
217   }
218
219   int n;
220   if (highCompression_) {
221     n = LZ4_compressHC(reinterpret_cast<const char*>(data->data()),
222                        reinterpret_cast<char*>(out->writableTail()),
223                        data->length());
224   } else {
225     n = LZ4_compress(reinterpret_cast<const char*>(data->data()),
226                      reinterpret_cast<char*>(out->writableTail()),
227                      data->length());
228   }
229
230   CHECK_GE(n, 0);
231   CHECK_LE(n, out->capacity());
232
233   out->append(n);
234   return out;
235 }
236
237 std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
238     const IOBuf* data,
239     uint64_t uncompressedLength) {
240   std::unique_ptr<IOBuf> clone;
241   if (data->isChained()) {
242     // LZ4 doesn't support streaming, so we have to coalesce
243     clone = data->clone();
244     clone->coalesce();
245     data = clone.get();
246   }
247
248   folly::io::Cursor cursor(data);
249   uint64_t actualUncompressedLength;
250   if (encodeSize()) {
251     actualUncompressedLength = decodeVarintFromCursor(cursor);
252     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
253         uncompressedLength != actualUncompressedLength) {
254       throw std::runtime_error("LZ4Codec: invalid uncompressed length");
255     }
256   } else {
257     actualUncompressedLength = uncompressedLength;
258     DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
259   }
260
261   auto out = IOBuf::create(actualUncompressedLength);
262   auto p = cursor.peek();
263   int n = LZ4_uncompress(reinterpret_cast<const char*>(p.first),
264                          reinterpret_cast<char*>(out->writableTail()),
265                          actualUncompressedLength);
266   if (n != p.second) {
267     throw std::runtime_error(to<std::string>(
268         "LZ4 decompression returned invalid value ", n));
269   }
270   out->append(actualUncompressedLength);
271   return out;
272 }
273
274 /**
275  * Snappy compression
276  */
277
278 /**
279  * Implementation of snappy::Source that reads from a IOBuf chain.
280  */
281 class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
282  public:
283   explicit IOBufSnappySource(const IOBuf* data);
284   size_t Available() const FOLLY_OVERRIDE;
285   const char* Peek(size_t* len) FOLLY_OVERRIDE;
286   void Skip(size_t n) FOLLY_OVERRIDE;
287  private:
288   size_t available_;
289   io::Cursor cursor_;
290 };
291
292 IOBufSnappySource::IOBufSnappySource(const IOBuf* data)
293   : available_(data->computeChainDataLength()),
294     cursor_(data) {
295 }
296
297 size_t IOBufSnappySource::Available() const {
298   return available_;
299 }
300
301 const char* IOBufSnappySource::Peek(size_t* len) {
302   auto p = cursor_.peek();
303   *len = p.second;
304   return reinterpret_cast<const char*>(p.first);
305 }
306
307 void IOBufSnappySource::Skip(size_t n) {
308   CHECK_LE(n, available_);
309   cursor_.skip(n);
310   available_ -= n;
311 }
312
313 class SnappyCodec FOLLY_FINAL : public Codec {
314  public:
315   static std::unique_ptr<Codec> create(int level, CodecType type);
316   explicit SnappyCodec(int level, CodecType type);
317
318  private:
319   uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
320   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
321   std::unique_ptr<IOBuf> doUncompress(
322       const IOBuf* data,
323       uint64_t uncompressedLength) FOLLY_OVERRIDE;
324 };
325
326 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
327   return make_unique<SnappyCodec>(level, type);
328 }
329
330 SnappyCodec::SnappyCodec(int level, CodecType type) : Codec(type) {
331   DCHECK(type == CodecType::SNAPPY);
332   switch (level) {
333   case COMPRESSION_LEVEL_FASTEST:
334   case COMPRESSION_LEVEL_DEFAULT:
335   case COMPRESSION_LEVEL_BEST:
336     level = 1;
337   }
338   if (level != 1) {
339     throw std::invalid_argument(to<std::string>(
340         "SnappyCodec: invalid level: ", level));
341   }
342 }
343
344 uint64_t SnappyCodec::doMaxUncompressedLength() const {
345   // snappy.h uses uint32_t for lengths, so there's that.
346   return std::numeric_limits<uint32_t>::max();
347 }
348
349 std::unique_ptr<IOBuf> SnappyCodec::doCompress(const IOBuf* data) {
350   IOBufSnappySource source(data);
351   auto out =
352     IOBuf::create(snappy::MaxCompressedLength(source.Available()));
353
354   snappy::UncheckedByteArraySink sink(reinterpret_cast<char*>(
355       out->writableTail()));
356
357   size_t n = snappy::Compress(&source, &sink);
358
359   CHECK_LE(n, out->capacity());
360   out->append(n);
361   return out;
362 }
363
364 std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
365                                                  uint64_t uncompressedLength) {
366   uint32_t actualUncompressedLength = 0;
367
368   {
369     IOBufSnappySource source(data);
370     if (!snappy::GetUncompressedLength(&source, &actualUncompressedLength)) {
371       throw std::runtime_error("snappy::GetUncompressedLength failed");
372     }
373     if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
374         uncompressedLength != actualUncompressedLength) {
375       throw std::runtime_error("snappy: invalid uncompressed length");
376     }
377   }
378
379   auto out = IOBuf::create(actualUncompressedLength);
380
381   {
382     IOBufSnappySource source(data);
383     if (!snappy::RawUncompress(&source,
384                                reinterpret_cast<char*>(out->writableTail()))) {
385       throw std::runtime_error("snappy::RawUncompress failed");
386     }
387   }
388
389   out->append(actualUncompressedLength);
390   return out;
391 }
392
393 /**
394  * Zlib codec
395  */
396 class ZlibCodec FOLLY_FINAL : public Codec {
397  public:
398   static std::unique_ptr<Codec> create(int level, CodecType type);
399   explicit ZlibCodec(int level, CodecType type);
400
401  private:
402   std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
403   std::unique_ptr<IOBuf> doUncompress(
404       const IOBuf* data,
405       uint64_t uncompressedLength) FOLLY_OVERRIDE;
406
407   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
408   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
409
410   int level_;
411 };
412
413 std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
414   return make_unique<ZlibCodec>(level, type);
415 }
416
417 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
418   DCHECK(type == CodecType::ZLIB);
419   switch (level) {
420   case COMPRESSION_LEVEL_FASTEST:
421     level = 1;
422     break;
423   case COMPRESSION_LEVEL_DEFAULT:
424     level = Z_DEFAULT_COMPRESSION;
425     break;
426   case COMPRESSION_LEVEL_BEST:
427     level = 9;
428     break;
429   }
430   if (level != Z_DEFAULT_COMPRESSION && (level < 0 || level > 9)) {
431     throw std::invalid_argument(to<std::string>(
432         "ZlibCodec: invalid level: ", level));
433   }
434   level_ = level;
435 }
436
437 std::unique_ptr<IOBuf> ZlibCodec::addOutputBuffer(z_stream* stream,
438                                                   uint32_t length) {
439   CHECK_EQ(stream->avail_out, 0);
440
441   auto buf = IOBuf::create(length);
442   buf->append(length);
443
444   stream->next_out = buf->writableData();
445   stream->avail_out = buf->length();
446
447   return buf;
448 }
449
450 bool ZlibCodec::doInflate(z_stream* stream,
451                           IOBuf* head,
452                           uint32_t bufferLength) {
453   if (stream->avail_out == 0) {
454     head->prependChain(addOutputBuffer(stream, bufferLength));
455   }
456
457   int rc = inflate(stream, Z_NO_FLUSH);
458
459   switch (rc) {
460   case Z_OK:
461     break;
462   case Z_STREAM_END:
463     return true;
464   case Z_BUF_ERROR:
465   case Z_NEED_DICT:
466   case Z_DATA_ERROR:
467   case Z_MEM_ERROR:
468     throw std::runtime_error(to<std::string>(
469         "ZlibCodec: inflate error: ", rc, ": ", stream->msg));
470   default:
471     CHECK(false) << rc << ": " << stream->msg;
472   }
473
474   return false;
475 }
476
477
478 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
479   z_stream stream;
480   stream.zalloc = nullptr;
481   stream.zfree = nullptr;
482   stream.opaque = nullptr;
483
484   int rc = deflateInit(&stream, level_);
485   if (rc != Z_OK) {
486     throw std::runtime_error(to<std::string>(
487         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
488   }
489
490   stream.next_in = stream.next_out = nullptr;
491   stream.avail_in = stream.avail_out = 0;
492   stream.total_in = stream.total_out = 0;
493
494   bool success = false;
495
496   SCOPE_EXIT {
497     int rc = deflateEnd(&stream);
498     // If we're here because of an exception, it's okay if some data
499     // got dropped.
500     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
501       << rc << ": " << stream.msg;
502   };
503
504   uint64_t uncompressedLength = data->computeChainDataLength();
505   uint64_t maxCompressedLength = deflateBound(&stream, uncompressedLength);
506
507   // Max 64MiB in one go
508   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
509   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
510
511   auto out = addOutputBuffer(
512       &stream,
513       (maxCompressedLength <= maxSingleStepLength ?
514        maxCompressedLength :
515        defaultBufferLength));
516
517   for (auto& range : *data) {
518     if (range.empty()) {
519       continue;
520     }
521
522     stream.next_in = const_cast<uint8_t*>(range.data());
523     stream.avail_in = range.size();
524
525     while (stream.avail_in != 0) {
526       if (stream.avail_out == 0) {
527         out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
528       }
529
530       rc = deflate(&stream, Z_NO_FLUSH);
531
532       CHECK_EQ(rc, Z_OK) << stream.msg;
533     }
534   }
535
536   do {
537     if (stream.avail_out == 0) {
538       out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
539     }
540
541     rc = deflate(&stream, Z_FINISH);
542   } while (rc == Z_OK);
543
544   CHECK_EQ(rc, Z_STREAM_END) << stream.msg;
545
546   out->prev()->trimEnd(stream.avail_out);
547
548   success = true;  // we survived
549
550   return out;
551 }
552
553 std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
554                                                uint64_t uncompressedLength) {
555   z_stream stream;
556   stream.zalloc = nullptr;
557   stream.zfree = nullptr;
558   stream.opaque = nullptr;
559
560   int rc = inflateInit(&stream);
561   if (rc != Z_OK) {
562     throw std::runtime_error(to<std::string>(
563         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
564   }
565
566   stream.next_in = stream.next_out = nullptr;
567   stream.avail_in = stream.avail_out = 0;
568   stream.total_in = stream.total_out = 0;
569
570   bool success = false;
571
572   SCOPE_EXIT {
573     int rc = inflateEnd(&stream);
574     // If we're here because of an exception, it's okay if some data
575     // got dropped.
576     CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR))
577       << rc << ": " << stream.msg;
578   };
579
580   // Max 64MiB in one go
581   constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
582   constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
583
584   auto out = addOutputBuffer(
585       &stream,
586       ((uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
587         uncompressedLength <= maxSingleStepLength) ?
588        uncompressedLength :
589        defaultBufferLength));
590
591   bool streamEnd = false;
592   for (auto& range : *data) {
593     if (range.empty()) {
594       continue;
595     }
596
597     stream.next_in = const_cast<uint8_t*>(range.data());
598     stream.avail_in = range.size();
599
600     while (stream.avail_in != 0) {
601       if (streamEnd) {
602         throw std::runtime_error(to<std::string>(
603             "ZlibCodec: junk after end of data"));
604       }
605
606       streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
607     }
608   }
609
610   while (!streamEnd) {
611     streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
612   }
613
614   out->prev()->trimEnd(stream.avail_out);
615
616   if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
617       uncompressedLength != stream.total_out) {
618     throw std::runtime_error(to<std::string>(
619         "ZlibCodec: invalid uncompressed length"));
620   }
621
622   success = true;  // we survived
623
624   return out;
625 }
626
627 typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
628
629 CodecFactory gCodecFactories[
630     static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
631   nullptr,  // USER_DEFINED
632   NoCompressionCodec::create,
633   LZ4Codec::create,
634   SnappyCodec::create,
635   ZlibCodec::create,
636   LZ4Codec::create
637 };
638
639 }  // namespace
640
641 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
642   size_t idx = static_cast<size_t>(type);
643   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
644     throw std::invalid_argument(to<std::string>(
645         "Compression type ", idx, " not supported"));
646   }
647   auto factory = gCodecFactories[idx];
648   if (!factory) {
649     throw std::invalid_argument(to<std::string>(
650         "Compression type ", idx, " not supported"));
651   }
652   auto codec = (*factory)(level, type);
653   DCHECK_EQ(static_cast<size_t>(codec->type()), idx);
654   return codec;
655 }
656
657 }}  // namespaces
658