PyORAM
[iotcloud.git] / PyORAM / src / pyoram / encrypted_storage / top_cached_encrypted_heap_storage.py
1 __all__ = ('TopCachedEncryptedHeapStorage',)
2
3 import logging
4 import tempfile
5 import mmap
6
7 import pyoram
8 from pyoram.util.virtual_heap import SizedVirtualHeap
9 from pyoram.encrypted_storage.encrypted_heap_storage import \
10     (EncryptedHeapStorageInterface,
11      EncryptedHeapStorage)
12
13 import tqdm
14 import six
15 from six.moves import xrange
16
17 log = logging.getLogger("pyoram")
18
class TopCachedEncryptedHeapStorage(EncryptedHeapStorageInterface):
    """
    An encrypted block storage device for accessing memory
    organized as a heap, where the top 1 or more levels can
    be cached in local memory. This achieves two things:

      (1) Reduces the number of buckets that need to be read
          from or written to external storage for a given
          path I/O operation.
      (2) Allows certain block storage devices to achieve
          concurrency across path writes by partitioning the
          storage space into independent subheaps starting
          below the cache line.

    This device takes as input an existing encrypted heap
    storage device. This class should not be cloned or used
    to setup storage, but rather used as a wrapper class for
    an existing heap storage device to speed up a bulk set
    of I/O requests. The original heap storage device should
    not be used after it is wrapped by this class. This
    class will close the original device when closing
    itself.

    The number of cached levels (starting from the root
    bucket at level 0) can be set with the 'cached_levels'
    keyword (>= 1). A negative value caches every level of
    the heap. A value of 0 disables the wrapper entirely:
    the original storage device is returned instead of an
    instance of this class.

    By default, this will create an independent storage
    device capable of reading from and writing to the
    original storage device's memory for each independent
    subheap (if any) below the last cached level. The
    'concurrency_level' keyword can be used to limit the
    number of concurrent devices to some level below the
    cache line (>= 0, <= 'cached_levels').

    Values for 'cached_levels' and 'concurrency_level' will
    be automatically reduced when they are larger than what
    is allowed by the heap size.
    """

    def __new__(cls, *args, **kwds):
        """
        Create the wrapper, or hand back the wrapped storage
        untouched when 'cached_levels' is 0.

        'cached_levels' is recognized whether it is given by
        keyword or as the second positional argument, so both
        spellings of "no caching" behave the same way. In the
        pass-through case the storage object gets an empty
        'cached_bucket_data' attribute so that callers can
        read that attribute regardless of which object they
        received. Because the returned object is then not an
        instance of this class, Python does not call __init__.
        """
        if "cached_levels" in kwds:
            cached_levels = kwds["cached_levels"]
        elif len(args) > 1:
            cached_levels = args[1]
        else:
            cached_levels = 1

        if cached_levels != 0:
            return super(TopCachedEncryptedHeapStorage, cls).\
                __new__(cls)

        # pass-through: the storage device is the first
        # positional argument or the 'heap_storage' keyword
        assert (len(args) >= 1) or ("heap_storage" in kwds)
        storage = args[0] if args else kwds["heap_storage"]
        storage.cached_bucket_data = bytes()
        return storage

    def __init__(self,
                 heap_storage,
                 cached_levels=1,
                 concurrency_level=None):
        """
        Wrap 'heap_storage', download the top of the heap into
        a local cache, and clone the devices used for the
        subheaps below the cache line.

        Args:
            heap_storage: the EncryptedHeapStorage device to
                wrap. It is closed by this object's close().
            cached_levels: number of heap levels (from the
                root) to cache locally. Must not be 0. A
                negative value means all levels. Clamped to
                the number of levels in the heap.
            concurrency_level: heap level whose buckets each
                get their own cloned device. Defaults to
                'cached_levels'. Must be >= 0. Clamped to
                'cached_levels'.

        Raises:
            AssertionError: if an argument check fails.
            Any exception raised while cloning a sub-heap
            device is re-raised after close() has been called.
        """
        # NOTE: argument checks are kept as assertions so that
        # callers continue to see AssertionError on misuse.
        assert isinstance(heap_storage, EncryptedHeapStorage)
        assert cached_levels != 0

        vheap = heap_storage.virtual_heap
        if cached_levels < 0:
            cached_levels = vheap.levels
        if concurrency_level is None:
            concurrency_level = cached_levels
        assert concurrency_level >= 0
        cached_levels = min(vheap.levels, cached_levels)
        concurrency_level = min(cached_levels, concurrency_level)
        self._external_level = cached_levels
        total_buckets = sum(vheap.bucket_count_at_level(l)
                            for l in xrange(cached_levels))

        log.debug("%s: cached_levels=%s, concurrency_level=%s",
                  self.__class__.__name__,
                  cached_levels,
                  concurrency_level)

        self._root_device = heap_storage
        # clone before we download the cache so that we can
        # track bytes transferred during read/write requests
        # (separate from the cached download)
        self._concurrent_devices = \
            {vheap.first_bucket_at_level(0): self._root_device.clone_device()}

        self._cached_bucket_count = total_buckets
        self._cached_buckets_tempfile = tempfile.TemporaryFile()
        self._cached_buckets_tempfile.seek(0)
        with tqdm.tqdm(desc=("Downloading %s Cached Heap Buckets"
                             % (self._cached_bucket_count)),
                       total=self._cached_bucket_count*self._root_device.bucket_size,
                       unit="B",
                       unit_scale=True,
                       disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
            # every bucket above the first external level is
            # appended to the temporary file in bucket order
            for bucket in self._root_device.bucket_storage.yield_blocks(
                    xrange(vheap.first_bucket_at_level(cached_levels))):
                self._cached_buckets_tempfile.write(bucket)
                progress_bar.update(self._root_device.bucket_size)
        self._cached_buckets_tempfile.flush()
        # map the whole file (length 0) so that buckets can be
        # read and overwritten in place by slicing
        self._cached_buckets_mmap = mmap.mmap(
            self._cached_buckets_tempfile.fileno(), 0)

        # Avoid cloning devices when the cache line is at the root
        # bucket or when the entire heap is cached
        if (concurrency_level > 0) and \
           (concurrency_level <= vheap.last_level):
            log.info("%s: Cloning %s sub-heap devices",
                     self.__class__.__name__,
                     vheap.bucket_count_at_level(concurrency_level))
            for b in xrange(vheap.first_bucket_at_level(concurrency_level),
                            vheap.first_bucket_at_level(concurrency_level+1)):
                try:
                    self._concurrent_devices[b] = self._root_device.clone_device()
                except:                                # pragma: no cover
                    # catch-all is intentional: the exception
                    # is always re-raised after cleanup
                    log.error(                         # pragma: no cover
                        "%s: Exception encountered "   # pragma: no cover
                        "while cloning device. "       # pragma: no cover
                        "Closing storage.",            # pragma: no cover
                        self.__class__.__name__)       # pragma: no cover
                    self.close()                       # pragma: no cover
                    raise                              # pragma: no cover

        self._subheap_storage = {}
        # Avoid populating this dictionary when the entire
        # heap is cached
        if self._external_level <= vheap.last_level:
            for b in xrange(vheap.first_bucket_at_level(self._external_level),
                            vheap.first_bucket_at_level(self._external_level+1)):
                # walk up from the first external bucket until
                # an ancestor that owns a cloned device is found
                node = vheap.Node(b)
                while node.bucket not in self._concurrent_devices:
                    node = node.parent_node()
                assert node.bucket >= 0
                assert node.level == concurrency_level
                self._subheap_storage[b] = self._concurrent_devices[node.bucket]

    #
    # Additional Methods
    #

    @property
    def cached_bucket_data(self):
        """The memory map holding the locally cached buckets."""
        return self._cached_buckets_mmap

    def _cache_slice(self, b):
        """Return the slice of the local cache occupied by bucket b."""
        size = self.bucket_size
        return slice(b*size, (b+1)*size)

    #
    # Define EncryptedHeapStorageInterface Methods
    #

    @property
    def key(self):
        """The 'key' attribute of the wrapped device."""
        return self._root_device.key

    @property
    def raw_storage(self):
        """The 'raw_storage' attribute of the wrapped device."""
        return self._root_device.raw_storage

    #
    # Define HeapStorageInterface Methods
    #

    def clone_device(self, *args, **kwds):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(                     # pragma: no cover
            "Class is not designed for cloning")       # pragma: no cover

    @classmethod
    def compute_storage_size(cls, *args, **kwds):
        """Delegate to EncryptedHeapStorage.compute_storage_size."""
        return EncryptedHeapStorage.compute_storage_size(*args, **kwds)

    @classmethod
    def setup(cls, *args, **kwds):
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError(                     # pragma: no cover
            "Class is not designed to setup storage")  # pragma: no cover

    @property
    def header_data(self):
        """The 'header_data' attribute of the wrapped device."""
        return self._root_device.header_data

    @property
    def bucket_count(self):
        """The 'bucket_count' attribute of the wrapped device."""
        return self._root_device.bucket_count

    @property
    def bucket_size(self):
        """The 'bucket_size' attribute of the wrapped device."""
        return self._root_device.bucket_size

    @property
    def blocks_per_bucket(self):
        """The 'blocks_per_bucket' attribute of the wrapped device."""
        return self._root_device.blocks_per_bucket

    @property
    def storage_name(self):
        """The 'storage_name' attribute of the wrapped device."""
        return self._root_device.storage_name

    @property
    def virtual_heap(self):
        """The 'virtual_heap' attribute of the wrapped device."""
        return self._root_device.virtual_heap

    @property
    def bucket_storage(self):
        """The 'bucket_storage' attribute of the wrapped device."""
        return self._root_device.bucket_storage

    def update_header_data(self, new_header_data):
        """Forward the header update to the wrapped device."""
        self._root_device.update_header_data(new_header_data)

    def close(self):
        """
        Upload the cached buckets through the wrapped device,
        close every cloned device and the wrapped device, then
        release the local memory map and temporary file.

        If the upload or a device close raises, the statements
        after it do not run, so the local cache is not
        released in that case.
        """
        log.info("%s: Uploading %s cached bucket data before closing",
                 self.__class__.__name__,
                 self._cached_bucket_count)
        with tqdm.tqdm(desc=("Uploading %s Cached Heap Buckets"
                             % (self._cached_bucket_count)),
                       total=self._cached_bucket_count*self.bucket_size,
                       unit="B",
                       unit_scale=True,
                       disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
            self.bucket_storage.\
                write_blocks(
                    xrange(self._cached_bucket_count),
                    (self._cached_buckets_mmap[self._cache_slice(b)]
                     for b in xrange(self._cached_bucket_count)),
                    callback=lambda i: progress_bar.update(self._root_device.bucket_size))
            # the devices are closed while the progress bar is
            # still open; this ordering is kept from the
            # original implementation
            for b in self._concurrent_devices:
                self._concurrent_devices[b].close()
            self._root_device.close()
            # forces the bar to become full at close
            # even if the write_blocks action was faster
            # than the mininterval time
            progress_bar.mininterval = 0

        self._cached_buckets_mmap.close()
        self._cached_buckets_tempfile.close()

    def read_path(self, b, level_start=0):
        """
        Read the buckets on the path from the root to bucket
        'b', starting at heap level 'level_start'.

        Buckets above the cache line come from the local
        cache; the remaining ones are read with a single
        read_blocks call on the device assigned to the
        subheap containing 'b'.

        Returns:
            A list with one bucket per level from
            'level_start' down to the level of 'b'.
        """
        assert 0 <= b < self.virtual_heap.bucket_count()
        bucket_list = self.virtual_heap.Node(b).bucket_path_from_root()

        if len(bucket_list) <= self._external_level:
            # the whole path lives in the local cache
            return [self._cached_buckets_mmap[self._cache_slice(bb)]
                    for bb in bucket_list[level_start:]]

        if level_start >= self._external_level:
            # the requested part of the path is entirely external
            return self._subheap_storage[bucket_list[self._external_level]].\
                bucket_storage.read_blocks(bucket_list[level_start:])

        # the requested part of the path straddles the cache line
        local_buckets = bucket_list[:self._external_level]
        external_buckets = bucket_list[self._external_level:]
        buckets = [self._cached_buckets_mmap[self._cache_slice(bb)]
                   for bb in local_buckets[level_start:]]
        if external_buckets:
            buckets.extend(
                self._subheap_storage[external_buckets[0]].\
                bucket_storage.read_blocks(external_buckets))
        assert len(buckets) == len(bucket_list[level_start:])
        return buckets

    def write_path(self, b, buckets, level_start=0):
        """
        Write 'buckets' onto the path from the root to bucket
        'b', starting at heap level 'level_start'.

        Buckets above the cache line are written into the
        local cache only (they reach external storage in
        close()); the remaining ones are written with a single
        write_blocks call on the device assigned to the
        subheap containing 'b'.
        """
        assert 0 <= b < self.virtual_heap.bucket_count()
        bucket_list = self.virtual_heap.Node(b).bucket_path_from_root()

        if len(bucket_list) <= self._external_level:
            # the whole path lives in the local cache
            for bb, bucket in zip(bucket_list[level_start:], buckets):
                self._cached_buckets_mmap[self._cache_slice(bb)] = bucket
            return

        if level_start >= self._external_level:
            # the requested part of the path is entirely external
            self._subheap_storage[bucket_list[self._external_level]].\
                bucket_storage.write_blocks(bucket_list[level_start:], buckets)
            return

        # the requested part of the path straddles the cache line
        buckets = list(buckets)
        assert len(buckets) == len(bucket_list[level_start:])
        local_buckets = bucket_list[:self._external_level][level_start:]
        external_buckets = bucket_list[self._external_level:]
        for bb, bucket in zip(local_buckets, buckets):
            self._cached_buckets_mmap[self._cache_slice(bb)] = bucket
        if external_buckets:
            # whatever was not consumed by the cached levels
            # belongs to the external part of the path
            self._subheap_storage[external_buckets[0]].\
                bucket_storage.write_blocks(external_buckets,
                                            buckets[len(local_buckets):])

    @property
    def bytes_sent(self):
        """Sum of 'bytes_sent' over the cloned devices."""
        return sum(device.bytes_sent for device
                   in self._concurrent_devices.values())

    @property
    def bytes_received(self):
        """Sum of 'bytes_received' over the cloned devices."""
        return sum(device.bytes_received for device
                   in self._concurrent_devices.values())