Fix SIGSEGV in StringPiece::find_first_of
authorMike Curtiss <mcurtiss@fb.com>
Tue, 12 Feb 2013 22:39:13 +0000 (14:39 -0800)
committerJordan DeLong <jdelong@fb.com>
Tue, 19 Mar 2013 00:07:36 +0000 (17:07 -0700)
Summary:
Our SSE version of find_first_of was reading past the end of
the StringPiece in some cases, which (very rarely) caused a seg-fault
when we were reading outside of our allotted virtual address space.

Modify the code to never read past the end of the underlying buffers
except when we think it's "safe" because we're still within the same
page. (ASSUMPTION: if a process is allowed to read a byte within a
page, then it is allowed to read _all_ bytes within that page.)

Test Plan:
Added tests that verify we won't go across page boundaries.

Sadly, this code hurts our benchmarks -- sometimes by up to 50% for
smaller strings.

Reviewed By: philipp@fb.com

FB internal diff: D707923

Blame Revision: D638500

folly/Range.cpp
folly/test/RangeTest.cpp

index 94dc4fe9c9ea89619733381214ec7f2d83647204..aca2daa73e235f870ae58f1b4fe905a236c48900 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "folly/Range.h"
 
+#include <emmintrin.h>  // __v16qi
 #include "folly/Likely.h"
 
 namespace folly {
@@ -56,6 +57,16 @@ size_t qfind_first_byte_of_memchr(const StringPiece& haystack,
 
 namespace {
 
+// It's okay if pages are bigger than this (as powers of two), but they should
+// not be smaller.
+constexpr size_t kMinPageSize = 4096;
+#define PAGE_FOR(addr) \
+  (reinterpret_cast<intptr_t>(addr) / kMinPageSize)
+
+// Rounds up to the next multiple of 16
+#define ROUND_UP_16(val) \
+  ((val + 15) & ~0xF)
+
 // build sse4.2-optimized version even if -msse4.2 is not passed to GCC
 size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
                                      const StringPiece& needles)
@@ -64,15 +75,30 @@ size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
 // helper method for case where needles.size() <= 16
 size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
                                      const StringPiece& needles) {
+  DCHECK(!haystack.empty());
+  DCHECK(!needles.empty());
   DCHECK_LE(needles.size(), 16);
-  if (needles.size() <= 2 && haystack.size() >= 256) {
+  if ((needles.size() <= 2 && haystack.size() >= 256) ||
+      // we can't load needles into SSE register if it could cross page boundary
+      (PAGE_FOR(needles.end() - 1) != PAGE_FOR(needles.data() + 15))) {
     // benchmarking shows that memchr beats out SSE for small needle-sets
     // with large haystacks.
     // TODO(mcurtiss): could this be because of unaligned SSE loads?
     return detail::qfind_first_byte_of_memchr(haystack, needles);
   }
-  auto arr2 = __builtin_ia32_loaddqu(needles.data());
-  for (size_t i = 0; i < haystack.size(); i+= 16) {
+
+  __v16qi arr2 = __builtin_ia32_loaddqu(needles.data());
+
+  // If true, the last byte we want to load into the SSE register is on the
+  // same page as the last byte of the actual Range.  No risk of segfault.
+  bool canSseLoadLastBlock =
+    (PAGE_FOR(haystack.end() - 1) ==
+     PAGE_FOR(haystack.data() + ROUND_UP_16(haystack.size()) - 1));
+  int64_t lastSafeBlockIdx = canSseLoadLastBlock ?
+    haystack.size() : static_cast<int64_t>(haystack.size()) - 16;
+
+  int64_t i = 0;
+  for (; i < lastSafeBlockIdx; i+= 16) {
     auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
     auto index = __builtin_ia32_pcmpestri128(arr2, needles.size(),
                                              arr1, haystack.size() - i, 0);
@@ -80,6 +106,15 @@ size_t qfind_first_byte_of_needles16(const StringPiece& haystack,
       return i + index;
     }
   }
+
+  if (!canSseLoadLastBlock) {
+    StringPiece tmp(haystack);
+    tmp.advance(i);
+    auto ret = detail::qfind_first_byte_of_memchr(tmp, needles);
+    if (ret != StringPiece::npos) {
+      return ret + i;
+    }
+  }
   return StringPiece::npos;
 }
 
@@ -127,6 +162,46 @@ size_t qfind_first_byte_of_byteset(const StringPiece& haystack,
   return StringPiece::npos;
 }
 
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+                                const StringPiece& needles,
+                                int64_t idx)
+// inlining is okay because it's only called from other sse4.2 functions
+  __attribute__ ((__target__("sse4.2")));
+
+// Scans a 16-byte block of haystack (starting at blockStartIdx) to find first
+// needle. If blockStartIdx is near the end of haystack, it may read a few bytes
+// past the end; it is the caller's responsibility to ensure this is safe.
+inline size_t scanHaystackBlock(const StringPiece& haystack,
+                                const StringPiece& needles,
+                                int64_t blockStartIdx) {
+  // small needle sets should be handled by qfind_first_byte_of_needles16()
+  DCHECK_GT(needles.size(), 16);
+  DCHECK(blockStartIdx + 16 <= haystack.size() ||
+         (PAGE_FOR(haystack.data() + blockStartIdx) ==
+          PAGE_FOR(haystack.data() + blockStartIdx + 15)));
+  size_t b = 16;
+  auto arr1 = __builtin_ia32_loaddqu(haystack.data() + blockStartIdx);
+  int64_t j = 0;
+  for (; j < static_cast<int64_t>(needles.size()) - 16; j += 16) {
+    auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
+    auto index = __builtin_ia32_pcmpestri128(
+      arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+    b = std::min<size_t>(index, b);
+  }
+
+  // Avoid reading any bytes past the end needles by just reading the last
+  // 16 bytes of needles. We know this is safe because needles.size() > 16.
+  auto arr2 = __builtin_ia32_loaddqu(needles.end() - 16);
+  auto index = __builtin_ia32_pcmpestri128(
+    arr2, 16, arr1, haystack.size() - blockStartIdx, 0);
+  b = std::min<size_t>(index, b);
+
+  if (b < 16) {
+    return blockStartIdx + b;
+  }
+  return StringPiece::npos;
+}
+
 size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
                                  const StringPiece& needles)
   __attribute__ ((__target__("sse4.2"), noinline));
@@ -141,20 +216,26 @@ size_t qfind_first_byte_of_sse42(const StringPiece& haystack,
     return qfind_first_byte_of_needles16(haystack, needles);
   }
 
-  size_t index = haystack.size();
-  for (size_t i = 0; i < haystack.size(); i += 16) {
-    size_t b = 16;
-    auto arr1 = __builtin_ia32_loaddqu(haystack.data() + i);
-    for (size_t j = 0; j < needles.size(); j += 16) {
-      auto arr2 = __builtin_ia32_loaddqu(needles.data() + j);
-      auto index = __builtin_ia32_pcmpestri128(arr2, needles.size() - j,
-                                               arr1, haystack.size() - i, 0);
-      b = std::min<size_t>(index, b);
-    }
-    if (b < 16) {
-      return i + b;
+  int64_t i = 0;
+  for (; i < static_cast<int64_t>(haystack.size()) - 16; i += 16) {
+    auto ret = scanHaystackBlock(haystack, needles, i);
+    if (ret != StringPiece::npos) {
+      return ret;
     }
   };
+
+  if (i == haystack.size() - 16 ||
+      PAGE_FOR(haystack.end() - 1) == PAGE_FOR(haystack.data() + i + 15)) {
+    return scanHaystackBlock(haystack, needles, i);
+  } else {
+    auto ret = qfind_first_byte_of_nosse(StringPiece(haystack.data() + i,
+                                                     haystack.end()),
+                                         needles);
+    if (ret != StringPiece::npos) {
+      return i + ret;
+    }
+  }
+
   return StringPiece::npos;
 }
 
index 6b29725224a4f4b1330e666b0e6218d43409885b..095cb3a5a63511fb550a6b7907d61370fbb7819e 100644 (file)
 // @author Kristina Holst (kholst@fb.com)
 // @author Andrei Alexandrescu (andrei.alexandrescu@fb.com)
 
+#include "folly/Range.h"
+
 #include <limits>
+#include <stdlib.h>
 #include <string>
+#include <sys/mman.h>
 #include <boost/range/concepts.hpp>
 #include <gtest/gtest.h>
-#include "folly/Range.h"
 
 namespace folly { namespace detail {
 
@@ -336,3 +339,60 @@ TYPED_TEST(NeedleFinderTest, Base) {
     }
   }
 }
+
+const size_t kPageSize = 4096;
+// Updates contents so that any read accesses past the last byte will
+// cause a SIGSEGV.  It accomplishes this by changing access to the page that
+// begins immediately after the end of the contents (as allocators and mmap()
+// all operate on page boundaries, this is a reasonable assumption).
+// This function will also initialize buf, which caller must free().
+void createProtectedBuf(StringPiece& contents, char** buf) {
+  ASSERT_LE(contents.size(), kPageSize);
+  const size_t kSuccess = 0;
+  char* tmp;
+  if (kSuccess != posix_memalign((void**)buf, kPageSize, 2 * kPageSize)) {
+    ASSERT_FALSE(true);
+  }
+  mprotect(*buf + kPageSize, kPageSize, PROT_NONE);
+  size_t newBegin = kPageSize - contents.size();
+  memcpy(*buf + newBegin, contents.data(), contents.size());
+  contents.reset(*buf + newBegin, contents.size());
+}
+
+TYPED_TEST(NeedleFinderTest, NoSegFault) {
+  const string base = string(32, 'a') + string("b");
+  const string delims = string(32, 'c') + string("b");
+  for (int i = 0; i <= 32; i++) {
+    for (int j = 0; j <= 33; j++) {
+      for (int shouldFind = 0; shouldFind <= 1; ++shouldFind) {
+        StringPiece s1(base);
+        s1.advance(i);
+        ASSERT_TRUE(!s1.empty());
+        if (!shouldFind) {
+          s1.pop_back();
+        }
+        StringPiece s2(delims);
+        s2.advance(j);
+        char* buf1;
+        char* buf2;
+        createProtectedBuf(s1, &buf1);
+        createProtectedBuf(s2, &buf2);
+        // printf("s1: '%s' (%ld) \ts2: '%s' (%ld)\n",
+        //        string(s1.data(), s1.size()).c_str(), s1.size(),
+        //        string(s2.data(), s2.size()).c_str(), s2.size());
+        auto r1 = this->find_first_byte_of(s1, s2);
+        auto f1 = std::find_first_of(s1.begin(), s1.end(),
+                                     s2.begin(), s2.end());
+        auto e1 = (f1 == s1.end()) ? StringPiece::npos : f1 - s1.begin();
+        EXPECT_EQ(r1, e1);
+        auto r2 = this->find_first_byte_of(s2, s1);
+        auto f2 = std::find_first_of(s2.begin(), s2.end(),
+                                     s1.begin(), s1.end());
+        auto e2 = (f2 == s2.end()) ? StringPiece::npos : f2 - s2.begin();
+        EXPECT_EQ(r2, e2);
+        free(buf1);
+        free(buf2);
+      }
+    }
+  }
+}