X-Git-Url: http://plrg.eecs.uci.edu/git/?a=blobdiff_plain;f=folly%2Fio%2FCompression.cpp;h=c626be74b3949e21adfe39fe9ce7eb0aafd112c9;hb=69bd8deb890fb6d2c648815e8e986f1cd1b93fac;hp=97c62f0705a84b49cf1673fa1719181c65e8fe09;hpb=8fe79ff17abec2d7d254cdfbe2a4c13f954493ad;p=folly.git diff --git a/folly/io/Compression.cpp b/folly/io/Compression.cpp index 97c62f07..c626be74 100644 --- a/folly/io/Compression.cpp +++ b/folly/io/Compression.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2015 Facebook, Inc. + * Copyright 2017 Facebook, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -36,6 +36,10 @@ #include #endif +#if FOLLY_HAVE_LIBZSTD +#include +#endif + #include #include #include @@ -52,13 +56,26 @@ std::unique_ptr Codec::compress(const IOBuf* data) { uint64_t len = data->computeChainDataLength(); if (len == 0) { return IOBuf::create(0); - } else if (len > maxUncompressedLength()) { + } + if (len > maxUncompressedLength()) { throw std::runtime_error("Codec: uncompressed length too large"); } return doCompress(data); } +std::string Codec::compress(const StringPiece data) { + const uint64_t len = data.size(); + if (len == 0) { + return ""; + } + if (len > maxUncompressedLength()) { + throw std::runtime_error("Codec: uncompressed length too large"); + } + + return doCompressString(data); +} + std::unique_ptr Codec::uncompress(const IOBuf* data, uint64_t uncompressedLength) { if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) { @@ -80,6 +97,28 @@ std::unique_ptr Codec::uncompress(const IOBuf* data, return doUncompress(data, uncompressedLength); } +std::string Codec::uncompress( + const StringPiece data, + uint64_t uncompressedLength) { + if (uncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH) { + if (needsUncompressedLength()) { + throw std::invalid_argument("Codec: uncompressed length required"); + } + } else if (uncompressedLength > maxUncompressedLength()) { + throw std::runtime_error("Codec: uncompressed length too large"); + } + + if (data.empty()) { + if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && + uncompressedLength != 0) { + throw std::runtime_error("Codec: invalid uncompressed length"); + } + return ""; + } + + return doUncompressString(data, uncompressedLength); +} + bool Codec::needsUncompressedLength() const { return doNeedsUncompressedLength(); } @@ -96,6 +135,30 @@ uint64_t Codec::doMaxUncompressedLength() const { return UNLIMITED_UNCOMPRESSED_LENGTH; } +std::string Codec::doCompressString(const StringPiece data) { + const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data}; + auto outputBuffer = doCompress(&inputBuffer); + std::string output; + output.reserve(outputBuffer->computeChainDataLength()); + for (auto range : *outputBuffer) { + output.append(reinterpret_cast(range.data()), range.size()); + } + return output; +} + +std::string Codec::doUncompressString( + const StringPiece data, + uint64_t uncompressedLength) { + const IOBuf inputBuffer{IOBuf::WRAP_BUFFER, data}; + auto outputBuffer = doUncompress(&inputBuffer, uncompressedLength); + std::string output; + output.reserve(outputBuffer->computeChainDataLength()); + for (auto range : *outputBuffer) { + output.append(reinterpret_cast(range.data()), range.size()); + } + return output; +} + namespace { /** @@ -240,12 +303,11 @@ uint64_t LZ4Codec::doMaxUncompressedLength() const { } std::unique_ptr LZ4Codec::doCompress(const IOBuf* data) { - std::unique_ptr clone; + IOBuf clone; if (data->isChained()) { // LZ4 doesn't support streaming, so we have to coalesce - clone = data->clone(); - clone->coalesce(); - data = clone.get(); + clone = data->cloneCoalescedAsValue(); + data = &clone; } uint32_t extraSize = encodeSize() ? kMaxVarintLength64 : 0; @@ -255,15 +317,22 @@ std::unique_ptr LZ4Codec::doCompress(const IOBuf* data) { } int n; + auto input = reinterpret_cast(data->data()); + auto output = reinterpret_cast(out->writableTail()); + const auto inputLength = data->length(); +#if LZ4_VERSION_NUMBER >= 10700 + if (highCompression_) { + n = LZ4_compress_HC(input, output, inputLength, out->tailroom(), 0); + } else { + n = LZ4_compress_default(input, output, inputLength, out->tailroom()); + } +#else if (highCompression_) { - n = LZ4_compressHC(reinterpret_cast(data->data()), - reinterpret_cast(out->writableTail()), - data->length()); + n = LZ4_compressHC(input, output, inputLength); } else { - n = LZ4_compress(reinterpret_cast(data->data()), - reinterpret_cast(out->writableTail()), - data->length()); + n = LZ4_compress(input, output, inputLength); } +#endif CHECK_GE(n, 0); CHECK_LE(n, out->capacity()); @@ -275,12 +344,11 @@ std::unique_ptr LZ4Codec::doCompress(const IOBuf* data) { std::unique_ptr LZ4Codec::doUncompress( const IOBuf* data, uint64_t uncompressedLength) { - std::unique_ptr clone; + IOBuf clone; if (data->isChained()) { // LZ4 doesn't support streaming, so we have to coalesce - clone = data->clone(); - clone->coalesce(); - data = clone.get(); + clone = data->cloneCoalescedAsValue(); + data = &clone; } folly::io::Cursor cursor(data); @@ -299,12 +367,13 @@ std::unique_ptr LZ4Codec::doUncompress( } } - auto p = cursor.peek(); + auto sp = StringPiece{cursor.peekBytes()}; auto out = IOBuf::create(actualUncompressedLength); - int n = LZ4_decompress_safe(reinterpret_cast(p.first), - reinterpret_cast(out->writableTail()), - p.second, - actualUncompressedLength); + int n = LZ4_decompress_safe( + sp.data(), + reinterpret_cast(out->writableTail()), + sp.size(), + actualUncompressedLength); if (n < 0 || uint64_t(n) != actualUncompressedLength) { throw std::runtime_error(to( @@ -346,9 +415,9 @@ size_t IOBufSnappySource::Available() const { } const char* IOBufSnappySource::Peek(size_t* len) { - auto p = cursor_.peek(); - *len = p.second; - return reinterpret_cast(p.first); + auto sp = StringPiece{cursor_.peekBytes()}; + *len = sp.size(); + return sp.data(); } void IOBufSnappySource::Skip(size_t n) { @@ -465,7 +534,7 @@ std::unique_ptr ZlibCodec::create(int level, CodecType type) { } ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) { - DCHECK(type == CodecType::ZLIB); + DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP); switch (level) { case COMPRESSION_LEVEL_FASTEST: level = 1; @@ -530,7 +599,22 @@ std::unique_ptr ZlibCodec::doCompress(const IOBuf* data) { stream.zfree = nullptr; stream.opaque = nullptr; - int rc = deflateInit(&stream, level_); + // Using deflateInit2() to support gzip. "The windowBits parameter is the + // base two logarithm of the maximum window size (...) The default value is + // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer + // around the compressed data instead of a zlib wrapper. The gzip header + // will have no file name, no extra data, no comment, no modification time + // (set to zero), no header crc, and the operating system will be set to 255 + // (unknown)." + int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); + // All other parameters (method, memLevel, strategy) get default values from + // the zlib manual. + int rc = deflateInit2(&stream, + level_, + Z_DEFLATED, + windowBits, + /* memLevel */ 8, + Z_DEFAULT_STRATEGY); if (rc != Z_OK) { throw std::runtime_error(to( "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg)); @@ -543,7 +627,7 @@ std::unique_ptr ZlibCodec::doCompress(const IOBuf* data) { bool success = false; SCOPE_EXIT { - int rc = deflateEnd(&stream); + rc = deflateEnd(&stream); // If we're here because of an exception, it's okay if some data // got dropped. CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) @@ -610,7 +694,11 @@ std::unique_ptr ZlibCodec::doUncompress(const IOBuf* data, stream.zfree = nullptr; stream.opaque = nullptr; - int rc = inflateInit(&stream); + // "The windowBits parameter is the base two logarithm of the maximum window + // size (...) The default value is 15 (...) add 16 to decode only the gzip + // format (the zlib format will return a Z_DATA_ERROR)." + int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0); + int rc = inflateInit2(&stream, windowBits); if (rc != Z_OK) { throw std::runtime_error(to( "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg)); @@ -623,7 +711,7 @@ std::unique_ptr ZlibCodec::doUncompress(const IOBuf* data, bool success = false; SCOPE_EXIT { - int rc = inflateEnd(&stream); + rc = inflateEnd(&stream); // If we're here because of an exception, it's okay if some data // got dropped. CHECK(rc == Z_OK || (!success && rc == Z_DATA_ERROR)) @@ -885,10 +973,10 @@ std::unique_ptr LZMA2Codec::doUncompress(const IOBuf* data, defaultBufferLength)); bool streamEnd = false; - auto buf = cursor.peek(); - while (buf.second != 0) { - stream.next_in = const_cast(buf.first); - stream.avail_in = buf.second; + auto buf = cursor.peekBytes(); + while (!buf.empty()) { + stream.next_in = const_cast(buf.data()); + stream.avail_in = buf.size(); while (stream.avail_in != 0) { if (streamEnd) { @@ -899,8 +987,8 @@ std::unique_ptr LZMA2Codec::doUncompress(const IOBuf* data, streamEnd = doInflate(&stream, out.get(), defaultBufferLength); } - cursor.skip(buf.second); - buf = cursor.peek(); + cursor.skip(buf.size()); + buf = cursor.peekBytes(); } while (!streamEnd) { @@ -919,53 +1007,253 @@ std::unique_ptr LZMA2Codec::doUncompress(const IOBuf* data, #endif // FOLLY_HAVE_LIBLZMA -} // namespace +#ifdef FOLLY_HAVE_LIBZSTD -std::unique_ptr getCodec(CodecType type, int level) { - typedef std::unique_ptr (*CodecFactory)(int, CodecType); +/** + * ZSTD compression + */ +class ZSTDCodec final : public Codec { + public: + static std::unique_ptr create(int level, CodecType); + explicit ZSTDCodec(int level, CodecType type); + + private: + bool doNeedsUncompressedLength() const override; + std::unique_ptr doCompress(const IOBuf* data) override; + std::unique_ptr doUncompress( + const IOBuf* data, + uint64_t uncompressedLength) override; + + int level_; +}; + +std::unique_ptr ZSTDCodec::create(int level, CodecType type) { + return make_unique(level, type); +} + +ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) { + DCHECK(type == CodecType::ZSTD); + switch (level) { + case COMPRESSION_LEVEL_FASTEST: + level = 1; + break; + case COMPRESSION_LEVEL_DEFAULT: + level = 1; + break; + case COMPRESSION_LEVEL_BEST: + level = 19; + break; + } + if (level < 1 || level > ZSTD_maxCLevel()) { + throw std::invalid_argument( + to("ZSTD: invalid level: ", level)); + } + level_ = level; +} - static CodecFactory codecFactories[ - static_cast(CodecType::NUM_CODEC_TYPES)] = { - nullptr, // USER_DEFINED - NoCompressionCodec::create, +bool ZSTDCodec::doNeedsUncompressedLength() const { + return false; +} + +void zstdThrowIfError(size_t rc) { + if (!ZSTD_isError(rc)) { + return; + } + throw std::runtime_error( + to("ZSTD returned an error: ", ZSTD_getErrorName(rc))); +} + +std::unique_ptr ZSTDCodec::doCompress(const IOBuf* data) { + // Support earlier versions of the codec (working with a single IOBuf, + // and using ZSTD_decompress which requires ZSTD frame to contain size, + // which isn't populated by streaming API). + if (!data->isChained()) { + auto out = IOBuf::createCombined(ZSTD_compressBound(data->length())); + const auto rc = ZSTD_compress( + out->writableData(), + out->capacity(), + data->data(), + data->length(), + level_); + zstdThrowIfError(rc); + out->append(rc); + return out; + } + + auto zcs = ZSTD_createCStream(); + SCOPE_EXIT { + ZSTD_freeCStream(zcs); + }; + + auto rc = ZSTD_initCStream(zcs, level_); + zstdThrowIfError(rc); + + Cursor cursor(data); + auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength())); + + ZSTD_outBuffer out; + out.dst = result->writableTail(); + out.size = result->capacity(); + out.pos = 0; + + for (auto buffer = cursor.peekBytes(); !buffer.empty();) { + ZSTD_inBuffer in; + in.src = buffer.data(); + in.size = buffer.size(); + for (in.pos = 0; in.pos != in.size;) { + rc = ZSTD_compressStream(zcs, &out, &in); + zstdThrowIfError(rc); + } + cursor.skip(in.size); + buffer = cursor.peekBytes(); + } + + rc = ZSTD_endStream(zcs, &out); + zstdThrowIfError(rc); + CHECK_EQ(rc, 0); + + result->append(out.pos); + return result; +} + +std::unique_ptr ZSTDCodec::doUncompress( + const IOBuf* data, + uint64_t uncompressedLength) { + auto zds = ZSTD_createDStream(); + SCOPE_EXIT { + ZSTD_freeDStream(zds); + }; + + auto rc = ZSTD_initDStream(zds); + zstdThrowIfError(rc); + + ZSTD_outBuffer out{}; + ZSTD_inBuffer in{}; + + auto outputSize = ZSTD_DStreamOutSize(); + if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) { + outputSize = uncompressedLength; + } else { + auto decompressedSize = + ZSTD_getDecompressedSize(data->data(), data->length()); + if (decompressedSize != 0 && decompressedSize < outputSize) { + outputSize = decompressedSize; + } + } + + IOBufQueue queue(IOBufQueue::cacheChainLength()); + + Cursor cursor(data); + for (rc = 0;;) { + if (in.pos == in.size) { + auto buffer = cursor.peekBytes(); + in.src = buffer.data(); + in.size = buffer.size(); + in.pos = 0; + cursor.skip(in.size); + if (rc > 1 && in.size == 0) { + throw std::runtime_error(to("ZSTD: incomplete input")); + } + } + if (out.pos == out.size) { + if (out.pos != 0) { + queue.postallocate(out.pos); + } + auto buffer = queue.preallocate(outputSize, outputSize); + out.dst = buffer.first; + out.size = buffer.second; + out.pos = 0; + outputSize = ZSTD_DStreamOutSize(); + } + rc = ZSTD_decompressStream(zds, &out, &in); + zstdThrowIfError(rc); + if (rc == 0) { + break; + } + } + if (out.pos != 0) { + queue.postallocate(out.pos); + } + if (in.pos != in.size || !cursor.isAtEnd()) { + throw std::runtime_error("ZSTD: junk after end of data"); + } + if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && + queue.chainLength() != uncompressedLength) { + throw std::runtime_error("ZSTD: invalid uncompressed length"); + } + + return queue.move(); +} + +#endif // FOLLY_HAVE_LIBZSTD + +} // namespace + +typedef std::unique_ptr (*CodecFactory)(int, CodecType); +static CodecFactory + codecFactories[static_cast(CodecType::NUM_CODEC_TYPES)] = { + nullptr, // USER_DEFINED + NoCompressionCodec::create, #if FOLLY_HAVE_LIBLZ4 - LZ4Codec::create, + LZ4Codec::create, #else - nullptr, + nullptr, #endif #if FOLLY_HAVE_LIBSNAPPY - SnappyCodec::create, + SnappyCodec::create, #else - nullptr, + nullptr, #endif #if FOLLY_HAVE_LIBZ - ZlibCodec::create, + ZlibCodec::create, #else - nullptr, + nullptr, #endif #if FOLLY_HAVE_LIBLZ4 - LZ4Codec::create, + LZ4Codec::create, #else - nullptr, + nullptr, #endif #if FOLLY_HAVE_LIBLZMA - LZMA2Codec::create, - LZMA2Codec::create, + LZMA2Codec::create, + LZMA2Codec::create, #else - nullptr, - nullptr, + nullptr, + nullptr, #endif - }; +#if FOLLY_HAVE_LIBZSTD + ZSTDCodec::create, +#else + nullptr, +#endif + +#if FOLLY_HAVE_LIBZ + ZlibCodec::create, +#else + nullptr, +#endif +}; + +bool hasCodec(CodecType type) { size_t idx = static_cast(type); if (idx >= static_cast(CodecType::NUM_CODEC_TYPES)) { - throw std::invalid_argument(to( - "Compression type ", idx, " not supported")); + throw std::invalid_argument( + to("Compression type ", idx, " invalid")); + } + return codecFactories[idx] != nullptr; +} + +std::unique_ptr getCodec(CodecType type, int level) { + size_t idx = static_cast(type); + if (idx >= static_cast(CodecType::NUM_CODEC_TYPES)) { + throw std::invalid_argument( + to("Compression type ", idx, " invalid")); } auto factory = codecFactories[idx]; if (!factory) {