Fix SIGSEGV in StringPiece::find_first_of
[folly.git] / folly / Range.cpp
index 94dc4fe9c9ea89619733381214ec7f2d83647204..aca2daa73e235f870ae58f1b4fe905a236c48900 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "folly/Range.h"
 
+#include <emmintrin.h>  // __v16qi
 #include "folly/Likely.h"
 
 namespace folly {
@@ -56,6 +57,16 @@ size_t qfind_first_byte_of_memchr(const StringPiece& haystack,
 
 namespace {
 
+// It's okay if pages are bigger than this (as powers of two), but they should
+// not be smaller.
+constexpr size_t kMinPageSize = 4096;
+#define PAGE_FOR(addr) \
+  (reinterpret_cast<intptr_t>(addr) / kMinPageSize)
+
+// Rounds up to the next multiple of 16
+#define ROUND_UP_16(val) \
+  ((val + 15) & ~0xF)
+
 // build sse4.2-optimized version even if -msse4.2 is not passed to GCC
 size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
                                      const StringPiece& needles)
@@ -64,15 +75,30 @@ size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
 // helper method for case where needles.size() <= 16
 size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
                                      const StringPiece& needles) {
+  DCHECK(!haystack.empty());
+  DCHECK(!needles.empty());
   DCHECK_LE(needles.size(), 16);
-  if (needles.size() <= 2 && haystack.size() >= 256) {
+  if ((needles.size() <= 2 && haystack.size() >= 256) ||
+      // we can't load needles into SSE register if it could cross page boundary
+      (PAGE_FOR(needles.end() - 1) != PAGE_FOR(needles.data() + 15))) {
     // benchmarking shows that memchr beats out SSE for small needle-sets
     // with large haystacks.
     // TODO(mcurtiss): could this be because of unaligned SSE loads?
     return detail::qfind_first_byte_of_memchr(haystack, needles);
   }
-  auto arr2 = __builtin_ia32_loaddqu(needles.data());
-  for (size_t i = 0; i < haystack.size(); i+= 16) {
+
+  __v16qi arr2 = __builtin_ia32_loaddqu(needles.data());
+
+  // If true, the last byte we want to load into the SSE register is on the
+  // same page as the last byte of the actual Range.  No risk of segfault.
+  bool canSseLoadLastBlock =
+    (PAGE_FOR(haystack.end() - 1) ==
+     PAGE_FOR(haystack.data() + ROUND_UP_16(haystack.size()) - 1));
+  int64_t lastSafeBlockIdx = canSseLoadLastBlock ?
+    haystack.size() : static_cast<int64_t>(haystack.size()) - 16;
+
+  int64_t i = 0;
+  for (; i < lastSafeBlockIdx; i+= 16) {
     auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
     auto index = __builtin_ia32_pcmpestri128(arr2, needles.size(),
                                              arr1, haystack.size() - i, 0);
@@ -80,6 +106,15 @@ size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
       return i + index;
     }
   }
+
+  if (!canSseLoadLastBlock) {
+    StringPiece tmp(haystack);
+    tmp.advance(i);
+    auto ret = detail::qfind_first_byte_of_memchr(tmp, needles);
+    if (ret != StringPiece::npos) {
+      return ret + i;
+    }
+  }
   return StringPiece::npos;
 }
 
@@ -127,6 +162,46 @@ size_t qfind_first_byte_of_byteset(const StringPiece& haystack,
   return StringPiece::npos;
 }
 
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+                                const StringPiece& needles,
+                                int64_t idx)
+// inlining is okay because it's only called from other sse4.2 functions
+  __attribute__ ((__target__("sse4.2")));
+
+// Scans a 16-byte block of haystack (starting at blockStartIdx) to find first
+// needle. If blockStartIdx is near the end of haystack, it may read a few bytes
+// past the end; it is the caller's responsibility to ensure this is safe.
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+                                const StringPiece& needles,
+                                int64_t blockStartIdx) {
+  // small needle sets should be handled by qfind_first_byte_of_needles16()
+  DCHECK_GT(needles.size(), 16);
+  DCHECK(blockStartIdx + 16 <= haystack.size() ||
+         (PAGE_FOR(haystack.data() + blockStartIdx) ==
+          PAGE_FOR(haystack.data() + blockStartIdx + 15)));
+  size_t b = 16;
+  auto arr1 = __builtin_ia32_loaddqu(haystack.data() + blockStartIdx);
+  int64_t j = 0;
+  for (; j < static_cast<int64_t>(needles.size()) - 16; j += 16) {
+    auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
+    auto index = __builtin_ia32_pcmpestri128(
+      arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+    b = std::min<size_t>(index, b);
+  }
+
+  // Avoid reading any bytes past the end needles by just reading the last
+  // 16 bytes of needles. We know this is safe because needles.size() > 16.
+  auto arr2 = __builtin_ia32_loaddqu(needles.end() - 16);
+  auto index = __builtin_ia32_pcmpestri128(
+    arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+  b = std::min<size_t>(index, b);
+
+  if (b < 16) {
+    return blockStartIdx + b;
+  }
+  return StringPiece::npos;
+}
+
 size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
                                  const StringPiece& needles)
   __attribute__ ((__target__("sse4.2"), noinline));
@@ -141,20 +216,26 @@ size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
     return qfind_first_byte_of_needles16(haystack, needles);
   }
 
-  size_t index = haystack.size();
-  for (size_t i = 0; i < haystack.size(); i += 16) {
-    size_t b = 16;
-    auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
-    for (size_t j = 0; j < needles.size(); j += 16) {
-      auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
-      auto index = __builtin_ia32_pcmpestri128(arr2, needles.size() - j,
-                                               arr1, haystack.size() - i, 0);
-      b = std::min<size_t>(index, b);
-    }
-    if (b < 16) {
-      return i + b;
+  int64_t i = 0;
+  for (; i < static_cast<int64_t>(haystack.size()) - 16; i += 16) {
+    auto ret = scanHaystackBlock(haystack, needles, i);
+    if (ret != StringPiece::npos) {
+      return ret;
     }
   };
+
+  if (i == haystack.size() - 16 ||
+      PAGE_FOR(haystack.end() - 1) == PAGE_FOR(haystack.data() + i + 15)) {
+    return scanHaystackBlock(haystack, needles, i);
+  } else {
+    auto ret = qfind_first_byte_of_nosse(StringPiece(haystack.data() + i,
+                                                     haystack.end()),
+                                         needles);
+    if (ret != StringPiece::npos) {
+      return i + ret;
+    }
+  }
+
   return StringPiece::npos;
 }