Allow getAutoUncompressionCodec() to have 1 terminal decoder
authorStella Lau <laus@fb.com>
Fri, 15 Sep 2017 17:13:08 +0000 (10:13 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 15 Sep 2017 17:21:54 +0000 (10:21 -0700)
Summary: getAutoUncompressionCodec() currently only allows unambiguous headers. Allow a single "terminal codec" to be called if all other codecs can't uncompress or throw.

Reviewed By: terrelln

Differential Revision: D5804833

fbshipit-source-id: 057cb6e13a48fea20508d5c028234afddf7435f6

folly/io/Compression.cpp
folly/io/Compression.h
folly/io/test/CompressionTest.cpp

index d02a6b3e6f1bc741f972ce5c55af8f70c28c269b..0d38652543f7588bccb850da6cb72bb18ba89b96 100644 (file)
@@ -1818,8 +1818,11 @@ std::unique_ptr<StreamCodec> getZlibStreamCodec(int level, CodecType type) {
 class AutomaticCodec final : public Codec {
  public:
   static std::unique_ptr<Codec> create(
-      std::vector<std::unique_ptr<Codec>> customCodecs);
-  explicit AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs);
+      std::vector<std::unique_ptr<Codec>> customCodecs,
+      std::unique_ptr<Codec> terminalCodec);
+  explicit AutomaticCodec(
+      std::vector<std::unique_ptr<Codec>> customCodecs,
+      std::unique_ptr<Codec> terminalCodec);
 
   std::vector<std::string> validPrefixes() const override;
   bool canUncompress(const IOBuf* data, Optional<uint64_t> uncompressedLength)
@@ -1846,6 +1849,7 @@ class AutomaticCodec final : public Codec {
   void checkCompatibleCodecs() const;
 
   std::vector<std::unique_ptr<Codec>> codecs_;
+  std::unique_ptr<Codec> terminalCodec_;
   bool needsUncompressedLength_;
   uint64_t maxUncompressedLength_;
 };
@@ -1877,38 +1881,70 @@ void AutomaticCodec::addCodecIfSupported(CodecType type) {
       [&type](std::unique_ptr<Codec> const& codec) {
         return codec->type() == type;
       });
-  if (hasCodec(type) && !present) {
+  bool const isTerminalType = terminalCodec_ && terminalCodec_->type() == type;
+  if (hasCodec(type) && !present && !isTerminalType) {
     codecs_.push_back(getCodec(type));
   }
 }
 
 /* static */ std::unique_ptr<Codec> AutomaticCodec::create(
-    std::vector<std::unique_ptr<Codec>> customCodecs) {
-  return std::make_unique<AutomaticCodec>(std::move(customCodecs));
-}
-
-AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
-    : Codec(CodecType::USER_DEFINED), codecs_(std::move(customCodecs)) {
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec) {
+  return std::make_unique<AutomaticCodec>(
+      std::move(customCodecs), std::move(terminalCodec));
+}
+
+AutomaticCodec::AutomaticCodec(
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec)
+    : Codec(CodecType::USER_DEFINED),
+      codecs_(std::move(customCodecs)),
+      terminalCodec_(std::move(terminalCodec)) {
   // Fastest -> slowest
-  addCodecIfSupported(CodecType::LZ4_FRAME);
-  addCodecIfSupported(CodecType::ZSTD);
-  addCodecIfSupported(CodecType::ZLIB);
-  addCodecIfSupported(CodecType::GZIP);
-  addCodecIfSupported(CodecType::LZMA2);
-  addCodecIfSupported(CodecType::BZIP2);
+  std::array<CodecType, 6> defaultTypes{{
+      CodecType::LZ4_FRAME,
+      CodecType::ZSTD,
+      CodecType::ZLIB,
+      CodecType::GZIP,
+      CodecType::LZMA2,
+      CodecType::BZIP2,
+  }};
+
+  for (auto type : defaultTypes) {
+    addCodecIfSupported(type);
+  }
+
   if (kIsDebug) {
     checkCompatibleCodecs();
   }
-  // Check that none of the codes are are null
+
+  // Check that none of the codecs are null
   DCHECK(std::none_of(
       codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
         return codec == nullptr;
       }));
 
+  // Check that the terminal codec's type is not duplicated (with the exception
+  // of USER_DEFINED).
+  if (terminalCodec_) {
+    DCHECK(std::none_of(
+        codecs_.begin(),
+        codecs_.end(),
+        [&](std::unique_ptr<Codec> const& codec) {
+          return codec->type() != CodecType::USER_DEFINED &&
+              codec->type() == terminalCodec_->type();
+        }));
+  }
+
+  bool const terminalNeedsUncompressedLength =
+      terminalCodec_ && terminalCodec_->needsUncompressedLength();
   needsUncompressedLength_ = std::any_of(
-      codecs_.begin(), codecs_.end(), [](std::unique_ptr<Codec> const& codec) {
-        return codec->needsUncompressedLength();
-      });
+                                 codecs_.begin(),
+                                 codecs_.end(),
+                                 [](std::unique_ptr<Codec> const& codec) {
+                                   return codec->needsUncompressedLength();
+                                 }) ||
+      terminalNeedsUncompressedLength;
 
   const auto it = std::max_element(
       codecs_.begin(),
@@ -1917,7 +1953,10 @@ AutomaticCodec::AutomaticCodec(std::vector<std::unique_ptr<Codec>> customCodecs)
         return lhs->maxUncompressedLength() < rhs->maxUncompressedLength();
       });
   DCHECK(it != codecs_.end());
