+ int level_;
+};
+
+std::unique_ptr<Codec> ZSTDCodec::create(int level, CodecType type) {
+ return make_unique<ZSTDCodec>(level, type);
+}
+
+ZSTDCodec::ZSTDCodec(int level, CodecType type) : Codec(type) {
+ DCHECK(type == CodecType::ZSTD);
+ switch (level) {
+ case COMPRESSION_LEVEL_FASTEST:
+ level = 1;
+ break;
+ case COMPRESSION_LEVEL_DEFAULT:
+ level = 1;
+ break;
+ case COMPRESSION_LEVEL_BEST:
+ level = 19;
+ break;
+ }
+ if (level < 1 || level > ZSTD_maxCLevel()) {
+ throw std::invalid_argument(
+ to<std::string>("ZSTD: invalid level: ", level));
+ }
+ level_ = level;
+}
+
+bool ZSTDCodec::doNeedsUncompressedLength() const {
+ return false;
+}
+
+void zstdThrowIfError(size_t rc) {
+ if (!ZSTD_isError(rc)) {
+ return;
+ }
+ throw std::runtime_error(
+ to<std::string>("ZSTD returned an error: ", ZSTD_getErrorName(rc)));
+}
+
+std::unique_ptr<IOBuf> ZSTDCodec::doCompress(const IOBuf* data) {
+ // Support earlier versions of the codec (working with a single IOBuf,
+ // and using ZSTD_decompress which requires ZSTD frame to contain size,
+ // which isn't populated by streaming API).
+ if (!data->isChained()) {
+ auto out = IOBuf::createCombined(ZSTD_compressBound(data->length()));
+ const auto rc = ZSTD_compress(
+ out->writableData(),
+ out->capacity(),
+ data->data(),
+ data->length(),
+ level_);
+ zstdThrowIfError(rc);
+ out->append(rc);
+ return out;
+ }
+
+ auto zcs = ZSTD_createCStream();
+ SCOPE_EXIT {
+ ZSTD_freeCStream(zcs);
+ };
+
+ auto rc = ZSTD_initCStream(zcs, level_);
+ zstdThrowIfError(rc);
+
+ Cursor cursor(data);
+ auto result = IOBuf::createCombined(ZSTD_compressBound(cursor.totalLength()));
+
+ ZSTD_outBuffer out;
+ out.dst = result->writableTail();
+ out.size = result->capacity();
+ out.pos = 0;
+
+ for (auto buffer = cursor.peekBytes(); !buffer.empty();) {
+ ZSTD_inBuffer in;
+ in.src = buffer.data();
+ in.size = buffer.size();
+ for (in.pos = 0; in.pos != in.size;) {
+ rc = ZSTD_compressStream(zcs, &out, &in);
+ zstdThrowIfError(rc);
+ }
+ cursor.skip(in.size);
+ buffer = cursor.peekBytes();
+ }
+
+ rc = ZSTD_endStream(zcs, &out);
+ zstdThrowIfError(rc);
+ CHECK_EQ(rc, 0);
+
+ result->append(out.pos);
+ return result;
+}
+
+std::unique_ptr<IOBuf> ZSTDCodec::doUncompress(
+ const IOBuf* data,
+ uint64_t uncompressedLength) {
+ auto zds = ZSTD_createDStream();
+ SCOPE_EXIT {
+ ZSTD_freeDStream(zds);
+ };
+
+ auto rc = ZSTD_initDStream(zds);
+ zstdThrowIfError(rc);
+
+ ZSTD_outBuffer out{};
+ ZSTD_inBuffer in{};
+
+ auto outputSize = ZSTD_DStreamOutSize();
+ if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH) {
+ outputSize = uncompressedLength;
+ } else {
+ auto decompressedSize =
+ ZSTD_getDecompressedSize(data->data(), data->length());
+ if (decompressedSize != 0 && decompressedSize < outputSize) {
+ outputSize = decompressedSize;
+ }
+ }
+
+ IOBufQueue queue(IOBufQueue::cacheChainLength());
+
+ Cursor cursor(data);
+ for (rc = 0;;) {
+ if (in.pos == in.size) {
+ auto buffer = cursor.peekBytes();
+ in.src = buffer.data();
+ in.size = buffer.size();
+ in.pos = 0;
+ cursor.skip(in.size);
+ if (rc > 1 && in.size == 0) {
+ throw std::runtime_error(to<std::string>("ZSTD: incomplete input"));
+ }
+ }
+ if (out.pos == out.size) {
+ if (out.pos != 0) {
+ queue.postallocate(out.pos);
+ }
+ auto buffer = queue.preallocate(outputSize, outputSize);
+ out.dst = buffer.first;
+ out.size = buffer.second;
+ out.pos = 0;
+ outputSize = ZSTD_DStreamOutSize();
+ }
+ rc = ZSTD_decompressStream(zds, &out, &in);
+ zstdThrowIfError(rc);
+ if (rc == 0) {
+ break;
+ }
+ }
+ if (out.pos != 0) {
+ queue.postallocate(out.pos);
+ }
+ if (in.pos != in.size || !cursor.isAtEnd()) {
+ throw std::runtime_error("ZSTD: junk after end of data");
+ }
+ if (uncompressedLength != UNKNOWN_UNCOMPRESSED_LENGTH &&
+ queue.chainLength() != uncompressedLength) {
+ throw std::runtime_error("ZSTD: invalid uncompressed length");
+ }
+
+ return queue.move();
+}
+
+#endif // FOLLY_HAVE_LIBZSTD
+
+} // namespace
+
+typedef std::unique_ptr<Codec> (*CodecFactory)(int, CodecType);
+static CodecFactory
+ codecFactories[static_cast<size_t>(CodecType::NUM_CODEC_TYPES)] = {
+ nullptr, // USER_DEFINED
+ NoCompressionCodec::create,