An intro to the upgrade mutex in the Synchronized docs
[folly.git] / folly / io / Compression.cpp
index 951619a17debccd79c0b2edfc0a5d1909e084dc2..356699cf49f2487635d6e21fc97049469d92c934 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2013 Facebook, Inc.
+ * Copyright 2016 Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * limitations under the License.
  */
 
-#include "folly/io/Compression.h"
+#include <folly/io/Compression.h>
 
+#if FOLLY_HAVE_LIBLZ4
 #include <lz4.h>
 #include <lz4hc.h>
+#endif
+
 #include <glog/logging.h>
+
+#if FOLLY_HAVE_LIBSNAPPY
 #include <snappy.h>
 #include <snappy-sinksource.h>
+#endif
+
+#if FOLLY_HAVE_LIBZ
 #include <zlib.h>
+#endif
+
+#if FOLLY_HAVE_LIBLZMA
+#include <lzma.h>
+#endif
+
+#if FOLLY_HAVE_LIBZSTD
+#include <zstd.h>
+#endif
 
-#include "folly/Conv.h"
-#include "folly/Memory.h"
-#include "folly/Portability.h"
-#include "folly/ScopeGuard.h"
-#include "folly/Varint.h"
-#include "folly/io/Cursor.h"
+#include <folly/Conv.h>
+#include <folly/Memory.h>
+#include <folly/Portability.h>
+#include <folly/ScopeGuard.h>
+#include <folly/Varint.h>
+#include <folly/io/Cursor.h>
 
 namespace folly { namespace io {
 
@@ -36,7 +53,14 @@ Codec::Codec(CodecType type) : type_(type) { }
 
 // Ensure consistent behavior in the nullptr case
 std::unique_ptr<IOBuf> Codec::compress(const IOBuf* data) {
-  return !data->empty() ? doCompress(data) : IOBuf::create(0);
+  uint64_t len = data->computeChainDataLength();
+  if (len == 0) {
+    return IOBuf::create(0);
+  } else if (len > maxUncompressedLength()) {
+    throw std::runtime_error("Codec: uncompressed length too large");
+  }
+
+  return doCompress(data);
 }
 
 std::unique_ptr<IOBuf> Codec::uncompress(const IOBuf* data,
@@ -73,7 +97,7 @@ bool Codec::doNeedsUncompressedLength() const {
 }
 
 uint64_t Codec::doMaxUncompressedLength() const {
-  return std::numeric_limits<uint64_t>::max() - 1;
+  return UNLIMITED_UNCOMPRESSED_LENGTH;
 }
 
 namespace {
@@ -81,16 +105,16 @@ namespace {
 /**
  * No compression
  */
-class NoCompressionCodec FOLLY_FINAL : public Codec {
+class NoCompressionCodec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(int level, CodecType type);
   explicit NoCompressionCodec(int level, CodecType type);
 
  private:
-  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
   std::unique_ptr<IOBuf> doUncompress(
       const IOBuf* data,
-      uint64_t uncompressedLength) FOLLY_OVERRIDE;
+      uint64_t uncompressedLength) override;
 };
 
 std::unique_ptr<Codec> NoCompressionCodec::create(int level, CodecType type) {
@@ -128,24 +152,55 @@ std::unique_ptr<IOBuf> NoCompressionCodec::doUncompress(
   return data->clone();
 }
 
+#if (FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA)
+
+namespace {
+
+void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
+  DCHECK_GE(out->tailroom(), kMaxVarintLength64);
+  out->append(encodeVarint(val, out->writableTail()));
+}
+
+inline uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
+  uint64_t val = 0;
+  int8_t b = 0;
+  for (int shift = 0; shift <= 63; shift += 7) {
+    b = cursor.read<int8_t>();
+    val |= static_cast<uint64_t>(b & 0x7f) << shift;
+    if (b >= 0) {
+      break;
+    }
+  }
+  if (b < 0) {
+    throw std::invalid_argument("Invalid varint value. Too big.");
+  }
+  return val;
+}
+
+}  // namespace
+
+#endif  // FOLLY_HAVE_LIBLZ4 || FOLLY_HAVE_LIBLZMA
+
+#if FOLLY_HAVE_LIBLZ4
+
 /**
  * LZ4 compression
  */
-class LZ4Codec FOLLY_FINAL : public Codec {
+class LZ4Codec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(int level, CodecType type);
   explicit LZ4Codec(int level, CodecType type);
 
  private:
-  bool doNeedsUncompressedLength() const FOLLY_OVERRIDE;
-  uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
+  bool doNeedsUncompressedLength() const override;
+  uint64_t doMaxUncompressedLength() const override;
 
   bool encodeSize() const { return type() == CodecType::LZ4_VARINT_SIZE; }
 
-  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
   std::unique_ptr<IOBuf> doUncompress(
       const IOBuf* data,
-      uint64_t uncompressedLength) FOLLY_OVERRIDE;
+      uint64_t uncompressedLength) override;
 
   bool highCompression_;
 };
@@ -177,30 +232,17 @@ bool LZ4Codec::doNeedsUncompressedLength() const {
   return !encodeSize();
 }
 
-uint64_t LZ4Codec::doMaxUncompressedLength() const {
-  // From lz4.h: "Max supported value is ~1.9GB"; I wish we had something
-  // more accurate.
-  return 1.8 * (uint64_t(1) << 30);
-}
-
-namespace {
+// The value comes from lz4.h in lz4-r117, but older versions of lz4 don't
+// define LZ4_MAX_INPUT_SIZE (even though the max size is the same), so do it
+// here.
+#ifndef LZ4_MAX_INPUT_SIZE
+# define LZ4_MAX_INPUT_SIZE 0x7E000000
+#endif
 
-void encodeVarintToIOBuf(uint64_t val, folly::IOBuf* out) {
-  DCHECK_GE(out->tailroom(), kMaxVarintLength64);
-  out->append(encodeVarint(val, out->writableTail()));
-}
-
-uint64_t decodeVarintFromCursor(folly::io::Cursor& cursor) {
-  // Must have enough room in *this* buffer.
-  auto p = cursor.peek();
-  folly::ByteRange range(p.first, p.second);
-  uint64_t val = decodeVarint(range);
-  cursor.skip(range.data() - p.first);
-  return val;
+uint64_t LZ4Codec::doMaxUncompressedLength() const {
+  return LZ4_MAX_INPUT_SIZE;
 }
 
-}  // namespace
-
 std::unique_ptr<IOBuf> LZ4Codec::doCompress(const IOBuf* data) {
   std::unique_ptr<IOBuf> clone;
   if (data->isChained()) {
@@ -255,15 +297,21 @@ std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
     }
   } else {
     actualUncompressedLength = uncompressedLength;
-    DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
+    if (actualUncompressedLength == UNKNOWN_UNCOMPRESSED_LENGTH ||
+        actualUncompressedLength > maxUncompressedLength()) {
+      throw std::runtime_error("LZ4Codec: invalid uncompressed length");
+    }
   }
 
+  auto sp = StringPiece{cursor.peekBytes()};
   auto out = IOBuf::create(actualUncompressedLength);
-  auto p = cursor.peek();
-  int n = LZ4_uncompress(reinterpret_cast<const char*>(p.first),
-                         reinterpret_cast<char*>(out->writableTail()),
-                         actualUncompressedLength);
-  if (n != p.second) {
+  int n = LZ4_decompress_safe(
+      sp.data(),
+      reinterpret_cast<char*>(out->writableTail()),
+      sp.size(),
+      actualUncompressedLength);
+
+  if (n < 0 || uint64_t(n) != actualUncompressedLength) {
     throw std::runtime_error(to<std::string>(
         "LZ4 decompression returned invalid value ", n));
   }
@@ -271,6 +319,10 @@ std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
   return out;
 }
 
+#endif  // FOLLY_HAVE_LIBLZ4
+
+#if FOLLY_HAVE_LIBSNAPPY
+
 /**
  * Snappy compression
  */
@@ -278,12 +330,12 @@ std::unique_ptr<IOBuf> LZ4Codec::doUncompress(
 /**
  * Implementation of snappy::Source that reads from a IOBuf chain.
  */
-class IOBufSnappySource FOLLY_FINAL : public snappy::Source {
+class IOBufSnappySource final : public snappy::Source {
  public:
   explicit IOBufSnappySource(const IOBuf* data);
-  size_t Available() const FOLLY_OVERRIDE;
-  const char* Peek(size_t* len) FOLLY_OVERRIDE;
-  void Skip(size_t n) FOLLY_OVERRIDE;
+  size_t Available() const override;
+  const char* Peek(size_t* len) override;
+  void Skip(size_t n) override;
  private:
   size_t available_;
   io::Cursor cursor_;
@@ -299,9 +351,9 @@ size_t IOBufSnappySource::Available() const {
 }
 
 const char* IOBufSnappySource::Peek(size_t* len) {
-  auto p = cursor_.peek();
-  *len = p.second;
-  return reinterpret_cast<const char*>(p.first);
+  auto sp = StringPiece{cursor_.peekBytes()};
+  *len = sp.size();
+  return sp.data();
 }
 
 void IOBufSnappySource::Skip(size_t n) {
@@ -310,17 +362,17 @@ void IOBufSnappySource::Skip(size_t n) {
   available_ -= n;
 }
 
-class SnappyCodec FOLLY_FINAL : public Codec {
+class SnappyCodec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(int level, CodecType type);
   explicit SnappyCodec(int level, CodecType type);
 
  private:
-  uint64_t doMaxUncompressedLength() const FOLLY_OVERRIDE;
-  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
+  uint64_t doMaxUncompressedLength() const override;
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
   std::unique_ptr<IOBuf> doUncompress(
       const IOBuf* data,
-      uint64_t uncompressedLength) FOLLY_OVERRIDE;
+      uint64_t uncompressedLength) override;
 };
 
 std::unique_ptr<Codec> SnappyCodec::create(int level, CodecType type) {
@@ -390,19 +442,22 @@ std::unique_ptr<IOBuf> SnappyCodec::doUncompress(const IOBuf* data,
   return out;
 }
 
+#endif  // FOLLY_HAVE_LIBSNAPPY
+
+#if FOLLY_HAVE_LIBZ
 /**
  * Zlib codec
  */
-class ZlibCodec FOLLY_FINAL : public Codec {
+class ZlibCodec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(int level, CodecType type);
   explicit ZlibCodec(int level, CodecType type);
 
  private:
-  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) FOLLY_OVERRIDE;
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
   std::unique_ptr<IOBuf> doUncompress(
       const IOBuf* data,
-      uint64_t uncompressedLength) FOLLY_OVERRIDE;
+      uint64_t uncompressedLength) override;
 
   std::unique_ptr<IOBuf> addOutputBuffer(z_stream* stream, uint32_t length);
   bool doInflate(z_stream* stream, IOBuf* head, uint32_t bufferLength);
@@ -415,7 +470,7 @@ std::unique_ptr<Codec> ZlibCodec::create(int level, CodecType type) {
 }
 
 ZlibCodec::ZlibCodec(int level, CodecType type) : Codec(type) {
-  DCHECK(type == CodecType::ZLIB);
+  DCHECK(type == CodecType::ZLIB || type == CodecType::GZIP);
   switch (level) {
   case COMPRESSION_LEVEL_FASTEST:
     level = 1;
@@ -474,14 +529,28 @@ bool ZlibCodec::doInflate(z_stream* stream,
   return false;
 }
 
-
 std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
   z_stream stream;
   stream.zalloc = nullptr;
   stream.zfree = nullptr;
   stream.opaque = nullptr;
 
-  int rc = deflateInit(&stream, level_);
+  // Using deflateInit2() to support gzip.  "The windowBits parameter is the
+  // base two logarithm of the maximum window size (...) The default value is
+  // 15 (...) Add 16 to windowBits to write a simple gzip header and trailer
+  // around the compressed data instead of a zlib wrapper. The gzip header
+  // will have no file name, no extra data, no comment, no modification time
+  // (set to zero), no header crc, and the operating system will be set to 255
+  // (unknown)."
+  int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
+  // All other parameters (method, memLevel, strategy) get default values from
+  // the zlib manual.
+  int rc = deflateInit2(&stream,
+                        level_,
+                        Z_DEFLATED,
+                        windowBits,
+                        /* memLevel */ 8,
+                        Z_DEFAULT_STRATEGY);
   if (rc != Z_OK) {
     throw std::runtime_error(to<std::string>(
         "ZlibCodec: deflateInit error: ", rc, ": ", stream.msg));
@@ -515,21 +584,25 @@ std::unique_ptr<IOBuf> ZlibCodec::doCompress(const IOBuf* data) {
        defaultBufferLength));
 
   for (auto& range : *data) {
-    if (range.empty()) {
-      continue;
-    }
-
-    stream.next_in = const_cast<uint8_t*>(range.data());
-    stream.avail_in = range.size();
-
-    while (stream.avail_in != 0) {
-      if (stream.avail_out == 0) {
-        out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
+    uint64_t remaining = range.size();
+    uint64_t written = 0;
+    while (remaining) {
+      uint32_t step = (remaining > maxSingleStepLength ?
+                       maxSingleStepLength : remaining);
+      stream.next_in = const_cast<uint8_t*>(range.data() + written);
+      stream.avail_in = step;
+      remaining -= step;
+      written += step;
+
+      while (stream.avail_in != 0) {
+        if (stream.avail_out == 0) {
+          out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
+        }
+
+        rc = deflate(&stream, Z_NO_FLUSH);
+
+        CHECK_EQ(rc, Z_OK) << stream.msg;
       }
-
-      rc = deflate(&stream, Z_NO_FLUSH);
-
-      CHECK_EQ(rc, Z_OK) << stream.msg;
     }
   }
 
@@ -557,7 +630,11 @@ std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
   stream.zfree = nullptr;
   stream.opaque = nullptr;
 
-  int rc = inflateInit(&stream);
+  // "The windowBits parameter is the base two logarithm of the maximum window
+  // size (...) The default value is 15 (...) add 16 to decode only the gzip
+  // format (the zlib format will return a Z_DATA_ERROR)."
+  int windowBits = 15 + (type() == CodecType::GZIP ? 16 : 0);
+  int rc = inflateInit2(&stream, windowBits);
   if (rc != Z_OK) {
     throw std::runtime_error(to<std::string>(
         "ZlibCodec: inflateInit error: ", rc, ": ", stream.msg));
@@ -624,27 +701,402 @@ std::unique_ptr<IOBuf> ZlibCodec::doUncompress(const IOBuf* data,
   return out;
 }
 
-typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
+#endif  // FOLLY_HAVE_LIBZ
 
-CodecFactory gCodecFactories[
-    static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
-  nullptr,  // USER_DEFINED
-  NoCompressionCodec::create,
-  LZ4Codec::create,
-  SnappyCodec::create,
-  ZlibCodec::create,
-  LZ4Codec::create
+#if FOLLY_HAVE_LIBLZMA
+
+/**
+ * LZMA2 compression
+ */
+class LZMA2Codec final : public Codec {
+ public:
+  static std::unique_ptr<Codec> create(int level, CodecType type);
+  explicit LZMA2Codec(int level, CodecType type);
+
+ private:
+  bool doNeedsUncompressedLength() const override;
+  uint64_t doMaxUncompressedLength() const override;
+
+  bool encodeSize() const { return type() == CodecType::LZMA2_VARINT_SIZE; }
+
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
+  std::unique_ptr<IOBuf> doUncompress(
+      const IOBuf* data,
+      uint64_t uncompressedLength) override;
+
+  std::unique_ptr<IOBuf> addOutputBuffer(lzma_stream* stream, size_t length);
+  bool doInflate(lzma_stream* stream, IOBuf* head, size_t bufferLength);
+
+  int level_;
 };
 
+std::unique_ptr<Codec> LZMA2Codec::create(int level, CodecType type) {
+  return make_unique<LZMA2Codec>(level, type);
+}
+
+LZMA2Codec::LZMA2Codec(int level, CodecType type) : Codec(type) {
+  DCHECK(type == CodecType::LZMA2 || type == CodecType::LZMA2_VARINT_SIZE);
+  switch (level) {
+  case COMPRESSION_LEVEL_FASTEST:
+    level = 0;
+    break;
+  case COMPRESSION_LEVEL_DEFAULT:
+    level = LZMA_PRESET_DEFAULT;
+    break;
+  case COMPRESSION_LEVEL_BEST:
+    level = 9;
+    break;
+  }
+  if (level < 0 || level > 9) {
+    throw std::invalid_argument(to<std::string>(
+        "LZMA2Codec: invalid level: ", level));
+  }
+  level_ = level;
+}
+
+bool LZMA2Codec::doNeedsUncompressedLength() const {
+  return !encodeSize();
+}
+
+uint64_t LZMA2Codec::doMaxUncompressedLength() const {
+  // From lzma/base.h: "Stream is roughly 8 EiB (2^63 bytes)"
+  return uint64_t(1) << 63;
+}
+
+std::unique_ptr<IOBuf> LZMA2Codec::addOutputBuffer(
+    lzma_stream* stream,
+    size_t length) {
+
+  CHECK_EQ(stream->avail_out, 0);
+
+  auto buf = IOBuf::create(length);
+  buf->append(length);
+
+  stream->next_out = buf->writableData();
+  stream->avail_out = buf->length();
+
+  return buf;
+}
+
+std::unique_ptr<IOBuf> LZMA2Codec::doCompress(const IOBuf* data) {
+  lzma_ret rc;
+  lzma_stream stream = LZMA_STREAM_INIT;
+
+  rc = lzma_easy_encoder(&stream, level_, LZMA_CHECK_NONE);
+  if (rc != LZMA_OK) {
+    throw std::runtime_error(folly::to<std::string>(
+      "LZMA2Codec: lzma_easy_encoder error: ", rc));
+  }
+
+  SCOPE_EXIT { lzma_end(&stream); };
+
+  uint64_t uncompressedLength = data->computeChainDataLength();
+  uint64_t maxCompressedLength = lzma_stream_buffer_bound(uncompressedLength);
+
+  // Max 64MiB in one go
+  constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
+  constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
+
+  auto out = addOutputBuffer(
+    &stream,
+    (maxCompressedLength <= maxSingleStepLength ?
+     maxCompressedLength :
+     defaultBufferLength));
+
+  if (encodeSize()) {
+    auto size = IOBuf::createCombined(kMaxVarintLength64);
+    encodeVarintToIOBuf(uncompressedLength, size.get());
+    size->appendChain(std::move(out));
+    out = std::move(size);
+  }
+
+  for (auto& range : *data) {
+    if (range.empty()) {
+      continue;
+    }
+
+    stream.next_in = const_cast<uint8_t*>(range.data());
+    stream.avail_in = range.size();
+
+    while (stream.avail_in != 0) {
+      if (stream.avail_out == 0) {
+        out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
+      }
+
+      rc = lzma_code(&stream, LZMA_RUN);
+
+      if (rc != LZMA_OK) {
+        throw std::runtime_error(folly::to<std::string>(
+          "LZMA2Codec: lzma_code error: ", rc));
+      }
+    }
+  }
+
+  do {
+    if (stream.avail_out == 0) {
+      out->prependChain(addOutputBuffer(&stream, defaultBufferLength));
+    }
+
+    rc = lzma_code(&stream, LZMA_FINISH);
+  } while (rc == LZMA_OK);
+
+  if (rc != LZMA_STREAM_END) {
+    throw std::runtime_error(folly::to<std::string>(
+      "LZMA2Codec: lzma_code ended with error: ", rc));
+  }
+
+  out->prev()->trimEnd(stream.avail_out);
+
+  return out;
+}
+
+bool LZMA2Codec::doInflate(lzma_stream* stream,
+                          IOBuf* head,
+                          size_t bufferLength) {
+  if (stream->avail_out == 0) {
+    head->prependChain(addOutputBuffer(stream, bufferLength));
+  }
+
+  lzma_ret rc = lzma_code(stream, LZMA_RUN);
+
+  switch (rc) {
+  case LZMA_OK:
+    break;
+  case LZMA_STREAM_END:
+    return true;
+  default:
+    throw std::runtime_error(to<std::string>(
+        "LZMA2Codec: lzma_code error: ", rc));
+  }
+
+  return false;
+}
+
+std::unique_ptr<IOBuf> LZMA2Codec::doUncompress(const IOBuf* data,
+                                               uint64_t uncompressedLength) {
+  lzma_ret rc;
+  lzma_stream stream = LZMA_STREAM_INIT;
+
+  rc = lzma_auto_decoder(&stream, std::numeric_limits<uint64_t>::max(), 0);
+  if (rc != LZMA_OK) {
+    throw std::runtime_error(folly::to<std::string>(
+      "LZMA2Codec: lzma_auto_decoder error: ", rc));
+  }
+
+  SCOPE_EXIT { lzma_end(&stream); };
+
+  // Max 64MiB in one go
+  constexpr uint32_t maxSingleStepLength = uint32_t(64) << 20;    // 64MiB
+  constexpr uint32_t defaultBufferLength = uint32_t(4) << 20;     // 4MiB
+
+  folly::io::Cursor cursor(data);
+  uint64_t actualUncompressedLength;
+  if (encodeSize()) {
+    actualUncompressedLength = decodeVarintFromCursor(cursor);
+    if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+        uncompressedLength != actualUncompressedLength) {
+      throw std::runtime_error("LZMA2Codec: invalid uncompressed length");
+    }
+  } else {
+    actualUncompressedLength = uncompressedLength;
+    DCHECK_NE(actualUncompressedLength, UNKNOWN_UNCOMPRESSED_LENGTH);
+  }
+
+  auto out = addOutputBuffer(
+      &stream,
+      (actualUncompressedLength <= maxSingleStepLength ?
+       actualUncompressedLength :
+       defaultBufferLength));
+
+  bool streamEnd = false;
+  auto buf = cursor.peekBytes();
+  while (!buf.empty()) {
+    stream.next_in = const_cast<uint8_t*>(buf.data());
+    stream.avail_in = buf.size();
+
+    while (stream.avail_in != 0) {
+      if (streamEnd) {
+        throw std::runtime_error(to<std::string>(
+            "LZMA2Codec: junk after end of data"));
+      }
+
+      streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
+    }
+
+    cursor.skip(buf.size());
+    buf = cursor.peekBytes();
+  }
+
+  while (!streamEnd) {
+    streamEnd = doInflate(&stream, out.get(), defaultBufferLength);
+  }
+
+  out->prev()->trimEnd(stream.avail_out);
+
+  if (actualUncompressedLength != stream.total_out) {
+    throw std::runtime_error(to<std::string>(
+        "LZMA2Codec: invalid uncompressed length"));
+  }
+
+  return out;
+}
+
+#endif  // FOLLY_HAVE_LIBLZMA
+
+#ifdef FOLLY_HAVE_LIBZSTD
+
+/**
+ * ZSTD_BETA compression
+ */
+class ZSTDCodec final : public Codec {
+ public:
+  static std::unique_ptr<Codec> create(int level, CodecType);
+  explicit ZSTDCodec(int level, CodecType type);
+
+ private:
+  bool doNeedsUncompressedLength() const override;
+  std::unique_ptr<IOBuf> doCompress(const IOBuf* data) override;
+  std::unique_ptr<IOBuf> doUncompress(
+      const IOBuf* data,
+      uint64_t uncompressedLength) override;
+
+  int level_{1};
+};
+
+std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
+  return make_unique<ZSTDCodec>(level, type);
+}
+
+ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
+  DCHECK(type == CodecType::ZSTD_BETA);
+  switch (level) {
+    case COMPRESSION_LEVEL_FASTEST:
+      level_ = 1;
+      break;
+    case COMPRESSION_LEVEL_DEFAULT:
+      level_ = 1;
+      break;
+    case COMPRESSION_LEVEL_BEST:
+      level_ = 19;
+      break;
+  }
+}
+
+bool ZSTDCodec::doNeedsUncompressedLength() const {
+  return true;
+}
+
+std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
+  size_t rc;
+  size_t maxCompressedLength = ZSTD_compressBound(data->length());
+  auto out = IOBuf::createCombined(maxCompressedLength);
+
+  CHECK_EQ(out->length(), 0);
+
+  rc = ZSTD_compress(out->writableTail(),
+                     out->capacity(),
+                     data->data(),
+                     data->length(),
+                     level_);
+
+  if (ZSTD_isError(rc)) {
+    throw std::runtime_error(to<std::string>(
+          "ZSTD compression returned an error: ",
+          ZSTD_getErrorName(rc)));
+  }
+
+  out->append(rc);
+  CHECK_EQ(out->length(), rc);
+
+  return out;
+}
+
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(const IOBuf* data,
+                                               uint64_t uncompressedLength) {
+  size_t rc;
+  auto out = IOBuf::createCombined(uncompressedLength);
+
+  CHECK_GE(out->capacity(), uncompressedLength);
+  CHECK_EQ(out->length(), 0);
+
+  rc = ZSTD_decompress(
+      out->writableTail(), out->capacity(), data->data(), data->length());
+
+  if (ZSTD_isError(rc)) {
+    throw std::runtime_error(to<std::string>(
+          "ZSTD decompression returned an error: ",
+          ZSTD_getErrorName(rc)));
+  }
+
+  out->append(rc);
+  CHECK_EQ(out->length(), rc);
+
+  return out;
+}
+
+#endif  // FOLLY_HAVE_LIBZSTD
+
 }  // namespace
 
 std::unique_ptr<Codec> getCodec(CodecType type, int level) {
+  typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
+
+  static CodecFactory codecFactories[
+    static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
+    nullptr,  // USER_DEFINED
+    NoCompressionCodec::create,
+
+#if FOLLY_HAVE_LIBLZ4
+    LZ4Codec::create,
+#else
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBSNAPPY
+    SnappyCodec::create,
+#else
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBZ
+    ZlibCodec::create,
+#else
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBLZ4
+    LZ4Codec::create,
+#else
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBLZMA
+    LZMA2Codec::create,
+    LZMA2Codec::create,
+#else
+    nullptr,
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBZSTD
+    ZSTDCodec::create,
+#else
+    nullptr,
+#endif
+
+#if FOLLY_HAVE_LIBZ
+    ZlibCodec::create,
+#else
+    nullptr,
+#endif
+  };
+
   size_t idx = static_cast<size_t>(type);
   if (idx >= static_cast<size_t>(CodecType::NUM_CODEC_TYPES)) {
     throw std::invalid_argument(to<std::string>(
         "Compression type ", idx, " not supported"));
   }
-  auto factory = gCodecFactories[idx];
+  auto factory = codecFactories[idx];
   if (!factory) {
     throw std::invalid_argument(to<std::string>(
         "Compression type ", idx, " not supported"));
@@ -655,4 +1107,3 @@ std::unique_ptr<Codec> getCodec(CodecType type, int level) {
 }
 
 }}  // namespaces
-