1115f550d064f8fd56b268621805f666fbad1842
[folly.git] / folly / GroupVarint.h
1 /*
2  * Copyright 2014 Facebook, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef FOLLY_GROUPVARINT_H_
18 #define FOLLY_GROUPVARINT_H_
19
20 #ifndef __GNUC__
21 #error GroupVarint.h requires GCC
22 #endif
23
24 #include <folly/Portability.h>
25
26 #if FOLLY_X64 || defined(__i386__)
27 #define HAVE_GROUP_VARINT 1
28
29 #include <cstdint>
30 #include <limits>
31 #include <folly/detail/GroupVarintDetail.h>
32 #include <folly/Bits.h>
33 #include <folly/Range.h>
34 #include <glog/logging.h>
35
36 #ifdef __SSSE3__
37 #include <x86intrin.h>
38 namespace folly {
39 namespace detail {
40 extern const __m128i groupVarintSSEMasks[];
41 }  // namespace detail
42 }  // namespace folly
43 #endif
44
45 namespace folly {
46 namespace detail {
47 extern const uint8_t groupVarintLengths[];
48 }  // namespace detail
49 }  // namespace folly
50
51 namespace folly {
52
53 template <typename T>
54 class GroupVarint;
55
56 /**
57  * GroupVarint encoding for 32-bit values.
58  *
59  * Encodes 4 32-bit integers at once, each using 1-4 bytes depending on size.
60  * There is one byte of overhead.  (The first byte contains the lengths of
61  * the four integers encoded as two bits each; 00=1 byte .. 11=4 bytes)
62  *
63  * This implementation assumes little-endian and does unaligned 32-bit
64  * accesses, so it's basically not portable outside of the x86[_64] world.
65  */
66 template <>
67 class GroupVarint<uint32_t> : public detail::GroupVarintBase<uint32_t> {
68  public:
69
70   /**
71    * Return the number of bytes used to encode these four values.
72    */
73   static size_t size(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
74     return kHeaderSize + kGroupSize + key(a) + key(b) + key(c) + key(d);
75   }
76
77   /**
78    * Return the number of bytes used to encode four uint32_t values stored
79    * at consecutive positions in an array.
80    */
81   static size_t size(const uint32_t* p) {
82     return size(p[0], p[1], p[2], p[3]);
83   }
84
85   /**
86    * Return the number of bytes used to encode count (<= 4) values.
87    * If you clip a buffer after these many bytes, you can still decode
88    * the first "count" values correctly (if the remaining size() -
89    * partialSize() bytes are filled with garbage).
90    */
91   static size_t partialSize(const type* p, size_t count) {
92     DCHECK_LE(count, kGroupSize);
93     size_t s = kHeaderSize + count;
94     for (; count; --count, ++p) {
95       s += key(*p);
96     }
97     return s;
98   }
99
100   /**
101    * Return the number of values from *p that are valid from an encoded
102    * buffer of size bytes.
103    */
104   static size_t partialCount(const char* p, size_t size) {
105     char v = *p;
106     size_t s = kHeaderSize;
107     s += 1 + b0key(v);
108     if (s > size) return 0;
109     s += 1 + b1key(v);
110     if (s > size) return 1;
111     s += 1 + b2key(v);
112     if (s > size) return 2;
113     s += 1 + b3key(v);
114     if (s > size) return 3;
115     return 4;
116   }
117
118   /**
119    * Given a pointer to the beginning of an GroupVarint32-encoded block,
120    * return the number of bytes used by the encoding.
121    */
122   static size_t encodedSize(const char* p) {
123     return (kHeaderSize + kGroupSize +
124             b0key(*p) + b1key(*p) + b2key(*p) + b3key(*p));
125   }
126
127   /**
128    * Encode four uint32_t values into the buffer pointed-to by p, and return
129    * the next position in the buffer (that is, one character past the last
130    * encoded byte).  p needs to have at least size()+4 bytes available.
131    */
132   static char* encode(char* p, uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
133     uint8_t b0key = key(a);
134     uint8_t b1key = key(b);
135     uint8_t b2key = key(c);
136     uint8_t b3key = key(d);
137     *p++ = (b3key << 6) | (b2key << 4) | (b1key << 2) | b0key;
138     storeUnaligned(p, a);
139     p += b0key+1;
140     storeUnaligned(p, b);
141     p += b1key+1;
142     storeUnaligned(p, c);
143     p += b2key+1;
144     storeUnaligned(p, d);
145     p += b3key+1;
146     return p;
147   }
148
149   /**
150    * Encode four uint32_t values from the array pointed-to by src into the
151    * buffer pointed-to by p, similar to encode(p,a,b,c,d) above.
152    */
153   static char* encode(char* p, const uint32_t* src) {
154     return encode(p, src[0], src[1], src[2], src[3]);
155   }
156
157   /**
158    * Decode four uint32_t values from a buffer, and return the next position
159    * in the buffer (that is, one character past the last encoded byte).
160    * The buffer needs to have at least 3 extra bytes available (they
161    * may be read but ignored).
162    */
163   static const char* decode_simple(const char* p, uint32_t* a, uint32_t* b,
164                                    uint32_t* c, uint32_t* d) {
165     size_t k = loadUnaligned<uint8_t>(p);
166     const char* end = p + detail::groupVarintLengths[k];
167     ++p;
168     size_t k0 = b0key(k);
169     *a = loadUnaligned<uint32_t>(p) & kMask[k0];
170     p += k0+1;
171     size_t k1 = b1key(k);
172     *b = loadUnaligned<uint32_t>(p) & kMask[k1];
173     p += k1+1;
174     size_t k2 = b2key(k);
175     *c = loadUnaligned<uint32_t>(p) & kMask[k2];
176     p += k2+1;
177     size_t k3 = b3key(k);
178     *d = loadUnaligned<uint32_t>(p) & kMask[k3];
179     p += k3+1;
180     return end;
181   }
182
183   /**
184    * Decode four uint32_t values from a buffer and store them in the array
185    * pointed-to by dest, similar to decode(p,a,b,c,d) above.
186    */
187   static const char* decode_simple(const char* p, uint32_t* dest) {
188     return decode_simple(p, dest, dest+1, dest+2, dest+3);
189   }
190
191 #ifdef __SSSE3__
192   static const char* decode(const char* p, uint32_t* dest) {
193     uint8_t key = p[0];
194     __m128i val = _mm_loadu_si128((const __m128i*)(p+1));
195     __m128i mask = detail::groupVarintSSEMasks[key];
196     __m128i r = _mm_shuffle_epi8(val, mask);
197     _mm_storeu_si128((__m128i*)dest, r);
198     return p + detail::groupVarintLengths[key];
199   }
200
201   static const char* decode(const char* p, uint32_t* a, uint32_t* b,
202                             uint32_t* c, uint32_t* d) {
203     uint8_t key = p[0];
204     __m128i val = _mm_loadu_si128((const __m128i*)(p+1));
205     __m128i mask = detail::groupVarintSSEMasks[key];
206     __m128i r = _mm_shuffle_epi8(val, mask);
207
208     // Extracting 32 bits at a time out of an XMM register is a SSE4 feature
209 #ifdef __SSE4__
210     *a = _mm_extract_epi32(r, 0);
211     *b = _mm_extract_epi32(r, 1);
212     *c = _mm_extract_epi32(r, 2);
213     *d = _mm_extract_epi32(r, 3);
214 #else  /* !__SSE4__ */
215     *a = _mm_extract_epi16(r, 0) + (_mm_extract_epi16(r, 1) << 16);
216     *b = _mm_extract_epi16(r, 2) + (_mm_extract_epi16(r, 3) << 16);
217     *c = _mm_extract_epi16(r, 4) + (_mm_extract_epi16(r, 5) << 16);
218     *d = _mm_extract_epi16(r, 6) + (_mm_extract_epi16(r, 7) << 16);
219 #endif  /* __SSE4__ */
220
221     return p + detail::groupVarintLengths[key];
222   }
223
224 #else  /* !__SSSE3__ */
225   static const char* decode(const char* p, uint32_t* a, uint32_t* b,
226                             uint32_t* c, uint32_t* d) {
227     return decode_simple(p, a, b, c, d);
228   }
229
230   static const char* decode(const char* p, uint32_t* dest) {
231     return decode_simple(p, dest);
232   }
233 #endif  /* __SSSE3__ */
234
235  private:
236   static uint8_t key(uint32_t x) {
237     // __builtin_clz is undefined for the x==0 case
238     return 3 - (__builtin_clz(x|1) / 8);
239   }
240   static size_t b0key(size_t x) { return x & 3; }
241   static size_t b1key(size_t x) { return (x >> 2) & 3; }
242   static size_t b2key(size_t x) { return (x >> 4) & 3; }
243   static size_t b3key(size_t x) { return (x >> 6) & 3; }
244
245   static const uint32_t kMask[];
246 };
247
248
249 /**
250  * GroupVarint encoding for 64-bit values.
251  *
252  * Encodes 5 64-bit integers at once, each using 1-8 bytes depending on size.
253  * There are two bytes of overhead.  (The first two bytes contain the lengths
254  * of the five integers encoded as three bits each; 000=1 byte .. 111 = 8 bytes)
255  *
256  * This implementation assumes little-endian and does unaligned 64-bit
257  * accesses, so it's basically not portable outside of the x86[_64] world.
258  */
259 template <>
260 class GroupVarint<uint64_t> : public detail::GroupVarintBase<uint64_t> {
261  public:
262   /**
263    * Return the number of bytes used to encode these five values.
264    */
265   static size_t size(uint64_t a, uint64_t b, uint64_t c, uint64_t d,
266                      uint64_t e) {
267     return (kHeaderSize + kGroupSize +
268             key(a) + key(b) + key(c) + key(d) + key(e));
269   }
270
271   /**
272    * Return the number of bytes used to encode five uint64_t values stored
273    * at consecutive positions in an array.
274    */
275   static size_t size(const uint64_t* p) {
276     return size(p[0], p[1], p[2], p[3], p[4]);
277   }
278
279   /**
280    * Return the number of bytes used to encode count (<= 4) values.
281    * If you clip a buffer after these many bytes, you can still decode
282    * the first "count" values correctly (if the remaining size() -
283    * partialSize() bytes are filled with garbage).
284    */
285   static size_t partialSize(const type* p, size_t count) {
286     DCHECK_LE(count, kGroupSize);
287     size_t s = kHeaderSize + count;
288     for (; count; --count, ++p) {
289       s += key(*p);
290     }
291     return s;
292   }
293
294   /**
295    * Return the number of values from *p that are valid from an encoded
296    * buffer of size bytes.
297    */
298   static size_t partialCount(const char* p, size_t size) {
299     uint16_t v = loadUnaligned<uint16_t>(p);
300     size_t s = kHeaderSize;
301     s += 1 + b0key(v);
302     if (s > size) return 0;
303     s += 1 + b1key(v);
304     if (s > size) return 1;
305     s += 1 + b2key(v);
306     if (s > size) return 2;
307     s += 1 + b3key(v);
308     if (s > size) return 3;
309     s += 1 + b4key(v);
310     if (s > size) return 4;
311     return 5;
312   }
313
314   /**
315    * Given a pointer to the beginning of an GroupVarint64-encoded block,
316    * return the number of bytes used by the encoding.
317    */
318   static size_t encodedSize(const char* p) {
319     uint16_t n = loadUnaligned<uint16_t>(p);
320     return (kHeaderSize + kGroupSize +
321             b0key(n) + b1key(n) + b2key(n) + b3key(n) + b4key(n));
322   }
323
324   /**
325    * Encode five uint64_t values into the buffer pointed-to by p, and return
326    * the next position in the buffer (that is, one character past the last
327    * encoded byte).  p needs to have at least size()+8 bytes available.
328    */
329   static char* encode(char* p, uint64_t a, uint64_t b, uint64_t c,
330                       uint64_t d, uint64_t e) {
331     uint8_t b0key = key(a);
332     uint8_t b1key = key(b);
333     uint8_t b2key = key(c);
334     uint8_t b3key = key(d);
335     uint8_t b4key = key(e);
336     storeUnaligned<uint16_t>(
337         p,
338         (b4key << 12) | (b3key << 9) | (b2key << 6) | (b1key << 3) | b0key);
339     p += 2;
340     storeUnaligned(p, a);
341     p += b0key+1;
342     storeUnaligned(p, b);
343     p += b1key+1;
344     storeUnaligned(p, c);
345     p += b2key+1;
346     storeUnaligned(p, d);
347     p += b3key+1;
348     storeUnaligned(p, e);
349     p += b4key+1;
350     return p;
351   }
352
353   /**
354    * Encode five uint64_t values from the array pointed-to by src into the
355    * buffer pointed-to by p, similar to encode(p,a,b,c,d,e) above.
356    */
357   static char* encode(char* p, const uint64_t* src) {
358     return encode(p, src[0], src[1], src[2], src[3], src[4]);
359   }
360
361   /**
362    * Decode five uint64_t values from a buffer, and return the next position
363    * in the buffer (that is, one character past the last encoded byte).
364    * The buffer needs to have at least 7 bytes available (they may be read
365    * but ignored).
366    */
367   static const char* decode(const char* p, uint64_t* a, uint64_t* b,
368                             uint64_t* c, uint64_t* d, uint64_t* e) {
369     uint16_t k = loadUnaligned<uint16_t>(p);
370     p += 2;
371     uint8_t k0 = b0key(k);
372     *a = loadUnaligned<uint64_t>(p) & kMask[k0];
373     p += k0+1;
374     uint8_t k1 = b1key(k);
375     *b = loadUnaligned<uint64_t>(p) & kMask[k1];
376     p += k1+1;
377     uint8_t k2 = b2key(k);
378     *c = loadUnaligned<uint64_t>(p) & kMask[k2];
379     p += k2+1;
380     uint8_t k3 = b3key(k);
381     *d = loadUnaligned<uint64_t>(p) & kMask[k3];
382     p += k3+1;
383     uint8_t k4 = b4key(k);
384     *e = loadUnaligned<uint64_t>(p) & kMask[k4];
385     p += k4+1;
386     return p;
387   }
388
389   /**
390    * Decode five uint64_t values from a buffer and store them in the array
391    * pointed-to by dest, similar to decode(p,a,b,c,d,e) above.
392    */
393   static const char* decode(const char* p, uint64_t* dest) {
394     return decode(p, dest, dest+1, dest+2, dest+3, dest+4);
395   }
396
397  private:
398   enum { kHeaderBytes = 2 };
399
400   static uint8_t key(uint64_t x) {
401     // __builtin_clzll is undefined for the x==0 case
402     return 7 - (__builtin_clzll(x|1) / 8);
403   }
404
405   static uint8_t b0key(uint16_t x) { return x & 7; }
406   static uint8_t b1key(uint16_t x) { return (x >> 3) & 7; }
407   static uint8_t b2key(uint16_t x) { return (x >> 6) & 7; }
408   static uint8_t b3key(uint16_t x) { return (x >> 9) & 7; }
409   static uint8_t b4key(uint16_t x) { return (x >> 12) & 7; }
410
411   static const uint64_t kMask[];
412 };
413
414 typedef GroupVarint<uint32_t> GroupVarint32;
415 typedef GroupVarint<uint64_t> GroupVarint64;
416
417 /**
418  * Simplify use of GroupVarint* for the case where data is available one
419  * entry at a time (instead of one group at a time).  Handles buffering
420  * and an incomplete last chunk.
421  *
422  * Output is a function object that accepts character ranges:
423  * out(StringPiece) appends the given character range to the output.
424  */
425 template <class T, class Output>
426 class GroupVarintEncoder {
427  public:
428   typedef GroupVarint<T> Base;
429   typedef T type;
430
431   explicit GroupVarintEncoder(Output out)
432     : out_(out),
433       count_(0) {
434   }
435
436   ~GroupVarintEncoder() {
437     finish();
438   }
439
440   /**
441    * Add a value to the encoder.
442    */
443   void add(type val) {
444     buf_[count_++] = val;
445     if (count_ == Base::kGroupSize) {
446       char* p = Base::encode(tmp_, buf_);
447       out_(StringPiece(tmp_, p));
448       count_ = 0;
449     }
450   }
451
452   /**
453    * Finish encoding, flushing any buffered values if necessary.
454    * After finish(), the encoder is immediately ready to encode more data
455    * to the same output.
456    */
457   void finish() {
458     if (count_) {
459       // This is not strictly necessary, but it makes testing easy;
460       // uninitialized bytes are guaranteed to be recorded as taking one byte
461       // (not more).
462       for (size_t i = count_; i < Base::kGroupSize; i++) {
463         buf_[i] = 0;
464       }
465       Base::encode(tmp_, buf_);
466       out_(StringPiece(tmp_, Base::partialSize(buf_, count_)));
467       count_ = 0;
468     }
469   }
470
471   /**
472    * Return the appender that was used.
473    */
474   Output& output() {
475     return out_;
476   }
477   const Output& output() const {
478     return out_;
479   }
480
481   /**
482    * Reset the encoder, disregarding any state (except what was already
483    * flushed to the output, of course).
484    */
485   void clear() {
486     count_ = 0;
487   }
488
489  private:
490   Output out_;
491   char tmp_[Base::kMaxSize];
492   type buf_[Base::kGroupSize];
493   size_t count_;
494 };
495
496 /**
497  * Simplify use of GroupVarint* for the case where the last group in the
498  * input may be incomplete (but the exact size of the input is known).
499  * Allows for extracting values one at a time.
500  */
501 template <typename T>
502 class GroupVarintDecoder {
503  public:
504   typedef GroupVarint<T> Base;
505   typedef T type;
506
507   GroupVarintDecoder() { }
508
509   explicit GroupVarintDecoder(StringPiece data,
510                               size_t maxCount = (size_t)-1)
511     : rrest_(data.end()),
512       p_(data.data()),
513       end_(data.end()),
514       limit_(end_),
515       pos_(0),
516       count_(0),
517       remaining_(maxCount) {
518   }
519
520   void reset(StringPiece data, size_t maxCount = (size_t)-1) {
521     rrest_ = data.end();
522     p_ = data.data();
523     end_ = data.end();
524     limit_ = end_;
525     pos_ = 0;
526     count_ = 0;
527     remaining_ = maxCount;
528   }
529
530   /**
531    * Read and return the next value.
532    */
533   bool next(type* val) {
534     if (pos_ == count_) {
535       // refill
536       size_t rem = end_ - p_;
537       if (rem == 0 || remaining_ == 0) {
538         return false;
539       }
540       // next() attempts to read one full group at a time, and so we must have
541       // at least enough bytes readable after its end to handle the case if the
542       // last group is full.
543       //
544       // The best way to ensure this is to ensure that data has at least
545       // Base::kMaxSize - 1 bytes readable *after* the end, otherwise we'll copy
546       // into a temporary buffer.
547       if (limit_ - p_ < Base::kMaxSize) {
548         memcpy(tmp_, p_, rem);
549         p_ = tmp_;
550         end_ = p_ + rem;
551         limit_ = tmp_ + sizeof(tmp_);
552       }
553       pos_ = 0;
554       const char* n = Base::decode(p_, buf_);
555       if (n <= end_) {
556         // Full group could be decoded
557         if (remaining_ >= Base::kGroupSize) {
558           remaining_ -= Base::kGroupSize;
559           count_ = Base::kGroupSize;
560           p_ = n;
561         } else {
562           count_ = remaining_;
563           remaining_ = 0;
564           p_ += Base::partialSize(buf_, count_);
565         }
566       } else {
567         // Can't decode a full group
568         count_ = Base::partialCount(p_, end_ - p_);
569         if (remaining_ >= count_) {
570           remaining_ -= count_;
571           p_ = end_;
572         } else {
573           count_ = remaining_;
574           remaining_ = 0;
575           p_ += Base::partialSize(buf_, count_);
576         }
577         if (count_ == 0) {
578           return false;
579         }
580       }
581     }
582     *val = buf_[pos_++];
583     return true;
584   }
585
586   StringPiece rest() const {
587     // This is only valid after next() returned false
588     CHECK(pos_ == count_ && (p_ == end_ || remaining_ == 0));
589     // p_ may point to the internal buffer (tmp_), but we want
590     // to return subpiece of the original data
591     size_t size = end_ - p_;
592     return StringPiece(rrest_ - size, rrest_);
593   }
594
595  private:
596   const char* rrest_;
597   const char* p_;
598   const char* end_;
599   const char* limit_;
600   char tmp_[2 * Base::kMaxSize];
601   type buf_[Base::kGroupSize];
602   size_t pos_;
603   size_t count_;
604   size_t remaining_;
605 };
606
607 typedef GroupVarintDecoder<uint32_t> GroupVarint32Decoder;
608 typedef GroupVarintDecoder<uint64_t> GroupVarint64Decoder;
609
610 }  // namespace folly
611
612 #endif /* FOLLY_X64 || defined(__i386__) */
613 #endif /* FOLLY_GROUPVARINT_H_ */
614