PyORAM
[iotcloud.git] / PyORAM / src / pyoram / oblivious_storage / tree / path_oram.py
1 import hashlib
2 import hmac
3 import struct
4 import array
5 import logging
6
7 import pyoram
8 from pyoram.oblivious_storage.tree.tree_oram_helper import \
9     (TreeORAMStorage,
10      TreeORAMStorageManagerExplicitAddressing)
11 from pyoram.encrypted_storage.encrypted_block_storage import \
12     EncryptedBlockStorageInterface
13 from pyoram.encrypted_storage.encrypted_heap_storage import \
14     (EncryptedHeapStorage,
15      EncryptedHeapStorageInterface)
16 from pyoram.encrypted_storage.top_cached_encrypted_heap_storage import \
17     TopCachedEncryptedHeapStorage
18 from pyoram.util.virtual_heap import \
19     (SizedVirtualHeap,
20      calculate_necessary_heap_height)
21
22 import tqdm
23 import six
24 from six.moves import xrange
25
# Module logger: the PathORAM class below emits its info/error records
# through the logger named "pyoram".
log = logging.getLogger(name="pyoram")
27
class PathORAM(EncryptedBlockStorageInterface):
    """Path ORAM block storage layered over an encrypted heap (tree).

    Every logical block id is assigned a bucket of the underlying virtual
    heap through ``position_map``. Each access loads the path ending at that
    bucket, reassigns the block to a random bucket on the same heap level,
    and writes the path back (``push_down_path`` / ``fill_path_from_stash``
    / ``evict_path``), with blocks that do not fit kept in ``stash``.

    The first ``_header_offset`` bytes of the heap storage header hold an
    HMAC-SHA384 of the stash, an HMAC-SHA384 of the position map, and the
    block count (``struct`` format ``"!L"``). The digests are rewritten by
    ``close`` and verified when the ORAM is reopened, so a stash or position
    map that does not match the stored state is rejected. User header data
    follows that prefix.
    """

    # Size in bytes of each HMAC-SHA384 digest stored in the header.
    _digest_size = hashlib.sha384().digest_size
    # Two zero-padded digest slots followed by the block count.
    _header_struct_string = "!" + ("x" * 2 * _digest_size) + "L"
    _header_offset = struct.calcsize(_header_struct_string)

    def __init__(self,
                 storage,
                 stash,
                 position_map,
                 **kwds):
        """Open an existing Path ORAM.

        :param storage: either an already-open
            ``EncryptedHeapStorageInterface`` (no extra keywords are then
            accepted, and it is not closed here on failure), or a storage
            argument forwarded to ``EncryptedHeapStorage`` together with
            ``kwds``.
        :param stash: the stash (block id -> ORAM block) matching the state
            recorded in the storage header.
        :param position_map: the position map matching the state recorded
            in the storage header.
        :param kwds: when ``storage`` is not a heap storage object,
            ``cached_levels`` (default 3; 0 disables top-level caching) and
            ``concurrency_level`` are consumed here and the remaining
            keywords are passed to ``EncryptedHeapStorage``.
        :raises ValueError: on unused keywords, on ``concurrency_level``
            given with ``cached_levels == 0``, or when the stash / position
            map HMACs do not match those saved in the storage header.
        """
        self._oram = None
        self._block_count = None

        if isinstance(storage, EncryptedHeapStorageInterface):
            storage_heap = storage
            close_storage_heap = False
            if len(kwds):
                raise ValueError(
                    "Keywords not used when initializing "
                    "with a storage device: %s"
                    % (str(kwds)))
        else:
            storage_heap = self._open_storage_heap(storage, kwds)
            close_storage_heap = True

        try:
            (self._block_count,) = struct.unpack(
                self._header_struct_string,
                storage_heap.header_data[:self._header_offset])
            self._check_digests(storage_heap, stash, position_map)
            self._oram = TreeORAMStorageManagerExplicitAddressing(
                storage_heap,
                stash,
                position_map)
            assert self._block_count <= \
                self._oram.storage_heap.bucket_count
        except BaseException:
            # Only release the heap if we opened it ourselves; a heap passed
            # in by the caller stays under the caller's control.
            if close_storage_heap:
                storage_heap.close()
            raise

    @staticmethod
    def _open_storage_heap(storage, kwds):
        """Open ``storage`` as an encrypted heap, top-cached if requested.

        Pops ``cached_levels`` and ``concurrency_level`` from ``kwds`` and
        forwards the remaining keywords to ``EncryptedHeapStorage``. The
        caching decision mirrors ``setup``: ``cached_levels == 0`` means no
        ``TopCachedEncryptedHeapStorage`` wrapper.
        """
        cached_levels = kwds.pop('cached_levels', 3)
        concurrency_level = kwds.pop('concurrency_level', None)
        if (cached_levels == 0) and (concurrency_level is not None):
            raise ValueError(
                "'concurrency_level' keyword is "
                "not used when no heap levels "
                "are cached")
        storage_heap = EncryptedHeapStorage(storage, **kwds)
        if cached_levels != 0:
            try:
                storage_heap = TopCachedEncryptedHeapStorage(
                    storage_heap,
                    cached_levels=cached_levels,
                    concurrency_level=concurrency_level)
            except BaseException:
                # Do not leak the heap we just opened.
                storage_heap.close()
                raise
        return storage_heap

    @staticmethod
    def _new_hmac(key):
        """Return a fresh HMAC-SHA384 object keyed with ``key``."""
        return hmac.HMAC(key=key, digestmod=hashlib.sha384)

    @classmethod
    def _check_digests(cls, storage_heap, stash, position_map):
        """Raise ValueError unless the stash / position map HMACs equal the
        digests stored at the front of ``storage_heap``'s header."""
        digest_size = cls._digest_size
        header = storage_heap.header_data
        stashdigest = header[:digest_size]
        positiondigest = header[digest_size:(2 * digest_size)]
        # compare_digest avoids leaking via timing how many bytes matched.
        if not hmac.compare_digest(
                stashdigest,
                cls.stash_digest(
                    stash,
                    digestmod=cls._new_hmac(storage_heap.key))):
            raise ValueError(
                "Stash HMAC does not match that saved with "
                "storage heap %s" % (storage_heap.storage_name))
        if not hmac.compare_digest(
                positiondigest,
                cls.position_map_digest(
                    position_map,
                    digestmod=cls._new_hmac(storage_heap.key))):
            raise ValueError(
                "Position map HMAC does not match that saved with "
                "storage heap %s" % (storage_heap.storage_name))

    @classmethod
    def _init_position_map(cls, vheap, block_count):
        """Return an ``array('L')`` assigning each of ``block_count`` blocks
        a random leaf bucket of ``vheap``."""
        return array.array("L", [vheap.random_leaf_bucket()
                                 for _ in xrange(block_count)])

    @staticmethod
    def _empty_bucket(oram_block_size, bucket_capacity):
        """Return a bucket (bytes) of ``bucket_capacity`` slots, each
        ``oram_block_size`` bytes long and tagged as empty."""
        bucket = bytearray(oram_block_size * bucket_capacity)
        view = memoryview(bucket)
        for slot in xrange(bucket_capacity):
            start = slot * oram_block_size
            TreeORAMStorageManagerExplicitAddressing.tag_block_as_empty(
                view[start:(start + oram_block_size)])
        return bytes(bucket)

    def _init_oram_block(self, id_, block):
        """Wrap user data ``block`` in a new ORAM block tagged with ``id_``.

        The first ``block_info_storage_size`` bytes are the ORAM bookkeeping
        prefix and the user data follows. ``block`` is expected to be exactly
        ``self.block_size`` bytes (``access`` checks this).
        """
        oram_block = bytearray(self.block_size)
        oram_block[self._oram.block_info_storage_size:] = block[:]
        self._oram.tag_block_with_id(oram_block, id_)
        return oram_block

    def _extract_virtual_block(self, block):
        """Return the user data of ORAM ``block`` (its bookkeeping prefix
        stripped)."""
        return block[self._oram.block_info_storage_size:]

    #
    # Add some methods specific to Path ORAM
    #

    @classmethod
    def stash_digest(cls, stash, digestmod=None):
        """Return the digest of ``stash`` computed with ``digestmod``.

        Block ids are processed in ascending order. Each id is packed with
        ``TreeORAMStorage.block_id_storage_string`` and followed by the
        block's bytes. An empty stash digests the single byte ``b'0'``.

        :param digestmod: an object with ``update``/``digest`` methods; a
            fresh ``hashlib.sha1()`` when omitted.
        :raises ValueError: if any stash id is negative.
        """
        if digestmod is None:
            digestmod = hashlib.sha1()
        if len(stash) == 0:
            digestmod.update(b'0')
            return digestmod.digest()
        id_format = TreeORAMStorage.block_id_storage_string
        for id_ in sorted(stash):
            if id_ < 0:
                raise ValueError(
                    "Invalid stash id '%s'. Values must be "
                    "nonnegative integers." % (id_))
            digestmod.update(struct.pack(id_format, id_))
            digestmod.update(bytes(stash[id_]))
        return digestmod.digest()

    @classmethod
    def position_map_digest(cls, position_map, digestmod=None):
        """Return the digest of ``position_map`` computed with ``digestmod``.

        Each address is packed, in map order, with
        ``TreeORAMStorage.block_id_storage_string``. The position map must be
        non-empty (asserted).

        :param digestmod: an object with ``update``/``digest`` methods; a
            fresh ``hashlib.sha1()`` when omitted.
        :raises ValueError: if any address is negative.
        """
        if digestmod is None:
            digestmod = hashlib.sha1()
        id_format = TreeORAMStorage.block_id_storage_string
        assert len(position_map) > 0
        for addr in position_map:
            if addr < 0:
                raise ValueError(
                    "Invalid position map address '%s'. Values must be "
                    "nonnegative integers." % (addr))
            digestmod.update(struct.pack(id_format, addr))
        return digestmod.digest()

    @property
    def position_map(self):
        return self._oram.position_map

    @property
    def stash(self):
        return self._oram.stash

    def access(self, id_, write_block=None):
        """Read block ``id_``, or overwrite it with ``write_block``.

        The path to the block's current bucket is always loaded and written
        back, whether reading or writing.

        :param id_: block index, ``0 <= id_ < block_count``.
        :param write_block: if given, the new user data (exactly
            ``block_size`` bytes).
        :returns: the block's user data when reading; ``None`` when writing.
        :raises ValueError: if ``write_block`` has the wrong length. This is
            checked before any ORAM state is modified.
        """
        assert 0 <= id_ < self.block_count
        if (write_block is not None) and \
           (len(write_block) != self.block_size):
            # A wrong-sized bytearray slice assignment would silently resize
            # the ORAM block, so reject it before remapping the block.
            raise ValueError(
                "Block for id %s has %s bytes; expected %s"
                % (id_, len(write_block), self.block_size))
        oram = self._oram
        vheap = oram.storage_heap.virtual_heap
        bucket = self.position_map[id_]
        bucket_level = vheap.Node(bucket).level
        # Reassign the block to a random bucket on the same level before
        # touching the current path.
        self.position_map[id_] = vheap.random_bucket_at_level(bucket_level)
        oram.load_path(bucket)
        block = oram.extract_block_from_path(id_)
        if block is None:
            # Not found on the loaded path: look it up in the stash.
            block = self.stash[id_]
        if write_block is not None:
            block = self._init_oram_block(id_, write_block)
        self.stash[id_] = block
        oram.push_down_path()
        oram.fill_path_from_stash()
        oram.evict_path()
        if write_block is None:
            return self._extract_virtual_block(block)

    @property
    def heap_storage(self):
        return self._oram.storage_heap

    #
    # Define EncryptedBlockStorageInterface Methods
    #

    @property
    def key(self):
        return self._oram.storage_heap.key

    @property
    def raw_storage(self):
        return self._oram.storage_heap.raw_storage

    #
    # Define BlockStorageInterface Methods
    #

    @classmethod
    def compute_storage_size(cls,
                             block_size,
                             block_count,
                             bucket_capacity=4,
                             heap_base=2,
                             ignore_header=False,
                             **kwds):
        """Return the storage size in bytes needed for a Path ORAM of
        ``block_count`` blocks of ``block_size`` bytes.

        The ORAM header prefix (``_header_offset`` bytes) is included unless
        ``ignore_header`` is true.
        """
        assert (block_size > 0) and (block_size == int(block_size))
        assert (block_count > 0) and (block_count == int(block_count))
        assert bucket_capacity >= 1
        assert heap_base >= 2
        assert 'heap_height' not in kwds
        heap_height = calculate_necessary_heap_height(heap_base,
                                                      block_count)
        oram_block_size = block_size + \
            TreeORAMStorageManagerExplicitAddressing.block_info_storage_size
        heap_size = EncryptedHeapStorage.compute_storage_size(
            oram_block_size,
            heap_height,
            blocks_per_bucket=bucket_capacity,
            heap_base=heap_base,
            ignore_header=bool(ignore_header),
            **kwds)
        if ignore_header:
            return heap_size
        return cls._header_offset + heap_size

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              bucket_capacity=4,
              heap_base=2,
              cached_levels=3,
              concurrency_level=None,
              **kwds):
        """Create, initialize and open a new Path ORAM storage.

        :param block_size: user block size in bytes (positive integer).
        :param block_count: number of blocks (positive integer).
        :param bucket_capacity: ORAM blocks per heap bucket.
        :param heap_base: branching factor of the heap (>= 2).
        :param cached_levels: heap levels cached by
            ``TopCachedEncryptedHeapStorage``; 0 disables the wrapper.
        :param concurrency_level: forwarded to the caching wrapper; not
            allowed when ``cached_levels == 0``.
        :param kwds: optional ``header_data`` (bytes, stored after the ORAM
            header prefix) and ``initialize`` (callable ``i -> bytes`` of
            length ``block_size``; zero bytes by default). Everything else
            goes to ``EncryptedHeapStorage.setup``.
        :returns: an open instance of ``cls`` over the new storage.
        :raises ValueError: on invalid sizes or options, or when an
            ``initialize`` value has the wrong length.
        :raises TypeError: if ``header_data`` is not ``bytes``.
        """
        if 'heap_height' in kwds:
            raise ValueError("'heap_height' keyword is not accepted")
        if (bucket_capacity <= 0) or \
           (bucket_capacity != int(bucket_capacity)):
            raise ValueError(
                "Bucket capacity must be a positive integer: %s"
                % (bucket_capacity))
        if (block_size <= 0) or (block_size != int(block_size)):
            raise ValueError(
                "Block size (bytes) must be a positive integer: %s"
                % (block_size))
        if (block_count <= 0) or (block_count != int(block_count)):
            raise ValueError(
                "Block count must be a positive integer: %s"
                % (block_count))
        if heap_base < 2:
            raise ValueError(
                "heap base must be 2 or greater. Invalid value: %s"
                % (heap_base))
        if (cached_levels == 0) and (concurrency_level is not None):
            # Checked before any storage is set up, so that the storage is
            # not created only to be closed again.
            raise ValueError(
                "'concurrency_level' keyword is "
                "not used when no heap levels "
                "are cached")

        user_header_data = kwds.pop('header_data', bytes())
        if type(user_header_data) is not bytes:
            raise TypeError(
                "'header_data' must be of type bytes. "
                "Invalid type: %s" % (type(user_header_data)))
        initialize = kwds.pop('initialize', None)
        if initialize is None:
            zeros = bytes(bytearray(block_size))
            initialize = lambda i: zeros

        heap_height = calculate_necessary_heap_height(heap_base,
                                                      block_count)
        stash = {}
        vheap = SizedVirtualHeap(
            heap_base,
            heap_height,
            blocks_per_bucket=bucket_capacity)
        position_map = cls._init_position_map(vheap, block_count)

        oram_block_size = block_size + \
            TreeORAMStorageManagerExplicitAddressing.block_info_storage_size

        # Reserve the ORAM header prefix (zeroed digest slots + block count);
        # the digests are filled in once every block has been placed.
        header_data = struct.pack(cls._header_struct_string, block_count)
        kwds['header_data'] = bytes(header_data) + user_header_data
        empty_bucket = cls._empty_bucket(oram_block_size, bucket_capacity)
        kwds['initialize'] = lambda i: empty_bucket

        f = None
        try:
            log.info("%s: setting up encrypted heap storage"
                     % (cls.__name__))
            f = EncryptedHeapStorage.setup(storage_name,
                                           oram_block_size,
                                           heap_height,
                                           heap_base=heap_base,
                                           blocks_per_bucket=bucket_capacity,
                                           **kwds)
            if cached_levels != 0:
                f = TopCachedEncryptedHeapStorage(
                    f,
                    cached_levels=cached_levels,
                    concurrency_level=concurrency_level)
            oram = TreeORAMStorageManagerExplicitAddressing(
                f, stash, position_map)
            initial_oram_block = bytearray(oram_block_size)
            for i in tqdm.tqdm(xrange(block_count),
                               desc=("Initializing %s Blocks" % (cls.__name__)),
                               total=block_count,
                               disable=not pyoram.config.SHOW_PROGRESS_BAR):
                value = initialize(i)
                if len(value) != block_size:
                    # Slice assignment below would otherwise resize the
                    # reusable block buffer.
                    raise ValueError(
                        "Initial value for block %s has %s bytes; "
                        "expected %s" % (i, len(value), block_size))
                oram.tag_block_with_id(initial_oram_block, i)
                initial_oram_block[oram.block_info_storage_size:] = value[:]

                bucket = oram.position_map[i]
                bucket_level = vheap.Node(bucket).level
                oram.position_map[i] = \
                    oram.storage_heap.virtual_heap.\
                    random_bucket_at_level(bucket_level)

                oram.load_path(bucket)
                oram.push_down_path()
                # place a copy in the stash
                oram.stash[i] = bytearray(initial_oram_block)
                oram.fill_path_from_stash()
                oram.evict_path()

            digest_key = oram.storage_heap.key
            stash_digest = cls.stash_digest(
                oram.stash,
                digestmod=cls._new_hmac(digest_key))
            position_map_digest = cls.position_map_digest(
                oram.position_map,
                digestmod=cls._new_hmac(digest_key))
            header_data = bytearray(header_data)
            header_data[:len(stash_digest)] = stash_digest[:]
            header_data[len(stash_digest):
                        (len(stash_digest) + len(position_map_digest))] = \
                position_map_digest[:]
            f.update_header_data(bytes(header_data) + user_header_data)
            return cls(f, stash, position_map=position_map)
        except BaseException:
            if f is not None:
                f.close()                              # pragma: no cover
            raise

    @property
    def header_data(self):
        """User header data (everything after the ORAM header prefix)."""
        return self._oram.storage_heap.\
            header_data[self._header_offset:]

    @property
    def block_count(self):
        return self._block_count

    @property
    def block_size(self):
        """User block size: the ORAM block size minus its bookkeeping
        prefix."""
        return self._oram.block_size - self._oram.block_info_storage_size

    @property
    def storage_name(self):
        return self._oram.storage_heap.storage_name

    def update_header_data(self, new_header_data):
        """Replace the user header data, keeping the ORAM header prefix."""
        self._oram.storage_heap.update_header_data(
            self._oram.storage_heap.header_data[:self._header_offset] + \
            new_header_data)

    def close(self):
        """Record the current stash / position map HMACs in the storage
        header and close the underlying heap storage.

        The heap storage is closed even if updating the header fails; the
        failure is logged and re-raised.
        """
        log.info("%s: Closing" % (self.__class__.__name__))
        if self._oram is None:
            return
        storage_heap = self._oram.storage_heap
        try:
            key = storage_heap.key
            stashdigest = self.stash_digest(
                self._oram.stash,
                digestmod=self._new_hmac(key))
            positiondigest = self.position_map_digest(
                self._oram.position_map,
                digestmod=self._new_hmac(key))
            digest_size = self._digest_size
            new_header_data = bytearray(
                storage_heap.header_data[:self._header_offset])
            new_header_data[:digest_size] = stashdigest
            new_header_data[digest_size:(2 * digest_size)] = positiondigest
            storage_heap.update_header_data(
                bytes(new_header_data) + self.header_data)
        except BaseException:                                  # pragma: no cover
            log.error(                                         # pragma: no cover
                "%s: Failed to update header data with "       # pragma: no cover
                "current stash and position map state"         # pragma: no cover
                % (self.__class__.__name__))                   # pragma: no cover
            raise
        finally:
            storage_heap.close()

    def read_blocks(self, indices):
        """Return the user data of each block in ``indices``, in order."""
        return [self.access(i) for i in indices]

    def read_block(self, i):
        return self.access(i)

    def write_blocks(self, indices, blocks):
        """Write ``blocks`` to the matching ``indices`` (pairs as ``zip``)."""
        for i, block in zip(indices, blocks):
            self.access(i, write_block=block)

    def write_block(self, i, block):
        self.access(i, write_block=block)

    @property
    def bytes_sent(self):
        return self._oram.storage_heap.bytes_sent

    @property
    def bytes_received(self):
        return self._oram.storage_heap.bytes_received
465     def bytes_received(self):
466         return self._oram.storage_heap.bytes_received