Refactors folly sync test cases
[folly.git] / folly / IPAddress.cpp
index 2ef69d64bda0e2c8341b4e3e012aad6ec3bdaccc..5f698f28f25ba66ba9af9fef448b425b9e586be4 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright 2017 Facebook, Inc.
+ * Copyright 2014-present Facebook, Inc.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -13,7 +13,6 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #include <folly/IPAddress.h>
 
 #include <limits>
@@ -68,43 +67,108 @@ IPAddressV6 IPAddress::createIPv6(const IPAddress& addr) {
   }
 }
 
+namespace {
+vector<string> splitIpSlashCidr(StringPiece ipSlashCidr) {
+  vector<string> vec;
+  split("/", ipSlashCidr, vec);
+  return vec;
+}
+} // namespace
+
 // public static
 CIDRNetwork IPAddress::createNetwork(
     StringPiece ipSlashCidr,
     int defaultCidr, /* = -1 */
     bool applyMask /* = true */) {
-  if (defaultCidr > std::numeric_limits<uint8_t>::max()) {
+  auto const ret =
+      IPAddress::tryCreateNetwork(ipSlashCidr, defaultCidr, applyMask);
+
+  if (ret.hasValue()) {
+    return ret.value();
+  }
+
+  if (ret.error() == CIDRNetworkError::INVALID_DEFAULT_CIDR) {
     throw std::range_error("defaultCidr must be <= UINT8_MAX");
   }
-  vector<string> vec;
-  split("/", ipSlashCidr, vec);
-  vector<string>::size_type elemCount = vec.size();
 
-  if (elemCount == 0 || // weird invalid string
-      elemCount > 2) { // invalid string (IP/CIDR/extras)
+  if (ret.error() == CIDRNetworkError::INVALID_IP_SLASH_CIDR) {
     throw IPAddressFormatException(sformat(
         "Invalid ipSlashCidr specified. Expected IP/CIDR format, got '{}'",
         ipSlashCidr));
   }
-  IPAddress subnet(vec.at(0));
-  auto cidr =
-      uint8_t((defaultCidr > -1) ? defaultCidr : (subnet.isV4() ? 32 : 128));
 
-  if (elemCount == 2) {
-    try {
-      cidr = to<uint8_t>(vec.at(1));
-    } catch (...) {
+  // Handler the remaining error cases. We re-parse the ip/mask pair
+  // to make error messages more meaningful
+  auto const vec = splitIpSlashCidr(ipSlashCidr);
+
+  switch (ret.error()) {
+    case CIDRNetworkError::INVALID_IP:
+      CHECK_GE(vec.size(), 1);
+      throw IPAddressFormatException(
+          sformat("Invalid IP address {}", vec.at(0)));
+    case CIDRNetworkError::INVALID_CIDR:
+      CHECK_GE(vec.size(), 2);
       throw IPAddressFormatException(
           sformat("Mask value '{}' not a valid mask", vec.at(1)));
+    case CIDRNetworkError::CIDR_MISMATCH: {
+      auto const subnet = IPAddress::tryFromString(vec.at(0)).value();
+      auto cidr = static_cast<uint8_t>(
+          (defaultCidr > -1) ? defaultCidr : (subnet.isV4() ? 32 : 128));
+
+      throw IPAddressFormatException(sformat(
+          "CIDR value '{}' is > network bit count '{}'",
+          vec.size() == 2 ? vec.at(1) : to<string>(cidr),
+          subnet.bitCount()));
     }
+    default:
+      // unreachable
+      break;
   }
-  if (cidr > subnet.bitCount()) {
-    throw IPAddressFormatException(sformat(
-        "CIDR value '{}' is > network bit count '{}'",
-        cidr,
-        subnet.bitCount()));
+
+  CHECK(0);
+
+  return CIDRNetwork{};
+}
+
+// public static
+Expected<CIDRNetwork, CIDRNetworkError> IPAddress::tryCreateNetwork(
+    StringPiece ipSlashCidr,
+    int defaultCidr,
+    bool applyMask) {
+  if (defaultCidr > std::numeric_limits<uint8_t>::max()) {
+    return makeUnexpected(CIDRNetworkError::INVALID_DEFAULT_CIDR);
+  }
+
+  auto const vec = splitIpSlashCidr(ipSlashCidr);
+  auto const elemCount = vec.size();
+
+  if (elemCount == 0 || // weird invalid string
+      elemCount > 2) { // invalid string (IP/CIDR/extras)
+    return makeUnexpected(CIDRNetworkError::INVALID_IP_SLASH_CIDR);
   }
-  return std::make_pair(applyMask ? subnet.mask(cidr) : subnet, cidr);
+
+  auto const subnet = IPAddress::tryFromString(vec.at(0));
+  if (subnet.hasError()) {
+    return makeUnexpected(CIDRNetworkError::INVALID_IP);
+  }
+
+  auto cidr = static_cast<uint8_t>(
+      (defaultCidr > -1) ? defaultCidr : (subnet.value().isV4() ? 32 : 128));
+
+  if (elemCount == 2) {
+    auto const maybeCidr = tryTo<uint8_t>(vec.at(1));
+    if (maybeCidr.hasError()) {
+      return makeUnexpected(CIDRNetworkError::INVALID_CIDR);
+    }
+    cidr = maybeCidr.value();
+  }
+
+  if (cidr > subnet.value().bitCount()) {
+    return makeUnexpected(CIDRNetworkError::CIDR_MISMATCH);
+  }
+
+  return std::make_pair(
+      applyMask ? subnet.value().mask(cidr) : subnet.value(), cidr);
 }
 
 // public static