-  maxUncompressedLength_ = (*it)->maxUncompressedLength();
+  auto const terminalMaxUncompressedLength =
+      terminalCodec_ ? terminalCodec_->maxUncompressedLength() : 0;
+  maxUncompressedLength_ =
+      std::max((*it)->maxUncompressedLength(), terminalMaxUncompressedLength);
 }
 
 void AutomaticCodec::checkCompatibleCodecs() const {
@@ -1968,11 +2007,23 @@ uint64_t AutomaticCodec::doMaxUncompressedLength() const {
 std::unique_ptr<IOBuf> AutomaticCodec::doUncompress(
     const IOBuf* data,
     Optional<uint64_t> uncompressedLength) {
-  for (auto&& codec : codecs_) {
-    if (codec->canUncompress(data, uncompressedLength)) {
-      return codec->uncompress(data, uncompressedLength);
+  try {
+    for (auto&& codec : codecs_) {
+      if (codec->canUncompress(data, uncompressedLength)) {
+        return codec->uncompress(data, uncompressedLength);
+      }
     }
+  } catch (std::exception const& e) {
+    if (!terminalCodec_) {
+      throw e;
+    }
+  }
+
+  // Try terminal codec
+  if (terminalCodec_) {
+    return terminalCodec_->uncompress(data, uncompressedLength);
   }
+
   throw std::runtime_error("AutomaticCodec error: Unknown compressed data");
 }
 
@@ -2086,8 +2137,10 @@ std::unique_ptr<StreamCodec> getStreamCodec(CodecType type, int level) {
 }
 
 std::unique_ptr<Codec> getAutoUncompressionCodec(
-    std::vector<std::unique_ptr<Codec>> customCodecs) {
-  return AutomaticCodec::create(std::move(customCodecs));
+    std::vector<std::unique_ptr<Codec>> customCodecs,
+    std::unique_ptr<Codec> terminalCodec) {
+  return AutomaticCodec::create(
+      std::move(customCodecs), std::move(terminalCodec));
 }
 } // namespace io
 } // namespace folly
index 4013e0a24dcc53ec3f3228e7ed46d932c2569030..345eda82e97fc719e03f5ef70724858851f9eb1b 100644 (file)
@@ -443,11 +443,28 @@ std::unique_ptr<StreamCodec> getStreamCodec(
  * Returns a codec that can uncompress any of the given codec types as well as
  * {LZ4_FRAME, ZSTD, ZLIB, GZIP, LZMA2, BZIP2}. Appends each default codec to
  * customCodecs in order, so long as a codec with the same type() isn't already
- * present. When uncompress() is called, each codec's canUncompress() is called
- * in the order that they are given. Appended default codecs are checked last.
- * uncompress() is called on the first codec whose canUncompress() returns true.
- * An exception is thrown if no codec canUncompress() the data.
- * An exception is thrown if the chosen codec's uncompress() throws on the data.
+ * present in customCodecs or as the terminalCodec. When uncompress() is called,
+ * each codec's canUncompress() is called in the order that they are given.
+ * Appended default codecs are checked last.  uncompress() is called on the
+ * first codec whose canUncompress() returns true.
+ *
+ * In addition, an optional `terminalCodec` can be provided. This codec's
+ * uncompress() will be called either when no other codec canUncompress() the
+ * data or the chosen codec throws an exception on the data. The terminalCodec
+ * is intended for ambiguous headers, when canUncompress() is false for some
+ * data it can actually uncompress. The terminalCodec does not need to override
+ * validPrefixes() or canUncompress() and overriding these functions will have
+ * no effect on the returned codec's validPrefixes() or canUncompress()
+ * functions. The terminalCodec's needsUncompressedLength() and
+ * maxUncompressedLength() will affect the returned codec's respective
+ * functions. The terminalCodec must not be duplicated in customCodecs.
+ *
+ * An exception is thrown if no codec canUncompress() the data and either no
+ * terminal codec was provided or a terminal codec was provided and it throws on
+ * the data.
+ * An exception is thrown if the chosen codec's uncompress() throws on the data
+ * and either no terminal codec was provided or a terminal codec was provided
+ * and it also throws on the data.
  * An exception is thrown if compress() is called on the returned codec.
  *
  * Requirements are checked in debug mode and are as follows:
@@ -457,9 +474,12 @@ std::unique_ptr<StreamCodec> getStreamCodec(
  *  3. No header in headers may be empty.
  *  4. headers must not contain any duplicate elements.
  *  5. No strict non-empty prefix of any header in headers may be in headers.
+ *  6. The terminalCodec's type must not be the same as any other codec's type
+ *     (with USER_DEFINED being the exception).
  */
 std::unique_ptr<Codec> getAutoUncompressionCodec(
-    std::vector<std::unique_ptr<Codec>> customCodecs = {});
+    std::vector<std::unique_ptr<Codec>> customCodecs = {},
+    std::unique_ptr<Codec> terminalCodec = {});
 
 /**
  * Check if a specified codec is supported.
index b9db1da0c2d63f481e870ba2f71b931272f0cd23..69a037d8597beb3570e26e3a29d1dd6518cb9cf9 100644 (file)
@@ -958,17 +958,46 @@ INSTANTIATE_TEST_CASE_P(
         testing::Values(12, 17, 20),
         testing::ValuesIn(availableStreamCodecs())));
 
+namespace {
+
+// Codec types included in the codec returned by getAutoUncompressionCodec() by
+// default.
+std::vector<CodecType> autoUncompressionCodecTypes = {{
+    CodecType::LZ4_FRAME,
+    CodecType::ZSTD,
+    CodecType::ZLIB,
+    CodecType::GZIP,
+    CodecType::LZMA2,
+    CodecType::BZIP2,
+}};
+
+} // namespace
+
 class AutomaticCodecTest : public testing::TestWithParam<CodecType> {
  protected:
   void SetUp() override {
-    codec_ = getCodec(GetParam());
-    auto_ = getAutoUncompressionCodec();
+    codecType_ = GetParam();
+    codec_ = getCodec(codecType_);
+    autoType_ = std::any_of(
+        autoUncompressionCodecTypes.begin(),
+        autoUncompressionCodecTypes.end(),
+        [&](CodecType o) { return codecType_ == o; });
+    // Add the codec with type codecType_ as the terminal codec if it is not in
+    // autoUncompressionCodecTypes.
+    auto_ = getAutoUncompressionCodec({}, getTerminalCodec());
   }
 
   void runSimpleTest(const DataHolder& dh);
 
+  std::unique_ptr<Codec> getTerminalCodec() {
+    return (autoType_ ? nullptr : getCodec(codecType_));
+  }
+
   std::unique_ptr<Codec> codec_;
   std::unique_ptr<Codec> auto_;
+  CodecType codecType_;
+  // true if codecType_ is in autoUncompressionCodecTypes
+  bool autoType_;
 };
 
 void AutomaticCodecTest::runSimpleTest(const DataHolder& dh) {
@@ -1034,10 +1063,17 @@ TEST_P(AutomaticCodecTest, DefaultCodec) {
   const uint64_t length = 42;
   std::vector<std::unique_ptr<Codec>> codecs;
   codecs.push_back(getCodec(CodecType::ZSTD));
-  auto automatic = getAutoUncompressionCodec(std::move(codecs));
+  auto automatic =
+      getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
   auto compressed = codec_->compress(original.get());
-  auto decompressed = automatic->uncompress(compressed.get());
+  std::unique_ptr<IOBuf> decompressed;
+
+  if (automatic->needsUncompressedLength()) {
+    decompressed = automatic->uncompress(compressed.get(), length);
+  } else {
+    decompressed = automatic->uncompress(compressed.get());
+  }
 
   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
 }
@@ -1096,17 +1132,28 @@ TEST_P(AutomaticCodecTest, CustomCodec) {
   auto ab = CustomCodec::create("ab", CodecType::ZSTD);
   std::vector<std::unique_ptr<Codec>> codecs;
   codecs.push_back(CustomCodec::create("ab", CodecType::ZSTD));
-  auto automatic = getAutoUncompressionCodec(std::move(codecs));
+  auto automatic =
+      getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
 
   auto abCompressed = ab->compress(original.get());
-  auto abDecompressed = automatic->uncompress(abCompressed.get());
+  std::unique_ptr<IOBuf> abDecompressed;
+  if (automatic->needsUncompressedLength()) {
+    abDecompressed = automatic->uncompress(abCompressed.get(), length);
+  } else {
+    abDecompressed = automatic->uncompress(abCompressed.get());
+  }
   EXPECT_TRUE(automatic->canUncompress(abCompressed.get()));
   EXPECT_FALSE(auto_->canUncompress(abCompressed.get()));
   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(abDecompressed.get()));
 
   auto compressed = codec_->compress(original.get());
-  auto decompressed = automatic->uncompress(compressed.get());
+  std::unique_ptr<IOBuf> decompressed;
+  if (automatic->needsUncompressedLength()) {
+    decompressed = automatic->uncompress(compressed.get(), length);
+  } else {
+    decompressed = automatic->uncompress(compressed.get());
+  }
   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
 }
 
@@ -1116,17 +1163,28 @@ TEST_P(AutomaticCodecTest, CustomDefaultCodec) {
   std::vector<std::unique_ptr<Codec>> codecs;
   codecs.push_back(CustomCodec::create("none", CodecType::NO_COMPRESSION));
   codecs.push_back(getCodec(CodecType::LZ4_FRAME));
-  auto automatic = getAutoUncompressionCodec(std::move(codecs));
+  auto automatic =
+      getAutoUncompressionCodec(std::move(codecs), getTerminalCodec());
   auto original = IOBuf::wrapBuffer(constantDataHolder.data(length));
 
   auto noneCompressed = none->compress(original.get());
-  auto noneDecompressed = automatic->uncompress(noneCompressed.get());
+  std::unique_ptr<IOBuf> noneDecompressed;
+  if (automatic->needsUncompressedLength()) {
+    noneDecompressed = automatic->uncompress(noneCompressed.get(), length);
+  } else {
+    noneDecompressed = automatic->uncompress(noneCompressed.get());
+  }
   EXPECT_TRUE(automatic->canUncompress(noneCompressed.get()));
   EXPECT_FALSE(auto_->canUncompress(noneCompressed.get()));
   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(noneDecompressed.get()));
 
   auto compressed = codec_->compress(original.get());
-  auto decompressed = automatic->uncompress(compressed.get());
+  std::unique_ptr<IOBuf> decompressed;
+  if (automatic->needsUncompressedLength()) {
+    decompressed = automatic->uncompress(compressed.get(), length);
+  } else {
+    decompressed = automatic->uncompress(compressed.get());
+  }
   EXPECT_EQ(constantDataHolder.hash(length), hashIOBuf(decompressed.get()));
 }
 
@@ -1143,13 +1201,92 @@ TEST_P(AutomaticCodecTest, canUncompressOneBytes) {
 INSTANTIATE_TEST_CASE_P(
     AutomaticCodecTest,
     AutomaticCodecTest,
-    testing::Values(
-        CodecType::LZ4_FRAME,
-        CodecType::ZSTD,
-        CodecType::ZLIB,
-        CodecType::GZIP,
-        CodecType::LZMA2,
-        CodecType::BZIP2));
+    testing::ValuesIn(availableCodecs()));
+
+namespace {
+
+// Codec that always "uncompresses" to the same string.
+class ConstantCodec : public Codec {
+ public:
+  static std::unique_ptr<Codec> create(
+      std::string uncompressed,
+      CodecType type) {
+    return std::make_unique<ConstantCodec>(std::move(uncompressed), type);
+  }
+  explicit ConstantCodec(std::string uncompressed, CodecType type)
+      : Codec(type), uncompressed_(std::move(uncompressed)) {}
+
+ private:
+  uint64_t doMaxCompressedLength(uint64_t uncompressedLength) const override {
+    return uncompressedLength;
+  }
+
+  std::unique_ptr<IOBuf> doCompress(const IOBuf*) override {
+    throw std::runtime_error("ConstantCodec error: compress() not supported.");
+  }
+
+  std::unique_ptr<IOBuf> doUncompress(const IOBuf*, Optional<uint64_t>)
+      override {
+    return IOBuf::copyBuffer(uncompressed_);
+  }
+
+  std::string uncompressed_;
+  std::unique_ptr<Codec> codec_;
+};
+
+} // namespace
+
+class TerminalCodecTest : public testing::TestWithParam<CodecType> {
+ protected:
+  void SetUp() override {
+    codecType_ = GetParam();
+    codec_ = getCodec(codecType_);
+    auto_ = getAutoUncompressionCodec();
+  }
+
+  CodecType codecType_;
+  std::unique_ptr<Codec> codec_;
+  std::unique_ptr<Codec> auto_;
+};
+
+// Test that the terminal codec's uncompress() function is called when the
+// default chosen automatic codec throws.
+TEST_P(TerminalCodecTest, uncompressIfDefaultThrows) {
+  std::string const original = "abc";
+  auto const compressed = codec_->compress(original);
+
+  // Sanity check: the automatic codec can uncompress the original string.
+  auto const uncompressed = auto_->uncompress(compressed);
+  EXPECT_EQ(uncompressed, original);
+
+  // Truncate the compressed string.
+  auto const truncated = compressed.substr(0, compressed.size() - 1);
+  auto const truncatedBuf =
+      IOBuf::wrapBuffer(truncated.data(), truncated.size());
+  EXPECT_TRUE(auto_->canUncompress(truncatedBuf.get()));
+  EXPECT_ANY_THROW(auto_->uncompress(truncated));
+
+  // Expect the terminal codec to successfully uncompress the string.
+  std::unique_ptr<Codec> terminal = getAutoUncompressionCodec(
+      {}, ConstantCodec::create("dummyString", CodecType::USER_DEFINED));
+  EXPECT_TRUE(terminal->canUncompress(truncatedBuf.get()));
+  EXPECT_EQ(terminal->uncompress(truncated), "dummyString");
+}
+
+// If the terminal codec has one of the "default types" automatically added in
+// the AutomaticCodec, check that the default codec is no longer added.
+TEST_P(TerminalCodecTest, terminalOverridesDefaults) {
+  std::unique_ptr<Codec> terminal = getAutoUncompressionCodec(
+      {}, ConstantCodec::create("dummyString", codecType_));
+  std::string const original = "abc";
+  auto const compressed = codec_->compress(original);
+  EXPECT_EQ(terminal->uncompress(compressed), "dummyString");
+}
+
+INSTANTIATE_TEST_CASE_P(
+    TerminalCodecTest,
+    TerminalCodecTest,
+    testing::ValuesIn(autoUncompressionCodecTypes));
 
 TEST(ValidPrefixesTest, CustomCodec) {
   std::vector<std::unique_ptr<Codec>> codecs;