Allow getAutoUncompressionCodec() to have 1 terminal decoder
[folly.git] / folly / io / Compression.cpp
index d02a6b3e6f1bc741f972ce5c55af8f70c28c269b..0d38652543f7588bccb850da6cb72bb18ba89b96 100644 (file)
@@ -1818,8 +1818,11 @@ std::unique_ptr<StreamCodec> getZlibStreamCodec(int level, CodecType type) {
 class AutomaticCodec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(
-      std::vector<std::unique_ptr<Codec>> customCodecs);
-  explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
+      std::vector<std::unique_ptr<Codec>> customCodecs,
+      std::unique_ptr<Codec> terminalCodec);
+  explicit AutomaticCodec(
+      std::vector<std::unique_ptr<Codec>> customCodecs,
+      std::unique_ptr<Codec> terminalCodec);
 
   std::vector<std::string> validPrefixes() const override;
   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
@@ -1846,6 +1849,7 @@ class AutomaticCodec final : public Codec {
   void checkCompatibleCodecs() const;
 
   std::vector<std::unique_ptr<Codec>> codecs_;
+  std::unique_ptr<Codec> terminalCodec_;
   bool needsUncompressedLength_;
   uint64_t maxUncompressedLength_;
 };
@@ -1877,38 +1881,70 @@ void AutomaticCodec::addCodecIfSupported(CodecType type) {
       [&type](std::unique_ptr<Codec> const& codec) {
         return codec->type() == type;
       });
-  if (hasCodec(type) && !present) {
+  bool const isTerminalType = terminalCodec_ && terminalCodec_->type() == type;
+  if (hasCodec(type) && !present && !isTerminalType) {
     codecs_.push_back(getCodec(type));
   }
 }
 
 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
-    std::vector<std::unique_ptr<Codec>> customCodecs) {
-  return std::make_unique<AutomaticCodec>(std::move(customCodecs));
-}
-
-AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
-    : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec) {
+  return std::make_unique<AutomaticCodec>(
+      std::move(customCodecs), std::move(terminalCodec));
+}
+
+AutomaticCodec::AutomaticCodec(
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec)
+    : Codec(CodecType::USER_DEFINED),
+      codecs_(std::move(customCodecs)),
+      terminalCodec_(std::move(terminalCodec)) {
   // Fastest -> slowest
-  addCodecIfSupported(CodecType::LZ4_FRAME);
-  addCodecIfSupported(CodecType::ZSTD);
-  addCodecIfSupported(CodecType::ZLIB);
-  addCodecIfSupported(CodecType::GZIP);
-  addCodecIfSupported(CodecType::LZMA2);
-  addCodecIfSupported(CodecType::BZIP2);
+  std::array<CodecType, 6> defaultTypes{{
+      CodecType::LZ4_FRAME,
+      CodecType::ZSTD,
+      CodecType::ZLIB,
+      CodecType::GZIP,
+      CodecType::LZMA2,
+      CodecType::BZIP2,
+  }};
+
+  for (auto type : defaultTypes) {
+    addCodecIfSupported(type);
+  }
+
   if (kIsDebug) {
     checkCompatibleCodecs();
   }
-  // Check that none of the codes are are null
+
+  // Check that none of the codecs are null
   DCHECK(std::none_of(
       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
         return codec == nullptr;
       }));
 
+  // Check that the terminal codec's type is not duplicated (with the exception
+  // of USER_DEFINED).
+  if (terminalCodec_) {
+    DCHECK(std::none_of(
+        codecs_.begin(),
+        codecs_.end(),
+        [&](std::unique_ptr<Codec> const& codec) {
+          return codec->type() != CodecType::USER_DEFINED &&
+              codec->type() == terminalCodec_->type();
+        }));
+  }
+
+  bool const terminalNeedsUncompressedLength =
+      terminalCodec_ && terminalCodec_->needsUncompressedLength();
   needsUncompressedLength_ = std::any_of(
-      codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
-        return codec->needsUncompressedLength();
-      });
+                                 codecs_.begin(),
+                                 codecs_.end(),
+                                 [](std::unique_ptr<Codec> const& codec) {
+                                   return codec->needsUncompressedLength();
+                                 }) ||
+      terminalNeedsUncompressedLength;
 
   const auto it = std::max_element(
       codecs_.begin(),
@@ -1917,7 +1953,10 @@ AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
       });
   DCHECK(it != codecs_.end());
-  maxUncompressedLength_ = (*it)->maxUncompressedLength();
+  auto const terminalMaxUncompressedLength =
+      terminalCodec_ ? terminalCodec_->maxUncompressedLength() : 0;
+  maxUncompressedLength_ =
+      std::max((*it)->maxUncompressedLength(), terminalMaxUncompressedLength);
 }
 
 void AutomaticCodec::checkCompatibleCodecs() const {
@@ -1968,11 +2007,23 @@ uint64_t AutomaticCodec::doMaxUncompressedLength() const {
 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
     const IOBuf* data,
     Optional<uint64_t> uncompressedLength) {
-  for (auto&& codec : codecs_) {
-    if (codec->canUncompress(data, uncompressedLength)) {
-      return codec->uncompress(data, uncompressedLength);
+  try {
+    for (auto&& codec : codecs_) {
+      if (codec->canUncompress(data, uncompressedLength)) {
+        return codec->uncompress(data, uncompressedLength);
+      }
     }
+  } catch (std::exception const& e) {
+    if (!terminalCodec_) {
+      throw e;
+    }
+  }
+
+  // Try terminal codec
+  if (terminalCodec_) {
+    return terminalCodec_->uncompress(data, uncompressedLength);
   }
+
   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
 }
 
@@ -2086,8 +2137,10 @@ std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
 }
 
 std::unique_ptr<Codec> getAutoUncompressionCodec(
-    std::vector<std::unique_ptr<Codec>> customCodecs) {
-  return AutomaticCodec::create(std::move(customCodecs));
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec) {
+  return AutomaticCodec::create(
+      std::move(customCodecs), std::move(terminalCodec));
 }
 } // namespace io
 } // namespace folly