Optimize ZSTDCodec::doUncompress()
authorNick Terrell <terrelln@fb.com>
Fri, 24 Mar 2017 19:18:57 +0000 (12:18 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 24 Mar 2017 19:21:52 +0000 (12:21 -0700)
Summary:
It is ~10% faster to call `ZSTD_decompress()` than use the
streaming API. The streaming API does some extra `memcpy`s that we
can avoid. We are working on improving the speed of the streaming
API in the case where all the data can be processed in one shot,
but that won't be available in the stable ZSTD api for a few versions.

Reviewed By: yfeldblum

Differential Revision: D4731058

fbshipit-source-id: 39026c499c0f5002466097b5afe7e30f850e0ae8

folly/io/Compression.cpp
folly/io/test/CompressionTest.cpp

index e0beac2..45a4d97 100644 (file)
@@ -1270,7 +1270,28 @@ std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
   return result;
 }
 
-std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+static std::unique_ptr<IOBuf> zstdUncompressBuffer(
+    const IOBuf* data,
+    uint64_t uncompressedLength) {
+  // Check preconditions
+  DCHECK(!data->isChained());
+  DCHECK(uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH);
+
+  auto uncompressed = IOBuf::create(uncompressedLength);
+  const auto decompressedSize = ZSTD_decompress(
+      uncompressed->writableTail(),
+      uncompressed->tailroom(),
+      data->data(),
+      data->length());
+  zstdThrowIfError(decompressedSize);
+  if (decompressedSize != uncompressedLength) {
+    throw std::runtime_error("ZSTD: invalid uncompressed length");
+  }
+  uncompressed->append(decompressedSize);
+  return uncompressed;
+}
+
+static std::unique_ptr<IOBuf> zstdUncompressStream(
     const IOBuf* data,
     uint64_t uncompressedLength) {
   auto zds = ZSTD_createDStream();
@@ -1285,14 +1306,8 @@ std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
   ZSTD_inBuffer in{};
 
   auto outputSize = ZSTD_DStreamOutSize();
-  if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
+  if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH) {
     outputSize = uncompressedLength;
-  } else {
-    auto decompressedSize =
-        ZSTD_getDecompressedSize(data->data(), data->length());
-    if (decompressedSize != 0 && decompressedSize < outputSize) {
-      outputSize = decompressedSize;
-    }
   }
 
   IOBufQueue queue(IOBufQueue::cacheChainLength());
@@ -1331,7 +1346,7 @@ std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
   if (in.pos != in.size || !cursor.isAtEnd()) {
     throw std::runtime_error("ZSTD: junk after end of data");
   }
-  if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+  if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
       queue.chainLength() != uncompressedLength) {
     throw std::runtime_error("ZSTD: invalid uncompressed length");
   }
@@ -1339,6 +1354,29 @@ std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
   return queue.move();
 }
 
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+    const IOBuf* data,
+    uint64_t uncompressedLength) {
+  {
+    // Read decompressed size from frame if available in first IOBuf.
+    const auto decompressedSize =
+        ZSTD_getDecompressedSize(data->data(), data->length());
+    if (decompressedSize != 0) {
+      if (uncompressedLength != Codec::UNKNOWN_UNCOMPRESSED_LENGTH &&
+          uncompressedLength != decompressedSize) {
+        throw std::runtime_error("ZSTD: invalid uncompressed length");
+      }
+      uncompressedLength = decompressedSize;
+    }
+  }
+  // Faster to decompress using ZSTD_decompress() if we can.
+  if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH && !data->isChained()) {
+    return zstdUncompressBuffer(data, uncompressedLength);
+  }
+  // Fall back to slower streaming decompression.
+  return zstdUncompressStream(data, uncompressedLength);
+}
+
 #endif  // FOLLY_HAVE_LIBZSTD
 
 }  // namespace
index b1599f8..a97b5bc 100644 (file)
@@ -392,6 +392,7 @@ INSTANTIATE_TEST_CASE_P(
         supportedCodecs({
             CodecType::SNAPPY,
             CodecType::ZLIB,
+            CodecType::ZSTD,
             CodecType::LZ4_FRAME,
         })));
 }}}  // namespaces