PyORAM
[iotcloud.git] / PyORAM / src / pyoram / oblivious_storage / tree / tree_oram_helper.py
# Public API of this module: the two concrete storage managers.
__all__ = (
    'TreeORAMStorageManagerExplicitAddressing',
    'TreeORAMStorageManagerPointerAddressing',
)
3
4 import struct
5 import copy
6
7 from pyoram.util.virtual_heap import \
8     SizedVirtualHeap
9
10 from six.moves import xrange
11
class TreeORAMStorage(object):
    """
    Base class that keeps an in-memory image of one root-to-bucket
    path of a storage heap and implements the operations performed
    on that image: loading it, pushing blocks down towards the stop
    bucket, filling empty slots from the stash, and writing the path
    back.

    Subclasses must implement :meth:`get_block_info`, which decodes
    a raw block into ``(block_id, bucket_address)``.
    """

    # Id reported for slots that do not hold a real block.
    empty_block_id = -1

    # struct formats (network byte order) for the block header:
    # a boolean "real block" flag, optionally followed by the id.
    block_status_storage_string = "!?"
    block_id_storage_string = "!L"
    block_info_storage_string = "!?L"

    block_status_storage_size = \
        struct.calcsize(block_status_storage_string)
    block_info_storage_size = \
        struct.calcsize(block_info_storage_string)

    # Pre-packed status flags for empty and real blocks.
    empty_block_bytes_tag = \
        struct.pack(block_status_storage_string, False)
    real_block_bytes_tag = \
        struct.pack(block_status_storage_string, True)

    def __init__(self,
                 storage_heap,
                 stash):
        """
        Allocate the path buffer and its per-bucket / per-block views.

        Args:
            storage_heap: object exposing ``virtual_heap``,
                ``bucket_size``, ``read_path`` and ``write_path``.
            stash: mapping from block id to raw block data; entries
                are removed by :meth:`fill_path_from_stash` once
                they have been placed on the path.
        """
        self.storage_heap = storage_heap
        self.stash = stash

        vheap = self.storage_heap.virtual_heap
        self.bucket_size = self.storage_heap.bucket_size
        self.block_size = self.bucket_size // vheap.blocks_per_bucket
        # A bucket must split evenly into its blocks.
        assert self.block_size * vheap.blocks_per_bucket == \
            self.bucket_size

        self.path_stop_bucket = None
        self.path_bucket_count = 0
        # One contiguous buffer big enough for the longest path;
        # the bucket and block views below alias it, so writing
        # through a view updates the buffer without copying.
        self.path_byte_dataview = \
            bytearray(self.bucket_size * vheap.levels)
        dataview = memoryview(self.path_byte_dataview)
        self.path_bucket_dataview = \
            [dataview[(i*self.bucket_size):((i+1)*self.bucket_size)]
             for i in xrange(vheap.levels)]

        # Flat list of block views, ordered by level and then by
        # slot within the bucket (index == level * Z + slot).
        self.path_block_dataview = []
        for i in xrange(vheap.levels):
            bucketview = self.path_bucket_dataview[i]
            for j in xrange(vheap.blocks_per_bucket):
                self.path_block_dataview.append(
                    bucketview[(j*self.block_size):((j+1)*self.block_size)])

        max_blocks_on_path = vheap.levels * vheap.blocks_per_bucket
        assert len(self.path_block_dataview) == max_blocks_on_path
        # Per-slot bookkeeping, indexed like path_block_dataview.
        self.path_block_ids = [-1] * max_blocks_on_path
        self.path_block_eviction_levels = [None] * max_blocks_on_path
        # For each slot: None (untouched), -1 (must be tagged empty
        # on eviction) or the slot index its new content comes from.
        self.path_block_reordering = [None] * max_blocks_on_path
        # (write_pos, block) pairs queued by fill_path_from_stash.
        self.path_blocks_inserted = []

    def load_path(self, b):
        """
        Read the path ending at bucket ``b`` into the path buffer
        and rebuild the per-slot bookkeeping.

        Buckets shared with the previously loaded path (as reported
        by ``calculate_last_common_level``) are not read again; the
        copies already in the buffer are reused.

        Args:
            b: index of the bucket where the path stops; must
                satisfy ``0 <= b < virtual_heap.bucket_count()``.
        """
        vheap = self.storage_heap.virtual_heap
        Z = vheap.blocks_per_bucket
        lcl = vheap.clib.calculate_last_common_level
        k = vheap.k

        # Validate the requested bucket before it is handed to the
        # last-common-level helper or stored on the instance.
        assert 0 <= b < vheap.bucket_count()

        read_level_start = 0
        if self.path_stop_bucket is not None:
            # don't download the root and any other buckets
            # that are common between the previous bucket path
            # and the new one
            read_level_start = lcl(k, self.path_stop_bucket, b)
        self.path_stop_bucket = b
        new_buckets = self.storage_heap.read_path(
            self.path_stop_bucket,
            level_start=read_level_start)

        self.path_bucket_count = read_level_start + len(new_buckets)
        pos = 0
        for i in xrange(self.path_bucket_count):
            if i >= read_level_start:
                self.path_bucket_dataview[i][:] = \
                    new_buckets[i-read_level_start][:]
            # Bookkeeping is recomputed for every bucket, including
            # the reused ones, because eviction levels depend on
            # the new stop bucket.
            for j in xrange(Z):
                block_id, block_addr = \
                    self.get_block_info(self.path_block_dataview[pos])
                self.path_block_ids[pos] = block_id
                if block_id != self.empty_block_id:
                    self.path_block_eviction_levels[pos] = \
                        lcl(k, self.path_stop_bucket, block_addr)
                else:
                    self.path_block_eviction_levels[pos] = None
                self.path_block_reordering[pos] = None
                pos += 1

        # Slots below the stop bucket are not part of this path.
        max_blocks_on_path = vheap.levels * Z
        while pos != max_blocks_on_path:
            self.path_block_ids[pos] = None
            self.path_block_eviction_levels[pos] = None
            self.path_block_reordering[pos] = None
            pos += 1

        self.path_blocks_inserted = []

    def push_down_path(self):
        """
        Move blocks on the loaded path into empty slots of deeper
        buckets, as far as their eviction level allows.

        Only the bookkeeping lists are updated here; the block
        bytes are moved later by :meth:`evict_path`, driven by
        ``path_block_reordering``.
        """
        vheap = self.storage_heap.virtual_heap
        Z = vheap.blocks_per_bucket

        bucket_count = self.path_bucket_count
        block_ids = self.path_block_ids
        block_eviction_levels = self.path_block_eviction_levels
        block_reordering = self.path_block_reordering

        def _do_swap(write_pos, read_pos):
            # Move the bookkeeping of the block at read_pos into
            # write_pos, mark read_pos empty and record the move.
            block_ids[write_pos], block_eviction_levels[write_pos] = \
                block_ids[read_pos], block_eviction_levels[read_pos]
            block_ids[read_pos], block_eviction_levels[read_pos] = \
                self.empty_block_id, None
            block_reordering[write_pos] = read_pos
            block_reordering[read_pos] = -1

        def _new_write_pos(current):
            # Nearest empty slot strictly before `current`, as
            # (position, level); (None, None) when there is none.
            current -= 1
            if current < 0:
                return None, None
            while (block_eviction_levels[current] is not None):
                current -= 1
                if current < 0:
                    return None, None
            assert block_ids[current] == \
                self.empty_block_id
            return current, current // Z

        def _new_read_pos(current):
            # Nearest occupied slot strictly before `current`,
            # or None when there is none.
            current -= 1
            if current < 0:
                return None
            while (block_eviction_levels[current] is None):
                current -= 1
                if current < 0:
                    return None
            assert block_ids[current] != \
                self.empty_block_id
            return current

        # Walk the empty slots from the deepest one upwards.
        write_pos, write_level = _new_write_pos(bucket_count * Z)
        while write_pos is not None:
            read_pos = _new_read_pos(write_pos)
            if read_pos is None:
                break
            # Skip candidates that sit in the same bucket as the
            # empty slot or that may not be placed this deep.
            while ((read_pos // Z) == write_level) or \
                  (write_level > block_eviction_levels[read_pos]):
                read_pos = _new_read_pos(read_pos)
                if read_pos is None:
                    break
            if read_pos is not None:
                _do_swap(write_pos, read_pos)
            else:
                # Jump directly to the start of this
                # bucket. There is no point in checking
                # for other empty slots because no blocks
                # can be evicted to this level.
                write_pos = Z * (write_pos//Z)
            write_pos, write_level = _new_write_pos(write_pos)

    def fill_path_from_stash(self):
        """
        Assign stash blocks to the empty slots of the loaded path,
        visiting slots from the deepest one upwards.

        A stash block is placed in the first visited slot whose
        level does not exceed the block's eviction level. Placed
        blocks are removed from the stash and queued in
        ``path_blocks_inserted`` so :meth:`evict_path` can copy
        their bytes into the path buffer.
        """
        vheap = self.storage_heap.virtual_heap
        lcl = vheap.clib.calculate_last_common_level
        k = vheap.k
        Z = vheap.blocks_per_bucket

        bucket_count = self.path_bucket_count
        stop_bucket = self.path_stop_bucket
        block_ids = self.path_block_ids
        block_eviction_levels = self.path_block_eviction_levels
        blocks_inserted = self.path_blocks_inserted

        # Eviction levels computed so far, keyed by stash block id,
        # so each stash block is decoded at most once per call.
        stash_eviction_levels = {}
        largest_write_position = (bucket_count * Z) - 1
        for write_pos in xrange(largest_write_position,-1,-1):
            write_level = write_pos // Z
            if block_ids[write_pos] == self.empty_block_id:
                del_id = None
                for id_ in self.stash:
                    if id_ not in stash_eviction_levels:
                        # Only the address is needed here; the id is
                        # already known from the stash key.
                        _, block_addr = \
                            self.get_block_info(self.stash[id_])
                        assert id_ != self.empty_block_id
                        eviction_level = stash_eviction_levels[id_] = \
                            lcl(k, stop_bucket, block_addr)
                    else:
                        eviction_level = stash_eviction_levels[id_]
                    if write_level <= eviction_level:
                        block_ids[write_pos] = id_
                        block_eviction_levels[write_pos] = \
                            eviction_level
                        blocks_inserted.append(
                            (write_pos, self.stash[id_]))
                        del_id = id_
                        break
                # Delete after the loop so the stash is not modified
                # while it is being iterated.
                if del_id is not None:
                    del self.stash[del_id]

    def evict_path(self):
        """
        Apply the pending moves, empty-slot tags and stash insertions
        to the path buffer, then write the path back through
        ``storage_heap.write_path``.
        """
        bucket_count = self.path_bucket_count
        stop_bucket = self.path_stop_bucket
        bucket_dataview = self.path_bucket_dataview
        block_dataview = self.path_block_dataview
        block_reordering = self.path_block_reordering
        blocks_inserted = self.path_blocks_inserted

        # Copy moved blocks starting from the last slot. Moves
        # recorded by push_down_path always go from a lower index
        # to a higher one, so a source slot is read before it is
        # overwritten by a later copy.
        for write_pos in xrange(len(block_reordering) - 1, -1, -1):
            read_pos = block_reordering[write_pos]
            if (read_pos is not None) and \
               (read_pos != -1):
                block_dataview[write_pos][:] = block_dataview[read_pos][:]

        # Slots that were vacated and not refilled become empty.
        for write_pos, read_pos in enumerate(block_reordering):
            if read_pos == -1:
                self.tag_block_as_empty(block_dataview[write_pos])

        # Blocks taken from the stash overwrite their target slots.
        for write_pos, block in blocks_inserted:
            block_dataview[write_pos][:] = block[:]

        self.storage_heap.write_path(
            stop_bucket,
            (bucket_dataview[i].tobytes()
             for i in xrange(bucket_count)))

    def extract_block_from_path(self, id_):
        """
        Remove the block with id ``id_`` from the loaded path.

        Returns:
            A ``bytearray`` copy of the block data, or ``None`` if
            no slot on the path holds that id.
        """
        # Keep the try body minimal: only the lookup signals
        # "not found" through ValueError.
        try:
            pos = self.path_block_ids.index(id_)
        except ValueError:
            return None
        # make a copy
        block = bytearray(self.path_block_dataview[pos])
        self._set_path_position_to_empty(pos)
        return block

    def _set_path_position_to_empty(self, pos):
        """
        Mark slot ``pos`` as empty in the bookkeeping and flag it
        (-1) so :meth:`evict_path` tags its bytes as empty.
        """
        self.path_block_ids[pos] = self.empty_block_id
        self.path_block_eviction_levels[pos] = None
        self.path_block_reordering[pos] = -1

    @staticmethod
    def tag_block_as_empty(block):
        """
        Overwrite the status flag at the start of ``block`` with
        the "empty" tag; the remaining bytes are left untouched.
        """
        block[:TreeORAMStorage.block_status_storage_size] = \
            TreeORAMStorage.empty_block_bytes_tag[:]

    @staticmethod
    def tag_block_with_id(block, id_):
        """
        Write the "real block" flag followed by ``id_`` at the
        start of ``block``. ``id_`` must be non-negative.
        """
        assert id_ >= 0
        struct.pack_into(TreeORAMStorage.block_info_storage_string,
                         block,
                         0,
                         True,
                         id_)

    def get_block_info(self, block):
        """
        Decode ``block`` into ``(block_id, bucket_address)``.
        Must be implemented by subclasses.
        """
        raise NotImplementedError                      # pragma: no cover
271
class TreeORAMStorageManagerExplicitAddressing(TreeORAMStorage):
    """
    Storage manager for tree-based ORAMs whose block locations
    are kept in an explicit position map.

    Each block is expected to start with the header described by
    ``block_info_storage_string``: a boolean "real block" flag
    followed by the block id. The bucket address of a block is
    looked up in ``position_map`` rather than read from the block.
    """

    # Same header layout as the base class (status flag + id).
    block_info_storage_string = TreeORAMStorage.block_info_storage_string
    block_info_storage_size = struct.calcsize(block_info_storage_string)

    def __init__(self, storage_heap, stash, position_map):
        """
        Args:
            storage_heap: forwarded to the base class.
            stash: forwarded to the base class.
            position_map: indexable by block id; yields the
                bucket address currently assigned to that block.
        """
        super(TreeORAMStorageManagerExplicitAddressing, self).__init__(
            storage_heap, stash)
        self.position_map = position_map

    def get_block_info(self, block):
        """
        Decode the header of ``block``.

        Returns:
            ``(block_id, position_map[block_id])`` for a real block,
            ``(empty_block_id, None)`` for an empty one.
        """
        is_real, block_id = struct.unpack_from(
            self.block_info_storage_string, block)
        if not is_real:
            return self.empty_block_id, None
        return block_id, self.position_map[block_id]
300
class TreeORAMStorageManagerPointerAddressing(TreeORAMStorage):
    """
    Storage manager for tree-based ORAMs whose position
    information travels with the blocks themselves.

    Each block is expected to start with a boolean "real block"
    flag, followed by the block id, followed by the heap bucket
    address currently assigned to the block.
    """

    # Base header (status flag + id) extended by the bucket address.
    block_info_storage_string = (
        TreeORAMStorage.block_info_storage_string + "L")
    block_info_storage_size = struct.calcsize(block_info_storage_string)

    def __init__(self, storage_heap, stash):
        """
        Args:
            storage_heap: forwarded to the base class.
            stash: forwarded to the base class.
        """
        super(TreeORAMStorageManagerPointerAddressing, self).__init__(
            storage_heap, stash)
        # No separate position map: addresses live in the blocks.
        self.position_map = None

    def get_block_info(self, block):
        """
        Decode the header of ``block``.

        Returns:
            ``(block_id, bucket_address)`` for a real block,
            ``(empty_block_id, 0)`` for an empty one.
        """
        is_real, block_id, bucket_address = struct.unpack_from(
            self.block_info_storage_string, block)
        if is_real:
            return block_id, bucket_address
        return self.empty_block_id, 0