Make most implicit integer truncations and sign conversions explicit
[folly.git] / folly / io / RecordIO.cpp
1 /*
2  * Copyright 2017 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <folly/io/RecordIO.h>
18
19 #include <sys/types.h>
20
21 #include <folly/Exception.h>
22 #include <folly/FileUtil.h>
23 #include <folly/Memory.h>
24 #include <folly/Portability.h>
25 #include <folly/ScopeGuard.h>
26 #include <folly/String.h>
27 #include <folly/portability/Unistd.h>
28
29 namespace folly {
30
31 using namespace recordio_helpers;
32
33 RecordIOWriter::RecordIOWriter(File file, uint32_t fileId)
34   : file_(std::move(file)),
35     fileId_(fileId),
36     writeLock_(file_, std::defer_lock),
37     filePos_(0) {
38   if (!writeLock_.try_lock()) {
39     throw std::runtime_error("RecordIOWriter: file locked by another process");
40   }
41
42   struct stat st;
43   checkUnixError(fstat(file_.fd(), &st), "fstat() failed");
44
45   filePos_ = st.st_size;
46 }
47
48 void RecordIOWriter::write(std::unique_ptr<IOBuf> buf) {
49   size_t totalLength = prependHeader(buf, fileId_);
50   if (totalLength == 0) {
51     return;  // nothing to do
52   }
53
54   DCHECK_EQ(buf->computeChainDataLength(), totalLength);
55
56   // We're going to write.  Reserve space for ourselves.
57   off_t pos = filePos_.fetch_add(off_t(totalLength));
58
59 #if FOLLY_HAVE_PWRITEV
60   auto iov = buf->getIov();
61   ssize_t bytes = pwritevFull(file_.fd(), iov.data(), iov.size(), pos);
62 #else
63   buf->unshare();
64   buf->coalesce();
65   ssize_t bytes = pwriteFull(file_.fd(), buf->data(), buf->length(), pos);
66 #endif
67
68   checkUnixError(bytes, "pwrite() failed");
69   DCHECK_EQ(size_t(bytes), totalLength);
70 }
71
72 RecordIOReader::RecordIOReader(File file, uint32_t fileId)
73   : map_(std::move(file)),
74     fileId_(fileId) {
75 }
76
77 RecordIOReader::Iterator::Iterator(ByteRange range, uint32_t fileId, off_t pos)
78   : range_(range),
79     fileId_(fileId),
80     recordAndPos_(ByteRange(), 0) {
81   if (size_t(pos) >= range_.size()) {
82     // Note that this branch can execute if pos is negative as well.
83     recordAndPos_.second = off_t(-1);
84     range_.clear();
85   } else {
86     recordAndPos_.second = pos;
87     range_.advance(size_t(pos));
88     advanceToValid();
89   }
90 }
91
92 void RecordIOReader::Iterator::advanceToValid() {
93   ByteRange record = findRecord(range_, fileId_).record;
94   if (record.empty()) {
95     recordAndPos_ = std::make_pair(ByteRange(), off_t(-1));
96     range_.clear();  // at end
97   } else {
98     size_t skipped = size_t(record.begin() - range_.begin());
99     DCHECK_GE(skipped, headerSize());
100     skipped -= headerSize();
101     range_.advance(skipped);
102     recordAndPos_.first = record;
103     recordAndPos_.second += off_t(skipped);
104   }
105 }
106
107 namespace recordio_helpers {
108
109 using namespace detail;
110
111 namespace {
112
113 constexpr uint32_t kHashSeed = 0xdeadbeef;  // for mcurtiss
114
115 uint32_t headerHash(const Header& header) {
116   return hash::SpookyHashV2::Hash32(&header, offsetof(Header, headerHash),
117                                     kHashSeed);
118 }
119
120 std::pair<size_t, uint64_t> dataLengthAndHash(const IOBuf* buf) {
121   size_t len = 0;
122   hash::SpookyHashV2 hasher;
123   hasher.Init(kHashSeed, kHashSeed);
124   for (auto br : *buf) {
125     len += br.size();
126     hasher.Update(br.data(), br.size());
127   }
128   uint64_t hash1;
129   uint64_t hash2;
130   hasher.Final(&hash1, &hash2);
131   if (len + headerSize() >= std::numeric_limits<uint32_t>::max()) {
132     throw std::invalid_argument("Record length must fit in 32 bits");
133   }
134   return std::make_pair(len, hash1);
135 }
136
137 uint64_t dataHash(ByteRange range) {
138   return hash::SpookyHashV2::Hash64(range.data(), range.size(), kHashSeed);
139 }
140
141 }  // namespace
142
143 size_t prependHeader(std::unique_ptr<IOBuf>& buf, uint32_t fileId) {
144   if (fileId == 0) {
145     throw std::invalid_argument("invalid file id");
146   }
147   auto lengthAndHash = dataLengthAndHash(buf.get());
148   if (lengthAndHash.first == 0) {
149     return 0;  // empty, nothing to do, no zero-length records
150   }
151
152   // Prepend to the first buffer in the chain if we have room, otherwise
153   // prepend a new buffer.
154   if (buf->headroom() >= headerSize()) {
155     buf->unshareOne();
156     buf->prepend(headerSize());
157   } else {
158     auto b = IOBuf::create(headerSize());
159     b->append(headerSize());
160     b->appendChain(std::move(buf));
161     buf = std::move(b);
162   }
163   detail::Header* header =
164     reinterpret_cast<detail::Header*>(buf->writableData());
165   memset(header, 0, sizeof(Header));
166   header->magic = detail::Header::kMagic;
167   header->fileId = fileId;
168   header->dataLength = uint32_t(lengthAndHash.first);
169   header->dataHash = lengthAndHash.second;
170   header->headerHash = headerHash(*header);
171
172   return lengthAndHash.first + headerSize();
173 }
174
175 RecordInfo validateRecord(ByteRange range, uint32_t fileId) {
176   if (range.size() <= headerSize()) {  // records may not be empty
177     return {0, {}};
178   }
179   const Header* header = reinterpret_cast<const Header*>(range.begin());
180   range.advance(sizeof(Header));
181   if (header->magic != Header::kMagic ||
182       header->version != 0 ||
183       header->hashFunction != 0 ||
184       header->flags != 0 ||
185       (fileId != 0 && header->fileId != fileId) ||
186       header->dataLength > range.size()) {
187     return {0, {}};
188   }
189   if (headerHash(*header) != header->headerHash) {
190     return {0, {}};
191   }
192   range.reset(range.begin(), header->dataLength);
193   if (dataHash(range) != header->dataHash) {
194     return {0, {}};
195   }
196   return {header->fileId, range};
197 }
198
199 RecordInfo findRecord(ByteRange searchRange,
200                       ByteRange wholeRange,
201                       uint32_t fileId) {
202   static const uint32_t magic = Header::kMagic;
203   static const ByteRange magicRange(reinterpret_cast<const uint8_t*>(&magic),
204                                     sizeof(magic));
205
206   DCHECK_GE(searchRange.begin(), wholeRange.begin());
207   DCHECK_LE(searchRange.end(), wholeRange.end());
208
209   const uint8_t* start = searchRange.begin();
210   const uint8_t* end = std::min(searchRange.end(),
211                                 wholeRange.end() - sizeof(Header));
212   // end-1: the last place where a Header could start
213   while (start < end) {
214     auto p = ByteRange(start, end + sizeof(magic)).find(magicRange);
215     if (p == ByteRange::npos) {
216       break;
217     }
218
219     start += p;
220     auto r = validateRecord(ByteRange(start, wholeRange.end()), fileId);
221     if (!r.record.empty()) {
222       return r;
223     }
224
225     // No repeated prefix in magic, so we can do better than start++
226     start += sizeof(magic);
227   }
228
229   return {0, {}};
230 }
231
232 }  // namespace
233
234 }  // namespaces