--- /dev/null
+environment:
+ matrix:
+ - PYTHON: "C:\\Python27"
+ - PYTHON: "C:\\Python33"
+ - PYTHON: "C:\\Python34"
+ - PYTHON: "C:\\Python35"
+ - PYTHON: "C:\\Python27-x64"
+ - PYTHON: "C:\\Python35-x64"
+install:
+ - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%"
+ - "%PYTHON%\\python.exe -m pip install -v -U pip setuptools"
+ - "%PYTHON%\\python.exe -m pip install -v -e ."
+ - "%PYTHON%\\python.exe -m pip install -U unittest2 nose2 cov-core codecov coverage"
+build: off
+test_script:
+ - "%PYTHON%\\python.exe -m nose2 -v --log-capture --with-coverage --coverage src --coverage examples -s src"
+on_success:
+ - "%PYTHON%\\Scripts\\codecov.exe"
--- /dev/null
+# Emacs
+*~
+**/\#*
+
+# Python
+*.py[cod]
+
+# C extensions
+*.so
+
+# Setuptools distribution folder.
+/dist/
+/build/
+
+# Python egg metadata, regenerated from source files by setuptools.
+*.egg-info
+*.egg
+*.eggs
+
+# nose
+.coverage
+coverage.xml
--- /dev/null
+# travis CI config
+language: python
+matrix:
+ include:
+# - python: 2.7
+# env: JYTHON=org.python:jython-installer:2.7.1
+ - python: 2.7
+ - python: 3.3
+ - python: 3.4
+ - python: 3.5
+ - python: pypy
+cache: false
+before_install:
+ - sudo apt-get update -q
+ - sudo apt-get install graphviz -y
+ - python -m pip install -v -U pip setuptools virtualenv wheel
+ - if [ -n "$JYTHON" ]; then bash ./.travis_before_install_jython.sh; source $HOME/myvirtualenv/bin/activate ;fi
+ - if [ -n "$JYTHON" ]; then pip install jip; fi
+ - if [ -n "$JYTHON" ]; then export CLASSPATH=$VIRTUAL_ENV/javalib/*; fi
+install:
+ - python -m pip install -v -e .
+ - python -m pip install -U unittest2 nose2 cov-core codecov coverage
+script: python -m nose2 -v --log-capture --with-coverage --coverage src --coverage examples -s src
+after_success:
+ - codecov
+branches:
+ only:
+ - master
--- /dev/null
+#!/bin/bash
+set -e
+
+pip install jip
+jip install $JYTHON
+NON_GROUP_ID=${JYTHON#*:}
+_JYTHON_BASENAME=${NON_GROUP_ID/:/-}
+OLD_VIRTUAL_ENV=${VIRTUAL_ENV:=.}
+java -jar $OLD_VIRTUAL_ENV/javalib/${_JYTHON_BASENAME}.jar -s -d $HOME/jython
+$HOME/jython/bin/jython -c "import sys; print(sys.version_info)"
+virtualenv --version
+virtualenv -p $HOME/jython/bin/jython $HOME/myvirtualenv
--- /dev/null
+Changelog
+=========
+
+0.3.0 - `master`_
+~~~~~~~~~~~~~~~~~
+
+0.2.0 - 2016-10-18
+~~~~~~~~~~~~~~~~~~
+
+* using chunking to speed up yield_blocks for SFTP
+* speed up clearing entries in S3 interface by chunking delete requests
+* adding helper property to access heap storage on path oram
+* use a mmap to store the top-cached heap buckets
+* replace the show_status_bar keywords by a global config item
+* express status bar units as a memory transfer rate during setup
+* tweaks to Path ORAM to make it easier to generalize to other schemes
+* changing suffix of S3 index file from txt to bin
+* updates to readme
+
+0.1.2 - 2016-05-15
+~~~~~~~~~~~~~~~~~~
+
+* Initial release.
+
+.. _`master`: https://github.com/ghackebeil/PyORAM
--- /dev/null
+The MIT License (MIT)
+
+Copyright (c) 2016 Gabriel Hackebeil
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
--- /dev/null
+include README.rst
+include CHANGELOG.rst
+include LICENSE.md
+
+recursive-include src/_cffi_src *.py
--- /dev/null
+all:
+
+.PHONY: clean
+clean:
+ find examples -name "*.pyc" | xargs rm
+ find src -name "*.pyc" | xargs rm
+ find . -maxdepth 1 -name "*.pyc" | xargs rm
+
+ find examples -name "*.pyo" | xargs rm
+ find src -name "*.pyo" | xargs rm
+ find . -maxdepth 1 -name "*.pyo" | xargs rm
+
+ find examples -name "__pycache__" | xargs rm -r
+ find src -name "__pycache__" | xargs rm -r
+ find . -maxdepth 1 -name "__pycache__" | xargs rm -r
+
+ find examples -name "*~" | xargs rm
+ find src -name "*~" | xargs rm
+ find . -maxdepth 1 -name "*~" | xargs rm
+
+ find src -name "*.so" | xargs rm
--- /dev/null
+PyORAM
+======
+
+.. image:: https://travis-ci.org/ghackebeil/PyORAM.svg?branch=master
+ :target: https://travis-ci.org/ghackebeil/PyORAM
+
+.. image:: https://ci.appveyor.com/api/projects/status/898bxsvqdch1btv6/branch/master?svg=true
+ :target: https://ci.appveyor.com/project/ghackebeil/PyORAM?branch=master
+
+.. image:: https://codecov.io/github/ghackebeil/PyORAM/coverage.svg?branch=master
+ :target: https://codecov.io/github/ghackebeil/PyORAM?branch=master
+
+.. image:: https://img.shields.io/pypi/v/PyORAM.svg
+ :target: https://pypi.python.org/pypi/PyORAM/
+
+Python-based Oblivious RAM (PyORAM) is a collection of
+Oblivious RAM algorithms implemented in Python. This package
+serves to enable rapid prototyping and testing of new ORAM
+algorithms and ORAM-based applications tailored for the
+cloud-storage setting. PyORAM is written to support as many
+Python versions as possible, including Python 2.7+, Python
+3.3+, and PyPy 2.6+.
+
+This software is copyright (c) by Gabriel A. Hackebeil (gabe.hackebeil@gmail.com).
+
+This software is released under the MIT software license.
+This license, including disclaimer, is available in the 'LICENSE' file.
+
+This work was funded by the Privacy Enhancing Technologies
+project under the guidance of Professor `Attila Yavuz
+<https://web.engr.oregonstate.edu/~yavuza>`_ at Oregon State
+University.
+
+Why Python?
+-----------
+
+This project is meant for research. It is provided mainly as
+a tool for other researchers studying the applicability of
+ORAM to the cloud-storage setting. In such a setting, we
+observe that network latency far outweighs any overhead
+introduced from switching to an interpreted language such as
+Python (as opposed to C++ or Java). Thus, our hope is that
+by providing a Python-based library of ORAM tools, we will
+enable researchers to spend more time prototyping new and
+interesting ORAM applications and less time fighting with a
+compiler or chasing down segmentation faults.
+
+Installation
+------------
+
+To install the latest release of PyORAM, simply execute::
+
+ $ pip install PyORAM
+
+To install the trunk version of PyORAM, first clone the repository::
+
+ $ git clone https://github.com/ghackebeil/PyORAM.git
+
+Next, enter the directory where PyORAM has been cloned and run setup::
+
+ $ python setup.py install
+
+If you are a developer, you should instead install using::
+
+ $ pip install -e .
+ $ pip install nose2 unittest2
+
+Installation Tips
+-----------------
+
+* OS X users are recommended to work with the `homebrew
+ <http://brew.sh/>`_ version of Python2 or Python3. If you
+ must use the default system Python, then the best thing to
+ do is create a virtual environment and install PyORAM into
+ that. The process of creating a virtual environment that is
+ stored in the PyORAM directory would look something like::
+
+ $ sudo pip install virtualenv
+ $ cd <PyORAM-directory>
+ $ virtualenv local_python2.7
+
+ If you had already attempted to install PyORAM into the
+ system Python and encountered errors, it may be necessary
+ to delete the directories :code:`build` and :code:`dist`
+ from the current directory using the command::
+
+ $ sudo rm -rf build dist
+
+ Once this virtual environment has been successfully
+ created, you can *activate* it using the command::
+
+ $ . local_python2.7/bin/activate
+
+ Then, proceed with the normal installation steps to
+ install PyORAM into this environment. Note that you must
+ *activate* this environment each time you open a new
+ terminal if PyORAM is installed in this way. Also, note
+ that use of the :code:`sudo` command is no longer
+ necessary (and should be avoided) once a virtual
+ environment is activated in the current shell.
+
+* If you have trouble installing the cryptography package
+ on OS X with PyPy: `stackoverflow <https://stackoverflow.com/questions/36662704/fatal-error-openssl-e-os2-h-file-not-found-in-pypy/36706513#36706513>`_.
+
+* If you encounter the dreaded "unable to find
+ vcvarsall.bat" error when installing packages with C
+ extensions through pip on Windows: `blog post <https://blogs.msdn.microsoft.com/pythonengineering/2016/04/11/unable-to-find-vcvarsall-bat>`_.
+
+Tools Available (So Far)
+------------------------
+
+Encrypted block storage
+~~~~~~~~~~~~~~~~~~~~~~~
+
+* The basic building block for any ORAM implementation.
+
+* Available storage interfaces include:
+
+ - local storage using a file, a memory-mapped file, or RAM
+
+ + Dropbox
+
+ - cloud storage using SFTP (requires SSH access to a server)
+
+ + Amazon EC2
+
+ + Microsoft Azure
+
+ + Google Cloud Platform
+
+ - cloud storage using Amazon Simple Storage Service (S3)
+
+* See Examples:
+
+ - examples/encrypted_storage_ram.py
+
+ - examples/encrypted_storage_mmap.py
+
+ - examples/encrypted_storage_file.py
+
+ - examples/encrypted_storage_sftp.py
+
+ - examples/encrypted_storage_s3.py
+
+Path ORAM
+~~~~~~~~~
+
+* Reference: `Stefanov et al. <http://arxiv.org/abs/1202.5150v3>`_
+
+* Generalized to work over k-kary storage heaps. Default
+ settings use a binary storage heap and bucket size
+ parameter set to 4. Using a k-ary storage heap can reduce
+ the access cost; however, stash size behavior has not been
+ formally analyzed in this setting.
+
+* Tree-Top caching can be used to reduce data transmission
+ per access as well as reduce access latency by exploiting
+ parallelism across independent sub-heaps below the last
+ cached heap level.
+
+* See Examples:
+
+ - examples/path_oram_ram.py
+
+ - examples/path_oram_mmap.py
+
+ - examples/path_oram_file.py
+
+ - examples/path_oram_sftp.py
+
+ - examples/path_oram_s3.py
+
+Performance Tips
+----------------
+
+Setup Storage Locally
+~~~~~~~~~~~~~~~~~~~~~
+
+Storage schemes such as BlockStorageFile ("file"), BlockStorageMMap
+("mmap"), BlockStorageRAM ("ram"), and BlockStorageSFTP ("sftp") all
+employ the same underlying storage format. Thus, an oblivious storage
+scheme can be initialized locally and then transferred to an external
+storage location and accessed via BlockStorageSFTP using SSH login
+credentials. See the following pair of files for an example of this:
+
+* examples/path_oram_sftp_setup.py
+
+* examples/path_oram_sftp_test.py
+
+BlockStorageS3 ("s3") employs a different format whereby the
+underlying blocks are stored in separate "file" objects.
+This design is due to the fact that the Amazon S3 API does
+not allow modifications to a specific byte range within a
+file, but instead requires that the entire modified file
+object be re-uploaded. Thus, any efficient block storage
+scheme must use separate "file" objects for each block.
+
+Tree-Top Caching
+~~~~~~~~~~~~~~~~
+
+For schemes that employ a storage heap (such as Path ORAM),
+tree-top caching provides the ability to parallelize I/O
+operations across the independent sub-heaps below the last
+cached heap level. The default behavior of this
+implementation of Path ORAM, for instance, caches the top
+three levels of the storage heap in RAM, which creates eight
+independent sub-heaps across which write operations can be
+asynchronous.
+
+If the underlying storage is being accessed through SFTP, the
+tree-top cached storage heap will attempt to open an
+independent SFTP session for each sub-heap using the same
+SSH connection. Typically, the maximum number of allowable
+sessions associated with a single SSH connection is limited
+by the SSH server. For instance, the default maximum number
+of sessions allowed by a server using OpenSSH is 10. Thus,
+increasing the number of cached levels beyond 3 when using
+a binary storage heap will attempt to generate 16 or more SFTP
+sessions and result in an error such as::
+
+ paramiko.ssh_exception.ChannelException: (1, 'Administratively prohibited')
+
+There are two options for avoiding this error:
+
+1. If you have administrative privileges on the server, you
+ can increase the maximum number of allowed sessions for a
+ single SSH connection. For example, to set the maximum
+ allowed sessions to 128 on a server using OpenSSH, one
+ would set::
+
+ MaxSessions 128
+
+ in :code:`/etc/ssh/sshd_config`, and then run the
+ command :code:`sudo service ssh restart`.
+
+2. You can limit the number of concurrent devices that will
+ be created by setting the concurrency level to something
+ below the last cached level using the
+ :code:`concurrency_level` keyword. For example, the
+ settings :code:`cached_levels=5` and
+ :code:`concurrency_level=0` would cache the top 5 levels
+ of the storage heap locally, but all external I/O
+ operations would take place through a single storage
+ device (e.g., using 1 SFTP session).
--- /dev/null
+import time
+import base64
+
+from pyoram.crypto.aes import AES
+
+def runtest(label, enc_func, dec_func):
+ print("")
+ print("$"*20)
+ print("{0:^20}".format(label))
+ print("$"*20)
+ for keysize in AES.key_sizes:
+ print("")
+ print("@@@@@@@@@@@@@@@@@@@@")
+ print(" Key Size: %s bytes" % (keysize))
+ print("@@@@@@@@@@@@@@@@@@@@")
+ print("\nTest Bulk")
+ #
+ # generate a key
+ #
+ key = AES.KeyGen(keysize)
+ print("Key: %s" % (base64.b64encode(key)))
+
+ #
+ # generate some plaintext
+ #
+ nblocks = 1000000
+ plaintext_numbytes = AES.block_size * nblocks
+ print("Plaintext Size: %s MB"
+ % (plaintext_numbytes * 1.0e-6))
+ # all zeros
+ plaintext = bytes(bytearray(plaintext_numbytes))
+
+ #
+ # time encryption
+ #
+ start_time = time.time()
+ ciphertext = enc_func(key, plaintext)
+ stop_time = time.time()
+ print("Encryption Time: %.3fs (%.3f MB/s)"
+ % (stop_time-start_time,
+ (plaintext_numbytes * 1.0e-6) / (stop_time-start_time)))
+
+ #
+ # time decryption
+ #
+ start_time = time.time()
+ plaintext_decrypted = dec_func(key, ciphertext)
+ stop_time = time.time()
+ print("Decryption Time: %.3fs (%.3f MB/s)"
+ % (stop_time-start_time,
+ (plaintext_numbytes * 1.0e-6) / (stop_time-start_time)))
+
+ assert plaintext_decrypted == plaintext
+ assert ciphertext != plaintext
+ # IND-CPA
+ assert enc_func(key, plaintext) != ciphertext
+ # make sure the only difference is not in the IV
+ assert enc_func(key, plaintext)[AES.block_size:] \
+ != ciphertext[AES.block_size:]
+ if enc_func is AES.CTREnc:
+ assert len(plaintext) == \
+ len(ciphertext) - AES.block_size
+ else:
+ assert enc_func is AES.GCMEnc
+ assert len(plaintext) == \
+ len(ciphertext) - 2*AES.block_size
+
+ del plaintext
+ del plaintext_decrypted
+ del ciphertext
+
+ print("\nTest Chunks")
+ #
+ # generate a key
+ #
+ key = AES.KeyGen(keysize)
+ print("Key: %s" % (base64.b64encode(key)))
+
+ #
+ # generate some plaintext
+ #
+ nblocks = 1000
+ blocksize = 16000
+ total_bytes = blocksize * nblocks
+ print("Block Size: %s KB" % (blocksize * 1.0e-3))
+ print("Block Count: %s" % (nblocks))
+ print("Total: %s MB" % (total_bytes * 1.0e-6))
+ plaintext_blocks = [bytes(bytearray(blocksize))
+ for i in range(nblocks)]
+
+ #
+ # time encryption
+ #
+ start_time = time.time()
+ ciphertext_blocks = [enc_func(key, b)
+ for b in plaintext_blocks]
+ stop_time = time.time()
+ print("Encryption Time: %.3fs (%.3f MB/s)"
+ % (stop_time-start_time,
+ (total_bytes * 1.0e-6) / (stop_time-start_time)))
+
+ #
+ # time decryption
+ #
+ start_time = time.time()
+ plaintext_decrypted_blocks = [dec_func(key, c)
+ for c in ciphertext_blocks]
+ stop_time = time.time()
+ print("Decryption Time: %.3fs (%.3f MB/s)"
+ % (stop_time-start_time,
+ (total_bytes * 1.0e-6) / (stop_time-start_time)))
+
+def main():
+ runtest("AES - CTR Mode", AES.CTREnc, AES.CTRDec)
+ runtest("AES - GCM Mode", AES.GCMEnc, AES.GCMDec)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM when
+# storage is accessed through an SSH client using the Secure
+# File Transfer Protocol (SFTP).
+#
+# In order to run this example, you must provide a host
+# (server) address along with valid login credentials
+#
+
+import os
+import random
+import time
+import array
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+from pyoram.storage.AliTimer import *
+
+import paramiko
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set SSH login credentials here
+# (by default, we pull these from the environment
+# for testing purposes)
+ssh_host = os.environ.get('PYORAM_SSH_TEST_HOST')
+ssh_username = os.environ.get('PYORAM_SSH_TEST_USERNAME')
+ssh_password = os.environ.get('PYORAM_SSH_TEST_PASSWORD')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 2KB block size
+block_size = 2048
+# one block per bucket in the
+# storage heap of height 8
+# block_count = 2**(8+1)-1
+# block_count = 2**(12+1)-1
+# block_count = 2**(15+1)-1
+block_count = 2**(8+1)-1
+
+def main():
+ timer = Foo.Instance()
+
+ print("Storage Name: %s" % (storage_name))
+ print("Block Count: %s" % (block_count))
+ print("Block Size: %s" % (MemorySize(block_size)))
+ print("Total Memory: %s"
+ % (MemorySize(block_size*block_count)))
+ print("Actual Storage Required: %s"
+ % (MemorySize(
+ PathORAM.compute_storage_size(
+ block_size,
+ block_count,
+ storage_type='sftp'))))
+ print("")
+
+ # Start an SSH client using paramiko
+ print("Starting SSH Client")
+ with paramiko.SSHClient() as ssh:
+ ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ ssh.load_system_host_keys()
+ ssh.connect(ssh_host,
+ username=ssh_username,
+ password=ssh_password)
+
+ print("Setting Up Path ORAM Storage")
+ start_time = time.time()
+ with PathORAM.setup(storage_name,
+ block_size,
+ block_count,
+ storage_type='sftp',
+ sshclient=ssh,
+ cached_levels=2,
+ concurrency_level = 0,
+ ignore_existing=True) as f:
+ print("Total Data Transmission: %s" % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("Done Initializing")
+ pass
+
+ stop_time = time.time()
+ print("Initial Setup Processing Time: " + str(stop_time - start_time))
+ print("Initial Setup Network Time: " + str(timer.getTime()))
+ print("")
+
+ # print("Total Setup Time: %.2f s"
+ # % (time.time()-setup_start))
+ # print("Current Stash Size: %s"
+ # % len(f.stash))
+
+ # print("")
+
+ start_time = time.time()
+ timer.resetTimer()
+ # We close the device and reopen it after
+ # setup to reset the bytes sent and bytes
+ # received stats.
+
+ print("Starting test Run...")
+
+
+ with PathORAM(storage_name,
+ f.stash,
+ f.position_map,
+ key=f.key,
+ cached_levels=2,
+ concurrency_level = 0,
+ storage_type='sftp',
+ sshclient=ssh) as f:
+
+ stop_time = time.time()
+ print("Total Data Transmission: %s" % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("Test Setup Processing Time: " + str(stop_time - start_time))
+ print("Test Setup Network Time: " + str(timer.getTime()))
+ print("")
+
+
+
+
+
+ keys = []
+ keys.extend(range(0, block_count))
+ random.shuffle(keys)
+
+
+ print("Starting Ali Test 2")
+ timer.resetTimer()
+ test_count = block_count
+ start_time = time.time()
+
+ for t in tqdm.tqdm(list(range(test_count)), desc="Running I/O Performance Test"):
+ ind = keys[t]
+ # ind = t
+ s = "a" + str(ind)
+ f.write_block(ind, bytearray(s.ljust(block_size, '\0')))
+ print("Total Data Transmission: %s" % (MemorySize(f.bytes_sent + f.bytes_received)))
+
+ # for t in tqdm.tqdm(list(range(test_count)), desc="Running I/O Performance Test"):
+ # ind = keys[t]
+ # # print(f.read_block(ind))
+ # f.read_block(ind)
+
+
+
+ stop_time = time.time()
+ print("Test Processing Time: " + str(stop_time - start_time))
+ print("Test Network Time: " + str(timer.getTime()))
+ print("")
+
+
+
+
+
+ # print("Current Stash Size: %s"
+ # % len(f.stash))
+ # print("Access Block Avg. Data Transmitted: %s (%.3fx)"
+ # % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
+ # (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
+ # print("Fetch Block Avg. Latency: %.2f ms"
+ # % ((stop_time-start_time)/float(test_count)*1000))
+ # print("")
+ # print("")
+ # print("")
+ # print("")
+ # print("")
+ # print("")
+
+
+
+
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+import os
+import tempfile
+
+from pyoram.util.virtual_heap import \
+ SizedVirtualHeap
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ EncryptedHeapStorage
+
+def main():
+ #
+ # get a unique filename in the current directory
+ #
+ fid, tmpname = tempfile.mkstemp(dir=os.getcwd())
+ os.close(fid)
+ os.remove(tmpname)
+ print("Storage Name: %s" % (tmpname))
+
+ key_size = 32
+ header_data = b'a message'
+ heap_base = 3
+ heap_height = 2
+ block_size = 8
+ blocks_per_bucket=4
+ initialize = lambda i: \
+ bytes(bytearray([i] * block_size * blocks_per_bucket))
+ vheap = SizedVirtualHeap(
+ heap_base,
+ heap_height,
+ blocks_per_bucket=blocks_per_bucket)
+
+ with EncryptedHeapStorage.setup(
+ tmpname,
+ block_size,
+ heap_height,
+ key_size=key_size,
+ header_data=header_data,
+ heap_base=heap_base,
+ blocks_per_bucket=blocks_per_bucket,
+ initialize=initialize) as f:
+ assert tmpname == f.storage_name
+ assert f.header_data == header_data
+ print(f.read_path(vheap.random_bucket()))
+ key = f.key
+ assert os.path.exists(tmpname)
+
+ with EncryptedHeapStorage(tmpname, key=key) as f:
+ assert tmpname == f.storage_name
+ assert f.header_data == header_data
+ print(f.read_path(vheap.random_bucket()))
+
+ #
+ # cleanup
+ #
+ os.remove(tmpname)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of encrypted
+# storage access through a local file.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
+def main():
+
+ print("Storage Name: %s" % (storage_name))
+ print("Block Count: %s" % (block_count))
+ print("Block Size: %s" % (MemorySize(block_size)))
+ print("Total Memory: %s"
+ % (MemorySize(block_size*block_count)))
+ print("Actual Storage Required: %s"
+ % (MemorySize(
+ EncryptedBlockStorage.compute_storage_size(
+ block_size,
+ block_count,
+ storage_type='file'))))
+ print("")
+
+ print("Setting Up Encrypted Block Storage")
+ setup_start = time.time()
+ with EncryptedBlockStorage.setup(storage_name,
+ block_size,
+ block_count,
+ storage_type='file',
+ ignore_existing=True) as f:
+ print("Total Setup Time: %.2f s"
+ % (time.time()-setup_start))
+ print("Total Data Transmission: %s"
+ % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("")
+
+ # We close the device and reopen it after
+ # setup to reset the bytes sent and bytes
+ # received stats.
+ with EncryptedBlockStorage(storage_name,
+ key=f.key,
+ storage_type='file') as f:
+
+ test_count = 1000
+ start_time = time.time()
+ for t in tqdm.tqdm(list(range(test_count)),
+ desc="Running I/O Performance Test"):
+ f.read_block(random.randint(0,f.block_count-1))
+ stop_time = time.time()
+ print("Access Block Avg. Data Transmitted: %s (%.3fx)"
+ % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
+ (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
+ print("Access Block Avg. Latency: %.2f ms"
+ % ((stop_time-start_time)/float(test_count)*1000))
+ print("")
+
+ # cleanup because this is a test example
+ os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of encrypted
+# storage access through a local memory-mapped file.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
+def main():
+
+ print("Storage Name: %s" % (storage_name))
+ print("Block Count: %s" % (block_count))
+ print("Block Size: %s" % (MemorySize(block_size)))
+ print("Total Memory: %s"
+ % (MemorySize(block_size*block_count)))
+ print("Actual Storage Required: %s"
+ % (MemorySize(
+ EncryptedBlockStorage.compute_storage_size(
+ block_size,
+ block_count,
+ storage_type='mmap'))))
+ print("")
+
+ print("Setting Up Encrypted Block Storage")
+ setup_start = time.time()
+ with EncryptedBlockStorage.setup(storage_name,
+ block_size,
+ block_count,
+ storage_type='mmap',
+ ignore_existing=True) as f:
+ print("Total Setup Time: %.2f s"
+ % (time.time()-setup_start))
+ print("Total Data Transmission: %s"
+ % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("")
+
+ # We close the device and reopen it after
+ # setup to reset the bytes sent and bytes
+ # received stats.
+ with EncryptedBlockStorage(storage_name,
+ key=f.key,
+ storage_type='mmap') as f:
+
+ test_count = 1000
+ start_time = time.time()
+ for t in tqdm.tqdm(list(range(test_count)),
+ desc="Running I/O Performance Test"):
+ f.read_block(random.randint(0,f.block_count-1))
+ stop_time = time.time()
+ print("Access Block Avg. Data Transmitted: %s (%.3fx)"
+ % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
+ (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
+ print("Access Block Avg. Latency: %.2f ms"
+ % ((stop_time-start_time)/float(test_count)*1000))
+ print("")
+
+ # cleanup because this is a test example
+ os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of encrypted
+# storage access through RAM.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+from pyoram.storage.block_storage_ram import \
+ BlockStorageRAM
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
+def main():
+
+ print("Storage Name: %s" % (storage_name))
+ print("Block Count: %s" % (block_count))
+ print("Block Size: %s" % (MemorySize(block_size)))
+ print("Total Memory: %s"
+ % (MemorySize(block_size*block_count)))
+ print("Actual Storage Required: %s"
+ % (MemorySize(
+ EncryptedBlockStorage.compute_storage_size(
+ block_size,
+ block_count,
+ storage_type='ram'))))
+ print("")
+
+ print("Setting Up Encrypted Block Storage")
+ setup_start = time.time()
+ with EncryptedBlockStorage.setup(storage_name, # RAM storage ignores this argument
+ block_size,
+ block_count,
+ storage_type='ram',
+ ignore_existing=True) as f:
+ print("Total Setup Time: %.2f s"
+ % (time.time()-setup_start))
+ print("Total Data Transmission: %s"
+ % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("")
+
+ # This must be done after closing the file to ensure the lock flag
+ # is set to False in the saved data. The tofile method only exists
+ # on BlockStorageRAM
+ f.raw_storage.tofile(storage_name)
+
+ # We close the device and reopen it after
+ # setup to reset the bytes sent and bytes
+ # received stats.
+ with EncryptedBlockStorage(BlockStorageRAM.fromfile(storage_name),
+ key=f.key) as f:
+
+ test_count = 1000
+ start_time = time.time()
+ for t in tqdm.tqdm(list(range(test_count)),
+ desc="Running I/O Performance Test"):
+ f.read_block(random.randint(0,f.block_count-1))
+ stop_time = time.time()
+ print("Access Block Avg. Data Transmitted: %s (%.3fx)"
+ % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
+ (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
+ print("Access Block Avg. Latency: %.2f ms"
+ % ((stop_time-start_time)/float(test_count)*1000))
+ print("")
+
+ # cleanup because this is a test example
+ os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of encrypted
+# storage access through Amazon Simple Storage Service
+# (S3).
+#
+# In order to run this example, you must provide a valid
+# S3 bucket name and have the following variables defined
+# in your current environment:
+# - AWS_ACCESS_KEY_ID
+# - AWS_SECRET_ACCESS_KEY
+# - AWS_DEFAULT_REGION
+# These can also be set using keywords.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set S3 bucket name here
+# (by default, we pull this from the environment
+# for testing purposes)
+bucket_name = os.environ.get('PYORAM_AWS_TEST_BUCKET')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
+def main():
+
+ print("Storage Name: %s" % (storage_name))
+ print("Block Count: %s" % (block_count))
+ print("Block Size: %s" % (MemorySize(block_size)))
+ print("Total Memory: %s"
+ % (MemorySize(block_size*block_count)))
+ print("Actual Storage Required: %s"
+ % (MemorySize(
+ EncryptedBlockStorage.compute_storage_size(
+ block_size,
+ block_count,
+ storage_type='s3'))))
+ print("")
+
+ print("Setting Up Encrypted Block Storage")
+ setup_start = time.time()
+ with EncryptedBlockStorage.setup(storage_name,
+ block_size,
+ block_count,
+ storage_type='s3',
+ bucket_name=bucket_name,
+ ignore_existing=True) as f:
+ print("Total Setup Time: %.2f s"
+ % (time.time()-setup_start))
+ print("Total Data Transmission: %s"
+ % (MemorySize(f.bytes_sent + f.bytes_received)))
+ print("")
+
+ # We close the device and reopen it after
+ # setup to reset the bytes sent and bytes
+ # received stats.
+ with EncryptedBlockStorage(storage_name,
+ key=f.key,
+ storage_type='s3',
+ bucket_name=bucket_name) as f:
+
+ test_count = 1000
+ start_time = time.time()
+ for t in tqdm.tqdm(list(range(test_count)),
+ desc="Running I/O Performance Test"):
+ f.read_block(random.randint(0,f.block_count-1))
+ stop_time = time.time()
+ print("Access Block Avg. Data Transmitted: %s (%.3fx)"
+ % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
+ (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
+ print("Access Block Avg. Latency: %.2f ms"
+ % ((stop_time-start_time)/float(test_count)*1000))
+ print("")
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of encrypted storage
+# access through an SSH client using the Secure File
+# Transfer Protocol (SFTP).
+#
+# In order to run this example, you must provide a host
+# (server) address along with valid login credentials
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+
+import paramiko
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set SSH login credentials here
+# (by default, we pull these from the environment
+# for testing purposes)
+ssh_host = os.environ.get('PYORAM_SSH_TEST_HOST')
+ssh_username = os.environ.get('PYORAM_SSH_TEST_USERNAME')
+ssh_password = os.environ.get('PYORAM_SSH_TEST_PASSWORD')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure encrypted-block-storage I/O performance over SFTP.

    Reads the host and credentials from the PYORAM_SSH_TEST_* module-level
    variables, builds the storage remotely, then times 1000 random block
    reads.  Pure side-effect driver: prints results, returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              EncryptedBlockStorage.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='sftp'))))
    print("")

    # Start an SSH client using paramiko
    print("Starting SSH Client")
    with paramiko.SSHClient() as ssh:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.load_system_host_keys()
        ssh.connect(ssh_host,
                    username=ssh_username,
                    password=ssh_password)

        print("Setting Up Encrypted Block Storage")
        setup_start = time.time()
        with EncryptedBlockStorage.setup(storage_name,
                                         block_size,
                                         block_count,
                                         storage_type='sftp',
                                         sshclient=ssh,
                                         ignore_existing=True) as f:
            # BUGFIX: was "%2.f" (field width 2, zero decimal places);
            # "%.2f" matches the two-decimal output of the other examples.
            print("Total Setup Time: %.2f s"
                  % (time.time()-setup_start))
            print("Total Data Transmission: %s"
                  % (MemorySize(f.bytes_sent + f.bytes_received)))
            print("")

        # We close the device and reopen it after
        # setup to reset the bytes sent and bytes
        # received stats.  Note that f.key is read from the
        # just-closed device before the name f is rebound below.
        with EncryptedBlockStorage(storage_name,
                                   key=f.key,
                                   storage_type='sftp',
                                   sshclient=ssh) as f:

            test_count = 1000
            start_time = time.time()
            for t in tqdm.tqdm(list(range(test_count)),
                               desc="Running I/O Performance Test"):
                f.read_block(random.randint(0, f.block_count-1))
            stop_time = time.time()
            print("Access Block Avg. Data Transmitted: %s (%.3fx)"
                  % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                     (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
            print("Access Block Avg. Latency: %.2f ms"
                  % ((stop_time-start_time)/float(test_count)*1000))
            print("")
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM
+# when storage is accessed through a local file.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure Path ORAM I/O performance with file-backed local storage.

    Builds the ORAM heap in a local file, times 100 random block reads,
    then deletes the file.  Pure side-effect driver: prints results,
    returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='file'))))
    print("")

    print("Setting Up Path ORAM Storage")
    setup_start = time.time()
    with PathORAM.setup(storage_name,
                        block_size,
                        block_count,
                        storage_type='file',
                        ignore_existing=True) as f:
        # BUGFIX: was "%2.f" (field width 2, zero decimal places);
        # "%.2f" matches the two-decimal output of the other examples.
        print("Total Setup Time: %.2f s"
              % (time.time()-setup_start))
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Total Data Transmission: %s"
              % (MemorySize(f.bytes_sent + f.bytes_received)))
        print("")

    # We close the device and reopen it after
    # setup to reset the bytes sent and bytes
    # received stats.  f.stash / f.position_map / f.key are read from
    # the just-closed device before the name f is rebound below.
    with PathORAM(storage_name,
                  f.stash,
                  f.position_map,
                  key=f.key,
                  storage_type='file') as f:

        test_count = 100
        start_time = time.time()
        for t in tqdm.tqdm(list(range(test_count)),
                           desc="Running I/O Performance Test"):
            f.read_block(random.randint(0, f.block_count-1))
        stop_time = time.time()
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Access Block Avg. Data Transmitted: %s (%.3fx)"
              % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                 (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
        print("Access Block Avg. Latency: %.2f ms"
              % ((stop_time-start_time)/float(test_count)*1000))
        print("")

    # cleanup because this is a test example
    os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM
+# when storage is accessed through a local memory-mapped
+# file (mmap).
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure Path ORAM I/O performance with mmap-backed local storage.

    Identical to the file-backed example except storage_type='mmap'.
    Pure side-effect driver: prints results, returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='mmap'))))
    print("")

    print("Setting Up Path ORAM Storage")
    setup_start = time.time()
    with PathORAM.setup(storage_name,
                        block_size,
                        block_count,
                        storage_type='mmap',
                        ignore_existing=True) as f:
        # BUGFIX: was "%2.f" (field width 2, zero decimal places);
        # "%.2f" matches the two-decimal output of the other examples.
        print("Total Setup Time: %.2f s"
              % (time.time()-setup_start))
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Total Data Transmission: %s"
              % (MemorySize(f.bytes_sent + f.bytes_received)))
        print("")

    # We close the device and reopen it after
    # setup to reset the bytes sent and bytes
    # received stats.  f.stash / f.position_map / f.key are read from
    # the just-closed device before the name f is rebound below.
    with PathORAM(storage_name,
                  f.stash,
                  f.position_map,
                  key=f.key,
                  storage_type='mmap') as f:

        test_count = 100
        start_time = time.time()
        for t in tqdm.tqdm(list(range(test_count)),
                           desc="Running I/O Performance Test"):
            f.read_block(random.randint(0, f.block_count-1))
        stop_time = time.time()
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Access Block Avg. Data Transmitted: %s (%.3fx)"
              % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                 (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
        print("Access Block Avg. Latency: %.2f ms"
              % ((stop_time-start_time)/float(test_count)*1000))
        print("")

    # cleanup because this is a test example
    os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM
+# when storage is accessed through RAM.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+from pyoram.storage.block_storage_ram import \
+ BlockStorageRAM
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure Path ORAM I/O performance with RAM-backed storage.

    Sets up the ORAM in memory, persists it with BlockStorageRAM.tofile,
    reloads it via BlockStorageRAM.fromfile, times 100 random reads, and
    removes the dump file.  Pure side-effect driver: prints results,
    returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='ram'))))
    print("")

    print("Setting Up Path ORAM Storage")
    setup_start = time.time()
    with PathORAM.setup(storage_name, # RAM storage ignores this argument
                        block_size,
                        block_count,
                        storage_type='ram',
                        ignore_existing=True) as f:
        # BUGFIX: was "%2.f" (field width 2, zero decimal places);
        # "%.2f" matches the two-decimal output of the other examples.
        print("Total Setup Time: %.2f s"
              % (time.time()-setup_start))
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Total Data Transmission: %s"
              % (MemorySize(f.bytes_sent + f.bytes_received)))
        print("")

    # This must be done after closing the file to ensure the lock flag
    # is set to False in the saved data. The tofile method only exists
    # on BlockStorageRAM
    f.raw_storage.tofile(storage_name)

    # We close the device and reopen it after
    # setup to reset the bytes sent and bytes
    # received stats.
    with PathORAM(BlockStorageRAM.fromfile(storage_name),
                  f.stash,
                  f.position_map,
                  key=f.key) as f:

        test_count = 100
        start_time = time.time()
        for t in tqdm.tqdm(list(range(test_count)),
                           desc="Running I/O Performance Test"):
            f.read_block(random.randint(0, f.block_count-1))
        stop_time = time.time()
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Access Block Avg. Data Transmitted: %s (%.3fx)"
              % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                 (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
        print("Access Block Avg. Latency: %.2f ms"
              % ((stop_time-start_time)/float(test_count)*1000))
        print("")

    # cleanup because this is a test example
    os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM when
+# storage is accessed through Amazon Simple Storage Service
+# (S3).
+#
+# In order to run this example, you must provide a valid
+# S3 bucket name and have the following variables defined
+# in your current environment:
+# - AWS_ACCESS_KEY_ID
+# - AWS_SECRET_ACCESS_KEY
+# - AWS_DEFAULT_REGION
+# These can also be set using keywords.
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set S3 bucket name here
+# (by default, we pull this from the environment
+# for testing purposes)
+bucket_name = os.environ.get('PYORAM_AWS_TEST_BUCKET')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure Path ORAM I/O performance with S3-backed storage.

    Requires a bucket name (PYORAM_AWS_TEST_BUCKET) plus the usual AWS
    credential environment variables.  Pure side-effect driver: prints
    results, returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='s3'))))
    print("")

    print("Setting Up Path ORAM Storage")
    setup_start = time.time()
    with PathORAM.setup(storage_name,
                        block_size,
                        block_count,
                        storage_type='s3',
                        bucket_name=bucket_name,
                        ignore_existing=True) as f:
        print("Total Setup Time: %.2f s"
              % (time.time()-setup_start))
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Total Data Transmission: %s"
              % (MemorySize(f.bytes_sent + f.bytes_received)))
        print("")

    # We close the device and reopen it after
    # setup to reset the bytes sent and bytes
    # received stats.
    # f.stash / f.position_map / f.key are read from the just-closed
    # device before the name f is rebound below.
    with PathORAM(storage_name,
                  f.stash,
                  f.position_map,
                  key=f.key,
                  storage_type='s3',
                  bucket_name=bucket_name) as f:

        test_count = 100
        start_time = time.time()
        for t in tqdm.tqdm(list(range(test_count)),
                           desc="Running I/O Performance Test"):
            f.read_block(random.randint(0,f.block_count-1))
        stop_time = time.time()
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Access Block Avg. Data Transmitted: %s (%.3fx)"
              % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                 (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
        print("Fetch Block Avg. Latency: %.2f ms"
              % ((stop_time-start_time)/float(test_count)*1000))
        print("")
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example measures the performance of Path ORAM when
+# storage is accessed through an SSH client using the Secure
+# File Transfer Protocol (SFTP).
+#
+# In order to run this example, you must provide a host
+# (server) address along with valid login credentials
+#
+
+import os
+import random
+import time
+
+import pyoram
+from pyoram.util.misc import MemorySize
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import paramiko
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set SSH login credentials here
+# (by default, we pull these from the environment
+# for testing purposes)
+ssh_host = os.environ.get('PYORAM_SSH_TEST_HOST')
+ssh_username = os.environ.get('PYORAM_SSH_TEST_USERNAME')
+ssh_password = os.environ.get('PYORAM_SSH_TEST_PASSWORD')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Measure Path ORAM I/O performance with SFTP-backed storage.

    Reads the host and credentials from the PYORAM_SSH_TEST_* module-level
    variables.  Pure side-effect driver: prints results, returns None.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='sftp'))))
    print("")

    # Start an SSH client using paramiko
    print("Starting SSH Client")
    with paramiko.SSHClient() as ssh:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.load_system_host_keys()
        ssh.connect(ssh_host,
                    username=ssh_username,
                    password=ssh_password)

        print("Setting Up Path ORAM Storage")
        setup_start = time.time()
        with PathORAM.setup(storage_name,
                            block_size,
                            block_count,
                            storage_type='sftp',
                            sshclient=ssh,
                            ignore_existing=True) as f:
            print("Total Setup Time: %.2f s"
                  % (time.time()-setup_start))
            print("Current Stash Size: %s"
                  % len(f.stash))
            print("Total Data Transmission: %s"
                  % (MemorySize(f.bytes_sent + f.bytes_received)))
            print("")

        # We close the device and reopen it after
        # setup to reset the bytes sent and bytes
        # received stats.
        # f.stash / f.position_map / f.key are read from the just-closed
        # device before the name f is rebound below.
        with PathORAM(storage_name,
                      f.stash,
                      f.position_map,
                      key=f.key,
                      storage_type='sftp',
                      sshclient=ssh) as f:

            test_count = 100
            start_time = time.time()
            for t in tqdm.tqdm(list(range(test_count)),
                               desc="Running I/O Performance Test"):
                f.read_block(random.randint(0,f.block_count-1))
            stop_time = time.time()
            print("Current Stash Size: %s"
                  % len(f.stash))
            print("Access Block Avg. Data Transmitted: %s (%.3fx)"
                  % (MemorySize((f.bytes_sent + f.bytes_received)/float(test_count)),
                     (f.bytes_sent + f.bytes_received)/float(test_count)/float(block_size)))
            print("Fetch Block Avg. Latency: %.2f ms"
                  % ((stop_time-start_time)/float(test_count)*1000))
            print("")
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example demonstrates how to setup an instance of Path ORAM
+# locally and then transfer the storage to a server using a paramiko
+# SSHClient. After executing this file, path_oram_sftp_test.py can be
+# executed to run simple I/O performance tests using different caching
+# settings.
+#
+# In order to run this example, you must provide a host
+# (server) address along with valid login credentials
+#
+
+import os
+import random
+import time
+import pickle
+
+import pyoram
+from pyoram.util.misc import MemorySize, save_private_key
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import paramiko
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set SSH login credentials here
+# (by default, we pull these from the environment
+# for testing purposes)
+ssh_host = os.environ.get('PYORAM_SSH_TEST_HOST')
+ssh_username = os.environ.get('PYORAM_SSH_TEST_USERNAME')
+ssh_password = os.environ.get('PYORAM_SSH_TEST_PASSWORD')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+# 4KB block size
+block_size = 4000
+# one block per bucket in the
+# storage heap of height 8
+block_count = 2**(8+1)-1
+
def main():
    """Build a Path ORAM heap locally (mmap), save the client state
    (key, stash, position map) to sidecar files, then upload the heap
    to a remote host over SFTP and delete the local copy.

    Run before path_oram_sftp_test.py.  Pure side-effect driver.
    """
    print("Storage Name: %s" % (storage_name))
    print("Block Count: %s" % (block_count))
    print("Block Size: %s" % (MemorySize(block_size)))
    print("Total Memory: %s"
          % (MemorySize(block_size*block_count)))
    print("Actual Storage Required: %s"
          % (MemorySize(
              PathORAM.compute_storage_size(
                  block_size,
                  block_count,
                  storage_type='mmap'))))
    print("")

    print("Setting Up Path ORAM Storage Locally")
    setup_start = time.time()
    with PathORAM.setup(storage_name,
                        block_size,
                        block_count,
                        storage_type='mmap',
                        ignore_existing=True) as f:
        print("Total Setup Time: %.2f s"
              % (time.time()-setup_start))
        print("Current Stash Size: %s"
              % len(f.stash))
        print("Total Data Transmission: %s"
              % (MemorySize(f.bytes_sent + f.bytes_received)))
        print("")

    # Persist the client-side state next to the heap file so the test
    # script can reconstruct the ORAM handle later.
    print("Saving key to file: %s.key"
          % (storage_name))
    save_private_key(storage_name+".key", f.key)
    print("Saving stash to file: %s.stash"
          % (storage_name))
    with open(storage_name+".stash", 'wb') as fstash:
        pickle.dump(f.stash, fstash)
    print("Saving position map to file: %s.position"
          % (storage_name))
    with open(storage_name+".position", 'wb') as fpos:
        pickle.dump(f.position_map, fpos)

    # Start an SSH client using paramiko
    print("Starting SSH Client")
    with paramiko.SSHClient() as ssh:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.load_system_host_keys()
        ssh.connect(ssh_host,
                    username=ssh_username,
                    password=ssh_password)

        sftp = ssh.open_sftp()

        # Adapter feeding paramiko's put() progress callback into a tqdm
        # bar; last_b tracks the previously reported byte count so only
        # the delta is added.
        def my_hook(t):
            def inner(b, total):
                t.total = total
                t.update(b - inner.last_b)
                inner.last_b = b
            inner.last_b = 0
            return inner
        with tqdm.tqdm(desc="Transferring Storage",
                       unit='B',
                       unit_scale=True,
                       miniters=1) as t:
            sftp.put(storage_name,
                     storage_name,
                     callback=my_hook(t))
        sftp.close()

        print("Deleting Local Copy of Storage")
        os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+#
+# This example demonstrates how to access an existing Path ORAM
+# storage space through an SSH client using the Secure File Transfer
+# Protocol (SFTP). This file should not be executed until the
+# path_oram_sftp_setup.py example has been executed. The user is
+# encouraged to tweak the settings for 'cached_levels',
+# 'concurrency_level', and 'threadpool_size' to observe their effect
+# on access latency.
+#
+# In order to run this example, you must provide a host
+# (server) address along with valid login credentials
+#
+
+import os
+import random
+import time
+import pickle
+import multiprocessing
+
+import pyoram
+from pyoram.util.misc import MemorySize, load_private_key
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+
+import paramiko
+import tqdm
+
+pyoram.config.SHOW_PROGRESS_BAR = True
+
+# Set SSH login credentials here
+# (by default, we pull these from the environment
+# for testing purposes)
+ssh_host = os.environ.get('PYORAM_SSH_TEST_HOST')
+ssh_username = os.environ.get('PYORAM_SSH_TEST_USERNAME')
+ssh_password = os.environ.get('PYORAM_SSH_TEST_PASSWORD')
+
+# Set the storage location and size
+storage_name = "heap.bin"
+
def main():
    """Reopen the Path ORAM heap uploaded by path_oram_sftp_setup.py
    over SFTP and time 100 random block reads.

    Loads key/stash/position-map from the sidecar files, and re-saves the
    (mutated) stash and position map in a finally block so a partial run
    does not lose client state.  Pure side-effect driver.
    """
    print("Loading key from file: %s.key"
          % (storage_name))
    key = load_private_key(storage_name+".key")
    print("Loading stash from file: %s.stash"
          % (storage_name))
    # NOTE(review): pickle.load assumes these sidecar files are trusted
    # local artifacts produced by the setup script.
    with open(storage_name+".stash", 'rb') as fstash:
        stash = pickle.load(fstash)
    print("Loading position map from file: %s.position"
          % (storage_name))
    with open(storage_name+".position", 'rb') as fpos:
        position_map = pickle.load(fpos)

    # Start an SSH client using paramiko
    print("Starting SSH Client")
    with paramiko.SSHClient() as ssh:
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.load_system_host_keys()
        ssh.connect(ssh_host,
                    username=ssh_username,
                    password=ssh_password)

        # Tunables discussed in the header comment: cached_levels,
        # concurrency_level, threadpool_size all affect access latency.
        with PathORAM(storage_name,
                      stash,
                      position_map,
                      key=key,
                      storage_type='sftp',
                      cached_levels=6,
                      concurrency_level=3,
                      threadpool_size=multiprocessing.cpu_count()*2,
                      sshclient=ssh) as f:

            try:

                test_count = 100
                start_time = time.time()
                for t in tqdm.tqdm(list(range(test_count)),
                                   desc="Running I/O Performance Test"):
                    f.read_block(random.randint(0,f.block_count-1))
                stop_time = time.time()
                print("Current Stash Size: %s"
                      % len(f.stash))
                print("Fetch Block Avg. Latency: %.2f ms"
                      % ((stop_time-start_time)/float(test_count)*1000))
                print("")

            finally:

                # Always persist the updated client state, even if the
                # timing loop above raised.
                print("Saving stash to file: %s.stash"
                      % (storage_name))
                with open(storage_name+".stash", 'wb') as fstash:
                    pickle.dump(f.stash, fstash)
                print("Saving position map to file: %s.position"
                      % (storage_name))
                with open(storage_name+".position", 'wb') as fpos:
                    pickle.dump(f.position_map, fpos)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+import os
+import struct
+import random
+
+from pyoram.util.virtual_heap import \
+ SizedVirtualHeap
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ EncryptedHeapStorage
+from pyoram.oblivious_storage.tree.tree_oram_helper import \
+ TreeORAMStorageManagerPointerAddressing
+
def main():
    """Demonstrate TreeORAMStorageManagerPointerAddressing on a small
    encrypted heap (base 2, height 2, 2 blocks per bucket).

    Loads a random path, pushes blocks down, evicts, and prints the
    manager's path state after each step.  Removes the heap file at the
    end.  Pure side-effect driver.
    """
    storage_name = "heap.bin"
    print("Storage Name: %s" % (storage_name))

    key_size = 32
    heap_base = 2
    heap_height = 2
    # each block packs (bool, ulong, ulong) in network byte order
    block_size = struct.calcsize("!?LL")
    blocks_per_bucket = 2
    vheap = SizedVirtualHeap(
        heap_base,
        heap_height,
        blocks_per_bucket=blocks_per_bucket)

    print("Block Size: %s" % (block_size))
    print("Blocks Per Bucket: %s" % (blocks_per_bucket))

    position_map = {}
    def initialize(i):
        """Return the packed initial contents for bucket i.

        Slots where (i*j) % 3 is nonzero are left empty (False, 0, 0);
        the rest receive a fresh block id and a position chosen by a
        random walk down the heap from node i (position_map is updated
        to match).
        """
        bucket = bytes()
        for j in range(blocks_per_bucket):
            if (i*j) % 3:
                bucket += struct.pack(
                    "!?LL", False, 0, 0)
            else:
                x = vheap.Node(i)
                while not vheap.is_nil_node(x):
                    x = x.child_node(random.randint(0, heap_base-1))
                x = x.parent_node()
                bucket += struct.pack(
                    "!?LL", True, initialize.id_, x.bucket)
                position_map[initialize.id_] = x.bucket
                initialize.id_ += 1
        return bucket
    # block ids start at 1; stored as function attribute so the closure
    # can increment it across calls
    initialize.id_ = 1

    with EncryptedHeapStorage.setup(
            storage_name,
            block_size,
            heap_height,
            heap_base=heap_base,
            key_size=key_size,
            blocks_per_bucket=blocks_per_bucket,
            initialize=initialize,
            ignore_existing=True) as f:
        assert storage_name == f.storage_name
        stash = {}
        oram = TreeORAMStorageManagerPointerAddressing(f, stash)

        # load a random path and show its state
        b = vheap.random_bucket()
        oram.load_path(b)
        print("")
        print(repr(vheap.Node(oram.path_stop_bucket)))
        print(oram.path_block_ids)
        print(oram.path_block_eviction_levels)

        # push blocks toward the leaves and show the reordering
        oram.push_down_path()
        print("")
        print(repr(vheap.Node(oram.path_stop_bucket)))
        print(oram.path_block_ids)
        print(oram.path_block_eviction_levels)
        print(oram.path_block_reordering)

        # write the path back, reload it, and verify a second push_down
        # produces no further reordering
        oram.evict_path()
        oram.load_path(b)
        print("")
        print(repr(vheap.Node(oram.path_stop_bucket)))
        print(oram.path_block_ids)
        print(oram.path_block_eviction_levels)

        oram.push_down_path()
        print("")
        print(repr(vheap.Node(oram.path_stop_bucket)))
        print(oram.path_block_ids)
        print(oram.path_block_eviction_levels)
        print(oram.path_block_reordering)
        assert all(x is None for x in oram.path_block_reordering)

    os.remove(storage_name)
+
+if __name__ == "__main__":
+ main() # pragma: no cover
--- /dev/null
+cryptography
+paramiko
+boto3
+cffi>=1.0.0
+six
+tqdm
--- /dev/null
+[metadata]
+description-file = README.rst
+
+[bdist_wheel]
+# supports python3
+universal=1
--- /dev/null
+import os
+import sys
+import platform
+from setuptools import setup, find_packages
+from codecs import open
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+about = {}
+with open(os.path.join("src", "pyoram", "__about__.py")) as f:
+ exec(f.read(), about)
+
+# Get the long description from the README file
def _readme():
    """Return the contents of README.rst (UTF-8) for the long description."""
    with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f:
        return f.read()
+
+setup_requirements = []
+requirements = ['cryptography',
+ 'paramiko',
+ 'boto3',
+ 'six',
+ 'tqdm']
+
# Enforce minimum interpreter versions and add the cffi requirement for
# non-PyPy interpreters.
if platform.python_implementation() == "PyPy":
    if sys.pypy_version_info < (2, 6):
        raise RuntimeError(
            "PyORAM is not compatible with PyPy < 2.6. Please "
            "upgrade PyPy to use this library.")
else:
    # BUGFIX: the original test `sys.version_info <= (2, 6)` failed to
    # reject micro releases, because e.g. (2, 6, 5) > (2, 6).  Comparing
    # with `< (2, 7)` rejects every 2.6.x interpreter as intended.
    if sys.version_info < (2, 7):
        raise RuntimeError(
            "PyORAM is not compatible with Python < 2.7. Please "
            "upgrade Python to use this library.")
    # cffi is only added for non-PyPy interpreters
    requirements.append("cffi>=1.0.0")
    setup_requirements.append("cffi>=1.0.0")
+
# Package metadata is sourced from src/pyoram/__about__.py (exec'd into
# the `about` dict above) so there is a single source of truth.
setup(
    name=about['__title__'],
    version=about['__version__'],
    description=about['__summary__'],
    long_description=_readme(),
    url=about['__uri__'],
    author=about['__author__'],
    author_email=about['__email__'],
    license=about['__license__'],
    # https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        'Development Status :: 4 - Beta',
        'Intended Audience :: Science/Research',
        'Topic :: Security :: Cryptography',
        "Natural Language :: English",
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: Implementation :: CPython',
        'Programming Language :: Python :: Implementation :: PyPy',
    ],
    keywords='oram, storage, privacy, cryptography, cloud storage',
    # the package lives under src/ (src-layout)
    package_dir={'': 'src'},
    packages=find_packages(where="src", exclude=["_cffi_src", "_cffi_src.*"]),
    setup_requires=setup_requirements,
    install_requires=requirements,
    # builds pyoram.util._virtual_heap_helper at install time
    cffi_modules=["src/_cffi_src/virtual_heap_helper_build.py:ffi"],
    # use MANIFEST.in
    include_package_data=True,
    test_suite='nose2.collector.collector',
    tests_require=['unittest2','nose2']
)
--- /dev/null
+import cffi
+
+#
+# C functions that speed up commonly
+# executed heap calculations in tree-based
+# orams
+#
+
# Declare the two helper functions implemented in the C source passed to
# ffi.set_source() below; cffi generates the extension module wrapper.
ffi = cffi.FFI()
ffi.cdef(
"""
int calculate_bucket_level(unsigned int k,
                           unsigned long long b);
int calculate_last_common_level(unsigned int k,
                                unsigned long long b1,
                                unsigned long long b2);
""")
+
+ffi.set_source("pyoram.util._virtual_heap_helper",
+"""
+#include <stdio.h>
+#include <stdlib.h>
+
+int calculate_bucket_level(unsigned int k,
+ unsigned long long b)
+{
+ unsigned int h;
+ unsigned long long pow;
+ if (k == 2) {
+ // This is simply log2floor(b+1)
+ h = 0;
+ b += 1;
+ while (b >>= 1) {++h;}
+ return h;
+ }
+ b = (k - 1) * (b + 1) + 1;
+ h = 0;
+ pow = k;
+ while (pow < b) {++h; pow *= k;}
+ return h;
+}
+
+int calculate_last_common_level(unsigned int k,
+ unsigned long long b1,
+ unsigned long long b2)
+{
+ int level1, level2;
+ level1 = calculate_bucket_level(k, b1);
+ level2 = calculate_bucket_level(k, b2);
+ if (level1 != level2) {
+ if (level1 > level2) {
+ while (level1 != level2) {
+ b1 = (b1 - 1)/k;
+ --level1;
+ }
+ }
+ else {
+ while (level2 != level1) {
+ b2 = (b2 - 1)/k;
+ --level2;
+ }
+ }
+ }
+ while (b1 != b2) {
+ b1 = (b1 - 1)/k;
+ b2 = (b2 - 1)/k;
+ --level1;
+ }
+ return level1;
+}
+""")
+
+if __name__ == "__main__":
+ ffi.compile()
--- /dev/null
# Project metadata: the single source of truth, exec'd by setup.py and
# imported by pyoram/__init__.py.
__all__ = ('__title__',
           '__summary__',
           '__uri__',
           '__version__',
           '__author__',
           '__email__',
           '__license__',
           '__copyright__')

__title__ = 'PyORAM'
__summary__ = 'Python-based Oblivious RAM'
__uri__ = 'https://github.com/ghackebeil/PyORAM'
__version__ = '0.3.0'
__author__ = 'Gabriel A. Hackebeil'
__email__ = 'gabe.hackebeil@gmail.com'
__license__ = 'MIT'
__copyright__ = 'Copyright {0}'.format(__author__)
--- /dev/null
+from pyoram.__about__ import __version__
+
def _configure_logging():
    """Attach a default handler to the 'pyoram' logger on first import.

    Level and optional logfile come from the PYORAM_LOGLEVEL and
    PYORAM_LOGFILE environment variables.  Does nothing if the root
    logger already has handlers (i.e. the application configured
    logging itself).
    """
    import os
    import logging

    log = logging.getLogger("pyoram")
    formatter = logging.Formatter(
        fmt=("[%(asctime)s.%(msecs)03d,"
             "%(name)s,%(levelname)s] %(threadName)s %(message)s"),
        datefmt="%Y-%m-%d %H:%M:%S")

    level = os.environ.get("PYORAM_LOGLEVEL", "WARNING")
    logfilename = os.environ.get("PYORAM_LOGFILE", None)
    if len(logging.root.handlers) == 0:
        # configure the logging with some sensible
        # defaults.
        try:
            # probe whether the current directory is writable by
            # creating (and immediately closing) a temporary file
            import tempfile
            tempfile = tempfile.TemporaryFile(dir=".")
            tempfile.close()
        except OSError:
            # cannot write in current directory, use the
            # console logger
            handler = logging.StreamHandler()
        else:
            if logfilename is None:
                handler = logging.StreamHandler()
            else:
                # set up a basic logfile in current directory
                handler = logging.FileHandler(logfilename)
        handler.setFormatter(formatter)
        handler.setLevel(level)
        log.addHandler(handler)
        log.setLevel(level)
        log.info("PyORAM log configured using built-in "
                 "defaults, level=%s", level)

# run once at import, then remove the helper from the namespace
_configure_logging()
del _configure_logging
+
def _configure_pyoram():
    """Build the global settings object exposed as ``pyoram.config``."""
    class _Settings(object):
        # __slots__ closes the attribute set: assigning a misspelled
        # option name raises AttributeError instead of passing silently.
        __slots__ = ("SHOW_PROGRESS_BAR",)
        def __init__(self):
            # progress bars are disabled unless explicitly enabled
            self.SHOW_PROGRESS_BAR = False
    return _Settings()
# build the singleton, then remove the factory from the namespace
config = _configure_pyoram()
del _configure_pyoram
+
+import pyoram.util
+import pyoram.crypto
+import pyoram.storage
+import pyoram.encrypted_storage
+import pyoram.oblivious_storage
--- /dev/null
+import pyoram.crypto.aes
--- /dev/null
+__all__ = ("AES",)
+
+import os
+import cryptography.hazmat.primitives.ciphers
+import cryptography.hazmat.backends
+
+_backend = cryptography.hazmat.backends.default_backend()
+_aes = cryptography.hazmat.primitives.ciphers.algorithms.AES
+_cipher = cryptography.hazmat.primitives.ciphers.Cipher
+_ctrmode = cryptography.hazmat.primitives.ciphers.modes.CTR
+_gcmmode = cryptography.hazmat.primitives.ciphers.modes.GCM
+
class AES(object):
    """Static helpers around the ``cryptography`` package's AES cipher.

    Ciphertext layouts produced/consumed here:

    * CTR mode: ``iv || ciphertext``
    * GCM mode: ``iv || ciphertext || tag``

    where ``iv`` and ``tag`` are each ``AES.block_size`` bytes.
    """

    # supported key sizes, converted from bits to bytes
    key_sizes = [k//8 for k in sorted(_aes.key_sizes)]
    # AES block size in bytes
    block_size = _aes.block_size//8

    @staticmethod
    def KeyGen(size_bytes):
        """Return ``size_bytes`` of OS-random key material."""
        assert size_bytes in AES.key_sizes
        return os.urandom(size_bytes)

    @staticmethod
    def CTREnc(key, plaintext):
        """Encrypt with AES-CTR using a fresh random iv; returns
        ``iv || ciphertext``."""
        iv = os.urandom(AES.block_size)
        cipher = _cipher(_aes(key), _ctrmode(iv), backend=_backend).encryptor()
        return iv + cipher.update(plaintext) + cipher.finalize()

    @staticmethod
    def CTRDec(key, ciphertext):
        """Invert :meth:`CTREnc`: split off the leading iv and decrypt."""
        iv = ciphertext[:AES.block_size]
        cipher = _cipher(_aes(key), _ctrmode(iv), backend=_backend).decryptor()
        return cipher.update(ciphertext[AES.block_size:]) + \
            cipher.finalize()

    @staticmethod
    def GCMEnc(key, plaintext):
        """Encrypt with AES-GCM using a fresh random iv; returns
        ``iv || ciphertext || tag``."""
        iv = os.urandom(AES.block_size)
        cipher = _cipher(_aes(key), _gcmmode(iv), backend=_backend).encryptor()
        return iv + cipher.update(plaintext) + cipher.finalize() + cipher.tag

    @staticmethod
    def GCMDec(key, ciphertext):
        """Invert :meth:`GCMEnc`; the library verifies the tag during
        finalize.

        NOTE(review): assumes the trailing tag is exactly
        ``AES.block_size`` (16) bytes — confirm no caller uses a custom
        tag length.
        """
        iv = ciphertext[:AES.block_size]
        tag = ciphertext[-AES.block_size:]
        cipher = _cipher(_aes(key), _gcmmode(iv, tag), backend=_backend).decryptor()
        return cipher.update(ciphertext[AES.block_size:-AES.block_size]) + \
            cipher.finalize()
--- /dev/null
+import pyoram.encrypted_storage.encrypted_block_storage
+import pyoram.encrypted_storage.encrypted_heap_storage
+import pyoram.encrypted_storage.top_cached_encrypted_heap_storage
--- /dev/null
+__all__ = ('EncryptedBlockStorage',)
+
+import struct
+import hmac
+import hashlib
+
+from pyoram.storage.block_storage import (BlockStorageInterface,
+ BlockStorageTypeFactory)
+from pyoram.crypto.aes import AES
+
+import six
+
class EncryptedBlockStorageInterface(BlockStorageInterface):
    """Abstract interface adding encryption-related accessors to
    BlockStorageInterface: the encryption key and the raw (wrapped)
    storage device."""

    #
    # Abstract Interface
    #

    @property
    def key(self, *args, **kwds):
        # the encryption key in use; must be provided by subclasses
        raise NotImplementedError                      # pragma: no cover
    @property
    def raw_storage(self, *args, **kwds):
        # the underlying storage device; must be provided by subclasses
        raise NotImplementedError                      # pragma: no cover
+
+class EncryptedBlockStorage(EncryptedBlockStorageInterface):
+
+ _index_struct_string = "!"+("x"*hashlib.sha384().digest_size)+"?"
+ _index_offset = struct.calcsize(_index_struct_string)
+ _verify_struct_string = "!LLL"
+ _verify_size = struct.calcsize(_verify_struct_string)
+
    def __init__(self, storage, **kwds):
        """Wrap a block storage device with encryption.

        :param storage: either an already-open BlockStorageInterface
            instance (no extra keywords allowed, and the device is NOT
            owned/closed on failure), or a value passed to the storage
            class selected by the 'storage_type' keyword
            (default 'file').
        :keyword key: required symmetric encryption key.
        :raises ValueError: if no key is given, if extra keywords
            accompany an existing device, or if the header HMAC check
            fails.
        """
        self._key = kwds.pop('key', None)
        if self._key is None:
            raise ValueError(
                "An encryption key is required using "
                "the 'key' keyword.")
        if isinstance(storage, BlockStorageInterface):
            storage_owned = False
            self._storage = storage
            if len(kwds):
                raise ValueError(
                    "Keywords not used when initializing "
                    "with a storage device: %s"
                    % (str(kwds)))
        else:
            storage_owned = True
            storage_type = kwds.pop('storage_type', 'file')
            self._storage = \
                BlockStorageTypeFactory(storage_type)(storage, **kwds)

        try:
            # Header layout (after GCM decryption), per
            # _index_struct_string ("!" + 48 pad bytes + "?"):
            # first sha384-digest-size (48) bytes hold the HMAC digest,
            # followed by one bool selecting the per-block cipher mode.
            header_data = AES.GCMDec(self._key,
                                     self._storage.header_data)
            (self._ismodegcm,) = struct.unpack(
                self._index_struct_string,
                header_data[:self._index_offset])
            self._verify_digest = header_data[:hashlib.sha384().digest_size]

            # Recompute the HMAC over the plaintext geometry
            # (block_size, block_count, header length) and compare.
            verify = hmac.HMAC(
                key=self.key,
                msg=struct.pack(self._verify_struct_string,
                                self._storage.block_size,
                                self._storage.block_count,
                                len(self._storage.header_data)),
                digestmod=hashlib.sha384)
            if verify.digest() != self._verify_digest:
                raise ValueError(
                    "HMAC of plaintext index data does not match")
            if self._ismodegcm:
                self._encrypt_block_func = AES.GCMEnc
                self._decrypt_block_func = AES.GCMDec
            else:
                self._encrypt_block_func = AES.CTREnc
                self._decrypt_block_func = AES.CTRDec
        # NOTE(review): bare except (also catches KeyboardInterrupt /
        # SystemExit); used here only to close an owned device before
        # re-raising.
        except:
            if storage_owned:
                self._storage.close()
            raise
+
+    #
+    # Define EncryptedBlockStorageInterface Methods
+    #
+
+    @property
+    def key(self):
+        # The symmetric key supplied at construction.
+        return self._key
+
+    @property
+    def raw_storage(self):
+        # The underlying (ciphertext-holding) storage device.
+        return self._storage
+
+    #
+    # Define BlockStorageInterface Methods
+    #
+
+    def clone_device(self):
+        # Clone the underlying device; the clone uses the same key.
+        return EncryptedBlockStorage(self._storage.clone_device(),
+                                     key=self.key)
+
+    @classmethod
+    def compute_storage_size(cls,
+                             block_size,
+                             block_count,
+                             aes_mode='ctr',
+                             storage_type='file',
+                             ignore_header=False,
+                             **kwds):
+        # Total bytes required on the underlying storage type: each
+        # ciphertext block grows by an IV (CTR) or IV + tag (GCM);
+        # unless ignore_header is set, the encrypted header adds the
+        # index region plus a GCM IV and tag.
+        assert (block_size > 0) and (block_size == int(block_size))
+        assert (block_count > 0) and (block_count == int(block_count))
+        assert aes_mode in ('ctr', 'gcm')
+        if not isinstance(storage_type, BlockStorageInterface):
+            storage_type = BlockStorageTypeFactory(storage_type)
+
+        if aes_mode == 'ctr':
+            extra_block_data = AES.block_size
+        else:
+            assert aes_mode == 'gcm'
+            extra_block_data = 2 * AES.block_size
+        if ignore_header:
+            return (extra_block_data * block_count) + \
+                storage_type.compute_storage_size(
+                    block_size,
+                    block_count,
+                    ignore_header=True,
+                    **kwds)
+        else:
+            return cls._index_offset + \
+                2 * AES.block_size + \
+                (extra_block_data * block_count) + \
+                storage_type.compute_storage_size(
+                    block_size,
+                    block_count,
+                    ignore_header=False,
+                    **kwds)
+
+    @classmethod
+    def setup(cls,
+              storage_name,
+              block_size,
+              block_count,
+              aes_mode='ctr',
+              key_size=None,
+              key=None,
+              storage_type='file',
+              initialize=None,
+              **kwds):
+        # Create, initialize, and return (open) a new encrypted block
+        # storage device. At most one of 'key'/'key_size' may be
+        # given; otherwise a key of the largest supported size is
+        # generated and made available via the 'key' property.
+
+        if (key is not None) and (key_size is not None):
+            raise ValueError(
+                "Only one of 'key' or 'keysize' keywords can "
+                "be specified at a time")
+        if key is None:
+            if key_size is None:
+                key_size = AES.key_sizes[-1]
+            if key_size not in AES.key_sizes:
+                raise ValueError(
+                    "Invalid key size: %s" % (key_size))
+            key = AES.KeyGen(key_size)
+        else:
+            if len(key) not in AES.key_sizes:
+                raise ValueError(
+                    "Invalid key size: %s" % (len(key)))
+
+        if (block_size <= 0) or (block_size != int(block_size)):
+            raise ValueError(
+                "Block size (bytes) must be a positive integer: %s"
+                % (block_size))
+
+        ismodegcm = None
+        encrypt_block_func = None
+        encrypted_block_size = block_size
+        if aes_mode == 'ctr':
+            ismodegcm = False
+            encrypt_block_func = AES.CTREnc
+            encrypted_block_size += AES.block_size
+        elif aes_mode == 'gcm':
+            ismodegcm = True
+            encrypt_block_func = AES.GCMEnc
+            encrypted_block_size += (2 * AES.block_size)
+        else:
+            raise ValueError(
+                "AES encryption mode must be one of 'ctr' or 'gcm'. "
+                "Invalid value: %s" % (aes_mode))
+        assert ismodegcm is not None
+        assert encrypt_block_func is not None
+
+        if not isinstance(storage_type, BlockStorageInterface):
+            storage_type = BlockStorageTypeFactory(storage_type)
+
+        if initialize is None:
+            zeros = bytes(bytearray(block_size))
+            initialize = lambda i: zeros
+        # wrap the user's initializer so stored blocks are encrypted
+        def encrypted_initialize(i):
+            return encrypt_block_func(key, initialize(i))
+        kwds['initialize'] = encrypted_initialize
+
+        user_header_data = kwds.get('header_data', bytes())
+        if type(user_header_data) is not bytes:
+            raise TypeError(
+                "'header_data' must be of type bytes. "
+                "Invalid type: %s" % (type(user_header_data)))
+        # we generate the first time simply to
+        # compute the length
+        tmp = hmac.HMAC(
+            key=key,
+            msg=struct.pack(cls._verify_struct_string,
+                            encrypted_block_size,
+                            block_count,
+                            0),
+            digestmod=hashlib.sha384).digest()
+        header_data = bytearray(struct.pack(cls._index_struct_string,
+                                            ismodegcm))
+        header_data[:hashlib.sha384().digest_size] = tmp
+        header_data = header_data + user_header_data
+        header_data = AES.GCMEnc(key, bytes(header_data))
+        # now that we know the length of the header data
+        # being sent to the underlying storage we can
+        # compute the real hmac
+        verify_digest = hmac.HMAC(
+            key=key,
+            msg=struct.pack(cls._verify_struct_string,
+                            encrypted_block_size,
+                            block_count,
+                            len(header_data)),
+            digestmod=hashlib.sha384).digest()
+        header_data = bytearray(struct.pack(cls._index_struct_string,
+                                            ismodegcm))
+        header_data[:hashlib.sha384().digest_size] = verify_digest
+        header_data = header_data + user_header_data
+        kwds['header_data'] = AES.GCMEnc(key, bytes(header_data))
+
+        return EncryptedBlockStorage(
+            storage_type.setup(storage_name,
+                               encrypted_block_size,
+                               block_count,
+                               **kwds),
+            key=key)
+
+    @property
+    def header_data(self):
+        # Decrypt the stored header and strip the internal index
+        # prefix, returning only the user-supplied portion.
+        return AES.GCMDec(self._key,
+                          self._storage.header_data)\
+                          [self._index_offset:]
+
+    @property
+    def block_count(self):
+        return self._storage.block_count
+
+    @property
+    def block_size(self):
+        # Usable (plaintext) block size: ciphertext blocks carry an
+        # IV (CTR) or an IV plus a GCM tag, each AES.block_size bytes.
+        if self._ismodegcm:
+            return self._storage.block_size - 2 * AES.block_size
+        else:
+            return self._storage.block_size - AES.block_size
+
+    @property
+    def storage_name(self):
+        return self._storage.storage_name
+
+    def update_header_data(self, new_header_data):
+        # Re-encrypt the header, preserving the internal index prefix
+        # while replacing the user portion with new_header_data.
+        self._storage.update_header_data(
+            AES.GCMEnc(
+                self.key,
+                AES.GCMDec(self._key,
+                           self._storage.header_data)\
+                           [:self._index_offset] + \
+                new_header_data))
+
+    def close(self):
+        self._storage.close()
+
+
+
+
+
+
+
+
+
+
+
+
+ def read_block(self, i):
+ a = self._storage.read_block(i)
+ return self._decrypt_block_func(self._key,a)
+
+ def read_blocks(self, indices, *args, **kwds):
+ a = self._storage.read_blocks(indices, *args, **kwds)
+ return [self._decrypt_block_func(self._key, b) for b in a]
+
+ def yield_blocks(self, indices, *args, **kwds):
+ for b in self._storage.yield_blocks(indices, *args, **kwds):
+ yield self._decrypt_block_func(self._key, b)
+
+
+ def write_block(self, i, block, *args, **kwds):
+ a = self._encrypt_block_func(self._key, block)
+ self._storage.write_block(i, a,*args, **kwds)
+
+ def write_blocks(self, indices, blocks, *args, **kwds):
+ enc_blocks = []
+ for i, b in zip(indices, blocks):
+ enc_blocks.append(self._encrypt_block_func(self._key, b))
+
+
+ self._storage.write_blocks(indices, enc_blocks, *args, **kwds)
+
+
+
+
+
+
+
+
+
+
+
+    @property
+    def bytes_sent(self):
+        # Byte counters are tracked by the raw storage device.
+        return self._storage.bytes_sent
+
+    @property
+    def bytes_received(self):
+        return self._storage.bytes_received
--- /dev/null
+__all__ = ('EncryptedHeapStorage',)
+
+import struct
+
+from pyoram.util.virtual_heap import SizedVirtualHeap
+from pyoram.storage.heap_storage import \
+ (HeapStorageInterface,
+ HeapStorage)
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ (EncryptedBlockStorageInterface,
+ EncryptedBlockStorage)
+
+class EncryptedHeapStorageInterface(HeapStorageInterface):
+
+    #
+    # Abstract Interface
+    #
+
+    @property
+    def key(self, *args, **kwds):
+        # Encryption key in use by the device (abstract).
+        raise NotImplementedError # pragma: no cover
+    @property
+    def raw_storage(self, *args, **kwds):
+        # The wrapped, unencrypted storage device (abstract).
+        raise NotImplementedError # pragma: no cover
+
+class EncryptedHeapStorage(HeapStorage,
+                           EncryptedHeapStorageInterface):
+
+    def __init__(self, storage, **kwds):
+        # 'storage' may be an already-constructed encrypted block
+        # storage device (no extra keywords allowed) or a value that
+        # is forwarded to the EncryptedBlockStorage constructor.
+        if isinstance(storage, EncryptedBlockStorageInterface):
+            if len(kwds):
+                raise ValueError(
+                    "Keywords not used when initializing "
+                    "with a storage device: %s"
+                    % (str(kwds)))
+        else:
+            storage = EncryptedBlockStorage(storage, **kwds)
+
+        super(EncryptedHeapStorage, self).__init__(storage)
+
+    #
+    # Define EncryptedHeapStorageInterface Methods
+    #
+
+    @property
+    def key(self):
+        # Encryption key of the wrapped block storage device.
+        return self._storage.key
+
+    @property
+    def raw_storage(self):
+        # The unencrypted device beneath the encrypted block storage.
+        return self._storage.raw_storage
+
+    #
+    # Define HeapStorageInterface Methods
+    # (override what is defined on HeapStorage)
+
+    def clone_device(self):
+        # Clone the underlying device; the clone shares the same key.
+        return EncryptedHeapStorage(self._storage.clone_device())
+
+    @classmethod
+    def compute_storage_size(cls,
+                             block_size,
+                             heap_height,
+                             blocks_per_bucket=1,
+                             heap_base=2,
+                             ignore_header=False,
+                             **kwds):
+        # Total bytes required for a heap of the given shape: one
+        # encrypted block per bucket (bucket = blocks_per_bucket
+        # virtual blocks), plus this class's header region unless
+        # ignore_header is set.
+        assert (block_size > 0) and (block_size == int(block_size))
+        assert heap_height >= 0
+        assert blocks_per_bucket >= 1
+        assert heap_base >= 2
+        assert 'block_count' not in kwds
+        vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+        if ignore_header:
+            return EncryptedBlockStorage.compute_storage_size(
+                vheap.blocks_per_bucket * block_size,
+                vheap.bucket_count(),
+                ignore_header=True,
+                **kwds)
+        else:
+            return cls._header_offset + \
+                EncryptedBlockStorage.compute_storage_size(
+                    vheap.blocks_per_bucket * block_size,
+                    vheap.bucket_count(),
+                    ignore_header=False,
+                    **kwds)
+
+    @classmethod
+    def setup(cls,
+              storage_name,
+              block_size,
+              heap_height,
+              blocks_per_bucket=1,
+              heap_base=2,
+              **kwds):
+        # Create, initialize, and return (open) a new encrypted heap
+        # storage device. The heap geometry is packed into the header
+        # ahead of any user-supplied header data; 'block_count' is
+        # derived from the geometry and may not be passed explicitly.
+        if 'block_count' in kwds:
+            raise ValueError("'block_count' keyword is not accepted")
+        if heap_height < 0:
+            raise ValueError(
+                "heap height must be 0 or greater. Invalid value: %s"
+                % (heap_height))
+        if blocks_per_bucket < 1:
+            raise ValueError(
+                "blocks_per_bucket must be 1 or greater. "
+                "Invalid value: %s" % (blocks_per_bucket))
+        if heap_base < 2:
+            raise ValueError(
+                "heap base must be 2 or greater. Invalid value: %s"
+                % (heap_base))
+
+        vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+
+        user_header_data = kwds.pop('header_data', bytes())
+        if type(user_header_data) is not bytes:
+            raise TypeError(
+                "'header_data' must be of type bytes. "
+                "Invalid type: %s" % (type(user_header_data)))
+        # prepend the heap geometry to the user header data
+        kwds['header_data'] = \
+            struct.pack(cls._header_struct_string,
+                        heap_base,
+                        heap_height,
+                        blocks_per_bucket) + \
+            user_header_data
+
+        return EncryptedHeapStorage(
+            EncryptedBlockStorage.setup(
+                storage_name,
+                vheap.blocks_per_bucket * block_size,
+                vheap.bucket_count(),
+                **kwds))
+
+ #@property
+ #def header_data(...)
+
+ #@property
+ #def bucket_count(...)
+
+ #@property
+ #def bucket_size(...)
+
+ #@property
+ #def blocks_per_bucket(...)
+
+ #@property
+ #def storage_name(...)
+
+ #@property
+ #def virtual_heap(...)
+
+ #@property
+ #def bucket_storage(...)
+
+ #def update_header_data(...)
+
+ #def close(...)
+
+ #def read_path(...)
+
+ #def write_path(...)
+
+ #@property
+ #def bytes_sent(...)
+
+ #@property
+ #def bytes_received(...)
--- /dev/null
+__all__ = ('TopCachedEncryptedHeapStorage',)
+
+import logging
+import tempfile
+import mmap
+
+import pyoram
+from pyoram.util.virtual_heap import SizedVirtualHeap
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ (EncryptedHeapStorageInterface,
+ EncryptedHeapStorage)
+
+import tqdm
+import six
+from six.moves import xrange
+
+log = logging.getLogger("pyoram")
+
+class TopCachedEncryptedHeapStorage(EncryptedHeapStorageInterface):
+ """
+ An encrypted block storage device for accessing memory
+ organized as a heap, where the top 1 or more levels can
+ be cached in local memory. This achieves two things:
+
+ (1) Reduces the number of buckets that need to be read
+ from or written to external storage for a given
+ path I/O operation.
+ (2) Allows certain block storage devices to achieve
+ concurrency across path writes by partioning the
+ storage space into independent subheaps starting
+ below the cache line.
+
+ This devices takes as input an existing encrypted heap
+ storage device. This class should not be cloned or used
+ to setup storage, but rather used as a wrapper class for
+ an existing heap storage device to speed up a bulk set
+ of I/O requests. The original heap storage device should
+ not be used after it is wrapped by this class. This
+ class will close the original device when closing
+ itself.
+
+ The number of cached levels (starting from the root
+ bucket at level 0) can be set with the 'cached_levels'
+ keyword (>= 1).
+
+ By default, this will create an independent storage
+ device capable of reading from and writing to the
+ original storage devices memory for each independent
+ subheap (if any) below the last cached level. The
+ 'concurrency_level' keyword can be used to limit the
+ number of concurrent devices to some level below the
+ cache line (>= 0, <= 'cached_levels').
+
+ Values for 'cached_levels' and 'concurrency_level' will
+ be automatically reduced when they are larger than what
+ is allowed by the heap size.
+ """
+
+    def __new__(cls, *args, **kwds):
+        # With cached_levels=0 no caching wrapper is needed: return
+        # the original heap storage device unchanged (tagged with an
+        # empty cache attribute) instead of building this class.
+        if kwds.get("cached_levels", 1) == 0:
+            assert len(args) == 1
+            storage = args[0]
+            storage.cached_bucket_data = bytes()
+            return storage
+        else:
+            return super(TopCachedEncryptedHeapStorage, cls).\
+                __new__(cls)
+
+ def __init__(self,
+ heap_storage,
+ cached_levels=1,
+ concurrency_level=None):
+ assert isinstance(heap_storage, EncryptedHeapStorage)
+ assert cached_levels != 0
+
+
+ vheap = heap_storage.virtual_heap
+ if cached_levels < 0:
+ cached_levels = vheap.levels
+ if concurrency_level is None:
+ concurrency_level = cached_levels
+ assert concurrency_level >= 0
+ cached_levels = min(vheap.levels, cached_levels)
+ concurrency_level = min(cached_levels, concurrency_level)
+ self._external_level = cached_levels
+ total_buckets = sum(vheap.bucket_count_at_level(l)
+ for l in xrange(cached_levels))
+
+
+ print(" ILA ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI")
+ print(cached_levels)
+ print(concurrency_level)
+ print(" ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI ALI")
+
+
+ self._root_device = heap_storage
+ # clone before we download the cache so that we can
+ # track bytes transferred during read/write requests
+ # (separate from the cached download)
+ self._concurrent_devices = \
+ {vheap.first_bucket_at_level(0): self._root_device.clone_device()}
+
+ self._cached_bucket_count = total_buckets
+ self._cached_buckets_tempfile = tempfile.TemporaryFile()
+ self._cached_buckets_tempfile.seek(0)
+ with tqdm.tqdm(desc=("Downloading %s Cached Heap Buckets"
+ % (self._cached_bucket_count)),
+ total=self._cached_bucket_count*self._root_device.bucket_size,
+ unit="B",
+ unit_scale=True,
+ disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
+ for b, bucket in enumerate(
+ self._root_device.bucket_storage.yield_blocks(
+ xrange(vheap.first_bucket_at_level(cached_levels)))):
+ self._cached_buckets_tempfile.write(bucket)
+ progress_bar.update(self._root_device.bucket_size)
+ self._cached_buckets_tempfile.flush()
+ self._cached_buckets_mmap = mmap.mmap(
+ self._cached_buckets_tempfile.fileno(), 0)
+
+ log.info("%s: Cloning %s sub-heap devices"
+ % (self.__class__.__name__, vheap.bucket_count_at_level(concurrency_level)))
+ # Avoid cloning devices when the cache line is at the root
+ # bucket or when the entire heap is cached
+ if (concurrency_level > 0) and \
+ (concurrency_level <= vheap.last_level):
+ for b in xrange(vheap.first_bucket_at_level(concurrency_level),
+ vheap.first_bucket_at_level(concurrency_level+1)):
+ try:
+ self._concurrent_devices[b] = self._root_device.clone_device()
+ except: # pragma: no cover
+ log.error( # pragma: no cover
+ "%s: Exception encountered " # pragma: no cover
+ "while cloning device. " # pragma: no cover
+ "Closing storage." # pragma: no cover
+ % (self.__class__.__name__)) # pragma: no cover
+ self.close() # pragma: no cover
+ raise # pragma: no cover
+
+ self._subheap_storage = {}
+ # Avoid populating this dictionary when the entire
+ # heap is cached
+ if self._external_level <= vheap.last_level:
+ for b in xrange(vheap.first_bucket_at_level(self._external_level),
+ vheap.first_bucket_at_level(self._external_level+1)):
+ node = vheap.Node(b)
+ while node.bucket not in self._concurrent_devices:
+ node = node.parent_node()
+ assert node.bucket >= 0
+ assert node.level == concurrency_level
+ self._subheap_storage[b] = self._concurrent_devices[node.bucket]
+
+    #
+    # Additional Methods
+    #
+
+    @property
+    def cached_bucket_data(self):
+        # Raw bytes of all cached buckets (mmap over a temp file).
+        return self._cached_buckets_mmap
+
+    #
+    # Define EncryptedHeapStorageInterface Methods
+    #
+
+    @property
+    def key(self):
+        return self._root_device.key
+
+    @property
+    def raw_storage(self):
+        return self._root_device.raw_storage
+
+    #
+    # Define HeapStorageInterface Methods
+    #
+
+    def clone_device(self, *args, **kwds):
+        # This wrapper is explicitly not cloneable (see class docstring).
+        raise NotImplementedError( # pragma: no cover
+            "Class is not designed for cloning") # pragma: no cover
+
+    @classmethod
+    def compute_storage_size(cls, *args, **kwds):
+        # Sizing is identical to the wrapped storage type.
+        return EncryptedHeapStorage.compute_storage_size(*args, **kwds)
+
+    @classmethod
+    def setup(cls, *args, **kwds):
+        # Storage must be set up via EncryptedHeapStorage, then wrapped.
+        raise NotImplementedError( # pragma: no cover
+            "Class is not designed to setup storage") # pragma: no cover
+
+    # The following properties and methods simply delegate to the
+    # wrapped root heap storage device.
+
+    @property
+    def header_data(self):
+        return self._root_device.header_data
+
+    @property
+    def bucket_count(self):
+        return self._root_device.bucket_count
+
+    @property
+    def bucket_size(self):
+        return self._root_device.bucket_size
+
+    @property
+    def blocks_per_bucket(self):
+        return self._root_device.blocks_per_bucket
+
+    @property
+    def storage_name(self):
+        return self._root_device.storage_name
+
+    @property
+    def virtual_heap(self):
+        return self._root_device.virtual_heap
+
+    @property
+    def bucket_storage(self):
+        return self._root_device.bucket_storage
+
+    def update_header_data(self, new_header_data):
+        self._root_device.update_header_data(new_header_data)
+
+ def close(self):
+ print("Heap Closing 1")
+ log.info("%s: Uploading %s cached bucket data before closing"
+ % (self.__class__.__name__, self._cached_bucket_count))
+ with tqdm.tqdm(desc=("Uploading %s Cached Heap Buckets"
+ % (self._cached_bucket_count)),
+ total=self._cached_bucket_count*self.bucket_size,
+ unit="B",
+ unit_scale=True,
+ disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
+ self.bucket_storage.\
+ write_blocks(
+ xrange(self._cached_bucket_count),
+ (self._cached_buckets_mmap[(b*self.bucket_size):
+ ((b+1)*self.bucket_size)]
+ for b in xrange(self._cached_bucket_count)),
+ callback=lambda i: progress_bar.update(self._root_device.bucket_size))
+ for b in self._concurrent_devices:
+ self._concurrent_devices[b].close()
+ self._root_device.close()
+ # forces the bar to become full at close
+ # even if te write_blocks action was faster
+ # the the mininterval time
+ progress_bar.mininterval = 0
+
+ self._cached_buckets_mmap.close()
+ self._cached_buckets_tempfile.close()
+
+    def read_path(self, b, level_start=0):
+        # Read the bucket path from the root down to bucket b,
+        # serving buckets above the cache line from the local mmap
+        # and the rest from the owning sub-heap device.
+        assert 0 <= b < self.virtual_heap.bucket_count()
+        bucket_list = self.virtual_heap.Node(b).bucket_path_from_root()
+        if len(bucket_list) <= self._external_level:
+            # the entire path is cached locally
+            return [self._cached_buckets_mmap[(bb*self.bucket_size):
+                                              ((bb+1)*self.bucket_size)]
+                    for bb in bucket_list[level_start:]]
+        elif level_start >= self._external_level:
+            # the requested sub-path lies entirely below the cache line
+            return self._subheap_storage[bucket_list[self._external_level]].\
+                bucket_storage.read_blocks(bucket_list[level_start:])
+        else:
+            # mixed: cached buckets first, then one external batch read
+            local_buckets = bucket_list[:self._external_level]
+            external_buckets = bucket_list[self._external_level:]
+            buckets = []
+            for bb in local_buckets[level_start:]:
+                buckets.append(
+                    self._cached_buckets_mmap[(bb*self.bucket_size):
+                                              ((bb+1)*self.bucket_size)])
+            if len(external_buckets) > 0:
+                buckets.extend(
+                    self._subheap_storage[external_buckets[0]].\
+                    bucket_storage.read_blocks(external_buckets))
+            assert len(buckets) == len(bucket_list[level_start:])
+            return buckets
+
+    def write_path(self, b, buckets, level_start=0):
+        # Write the bucket path from the root down to bucket b,
+        # updating the local mmap for cached levels and delegating
+        # the remainder to the owning sub-heap device.
+        assert 0 <= b < self.virtual_heap.bucket_count()
+        bucket_list = self.virtual_heap.Node(b).bucket_path_from_root()
+        if len(bucket_list) <= self._external_level:
+            # the entire path is cached locally
+            for bb, bucket in zip(bucket_list[level_start:], buckets):
+                self._cached_buckets_mmap[(bb*self.bucket_size):
+                                          ((bb+1)*self.bucket_size)] = bucket
+        elif level_start >= self._external_level:
+            # the sub-path lies entirely below the cache line
+            self._subheap_storage[bucket_list[self._external_level]].\
+                bucket_storage.write_blocks(bucket_list[level_start:], buckets)
+        else:
+            # mixed: write cached levels, then one external batch write
+            buckets = list(buckets)
+            assert len(buckets) == len(bucket_list[level_start:])
+            local_buckets = bucket_list[:self._external_level]
+            external_buckets = bucket_list[self._external_level:]
+            # ndx = -1 keeps buckets[(ndx+1):] correct when the local
+            # portion of the path is empty
+            ndx = -1
+            for ndx, bb in enumerate(local_buckets[level_start:]):
+                self._cached_buckets_mmap[(bb*self.bucket_size):
+                                          ((bb+1)*self.bucket_size)] = buckets[ndx]
+            if len(external_buckets) > 0:
+                self._subheap_storage[external_buckets[0]].\
+                    bucket_storage.write_blocks(external_buckets,
+                                                buckets[(ndx+1):])
+    @property
+    def bytes_sent(self):
+        # Summed over the cloned sub-heap devices; the initial cache
+        # download/upload through the root device is not included
+        # (devices are cloned before the cache download).
+        return sum(device.bytes_sent for device
+                   in self._concurrent_devices.values())
+
+    @property
+    def bytes_received(self):
+        return sum(device.bytes_received for device
+                   in self._concurrent_devices.values())
--- /dev/null
+import pyoram.oblivious_storage.tree
--- /dev/null
+import pyoram.oblivious_storage.tree.tree_oram_helper
+import pyoram.oblivious_storage.tree.path_oram
--- /dev/null
+import hashlib
+import hmac
+import struct
+import array
+import logging
+
+import pyoram
+from pyoram.oblivious_storage.tree.tree_oram_helper import \
+ (TreeORAMStorage,
+ TreeORAMStorageManagerExplicitAddressing)
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorageInterface
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ (EncryptedHeapStorage,
+ EncryptedHeapStorageInterface)
+from pyoram.encrypted_storage.top_cached_encrypted_heap_storage import \
+ TopCachedEncryptedHeapStorage
+from pyoram.util.virtual_heap import \
+ (SizedVirtualHeap,
+ calculate_necessary_heap_height)
+
+import tqdm
+import six
+from six.moves import xrange
+
+log = logging.getLogger("pyoram")
+
+class PathORAM(EncryptedBlockStorageInterface):
+
+    # Header layout: two sha384-digest-sized regions (struct 'x'
+    # padding, overwritten with HMACs of the stash and the position
+    # map) followed by the block count.
+    _header_struct_string = "!"+("x"*2*hashlib.sha384().digest_size)+"L"
+    _header_offset = struct.calcsize(_header_struct_string)
+
+    def __init__(self,
+                 storage,
+                 stash,
+                 position_map,
+                 **kwds):
+        # 'storage' is either an already-open encrypted heap storage
+        # device (no extra keywords allowed) or a value forwarded to
+        # the EncryptedHeapStorage constructor, which is then wrapped
+        # in a top-level cache ('cached_levels', 'concurrency_level').
+
+        self._oram = None
+        self._block_count = None
+
+        if isinstance(storage, EncryptedHeapStorageInterface):
+            storage_heap = storage
+            close_storage_heap = False
+            if len(kwds):
+                raise ValueError(
+                    "Keywords not used when initializing "
+                    "with a storage device: %s"
+                    % (str(kwds)))
+        else:
+            cached_levels = kwds.pop('cached_levels', 3)
+            # cached_levels = kwds.pop('cached_levels', 1)
+            concurrency_level = kwds.pop('concurrency_level', None)
+            close_storage_heap = True
+            storage_heap = TopCachedEncryptedHeapStorage(EncryptedHeapStorage(storage, **kwds), cached_levels=cached_levels, concurrency_level=concurrency_level)
+            # storage_heap = EncryptedHeapStorage(storage, **kwds)
+
+        # header: stash digest || position map digest || block count
+        (self._block_count,) = struct.unpack(
+            self._header_struct_string,
+            storage_heap.header_data[:self._header_offset])
+        stashdigest = storage_heap.\
+            header_data[:hashlib.sha384().digest_size]
+        positiondigest = storage_heap.\
+            header_data[hashlib.sha384().digest_size:\
+                        (2*hashlib.sha384().digest_size)]
+
+        # verify the caller-supplied stash matches the saved digest
+        try:
+            if stashdigest != \
+               PathORAM.stash_digest(
+                   stash,
+                   digestmod=hmac.HMAC(key=storage_heap.key,
+                                       digestmod=hashlib.sha384)):
+                raise ValueError(
+                    "Stash HMAC does not match that saved with "
+                    "storage heap %s" % (storage_heap.storage_name))
+        except:
+            if close_storage_heap:
+                storage_heap.close()
+            raise
+
+        # verify the caller-supplied position map likewise
+        try:
+            if positiondigest != \
+               PathORAM.position_map_digest(
+                   position_map,
+                   digestmod=hmac.HMAC(key=storage_heap.key,
+                                       digestmod=hashlib.sha384)):
+                raise ValueError(
+                    "Position map HMAC does not match that saved with "
+                    "storage heap %s" % (storage_heap.storage_name))
+        except:
+            if close_storage_heap:
+                storage_heap.close()
+            raise
+
+        self._oram = TreeORAMStorageManagerExplicitAddressing(
+            storage_heap,
+            stash,
+            position_map)
+        assert self._block_count <= \
+            self._oram.storage_heap.bucket_count
+
+    @classmethod
+    def _init_position_map(cls, vheap, block_count):
+        # Map each block id to an independent, uniformly random leaf
+        # bucket of the heap.
+        return array.array("L", [vheap.random_leaf_bucket()
+                                 for i in xrange(block_count)])
+
+    def _init_oram_block(self, id_, block):
+        # Build an ORAM block: per-block metadata (the id tag)
+        # followed by the user's block payload.
+        oram_block = bytearray(self.block_size)
+        oram_block[self._oram.block_info_storage_size:] = block[:]
+        self._oram.tag_block_with_id(oram_block, id_)
+        return oram_block
+
+    def _extract_virtual_block(self, block):
+        # Strip the ORAM metadata prefix, leaving the user payload.
+        return block[self._oram.block_info_storage_size:]
+
+ #
+ # Add some methods specific to Path ORAM
+ #
+
+    @classmethod
+    def stash_digest(cls, stash, digestmod=None):
+        # Deterministic digest of a stash mapping {id: block bytes},
+        # folding ids in sorted order followed by their block bytes.
+        # Defaults to sha1 when no digest object is supplied.
+        if digestmod is None:
+            digestmod = hashlib.sha1()
+        id_to_bytes = lambda id_: \
+            struct.pack(TreeORAMStorage.block_id_storage_string, id_)
+        if len(stash) == 0:
+            # keep the empty-stash digest well-defined
+            digestmod.update(b'0')
+        else:
+            for id_ in sorted(stash):
+                if id_ < 0:
+                    raise ValueError(
+                        "Invalid stash id '%s'. Values must be "
+                        "nonnegative integers." % (id_))
+                digestmod.update(id_to_bytes(id_))
+                digestmod.update(bytes(stash[id_]))
+        return digestmod.digest()
+
+    @classmethod
+    def position_map_digest(cls, position_map, digestmod=None):
+        # Deterministic digest of the position map (a sequence of
+        # bucket addresses), folded in order.
+        if digestmod is None:
+            digestmod = hashlib.sha1()
+        id_to_bytes = lambda id_: \
+            struct.pack(TreeORAMStorage.block_id_storage_string, id_)
+        assert len(position_map) > 0
+        for addr in position_map:
+            if addr < 0:
+                raise ValueError(
+                    "Invalid position map address '%s'. Values must be "
+                    "nonnegative integers." % (addr))
+            digestmod.update(id_to_bytes(addr))
+        return digestmod.digest()
+
+    @property
+    def position_map(self):
+        # Client-held mapping from block id to heap bucket.
+        return self._oram.position_map
+
+    @property
+    def stash(self):
+        # Client-held overflow blocks not currently on any path.
+        return self._oram.stash
+
+ def access(self, id_, write_block=None):
+ assert 0 <= id_ <= self.block_count
+ bucket = self.position_map[id_]
+ bucket_level = self._oram.storage_heap.virtual_heap.Node(bucket).level
+ self.position_map[id_] = \
+ self._oram.storage_heap.virtual_heap.\
+ random_bucket_at_level(bucket_level)
+ self._oram.load_path(bucket)
+ block = self._oram.extract_block_from_path(id_)
+
+ if block is None:
+ block = self.stash[id_]
+
+
+ if write_block is not None:
+ block = self._init_oram_block(id_, write_block)
+
+
+ self.stash[id_] = block
+ self._oram.push_down_path()
+ self._oram.fill_path_from_stash()
+ self._oram.evict_path()
+ if write_block is None:
+ return self._extract_virtual_block(block)
+
+    @property
+    def heap_storage(self):
+        # The (possibly cache-wrapped) encrypted heap storage device.
+        return self._oram.storage_heap
+
+    #
+    # Define EncryptedBlockStorageInterface Methods
+    #
+
+    @property
+    def key(self):
+        return self._oram.storage_heap.key
+
+    @property
+    def raw_storage(self):
+        return self._oram.storage_heap.raw_storage
+
+ #
+ # Define BlockStorageInterface Methods
+ #
+
+    @classmethod
+    def compute_storage_size(cls,
+                             block_size,
+                             block_count,
+                             bucket_capacity=4,
+                             heap_base=2,
+                             ignore_header=False,
+                             **kwds):
+        # Total bytes required for a Path ORAM holding block_count
+        # user blocks: each stored block grows by the per-block ORAM
+        # metadata, and the heap height is derived from block_count.
+        assert (block_size > 0) and (block_size == int(block_size))
+        assert (block_count > 0) and (block_count == int(block_count))
+        assert bucket_capacity >= 1
+        assert heap_base >= 2
+        assert 'heap_height' not in kwds
+        heap_height = calculate_necessary_heap_height(heap_base,
+                                                      block_count)
+        block_size += TreeORAMStorageManagerExplicitAddressing.\
+                      block_info_storage_size
+        if ignore_header:
+            return EncryptedHeapStorage.compute_storage_size(
+                block_size,
+                heap_height,
+                blocks_per_bucket=bucket_capacity,
+                heap_base=heap_base,
+                ignore_header=True,
+                **kwds)
+        else:
+            return cls._header_offset + \
+                EncryptedHeapStorage.compute_storage_size(
+                    block_size,
+                    heap_height,
+                    blocks_per_bucket=bucket_capacity,
+                    heap_base=heap_base,
+                    ignore_header=False,
+                    **kwds)
+
+    @classmethod
+    def setup(cls,
+              storage_name,
+              block_size,
+              block_count,
+              bucket_capacity=4,
+              heap_base=2,
+              cached_levels=3,
+              concurrency_level=None,
+              **kwds):
+        # Create, initialize, and return (open) a new PathORAM device.
+        # Builds the encrypted heap with every bucket tagged empty,
+        # writes each initial block through the ORAM, then stamps
+        # HMACs of the resulting stash and position map into the
+        # header so a later __init__ can verify the client state.
+        if 'heap_height' in kwds:
+            raise ValueError("'heap_height' keyword is not accepted")
+        if (bucket_capacity <= 0) or \
+           (bucket_capacity != int(bucket_capacity)):
+            raise ValueError(
+                "Bucket capacity must be a positive integer: %s"
+                % (bucket_capacity))
+        if (block_size <= 0) or (block_size != int(block_size)):
+            raise ValueError(
+                "Block size (bytes) must be a positive integer: %s"
+                % (block_size))
+        if (block_count <= 0) or (block_count != int(block_count)):
+            raise ValueError(
+                "Block count must be a positive integer: %s"
+                % (block_count))
+
+        if heap_base < 2:
+            raise ValueError(
+                "heap base must be 2 or greater. Invalid value: %s"
+                % (heap_base))
+
+        heap_height = calculate_necessary_heap_height(heap_base,
+                                                      block_count)
+        stash = {}
+        vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=bucket_capacity)
+        position_map = cls._init_position_map(vheap, block_count)
+
+        oram_block_size = block_size + \
+                          TreeORAMStorageManagerExplicitAddressing.\
+                          block_info_storage_size
+
+        user_header_data = kwds.pop('header_data', bytes())
+        if type(user_header_data) is not bytes:
+            raise TypeError(
+                "'header_data' must be of type bytes. "
+                "Invalid type: %s" % (type(user_header_data)))
+
+        initialize = kwds.pop('initialize', None)
+
+        header_data = struct.pack(
+            cls._header_struct_string,
+            block_count)
+        kwds['header_data'] = bytes(header_data) + user_header_data
+        # template bucket whose slots are all tagged as empty
+        empty_bucket = bytearray(oram_block_size * bucket_capacity)
+        empty_bucket_view = memoryview(empty_bucket)
+        for i in xrange(bucket_capacity):
+            TreeORAMStorageManagerExplicitAddressing.tag_block_as_empty(
+                empty_bucket_view[(i*oram_block_size):\
+                                  ((i+1)*oram_block_size)])
+        empty_bucket = bytes(empty_bucket)
+
+        kwds['initialize'] = lambda i: empty_bucket
+        f = None
+        try:
+            log.info("%s: setting up encrypted heap storage"
+                     % (cls.__name__))
+            f = EncryptedHeapStorage.setup(storage_name,
+                                           oram_block_size,
+                                           heap_height,
+                                           heap_base=heap_base,
+                                           blocks_per_bucket=bucket_capacity,
+                                           **kwds)
+            if cached_levels != 0:
+                f = TopCachedEncryptedHeapStorage(
+                    f,
+                    cached_levels=cached_levels,
+                    concurrency_level=concurrency_level)
+            elif concurrency_level is not None:
+                raise ValueError( # pragma: no cover
+                    "'concurrency_level' keyword is " # pragma: no cover
+                    "not used when no heap levels " # pragma: no cover
+                    "are cached") # pragma: no cover
+            oram = TreeORAMStorageManagerExplicitAddressing(
+                f, stash, position_map)
+            if initialize is None:
+                zeros = bytes(bytearray(block_size))
+                initialize = lambda i: zeros
+            initial_oram_block = bytearray(oram_block_size)
+            for i in tqdm.tqdm(xrange(block_count),
+                               desc=("Initializing %s Blocks" % (cls.__name__)),
+                               total=block_count,
+                               disable=not pyoram.config.SHOW_PROGRESS_BAR):
+                # tag the block with its id and fill in the payload
+                oram.tag_block_with_id(initial_oram_block, i)
+                initial_oram_block[oram.block_info_storage_size:] = \
+                    initialize(i)[:]
+
+                # remap the block before writing it through the ORAM
+                bucket = oram.position_map[i]
+                bucket_level = vheap.Node(bucket).level
+                oram.position_map[i] = \
+                    oram.storage_heap.virtual_heap.\
+                    random_bucket_at_level(bucket_level)
+
+                oram.load_path(bucket)
+                oram.push_down_path()
+                # place a copy in the stash
+                oram.stash[i] = bytearray(initial_oram_block)
+                oram.fill_path_from_stash()
+                oram.evict_path()
+
+            # stamp stash / position map digests into the header
+            header_data = bytearray(header_data)
+            stash_digest = cls.stash_digest(
+                oram.stash,
+                digestmod=hmac.HMAC(key=oram.storage_heap.key,
+                                    digestmod=hashlib.sha384))
+            position_map_digest = cls.position_map_digest(
+                oram.position_map,
+                digestmod=hmac.HMAC(key=oram.storage_heap.key,
+                                    digestmod=hashlib.sha384))
+            header_data[:len(stash_digest)] = stash_digest[:]
+            header_data[len(stash_digest):\
+                        (len(stash_digest)+len(position_map_digest))] = \
+                position_map_digest[:]
+            f.update_header_data(bytes(header_data) + user_header_data)
+            return PathORAM(f, stash, position_map=position_map)
+        except:
+            if f is not None:
+                f.close() # pragma: no cover
+            raise
+
+ @property
+ def header_data(self):
+ return self._oram.storage_heap.\
+ header_data[self._header_offset:]
+
+ @property
+ def block_count(self):
+ return self._block_count
+
    @property
    def block_size(self):
        """Usable bytes per user block: the raw ORAM block size minus
        the space used by the block's status/id tag."""
        return self._oram.block_size - self._oram.block_info_storage_size
+
    @property
    def storage_name(self):
        """Name of the underlying heap storage device."""
        return self._oram.storage_heap.storage_name
+
+ def update_header_data(self, new_header_data):
+ self._oram.storage_heap.update_header_data(
+ self._oram.storage_heap.header_data[:self._header_offset] + \
+ new_header_data)
+
+ def close(self):
+ log.info("%s: Closing" % (self.__class__.__name__))
+ print("Closing")
+
+ if self._oram is not None:
+ try:
+ stashdigest = \
+ PathORAM.stash_digest(
+ self._oram.stash,
+ digestmod=hmac.HMAC(key=self._oram.storage_heap.key,
+ digestmod=hashlib.sha384))
+
+ print("Closing 1")
+ positiondigest = \
+ PathORAM.position_map_digest(
+ self._oram.position_map,
+ digestmod=hmac.HMAC(key=self._oram.storage_heap.key,
+ digestmod=hashlib.sha384))
+
+ print("Closing 2")
+ new_header_data = \
+ bytearray(self._oram.storage_heap.\
+ header_data[:self._header_offset])
+
+ print("Closing 3")
+ new_header_data[:hashlib.sha384().digest_size] = \
+ stashdigest
+ new_header_data[hashlib.sha384().digest_size:\
+ (2*hashlib.sha384().digest_size)] = \
+ positiondigest
+
+ print("Closing 4")
+ self._oram.storage_heap.update_header_data(
+ bytes(new_header_data) + self.header_data)
+ print("Closing 5")
+ except: # pragma: no cover
+ log.error( # pragma: no cover
+ "%s: Failed to update header data with " # pragma: no cover
+ "current stash and position map state" # pragma: no cover
+ % (self.__class__.__name__)) # pragma: no cover
+ print("Closing ")
+ raise
+ finally:
+ print("Closing 6")
+ self._oram.storage_heap.close()
+ print("Closing 7")
+
+ def read_blocks(self, indices):
+ blocks = []
+ for i in indices:
+ blocks.append(self.access(i))
+ return blocks
+
    def read_block(self, i):
        """Read a single block via the ORAM access protocol."""
        return self.access(i)
+
+ def write_blocks(self, indices, blocks):
+ for i, block in zip(indices, blocks):
+ self.access(i, write_block=block)
+
    def write_block(self, i, block):
        """Write a single block via the ORAM access protocol."""
        self.access(i, write_block=block)
+
    @property
    def bytes_sent(self):
        """Total payload bytes sent by the underlying storage heap."""
        return self._oram.storage_heap.bytes_sent
+
    @property
    def bytes_received(self):
        """Total payload bytes received by the underlying storage heap."""
        return self._oram.storage_heap.bytes_received
--- /dev/null
+__all__ = ('TreeORAMStorageManagerExplicitAddressing',
+ 'TreeORAMStorageManagerPointerAddressing')
+
+import struct
+import copy
+
+from pyoram.util.virtual_heap import \
+ SizedVirtualHeap
+
+from six.moves import xrange
+
class TreeORAMStorage(object):
    """
    Base class providing the bucket/block bookkeeping shared by
    tree-based ORAM schemes: loading a root-to-leaf bucket path,
    pushing real blocks as deep as possible along it, refilling
    empty slots from the stash, and evicting the path back to the
    storage heap.
    """

    # Sentinel id stored in slots that hold no real block.
    empty_block_id = -1

    # struct formats (network byte order): '?' is the real-block
    # flag, 'L' is the unsigned block id.
    block_status_storage_string = "!?"
    block_id_storage_string = "!L"
    block_info_storage_string = "!?L"

    block_status_storage_size = \
        struct.calcsize(block_status_storage_string)
    block_info_storage_size = \
        struct.calcsize(block_info_storage_string)

    # Pre-packed tags written at the front of a block to mark it
    # empty (False) or real (True).
    empty_block_bytes_tag = \
        struct.pack(block_status_storage_string, False)
    real_block_bytes_tag = \
        struct.pack(block_status_storage_string, True)

    def __init__(self,
                 storage_heap,
                 stash):
        """
        storage_heap: a heap-organized storage device exposing
            virtual_heap, bucket_size, read_path, and write_path.
        stash: a mutable mapping of block id -> block bytes held
            client-side between path accesses.
        """
        self.storage_heap = storage_heap
        self.stash = stash

        vheap = self.storage_heap.virtual_heap
        self.bucket_size = self.storage_heap.bucket_size
        self.block_size = self.bucket_size // vheap.blocks_per_bucket
        # The bucket size must divide evenly into blocks.
        assert self.block_size * vheap.blocks_per_bucket == \
            self.bucket_size

        # State for the currently loaded path (None until load_path).
        self.path_stop_bucket = None
        self.path_bucket_count = 0
        # Single contiguous buffer holding the full path; the bucket
        # and block views below are memoryview aliases into it, so
        # writes through them mutate this buffer in place.
        self.path_byte_dataview = \
            bytearray(self.bucket_size * vheap.levels)
        dataview = memoryview(self.path_byte_dataview)
        self.path_bucket_dataview = \
            [dataview[(i*self.bucket_size):((i+1)*self.bucket_size)]
             for i in xrange(vheap.levels)]

        self.path_block_dataview = []
        for i in xrange(vheap.levels):
            bucketview = self.path_bucket_dataview[i]
            for j in xrange(vheap.blocks_per_bucket):
                self.path_block_dataview.append(
                    bucketview[(j*self.block_size):((j+1)*self.block_size)])

        max_blocks_on_path = vheap.levels * vheap.blocks_per_bucket
        assert len(self.path_block_dataview) == max_blocks_on_path
        # Parallel arrays tracking, per path slot: the block id, the
        # deepest level the block may be evicted to, and the slot
        # reordering computed by push_down_path (None = untouched,
        # -1 = emptied, otherwise the source slot index).
        self.path_block_ids = [-1] * max_blocks_on_path
        self.path_block_eviction_levels = [None] * max_blocks_on_path
        self.path_block_reordering = [None] * max_blocks_on_path
        self.path_blocks_inserted = []

    def load_path(self, b):
        """
        Download the bucket path from the root to bucket b and
        populate the per-slot bookkeeping arrays. Buckets shared
        with the previously loaded path are not re-downloaded.
        """
        vheap = self.storage_heap.virtual_heap
        Z = vheap.blocks_per_bucket
        lcl = vheap.clib.calculate_last_common_level
        k = vheap.k

        read_level_start = 0
        if self.path_stop_bucket is not None:
            # don't download the root and any other buckets
            # that are common between the previous bucket path
            # and the new one
            read_level_start = lcl(k, self.path_stop_bucket, b)
        assert 0 <= b < vheap.bucket_count()
        self.path_stop_bucket = b
        new_buckets = self.storage_heap.read_path(
            self.path_stop_bucket,
            level_start=read_level_start)

        self.path_bucket_count = read_level_start + len(new_buckets)
        pos = 0
        for i in xrange(self.path_bucket_count):
            if i >= read_level_start:
                # Copy freshly downloaded buckets into the path
                # buffer; levels below read_level_start keep their
                # previous contents.
                self.path_bucket_dataview[i][:] = \
                    new_buckets[i-read_level_start][:]
            for j in xrange(Z):
                block_id, block_addr = \
                    self.get_block_info(self.path_block_dataview[pos])
                self.path_block_ids[pos] = block_id
                if block_id != self.empty_block_id:
                    # Deepest level shared by this path and the
                    # block's assigned bucket address.
                    self.path_block_eviction_levels[pos] = \
                        lcl(k, self.path_stop_bucket, block_addr)
                else:
                    self.path_block_eviction_levels[pos] = None
                self.path_block_reordering[pos] = None
                pos += 1

        # Mark any slots beyond the downloaded path as unused.
        max_blocks_on_path = vheap.levels * Z
        while pos != max_blocks_on_path:
            self.path_block_ids[pos] = None
            self.path_block_eviction_levels[pos] = None
            self.path_block_reordering[pos] = None
            pos += 1

        self.path_blocks_inserted = []

    def push_down_path(self):
        """
        Greedily move real blocks on the loaded path into deeper
        empty slots, constrained by each block's eviction level.
        Only the bookkeeping arrays are updated here; the actual
        byte moves are applied later by evict_path.
        """
        vheap = self.storage_heap.virtual_heap
        Z = vheap.blocks_per_bucket

        bucket_count = self.path_bucket_count
        block_ids = self.path_block_ids
        block_eviction_levels = self.path_block_eviction_levels
        block_reordering = self.path_block_reordering
        def _do_swap(write_pos, read_pos):
            # Record that the block at read_pos moves to write_pos
            # (read_pos becomes empty).
            block_ids[write_pos], block_eviction_levels[write_pos] = \
                block_ids[read_pos], block_eviction_levels[read_pos]
            block_ids[read_pos], block_eviction_levels[read_pos] = \
                self.empty_block_id, None
            block_reordering[write_pos] = read_pos
            block_reordering[read_pos] = -1

        def _new_write_pos(current):
            # Scan upward (toward the root) for the next empty slot;
            # returns (slot, level) or (None, None) when exhausted.
            current -= 1
            if current < 0:
                return None, None
            while (block_eviction_levels[current] is not None):
                current -= 1
                if current < 0:
                    return None, None
            assert block_ids[current] == \
                self.empty_block_id
            return current, current // Z

        def _new_read_pos(current):
            # Scan upward for the next real block; returns the slot
            # index or None when exhausted.
            current -= 1
            if current < 0:
                return None
            while (block_eviction_levels[current] is None):
                current -= 1
                if current < 0:
                    return None
            assert block_ids[current] != \
                self.empty_block_id
            return current

        write_pos, write_level = _new_write_pos(bucket_count * Z)
        while write_pos is not None:
            read_pos = _new_read_pos(write_pos)
            if read_pos is None:
                break
            # Skip candidates already in the target bucket or not
            # allowed to descend to the target level.
            while ((read_pos // Z) == write_level) or \
                  (write_level > block_eviction_levels[read_pos]):
                read_pos = _new_read_pos(read_pos)
                if read_pos is None:
                    break
            if read_pos is not None:
                _do_swap(write_pos, read_pos)
            else:
                # Jump directly to the start of this
                # bucket. There is not point in checking
                # for other empty slots because no blocks
                # can be evicted to this level.
                write_pos = Z * (write_pos//Z)
            write_pos, write_level = _new_write_pos(write_pos)

    def fill_path_from_stash(self):
        """
        Fill remaining empty path slots (deepest first) with stash
        blocks whose eviction level permits that slot, removing the
        chosen blocks from the stash. Insertions are queued in
        path_blocks_inserted for evict_path to apply.
        """
        vheap = self.storage_heap.virtual_heap
        lcl = vheap.clib.calculate_last_common_level
        k = vheap.k
        Z = vheap.blocks_per_bucket

        bucket_count = self.path_bucket_count
        stop_bucket = self.path_stop_bucket
        block_ids = self.path_block_ids
        block_eviction_levels = self.path_block_eviction_levels
        blocks_inserted = self.path_blocks_inserted

        # Cache per-stash-block eviction levels across write slots.
        stash_eviction_levels = {}
        largest_write_position = (bucket_count * Z) - 1
        for write_pos in xrange(largest_write_position,-1,-1):
            write_level = write_pos // Z
            if block_ids[write_pos] == self.empty_block_id:
                del_id = None
                for id_ in self.stash:
                    if id_ not in stash_eviction_levels:
                        block_id, block_addr = \
                            self.get_block_info(self.stash[id_])
                        assert id_ != self.empty_block_id
                        eviction_level = stash_eviction_levels[id_] = \
                            lcl(k, stop_bucket, block_addr)
                    else:
                        eviction_level = stash_eviction_levels[id_]
                    if write_level <= eviction_level:
                        block_ids[write_pos] = id_
                        block_eviction_levels[write_pos] = \
                            eviction_level
                        blocks_inserted.append(
                            (write_pos, self.stash[id_]))
                        del_id = id_
                        break
                if del_id is not None:
                    # Deleted outside the iteration over the stash.
                    del self.stash[del_id]

    def evict_path(self):
        """
        Apply the reordering and stash insertions recorded by
        push_down_path / fill_path_from_stash to the path buffer,
        then write the path back to the storage heap.
        """
        vheap = self.storage_heap.virtual_heap
        Z = vheap.blocks_per_bucket

        bucket_count = self.path_bucket_count
        stop_bucket = self.path_stop_bucket
        bucket_dataview = self.path_bucket_dataview
        block_dataview = self.path_block_dataview
        block_reordering = self.path_block_reordering
        blocks_inserted = self.path_blocks_inserted

        # Apply moves deepest-first so a source slot is never
        # overwritten before it has been copied out.
        for i, read_pos in enumerate(
                reversed(block_reordering)):
            if (read_pos is not None) and \
               (read_pos != -1):
                write_pos = len(block_reordering) - 1 - i
                block_dataview[write_pos][:] = block_dataview[read_pos][:]

        # Tag vacated slots as empty.
        for write_pos, read_pos in enumerate(block_reordering):
            if read_pos == -1:
                self.tag_block_as_empty(block_dataview[write_pos])

        # Copy in the blocks pulled from the stash.
        for write_pos, block in blocks_inserted:
            block_dataview[write_pos][:] = block[:]

        self.storage_heap.write_path(
            stop_bucket,
            (bucket_dataview[i].tobytes()
             for i in xrange(bucket_count)))

    def extract_block_from_path(self, id_):
        """
        Remove the block with the given id from the loaded path and
        return a copy of its bytes, or None if it is not on the path.
        """
        block_ids = self.path_block_ids
        block_dataview = self.path_block_dataview
        try:
            pos = block_ids.index(id_)
            # make a copy
            block = bytearray(block_dataview[pos])
            self._set_path_position_to_empty(pos)
            return block
        except ValueError:
            return None

    def _set_path_position_to_empty(self, pos):
        # Mark a path slot empty in the bookkeeping arrays (the slot
        # bytes themselves are tagged empty during evict_path).
        self.path_block_ids[pos] = self.empty_block_id
        self.path_block_eviction_levels[pos] = None
        self.path_block_reordering[pos] = -1

    @staticmethod
    def tag_block_as_empty(block):
        """Overwrite the block's status bytes with the 'empty' tag."""
        block[:TreeORAMStorage.block_status_storage_size] = \
            TreeORAMStorage.empty_block_bytes_tag[:]

    @staticmethod
    def tag_block_with_id(block, id_):
        """Mark the block as real and stamp it with the given id."""
        assert id_ >= 0
        struct.pack_into(TreeORAMStorage.block_info_storage_string,
                         block,
                         0,
                         True,
                         id_)

    def get_block_info(self, block):
        """Return (block id, heap bucket address) for a raw block.
        Subclasses define how the address is obtained."""
        raise NotImplementedError                      # pragma: no cover
+
class TreeORAMStorageManagerExplicitAddressing(
        TreeORAMStorage):
    """
    This class should be used to implement tree-based ORAMs
    that use an explicit position map. Blocks are assumed to
    begin with bytes representing the block id.
    """

    # Same layout as the base class: status flag followed by id.
    block_info_storage_string = \
        TreeORAMStorage.block_info_storage_string
    block_info_storage_size = \
        struct.calcsize(block_info_storage_string)

    def __init__(self,
                 storage_heap,
                 stash,
                 position_map):
        super(TreeORAMStorageManagerExplicitAddressing, self).\
            __init__(storage_heap, stash)
        self.position_map = position_map

    def get_block_info(self, block):
        """Return (id, bucket address) for a raw block, looking the
        address up in the external position map; empty blocks yield
        (empty_block_id, None)."""
        real, id_ = struct.unpack_from(
            self.block_info_storage_string, block)
        if not real:
            return self.empty_block_id, None
        return id_, self.position_map[id_]
+
class TreeORAMStorageManagerPointerAddressing(
        TreeORAMStorage):
    """
    This class should be used to implement tree-based ORAMs
    that use a pointer-based position map stored with the
    blocks. Blocks are assumed to begin with bytes
    representing the block id followed by bytes representing
    the blocks current heap bucket address.
    """

    # Extend the base layout with the block's own bucket address.
    block_info_storage_string = \
        TreeORAMStorage.block_info_storage_string + "L"
    block_info_storage_size = \
        struct.calcsize(block_info_storage_string)

    def __init__(self,
                 storage_heap,
                 stash):
        super(TreeORAMStorageManagerPointerAddressing, self).\
            __init__(storage_heap, stash)
        # No external position map; addresses live inside the blocks.
        self.position_map = None

    def get_block_info(self, block):
        """Return (id, bucket address) unpacked directly from the
        block header; empty blocks yield (empty_block_id, 0)."""
        real, id_, addr = struct.unpack_from(
            self.block_info_storage_string, block)
        if real:
            return id_, addr
        return self.empty_block_id, 0
--- /dev/null
+import time
+
+
+
+
class Singleton:
    """
    A non-thread-safe helper class to ease implementing singletons.
    This should be used as a decorator -- not a metaclass -- to the
    class that should be a singleton.

    The decorated class can define one `__init__` function that
    takes only the `self` argument. Also, the decorated class cannot be
    inherited from. Other than that, there are no restrictions that apply
    to the decorated class.

    To get the singleton instance, use the `Instance` method. Trying
    to use `__call__` will result in a `TypeError` being raised.
    """

    def __init__(self, decorated):
        # Wrapped class; instantiated lazily on first Instance() call.
        self._decorated = decorated

    def Instance(self):
        """
        Return the singleton instance, creating it (by calling the
        decorated class) on the first call and reusing it afterwards.
        """
        if not hasattr(self, '_instance'):
            self._instance = self._decorated()
        return self._instance

    def __call__(self):
        # Direct instantiation is forbidden by design.
        raise TypeError('Singletons must be accessed through `Instance()`.')

    def __instancecheck__(self, inst):
        # isinstance(obj, decorator) checks against the wrapped class.
        return isinstance(inst, self._decorated)
+
+
@Singleton
class Foo:
    """Singleton accumulating wall-clock timer (start/end intervals
    are summed into a running total)."""

    def __init__(self):
        # NOTE: was a Python-2-only `print` statement, which is a
        # SyntaxError on the Python 3 interpreters targeted by CI.
        print('Foo created')
        self._totalTime = 0

    def getTime(self):
        """Return the total accumulated time in seconds."""
        return self._totalTime

    def resetTimer(self):
        """Zero the accumulated total."""
        self._totalTime = 0

    def startTimer(self):
        """Mark the start of a timed interval."""
        self._startTime = time.time()

    def endTimer(self):
        """Add the interval since startTimer() to the running total."""
        self._totalTime += time.time() - self._startTime
--- /dev/null
+import pyoram.storage.block_storage
+import pyoram.storage.block_storage_file
+import pyoram.storage.block_storage_mmap
+import pyoram.storage.block_storage_ram
+import pyoram.storage.block_storage_sftp
+import pyoram.storage.block_storage_s3
+import pyoram.storage.heap_storage
--- /dev/null
+__all__ = ('BlockStorageTypeFactory',)
+
+import logging
+
+log = logging.getLogger("pyoram")
+
def BlockStorageTypeFactory(storage_type_name):
    """Look up a registered block storage device class by name.
    Raises ValueError for unknown names."""
    registry = BlockStorageTypeFactory._registered_devices
    if storage_type_name not in registry:
        raise ValueError(
            "BlockStorageTypeFactory: Unsupported storage "
            "type: %s" % (storage_type_name))
    return registry[storage_type_name]
# Registry of name -> device class, populated via register_device.
BlockStorageTypeFactory._registered_devices = {}
+
def _register_device(name, type_):
    # Register a BlockStorageInterface subclass under a unique name so
    # it can later be retrieved via BlockStorageTypeFactory(name).
    # Duplicate names raise ValueError; non-subclasses raise TypeError.
    if name in BlockStorageTypeFactory._registered_devices:
        raise ValueError("Can not register block storage device type "
                         "with name '%s'. A device type is already "
                         "registered with that name." % (name))
    if not issubclass(type_, BlockStorageInterface):
        raise TypeError("Can not register block storage device type "
                        "'%s'. The device must be a subclass of "
                        "BlockStorageInterface" % (type_))
    BlockStorageTypeFactory._registered_devices[name] = type_
# Expose registration as an attribute of the factory function itself.
BlockStorageTypeFactory.register_device = _register_device
+
class BlockStorageInterface(object):
    """
    Abstract interface for block storage devices: fixed-size block
    reads/writes plus a small user header region. Concrete devices
    register themselves with BlockStorageTypeFactory.
    """

    def __enter__(self):
        return self
    def __exit__(self, *args):
        # Leaving the context always closes the device.
        self.close()

    #
    # Abstract Interface
    #

    def clone_device(self, *args, **kwds):
        # Return an independent handle onto the same storage.
        raise NotImplementedError # pragma: no cover

    @classmethod
    def compute_storage_size(cls, *args, **kwds):
        # Total bytes required for a device with the given geometry.
        raise NotImplementedError # pragma: no cover
    @classmethod
    def setup(cls, *args, **kwds):
        # Create and initialize a new storage space; returns a device.
        raise NotImplementedError # pragma: no cover

    @property
    def header_data(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def block_count(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def block_size(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def storage_name(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    def update_header_data(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def close(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def read_blocks(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def yield_blocks(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def read_block(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def write_blocks(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def write_block(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    @property
    def bytes_sent(self):
        # Total payload bytes written to the device so far.
        raise NotImplementedError # pragma: no cover
    @property
    def bytes_received(self):
        # Total payload bytes read from the device so far.
        raise NotImplementedError # pragma: no cover
--- /dev/null
+__all__ = ('BlockStorageFile',)
+
+import os
+import struct
+import logging
+import errno
+from multiprocessing.pool import ThreadPool
+
+import pyoram
+from pyoram.storage.block_storage import \
+ (BlockStorageInterface,
+ BlockStorageTypeFactory)
+
+import tqdm
+import six
+from six.moves import xrange
+
+import time
+from AliTimer import *
+
+log = logging.getLogger("pyoram")
+
class default_filesystem(object):
    # Default filesystem shim used by BlockStorageFile: plain builtin
    # and os-level file operations. Alternate implementations (e.g.
    # for testing or remote filesystems) supply the same attributes.
    open = open
    remove = os.remove
    stat = os.stat
+
class BlockStorageFile(BlockStorageInterface):
    """
    A class implementing the block storage interface
    using a local file.
    """

    # Device header layout (network byte order): block_size,
    # block_count, user_header_size (unsigned longs) and a locked flag.
    _index_struct_string = "!LLL?"
    _index_offset = struct.calcsize(_index_struct_string)

    def __init__(self,
                 storage_name,
                 threadpool_size=None,
                 ignore_lock=False,
                 _filesystem=default_filesystem):
        """
        Open an existing block storage file (created via setup).

        storage_name: path of the storage file.
        threadpool_size: size of the thread pool used for async
            writes; 0 disables the pool (writes become synchronous).
        ignore_lock: open the device even if its locked flag is set.
        _filesystem: object providing open/remove/stat (test hook).

        Raises IOError if the device is locked by another process
        and ignore_lock is False.
        """
        # Shared timer singleton accumulating raw file I/O time
        # (provided by the AliTimer module).
        self._timer = Foo.Instance()

        self._bytes_sent = 0
        self._bytes_received = 0
        self._filesystem = _filesystem
        self._ignore_lock = ignore_lock
        self._f = None
        self._pool = None
        self._close_pool = True
        self._async_write = None
        self._storage_name = storage_name
        self._f = self._filesystem.open(self.storage_name, "r+b")
        self._f.seek(0)
        self._block_size, self._block_count, user_header_size, locked = \
            struct.unpack(
                BlockStorageFile._index_struct_string,
                self._f.read(BlockStorageFile._index_offset))

        if locked and (not self._ignore_lock):
            self._f.close()
            self._f = None
            raise IOError(
                "Can not open block storage device because it is "
                "locked by another process. To ignore this check, "
                "initialize this class with the keyword 'ignore_lock' "
                "set to True.")
        self._user_header_data = bytes()
        if user_header_size > 0:
            self._user_header_data = \
                self._f.read(user_header_size)
        # Blocks begin immediately after the index and user header.
        self._header_offset = BlockStorageFile._index_offset + \
                              len(self._user_header_data)

        # TODO: Figure out why this is required for Python3
        #       in order to prevent issues with the
        #       TopCachedEncryptedHeapStorage class. The
        #       problem has something to do with bufferedio,
        #       but it makes no sense why this fixes it (all
        #       we've done is read above these lines). As
        #       part of this, investigate whether or not we
        #       need the call to flush after write_block(s),
        #       or if its simply connected to some Python3
        #       bug in bufferedio.
        self._f.flush()

        if not self._ignore_lock:
            # turn on the locked flag
            self._f.seek(0)
            self._f.write(
                struct.pack(BlockStorageFile._index_struct_string,
                            self.block_size,
                            self.block_count,
                            len(self._user_header_data),
                            True))
            self._f.flush()

        if threadpool_size != 0:
            # None lets ThreadPool pick its default size.
            self._pool = ThreadPool(threadpool_size)

    def _check_async(self):
        # Block until any outstanding async write completes,
        # re-raising any exception it produced.
        if self._async_write is not None:
            self._async_write.get()
            self._async_write = None
        # TODO: Figure out why tests fail on Python3 without this
        if six.PY3:
            if self._f is None:
                return
            self._f.flush()

    def _schedule_async_write(self, args, callback=None):
        # At most one async write may be in flight at a time.
        assert self._async_write is None
        if self._pool is not None:
            self._async_write = \
                self._pool.apply_async(self._writev, (args, callback))
        else:
            # No pool: fall back to a synchronous write.
            self._writev(args, callback)

    # This method is usually executed in another thread, so
    # do not attempt to handle exceptions because it will
    # not work.
    def _writev(self, chunks, callback):
        for i, block in chunks:
            # Timed seek+write of one raw block.
            self._timer.startTimer()
            self._f.seek(self._header_offset + i * self.block_size)
            self._f.write(block)
            self._timer.endTimer()
            if callback is not None:
                callback(i)

    def _read_block_at(self, i):
        # Timed seek+read of one raw block (no bounds or byte
        # accounting; callers handle that).
        self._timer.startTimer()
        self._f.seek(self._header_offset + i * self.block_size)
        data = self._f.read(self.block_size)
        self._timer.endTimer()
        return data

    def _prep_for_close(self):
        # Flush pending async writes, shut down the pool (if owned),
        # and clear the locked flag in the file header (if we set it).
        self._check_async()
        if self._close_pool and (self._pool is not None):
            self._pool.close()
            self._pool.join()
            self._pool = None
        if self._f is not None:
            if not self._ignore_lock:
                # turn off the locked flag
                self._f.seek(0)
                self._f.write(
                    struct.pack(BlockStorageFile._index_struct_string,
                                self.block_size,
                                self.block_count,
                                len(self._user_header_data),
                                False))
                self._f.flush()

    #
    # Define BlockStorageInterface Methods
    #

    def clone_device(self):
        """Return an independent handle onto the same storage file,
        sharing this device's thread pool."""
        f = BlockStorageFile(self.storage_name,
                             threadpool_size=0,
                             ignore_lock=True)
        f._pool = self._pool
        f._close_pool = False
        return f

    @classmethod
    def compute_storage_size(cls,
                             block_size,
                             block_count,
                             header_data=None,
                             ignore_header=False):
        """Total bytes required for a device with the given geometry,
        optionally excluding the index and user header."""
        assert (block_size > 0) and (block_size == int(block_size))
        assert (block_count > 0) and (block_count == int(block_count))
        if header_data is None:
            header_data = bytes()
        if ignore_header:
            return block_size * block_count
        return BlockStorageFile._index_offset + \
               len(header_data) + \
               block_size * block_count

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              initialize=None,
              header_data=None,
              ignore_existing=False,
              threadpool_size=None,
              _filesystem=default_filesystem):
        """
        Create and initialize a new block storage file and return an
        opened BlockStorageFile for it.

        initialize: optional callable mapping a block index to its
            initial bytes (defaults to all-zero blocks).
        header_data: optional bytes stored in the user header region.
        ignore_existing: overwrite an existing file instead of
            raising IOError.

        Raises IOError if the file exists (and ignore_existing is
        False), ValueError for a non-positive geometry, and TypeError
        if header_data is not bytes. On failure during initialization
        the partially written file is removed.
        """
        if (not ignore_existing):
            _exists = True
            try:
                _filesystem.stat(storage_name)
            except OSError as e:
                if e.errno == errno.ENOENT:
                    _exists = False
            if _exists:
                raise IOError(
                    "Storage location already exists: %s"
                    % (storage_name))
        if (block_size <= 0) or (block_size != int(block_size)):
            raise ValueError(
                "Block size (bytes) must be a positive integer: %s"
                % (block_size))
        if (block_count <= 0) or (block_count != int(block_count)):
            raise ValueError(
                "Block count must be a positive integer: %s"
                % (block_count))
        if (header_data is not None) and \
           (type(header_data) is not bytes):
            raise TypeError(
                "'header_data' must be of type bytes. "
                "Invalid type: %s" % (type(header_data)))

        if initialize is None:
            zeros = bytes(bytearray(block_size))
            initialize = lambda i: zeros
        try:
            with _filesystem.open(storage_name, "wb") as f:
                # create_index
                if header_data is None:
                    f.write(struct.pack(BlockStorageFile._index_struct_string,
                                        block_size,
                                        block_count,
                                        0,
                                        False))
                else:
                    f.write(struct.pack(BlockStorageFile._index_struct_string,
                                        block_size,
                                        block_count,
                                        len(header_data),
                                        False))
                    f.write(header_data)
                with tqdm.tqdm(total=block_count*block_size,
                               desc="Initializing File Block Storage Space",
                               unit="B",
                               unit_scale=True,
                               disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
                    for i in xrange(block_count):
                        block = initialize(i)
                        assert len(block) == block_size, \
                            ("%s != %s" % (len(block), block_size))
                        f.write(block)
                        progress_bar.update(n=block_size)
        except:                                        # pragma: no cover
            _filesystem.remove(storage_name)           # pragma: no cover
            raise                                      # pragma: no cover

        return BlockStorageFile(storage_name,
                                threadpool_size=threadpool_size,
                                _filesystem=_filesystem)

    @property
    def header_data(self):
        """User-provided header bytes stored after the device index."""
        return self._user_header_data

    @property
    def block_count(self):
        """Number of blocks in this device."""
        return self._block_count

    @property
    def block_size(self):
        """Size of each block in bytes."""
        return self._block_size

    @property
    def storage_name(self):
        """Path of the underlying storage file."""
        return self._storage_name

    def update_header_data(self, new_header_data):
        """Replace the user header. The new header must be exactly
        the same size as the existing one (raises ValueError)."""
        self._check_async()
        if len(new_header_data) != len(self.header_data):
            raise ValueError(
                "The size of header data can not change.\n"
                "Original bytes: %s\n"
                "New bytes: %s" % (len(self.header_data),
                                   len(new_header_data)))
        self._user_header_data = bytes(new_header_data)
        self._timer.startTimer()
        self._f.seek(BlockStorageFile._index_offset)
        self._f.write(self._user_header_data)
        self._timer.endTimer()

    def close(self):
        """Flush pending writes, clear the lock flag, and release the
        file handle."""
        self._prep_for_close()
        if self._f is not None:
            try:
                self._f.close()
            except OSError:                            # pragma: no cover
                pass                                   # pragma: no cover
            self._f = None

    def read_blocks(self, indices):
        """Read and return the list of blocks at the given indices."""
        self._check_async()
        blocks = []
        for i in indices:
            assert 0 <= i < self.block_count
            self._bytes_received += self.block_size
            blocks.append(self._read_block_at(i))
        return blocks

    def yield_blocks(self, indices):
        """Lazily yield the blocks at the given indices."""
        self._check_async()
        for i in indices:
            assert 0 <= i < self.block_count
            self._bytes_received += self.block_size
            yield self._read_block_at(i)

    def read_block(self, i):
        """Read and return the block at index i."""
        self._check_async()
        assert 0 <= i < self.block_count
        self._bytes_received += self.block_size
        return self._read_block_at(i)

    def write_blocks(self, indices, blocks, callback=None):
        """Write the given blocks at the given indices, asynchronously
        when a thread pool is available. callback(i) is invoked after
        each block is written."""
        self._check_async()
        chunks = []
        for i, block in zip(indices, blocks):
            assert 0 <= i < self.block_count
            assert len(block) == self.block_size, \
                ("%s != %s" % (len(block), self.block_size))
            self._bytes_sent += self.block_size
            chunks.append((i, block))
        self._schedule_async_write(chunks, callback=callback)

    def write_block(self, i, block):
        """Write a single block at index i (asynchronously when a
        thread pool is available)."""
        self._check_async()
        assert 0 <= i < self.block_count
        assert len(block) == self.block_size
        self._bytes_sent += self.block_size
        self._schedule_async_write(((i, block),))

    @property
    def bytes_sent(self):
        """Total payload bytes written so far."""
        return self._bytes_sent

    @property
    def bytes_received(self):
        """Total payload bytes read so far."""
        return self._bytes_received
+
+BlockStorageTypeFactory.register_device("file", BlockStorageFile)
--- /dev/null
+__all__ = ('BlockStorageMMap',)
+
+import logging
+import mmap
+
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.storage.block_storage_file import \
+ BlockStorageFile
+
+log = logging.getLogger("pyoram")
+
+class _BlockStorageMemoryImpl(object):
+ """
+ This class implementents the BlockStorageInterface read/write
+ methods for classes with a private attribute _f that can be
+ accessed using __getslice__/__setslice__ notation.
+ """
+
+ def read_blocks(self, indices):
+ blocks = []
+ for i in indices:
+ assert 0 <= i < self.block_count
+ self._bytes_received += self.block_size
+ pos_start = self._header_offset + i * self.block_size
+ pos_stop = pos_start + self.block_size
+ blocks.append(self._f[pos_start:pos_stop])
+ return blocks
+
+ def yield_blocks(self, indices):
+ for i in indices:
+ assert 0 <= i < self.block_count
+ self._bytes_received += self.block_size
+ pos_start = self._header_offset + i * self.block_size
+ pos_stop = pos_start + self.block_size
+ yield self._f[pos_start:pos_stop]
+
+ def read_block(self, i):
+ assert 0 <= i < self.block_count
+ self._bytes_received += self.block_size
+ pos_start = self._header_offset + i * self.block_size
+ pos_stop = pos_start + self.block_size
+ return self._f[pos_start:pos_stop]
+
+ def write_blocks(self, indices, blocks, callback=None):
+ for i, block in zip(indices, blocks):
+ assert 0 <= i < self.block_count
+ self._bytes_sent += self.block_size
+ pos_start = self._header_offset + i * self.block_size
+ pos_stop = pos_start + self.block_size
+ self._f[pos_start:pos_stop] = block
+ if callback is not None:
+ callback(i)
+
+ def write_block(self, i, block):
+ assert 0 <= i < self.block_count
+ self._bytes_sent += self.block_size
+ pos_start = self._header_offset + i * self.block_size
+ pos_stop = pos_start + self.block_size
+ self._f[pos_start:pos_stop] = block
+
class BlockStorageMMap(_BlockStorageMemoryImpl,
                       BlockStorageFile):
    """
    A class implementing the block storage interface by creating a
    memory map over a local file. This class uses the same storage
    format as BlockStorageFile. Thus, a block storage space can be
    created using this class and then, after saving the raw storage
    data to disk, reopened with any other class compatible with
    BlockStorageFile (and visa versa).
    """

    def __init__(self, *args, **kwds):
        # 'mm' is an already-created mmap passed by clone_device; when
        # provided, this instance does not own it and must not close it.
        mm = kwds.pop('mm', None)
        self._mmap_owned = True
        super(BlockStorageMMap, self).__init__(*args, **kwds)
        if mm is None:
            # Flush so the header bytes written by
            # BlockStorageFile.__init__ are visible to the mapping.
            self._f.flush()
            mm = mmap.mmap(self._f.fileno(), 0)
        else:
            self._mmap_owned = False
        # Swap the file handle for the mmap; the slice-based
        # read/write methods from _BlockStorageMemoryImpl operate on
        # it directly.
        self._f.close()
        self._f = mm

    #
    # Define BlockStorageInterface Methods
    # (override what is defined on BlockStorageFile)
    #

    #@classmethod
    #def compute_storage_size(...)

    def clone_device(self):
        # Share the existing mmap with the clone (mm keyword), so the
        # clone does not own it and will not close it.
        f = BlockStorageMMap(self.storage_name,
                             threadpool_size=0,
                             mm=self._f,
                             ignore_lock=True)
        f._pool = self._pool
        f._close_pool = False
        return f

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              **kwds):
        # Delegate creation to BlockStorageFile (same on-disk format),
        # then reopen the finished file through a memory map.
        f = BlockStorageFile.setup(storage_name,
                                   block_size,
                                   block_count,
                                   **kwds)
        f.close()
        return BlockStorageMMap(storage_name)

    #def update_header_data(...)

    def close(self):
        # Only close the mmap if this instance owns it (clones share
        # the parent's mapping).
        self._prep_for_close()
        if self._f is not None:
            if self._mmap_owned:
                try:
                    self._f.close()
                except OSError: # pragma: no cover
                    pass        # pragma: no cover
            self._f = None

    #def read_blocks(...)

    #def yield_blocks(...)

    #def read_block(...)

    #def write_blocks(...)

    #def write_block(...)

    #@property
    #def bytes_sent(...)

    #@property
    #def bytes_received(...)
+
+BlockStorageTypeFactory.register_device("mmap", BlockStorageMMap)
--- /dev/null
+__all__ = ('BlockStorageRAM',)
+
+import os
+import struct
+import logging
+import errno
+from multiprocessing.pool import ThreadPool
+
+import pyoram
+from pyoram.storage.block_storage import \
+ (BlockStorageInterface,
+ BlockStorageTypeFactory)
+from pyoram.storage.block_storage_mmap import \
+ (BlockStorageMMap,
+ _BlockStorageMemoryImpl)
+
+import tqdm
+import six
+from six.moves import xrange
+
+log = logging.getLogger("pyoram")
+
class BlockStorageRAM(_BlockStorageMemoryImpl,
                      BlockStorageInterface):
    """
    A class implementing the block storage interface where all data is
    kept in RAM. This class uses the same storage format as
    BlockStorageFile. Thus, a block storage space can be created using
    this class and then, after saving the raw storage data to disk,
    reopened with any other class compatible with BlockStorageFile
    (and visa versa).
    """

    # Share the packed index layout with BlockStorageMMap so the
    # serialized formats remain interchangeable.
    _index_struct_string = BlockStorageMMap._index_struct_string
    _index_offset = struct.calcsize(_index_struct_string)

    def __init__(self,
                 storage_data,
                 threadpool_size=None,
                 ignore_lock=False):
        """
        Open a block storage device backed by an in-memory bytearray.

        Args:
            storage_data (bytearray): the raw storage, beginning with
                the packed index header.
            threadpool_size: thread pool size (None uses the
                ThreadPool default; 0 disables pool creation).
            ignore_lock (bool): if True, do not check or set the
                locked flag in the storage header.

        Raises:
            TypeError: if storage_data is not a bytearray.
            IOError: if the device is locked by another process and
                ignore_lock is False.
        """
        self._bytes_sent = 0
        self._bytes_received = 0
        self._ignore_lock = ignore_lock
        self._f = None
        self._pool = None
        self._close_pool = True
        if type(storage_data) is not bytearray:
            raise TypeError(
                "BlockStorageRAM requires input argument of type "
                "'bytearray'. Invalid input type: %s"
                % (type(storage_data)))
        self._f = storage_data
        # index header: (block_size, block_count, user_header_size, locked)
        self._block_size, self._block_count, user_header_size, locked = \
            struct.unpack(
                BlockStorageRAM._index_struct_string,
                self._f[:BlockStorageRAM._index_offset])

        if locked and (not self._ignore_lock):
            raise IOError(
                "Can not open block storage device because it is "
                "locked by another process. To ignore this check, "
                "initialize this class with the keyword 'ignore_lock' "
                "set to True.")
        self._user_header_data = bytes()
        if user_header_size > 0:
            self._user_header_data = \
                bytes(self._f[BlockStorageRAM._index_offset:\
                              (BlockStorageRAM._index_offset+user_header_size)])
        assert len(self._user_header_data) == user_header_size
        self._header_offset = BlockStorageRAM._index_offset + \
                              len(self._user_header_data)

        if not self._ignore_lock:
            # turn on the locked flag
            self._f[:BlockStorageRAM._index_offset] = \
                struct.pack(BlockStorageRAM._index_struct_string,
                            self.block_size,
                            self.block_count,
                            len(self._user_header_data),
                            True)

        # Although we do not use the threadpool we still
        # create just in case we are the first
        if threadpool_size != 0:
            self._pool = ThreadPool(threadpool_size)

    #
    # Add some methods specific to BlockStorageRAM
    #

    @staticmethod
    def fromfile(file_,
                 threadpool_size=None,
                 ignore_lock=False):
        """
        Instantiate BlockStorageRAM device from a file saved in block
        storage format. The file_ argument can be a file object or a
        string that represents a filename. If called with a file
        object, it should be opened in binary mode, and the caller is
        responsible for closing the file.

        This method returns a BlockStorageRAM instance.
        """
        close_file = False
        if not hasattr(file_, 'read'):
            file_ = open(file_, 'rb')
            close_file = True
        try:
            header_data = file_.read(BlockStorageRAM._index_offset)
            block_size, block_count, user_header_size, locked = \
                struct.unpack(
                    BlockStorageRAM._index_struct_string,
                    header_data)
            if locked and (not ignore_lock):
                raise IOError(
                    "Can not open block storage device because it is "
                    "locked by another process. To ignore this check, "
                    "call this method with the keyword 'ignore_lock' "
                    "set to True.")
            # pre-size the bytearray, then fill header and blocks
            header_offset = len(header_data) + \
                            user_header_size
            f = bytearray(header_offset + \
                          (block_size * block_count))
            f[:header_offset] = header_data + file_.read(user_header_size)
            f[header_offset:] = file_.read(block_size * block_count)
        finally:
            if close_file:
                file_.close()

        return BlockStorageRAM(f,
                               threadpool_size=threadpool_size,
                               ignore_lock=ignore_lock)

    def tofile(self, file_):
        """
        Dump all storage data to a file. The file_ argument can be a
        file object or a string that represents a filename. If called
        with a file object, it should be opened in binary mode, and
        the caller is responsible for closing the file.

        The method should only be called after the storage device has
        been closed to ensure that the locked flag has been set to
        False.
        """
        close_file = False
        if not hasattr(file_, 'write'):
            file_ = open(file_, 'wb')
            close_file = True
        file_.write(self._f)
        if close_file:
            file_.close()

    @property
    def data(self):
        """Access the raw bytearray"""
        return self._f

    #
    # Define BlockStorageInterface Methods
    #

    def clone_device(self):
        """Return a device sharing this device's storage and thread
        pool. The clone ignores the lock flag and does not close the
        shared pool when closed."""
        f = BlockStorageRAM(self._f,
                            threadpool_size=0,
                            ignore_lock=True)
        f._pool = self._pool
        f._close_pool = False
        return f

    @classmethod
    def compute_storage_size(cls, *args, **kwds):
        """Compute the total storage size in bytes (delegates to
        BlockStorageMMap, which shares the same format)."""
        return BlockStorageMMap.compute_storage_size(*args, **kwds)

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              initialize=None,
              header_data=None,
              ignore_existing=False,
              threadpool_size=None):
        """
        Create a new in-memory block storage space.

        Args:
            storage_name: ignored (kept for interface compatibility).
            block_size (int): positive size of each block in bytes.
            block_count (int): positive number of blocks.
            initialize: optional callable mapping a block index to
                that block's initial bytes (defaults to zeros).
            header_data (bytes): optional user header stored after
                the index.
            ignore_existing: ignored (kept for interface
                compatibility).
            threadpool_size: forwarded to the constructor.

        Raises:
            ValueError: on a non-positive block size or count.
            TypeError: if header_data is not bytes.
        """
        # We ignore the 'storage_name' argument
        # We ignore the 'ignore_existing' flag
        if (block_size <= 0) or (block_size != int(block_size)):
            raise ValueError(
                "Block size (bytes) must be a positive integer: %s"
                % (block_size))
        if (block_count <= 0) or (block_count != int(block_count)):
            raise ValueError(
                "Block count must be a positive integer: %s"
                % (block_count))
        if (header_data is not None) and \
           (type(header_data) is not bytes):
            raise TypeError(
                "'header_data' must be of type bytes. "
                "Invalid type: %s" % (type(header_data)))

        if initialize is None:
            zeros = bytes(bytearray(block_size))
            initialize = lambda i: zeros

        # create_index
        index_data = None
        if header_data is None:
            index_data = struct.pack(BlockStorageRAM._index_struct_string,
                                     block_size,
                                     block_count,
                                     0,
                                     False)
            header_data = bytes()
        else:
            index_data = struct.pack(BlockStorageRAM._index_struct_string,
                                     block_size,
                                     block_count,
                                     len(header_data),
                                     False)
        header_offset = len(index_data) + len(header_data)
        f = bytearray(header_offset + \
                      (block_size * block_count))
        f[:header_offset] = index_data + header_data
        progress_bar = tqdm.tqdm(total=block_count*block_size,
                                 desc="Initializing File Block Storage Space",
                                 unit="B",
                                 unit_scale=True,
                                 disable=not pyoram.config.SHOW_PROGRESS_BAR)
        for i in xrange(block_count):
            block = initialize(i)
            assert len(block) == block_size, \
                ("%s != %s" % (len(block), block_size))
            # fix: the original computed pos_start twice (duplicate
            # statement); compute it once per block
            pos_start = header_offset + i * block_size
            pos_stop = pos_start + block_size
            f[pos_start:pos_stop] = block[:]
            progress_bar.update(n=block_size)
        progress_bar.close()

        return BlockStorageRAM(f, threadpool_size=threadpool_size)

    @property
    def header_data(self):
        """The user portion of the storage header (bytes)."""
        return self._user_header_data

    @property
    def block_count(self):
        """Number of blocks in the storage space."""
        return self._block_count

    @property
    def block_size(self):
        """Size of each block in bytes."""
        return self._block_size

    @property
    def storage_name(self):
        """Always None: RAM storage has no backing name."""
        return None

    def update_header_data(self, new_header_data):
        """Replace the user header in place. The new header must have
        the same length as the existing one."""
        if len(new_header_data) != len(self.header_data):
            raise ValueError(
                "The size of header data can not change.\n"
                "Original bytes: %s\n"
                "New bytes: %s" % (len(self.header_data),
                                   len(new_header_data)))
        self._user_header_data = bytes(new_header_data)
        self._f[BlockStorageRAM._index_offset:\
                (BlockStorageRAM._index_offset+len(new_header_data))] = \
            self._user_header_data

    def close(self):
        """Close the thread pool (if owned) and clear the locked flag
        in the header (unless ignore_lock was set)."""
        if self._close_pool and (self._pool is not None):
            self._pool.close()
            self._pool.join()
            self._pool = None
        if not self._ignore_lock:
            # turn off the locked flag
            self._f[:BlockStorageRAM._index_offset] = \
                struct.pack(BlockStorageRAM._index_struct_string,
                            self.block_size,
                            self.block_count,
                            len(self._user_header_data),
                            False)
            self._ignore_lock = True

    #
    # We must cast from bytearray to bytes
    # when reading from a bytearray so that this
    # class works with the encryption layer.
    #

    def read_blocks(self, indices):
        return [bytes(block) for block
                in super(BlockStorageRAM, self).read_blocks(indices)]

    def yield_blocks(self, indices):
        for block in super(BlockStorageRAM, self).yield_blocks(indices):
            yield bytes(block)

    def read_block(self, i):
        return bytes(super(BlockStorageRAM, self).read_block(i))

    #def write_blocks(...)

    #def write_block(...)

    @property
    def bytes_sent(self):
        """Total payload bytes written through this device."""
        return self._bytes_sent

    @property
    def bytes_received(self):
        """Total payload bytes read through this device."""
        return self._bytes_received

BlockStorageTypeFactory.register_device("ram", BlockStorageRAM)
--- /dev/null
+__all__ = ('BlockStorageS3',)
+
+import struct
+import logging
+from multiprocessing.pool import ThreadPool
+
+import pyoram
+from pyoram.storage.block_storage import \
+ (BlockStorageInterface,
+ BlockStorageTypeFactory)
+from pyoram.storage.boto3_s3_wrapper import Boto3S3Wrapper
+
+import tqdm
+import six
+from six.moves import xrange, map
+
+log = logging.getLogger("pyoram")
+
class BlockStorageS3(BlockStorageInterface):
    """
    A block storage device for Amazon Simple
    Storage Service (S3).
    """

    # name of the object holding the packed index header
    _index_name = "PyORAMBlockStorageS3_index.bin"
    # index layout: (block_size, block_count, user_header_size, locked)
    _index_struct_string = "!LLL?"
    _index_offset = struct.calcsize(_index_struct_string)

    def __init__(self,
                 storage_name,
                 bucket_name=None,
                 aws_access_key_id=None,
                 aws_secret_access_key=None,
                 region_name=None,
                 ignore_lock=False,
                 threadpool_size=None,
                 s3_wrapper=Boto3S3Wrapper):
        """
        Open an existing S3 block storage space.

        Args:
            storage_name (str): key prefix of the storage space.
            bucket_name (str): S3 bucket name (required).
            aws_access_key_id / aws_secret_access_key / region_name:
                forwarded to the S3 wrapper.
            ignore_lock (bool): if True, do not check or set the
                locked flag in the index header.
            threadpool_size: thread pool size (None uses the
                ThreadPool default; 0 disables pool creation).
            s3_wrapper: wrapper class providing upload/download/
                exists/clear (defaults to Boto3S3Wrapper).

        Raises:
            ValueError: if 'bucket_name' is not given.
            IOError: if the storage is locked by another process and
                ignore_lock is False.
        """
        self._bytes_sent = 0
        self._bytes_received = 0
        self._storage_name = storage_name
        self._bucket_name = bucket_name
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._region_name = region_name
        self._pool = None
        self._close_pool = True
        self._s3 = None
        self._ignore_lock = ignore_lock
        self._async_write = None
        self._async_write_callback = None

        if bucket_name is None:
            raise ValueError("'bucket_name' keyword is required")

        if threadpool_size != 0:
            self._pool = ThreadPool(threadpool_size)

        self._s3 = s3_wrapper(bucket_name,
                              aws_access_key_id=aws_access_key_id,
                              aws_secret_access_key=aws_secret_access_key,
                              region_name=region_name)
        # key template for block i: "<storage_name>/b<i>"
        self._basename = self.storage_name+"/b%d"

        index_data = self._s3.download(
            self._storage_name+"/"+BlockStorageS3._index_name)
        self._block_size, self._block_count, user_header_size, locked = \
            struct.unpack(
                BlockStorageS3._index_struct_string,
                index_data[:BlockStorageS3._index_offset])
        if locked and (not self._ignore_lock):
            raise IOError(
                "Can not open block storage device because it is "
                "locked by another process. To ignore this check, "
                "initialize this class with the keyword 'ignore_lock' "
                "set to True.")
        self._user_header_data = bytes()
        if user_header_size > 0:
            self._user_header_data = \
                index_data[BlockStorageS3._index_offset:
                           (BlockStorageS3._index_offset+user_header_size)]

        if not self._ignore_lock:
            # turn on the locked flag
            self._s3.upload((self._storage_name+"/"+BlockStorageS3._index_name,
                             struct.pack(BlockStorageS3._index_struct_string,
                                         self.block_size,
                                         self.block_count,
                                         len(self.header_data),
                                         True) + \
                             self.header_data))

    def _check_async(self):
        """Drain any outstanding asynchronous writes, invoking the
        registered callback for each completed upload."""
        if self._async_write is not None:
            for i in self._async_write:
                if self._async_write_callback is not None:
                    self._async_write_callback(i)
            self._async_write = None
            self._async_write_callback = None

    def _schedule_async_write(self, arglist, callback=None):
        """Upload (key, block) pairs; asynchronously via the thread
        pool when available, otherwise synchronously."""
        assert self._async_write is None
        if self._pool is not None:
            self._async_write = \
                self._pool.imap_unordered(self._s3.upload, arglist)
            self._async_write_callback = callback
        else:
            # Note: we are using six.map which always
            # behaves like imap
            for i in map(self._s3.upload, arglist):
                if callback is not None:
                    callback(i)

    def _download(self, i):
        """Fetch the raw bytes of block i."""
        return self._s3.download(self._basename % i)

    #
    # Define BlockStorageInterface Methods
    #

    def clone_device(self):
        """Return a device sharing this device's thread pool. The
        clone ignores the lock flag and does not close the shared
        pool when closed."""
        f = BlockStorageS3(self.storage_name,
                           bucket_name=self._bucket_name,
                           aws_access_key_id=self._aws_access_key_id,
                           aws_secret_access_key=self._aws_secret_access_key,
                           region_name=self._region_name,
                           threadpool_size=0,
                           s3_wrapper=type(self._s3),
                           ignore_lock=True)
        f._pool = self._pool
        f._close_pool = False
        return f

    @classmethod
    def compute_storage_size(cls,
                             block_size,
                             block_count,
                             header_data=None,
                             ignore_header=False):
        """Total storage size in bytes; includes the index and user
        header unless ignore_header is True."""
        assert (block_size > 0) and (block_size == int(block_size))
        assert (block_count > 0) and (block_count == int(block_count))
        if header_data is None:
            header_data = bytes()
        if ignore_header:
            return block_size * block_count
        else:
            return BlockStorageS3._index_offset + \
                   len(header_data) + \
                   block_size * block_count

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              bucket_name=None,
              aws_access_key_id=None,
              aws_secret_access_key=None,
              region_name=None,
              header_data=None,
              initialize=None,
              threadpool_size=None,
              ignore_existing=False,
              s3_wrapper=Boto3S3Wrapper):
        """
        Create a new S3 block storage space under the given prefix.

        Raises:
            ValueError: on missing bucket_name, non-positive block
                size/count, or a failed initialization.
            TypeError: if header_data is not bytes.
            IOError: if the prefix already exists and
                ignore_existing is False.
        """
        if bucket_name is None:
            raise ValueError("'bucket_name' is required")
        if (block_size <= 0) or (block_size != int(block_size)):
            raise ValueError(
                "Block size (bytes) must be a positive integer: %s"
                % (block_size))
        if (block_count <= 0) or (block_count != int(block_count)):
            raise ValueError(
                "Block count must be a positive integer: %s"
                % (block_count))
        if (header_data is not None) and \
           (type(header_data) is not bytes):
            raise TypeError(
                "'header_data' must be of type bytes. "
                "Invalid type: %s" % (type(header_data)))

        pool = None
        if threadpool_size != 0:
            pool = ThreadPool(threadpool_size)

        s3 = s3_wrapper(bucket_name,
                        aws_access_key_id=aws_access_key_id,
                        aws_secret_access_key=aws_secret_access_key,
                        region_name=region_name)
        exists = s3.exists(storage_name)
        if (not ignore_existing) and exists:
            raise IOError(
                "Storage location already exists in bucket %s: %s"
                % (bucket_name, storage_name))
        if exists:
            log.info("Deleting objects in existing S3 entry: %s/%s"
                     % (bucket_name, storage_name))
            print("Clearing Existing S3 Objects With Prefix %s/%s/"
                  % (bucket_name, storage_name))
            s3.clear(storage_name, threadpool=pool)

        # write the index header (with unlocked flag)
        if header_data is None:
            s3.upload((storage_name+"/"+BlockStorageS3._index_name,
                       struct.pack(BlockStorageS3._index_struct_string,
                                   block_size,
                                   block_count,
                                   0,
                                   False)))
        else:
            s3.upload((storage_name+"/"+BlockStorageS3._index_name,
                       struct.pack(BlockStorageS3._index_struct_string,
                                   block_size,
                                   block_count,
                                   len(header_data),
                                   False) + \
                       header_data))

        if initialize is None:
            zeros = bytes(bytearray(block_size))
            initialize = lambda i: zeros
        basename = storage_name+"/b%d"
        # NOTE: We will not be informed when a thread
        #       encounters an exception (e.g., when
        #       calling initialize(i). We must ensure
        #       that all iterations were processed
        #       by counting the results.
        def init_blocks():
            for i in xrange(block_count):
                yield (basename % i, initialize(i))
        def _do_upload(arg):
            try:
                s3.upload(arg)
            except Exception as e:                        # pragma: no cover
                log.error(                                # pragma: no cover
                    "An exception occured during S3 "     # pragma: no cover
                    "setup when calling the block "       # pragma: no cover
                    "initialization function: %s"         # pragma: no cover
                    % (str(e)))                           # pragma: no cover
                raise                                     # pragma: no cover
        total = None
        progress_bar = tqdm.tqdm(total=block_count*block_size,
                                 desc="Initializing S3 Block Storage Space",
                                 unit="B",
                                 unit_scale=True,
                                 disable=not pyoram.config.SHOW_PROGRESS_BAR)
        if pool is not None:
            try:
                for i,_ in enumerate(
                        pool.imap_unordered(_do_upload, init_blocks())):
                    total = i
                    progress_bar.update(n=block_size)
            except Exception:                             # pragma: no cover
                s3.clear(storage_name)                    # pragma: no cover
                raise                                     # pragma: no cover
            finally:
                progress_bar.close()
                pool.close()
                pool.join()
        else:
            try:
                for i,_ in enumerate(
                        map(s3.upload, init_blocks())):
                    total = i
                    progress_bar.update(n=block_size)
            except Exception:                             # pragma: no cover
                s3.clear(storage_name)                    # pragma: no cover
                raise                                     # pragma: no cover
            finally:
                progress_bar.close()

        if total != block_count - 1:
            s3.clear(storage_name)                        # pragma: no cover
            if pool is not None:                          # pragma: no cover
                pool.close()                              # pragma: no cover
                pool.join()                               # pragma: no cover
            raise ValueError(                             # pragma: no cover
                "Something went wrong during S3 block"    # pragma: no cover
                " initialization. Check the logger "      # pragma: no cover
                "output for more information.")           # pragma: no cover

        return BlockStorageS3(storage_name,
                              bucket_name=bucket_name,
                              aws_access_key_id=aws_access_key_id,
                              aws_secret_access_key=aws_secret_access_key,
                              region_name=region_name,
                              threadpool_size=threadpool_size,
                              s3_wrapper=s3_wrapper)

    @property
    def header_data(self):
        """The user portion of the storage header (bytes)."""
        return self._user_header_data

    @property
    def block_count(self):
        """Number of blocks in the storage space."""
        return self._block_count

    @property
    def block_size(self):
        """Size of each block in bytes."""
        return self._block_size

    @property
    def storage_name(self):
        """The S3 key prefix of this storage space."""
        return self._storage_name

    def update_header_data(self, new_header_data):
        """Replace the user header in place (same length required),
        re-uploading the index object."""
        self._check_async()
        if len(new_header_data) != len(self.header_data):
            raise ValueError(
                "The size of header data can not change.\n"
                "Original bytes: %s\n"
                "New bytes: %s" % (len(self.header_data),
                                   len(new_header_data)))
        self._user_header_data = new_header_data

        index_data = bytearray(self._s3.download(
            self._storage_name+"/"+BlockStorageS3._index_name))
        lenbefore = len(index_data)
        index_data[BlockStorageS3._index_offset:] = new_header_data
        assert lenbefore == len(index_data)
        self._s3.upload((self._storage_name+"/"+BlockStorageS3._index_name,
                         bytes(index_data)))

    def close(self):
        """Flush pending writes, clear the locked flag (unless
        ignore_lock was set), and close the owned thread pool."""
        self._check_async()
        if self._s3 is not None:
            if not self._ignore_lock:
                # turn off the locked flag
                self._s3.upload(
                    (self._storage_name+"/"+BlockStorageS3._index_name,
                     struct.pack(BlockStorageS3._index_struct_string,
                                 self.block_size,
                                 self.block_count,
                                 len(self.header_data),
                                 False) + \
                     self.header_data))
        if self._close_pool and (self._pool is not None):
            self._pool.close()
            self._pool.join()
            self._pool = None

    def read_blocks(self, indices):
        """Read the given block indices, returning a list of bytes."""
        self._check_async()
        # be sure not to exhaust this if it is an iterator
        # or generator
        indices = list(indices)
        # fix: valid indices are 0..block_count-1 (the original
        # assert used <=, allowing an out-of-range index)
        assert all(0 <= i < self.block_count for i in indices)
        self._bytes_received += self.block_size * len(indices)
        if self._pool is not None:
            return self._pool.map(self._download, indices)
        else:
            return list(map(self._download, indices))

    def yield_blocks(self, indices):
        """Lazily read the given block indices, yielding bytes."""
        self._check_async()
        # be sure not to exhaust this if it is an iterator
        # or generator
        indices = list(indices)
        # fix: valid indices are 0..block_count-1
        assert all(0 <= i < self.block_count for i in indices)
        self._bytes_received += self.block_size * len(indices)
        if self._pool is not None:
            return self._pool.imap(self._download, indices)
        else:
            return map(self._download, indices)

    def read_block(self, i):
        """Read a single block by index."""
        self._check_async()
        assert 0 <= i < self.block_count
        self._bytes_received += self.block_size
        return self._download(i)

    def write_blocks(self, indices, blocks, callback=None):
        """Write the given blocks (scheduled asynchronously when a
        thread pool exists)."""
        self._check_async()
        # be sure not to exhaust this if it is an iterator
        # or generator
        indices = list(indices)
        # fix: valid indices are 0..block_count-1
        assert all(0 <= i < self.block_count for i in indices)
        self._bytes_sent += self.block_size * len(indices)
        indices = (self._basename % i for i in indices)
        self._schedule_async_write(zip(indices, blocks),
                                   callback=callback)

    def write_block(self, i, block):
        """Write a single block by index."""
        self._check_async()
        assert 0 <= i < self.block_count
        self._bytes_sent += self.block_size
        self._schedule_async_write((((self._basename % i), block),))

    @property
    def bytes_sent(self):
        """Total payload bytes written through this device."""
        return self._bytes_sent

    @property
    def bytes_received(self):
        """Total payload bytes read through this device."""
        return self._bytes_received

BlockStorageTypeFactory.register_device("s3", BlockStorageS3)
--- /dev/null
+__all__ = ('BlockStorageSFTP',)
+
+import logging
+from AliTimer import *
+from pyoram.util.misc import chunkiter
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.storage.block_storage_file import \
+ BlockStorageFile
+
+import time
+
+log = logging.getLogger("pyoram")
+
class BlockStorageSFTP(BlockStorageFile):
    """
    A block storage device for accessing file data through
    an SSH portal using Secure File Transfer Protocol (SFTP).
    """

    def __init__(self,
                 storage_name,
                 sshclient=None,
                 **kwds):
        """
        Open an SFTP-backed block storage device.

        Args:
            storage_name: name of the storage file on the remote
                host.
            sshclient: a connected ssh client object providing
                open_sftp() (required).
            **kwds: forwarded to BlockStorageFile.__init__.

        Raises:
            ValueError: if no ssh client is given.
        """
        if sshclient is None:
            raise ValueError(
                "Can not open sftp block storage device "
                "without an ssh client.")
        super(BlockStorageSFTP, self).__init__(
            storage_name,
            _filesystem=sshclient.open_sftp(),
            **kwds)
        self._sshclient = sshclient
        # NOTE(review): presumably paramiko's SFTPFile.set_pipelined,
        # which speeds up bulk readv requests -- confirm against the
        # sshclient implementation in use.
        self._f.set_pipelined()
        # NOTE(review): timing singleton from the wildcard
        # 'from AliTimer import *' import; appears to be research
        # instrumentation accumulating readv durations.
        self._timer = Foo.Instance()

    #
    # Define BlockStorageInterface Methods
    #

    def clone_device(self):
        """Return a device sharing this device's ssh client and
        thread pool. The clone ignores the lock flag and does not
        close the shared pool when closed."""
        f = BlockStorageSFTP(self.storage_name,
                             sshclient=self._sshclient,
                             threadpool_size=0,
                             ignore_lock=True)
        f._pool = self._pool
        f._close_pool = False
        return f

    #@classmethod
    #def compute_storage_size(...)

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              block_count,
              sshclient=None,
              threadpool_size=None,
              **kwds):
        """
        Create a new block storage space on the remote host, then
        reopen it as a BlockStorageSFTP device.

        Raises:
            ValueError: if no ssh client is given.
        """
        if sshclient is None:
            raise ValueError(
                "Can not setup sftp block storage device "
                "without an ssh client.")

        # create the storage space remotely using the file-based
        # setup, closing the device (and its sftp handle) when done
        with BlockStorageFile.setup(storage_name,
                                    block_size,
                                    block_count,
                                    _filesystem=sshclient.open_sftp(),
                                    threadpool_size=threadpool_size,
                                    **kwds) as f:
            pass
        f._filesystem.close()

        return BlockStorageSFTP(storage_name,
                                sshclient=sshclient,
                                threadpool_size=threadpool_size)

    #@property
    #def header_data(...)

    #@property
    #def block_count(...)

    #@property
    #def block_size(...)

    #@property
    #def storage_name(...)

    #def update_header_data(...)

    def close(self):
        """Close the underlying file device, then the sftp session.
        (fix: removed leftover debug print statements)"""
        super(BlockStorageSFTP, self).close()
        self._filesystem.close()

    def read_blocks(self, indices):
        """Read the given block indices with a single batched readv
        request."""
        self._check_async()
        args = []
        for i in indices:
            assert 0 <= i < self.block_count
            self._bytes_received += self.block_size
            args.append((self._header_offset + i * self.block_size,
                         self.block_size))
        self._timer.startTimer()
        blocks = self._f.readv(args)
        self._timer.endTimer()
        return blocks

    def yield_blocks(self, indices, chunksize=100):
        """Lazily read the given block indices, issuing one batched
        readv request per chunk of `chunksize` indices."""
        for chunk in chunkiter(indices, n=chunksize):
            # fix: valid indices are 0..block_count-1 (the original
            # assert used <=, allowing an out-of-range index)
            assert all(0 <= i < self.block_count for i in chunk)
            self._bytes_received += self.block_size * len(chunk)
            args = [(self._header_offset + i * self.block_size,
                     self.block_size)
                    for i in chunk]
            self._timer.startTimer()
            blocks = self._f.readv(args)
            self._timer.endTimer()
            for block in blocks:
                yield block

    #def read_block(...)

    #def write_blocks(...)

    #def write_block(...)

    #@property
    #def bytes_sent(...)

    #@property
    #def bytes_received(...)

BlockStorageTypeFactory.register_device("sftp", BlockStorageSFTP)
--- /dev/null
+__all__ = ("Boto3S3Wrapper",
+ "MockBoto3S3Wrapper")
+import os
+import shutil
+
+import pyoram
+
+import tqdm
+try:
+ import boto3
+ import botocore
+ boto3_available = True
+except: # pragma: no cover
+ boto3_available = False # pragma: no cover
+
+import six
+from six.moves import xrange, map
+
class Boto3S3Wrapper(object):
    """
    A wrapper class for the boto3 S3 service.

    Provides the minimal interface used by BlockStorageS3:
    exists / download / upload / clear against a single bucket.
    """

    def __init__(self,
                 bucket_name,
                 aws_access_key_id=None,
                 aws_secret_access_key=None,
                 region_name=None):
        # boto3 is an optional dependency; fail loudly if it was
        # not importable at module load time
        if not boto3_available:
            raise ImportError( # pragma: no cover
                "boto3 module is required to " # pragma: no cover
                "use BlockStorageS3 device") # pragma: no cover

        self._s3 = boto3.session.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name).resource('s3')
        self._bucket = self._s3.Bucket(bucket_name)

    def exists(self, key):
        """Return True if 'key' names an object, or is the prefix of
        at least one object (a "directory"), in the bucket."""
        try:
            self._bucket.Object(key).load()
        except botocore.exceptions.ClientError as e:
            # 404 means no object with that exact key; any other
            # error is re-raised
            if e.response['Error']['Code'] == "404":
                pass
            else:
                raise e
        else:
            return True
        # It's not a file. Check if it's a "directory".
        for obj in self._bucket.objects.filter(
                Prefix=key+"/"):
            return True
        return False

    def download(self, key):
        """Return the full body of the object at 'key' as bytes.

        Raises IOError if the object can not be fetched."""
        try:
            return self._s3.meta.client.get_object(
                Bucket=self._bucket.name,
                Key=key)['Body'].read()
        except botocore.exceptions.ClientError:
            raise IOError("Can not download key: %s"
                          % (key))

    def upload(self, key_block):
        """Store a (key, block) pair as an object in the bucket.
        Takes a single tuple argument so it can be mapped over by a
        thread pool."""
        key, block = key_block
        self._bucket.put_object(Key=key, Body=block)

    # Chunk a streamed iterator of which we do not know
    # the size
    def _chunks(self, objs, n=100):
        """Yield delete_objects request dicts, each covering at most
        n object summaries drawn from the iterator 'objs'."""
        assert 1 <= n <= 1000 # required by boto3
        objs = iter(objs)
        try:
            while (1):
                chunk = []
                while len(chunk) < n:
                    chunk.append({'Key': six.next(objs).key})
                yield {'Objects': chunk}
        except StopIteration:
            pass
        # flush the final, partially-filled chunk (if any)
        if len(chunk):
            yield {'Objects': chunk}

    def _del(self, chunk):
        """Issue one bulk delete request; return the number of keys
        it covered (used to advance the progress bar)."""
        self._bucket.delete_objects(Delete=chunk)
        return len(chunk['Objects'])

    def clear(self, key, threadpool=None):
        """Bulk-delete every object under the prefix 'key/',
        optionally parallelized with the given thread pool."""
        objs = self._bucket.objects.filter(Prefix=key+"/")
        if threadpool is not None:
            deliter = threadpool.imap(self._del, self._chunks(objs))
        else:
            deliter = map(self._del, self._chunks(objs))
        with tqdm.tqdm(total=None, #len(objs),
                       desc="Clearing S3 Blocks",
                       unit=" objects",
                       disable=not pyoram.config.SHOW_PROGRESS_BAR) as progress_bar:
            progress_bar.update(n=0)
            for chunksize in deliter:
                progress_bar.update(n=chunksize)
+
class MockBoto3S3Wrapper(object):
    """
    A mock class for Boto3S3Wrapper that uses the local filesystem
    and treats the bucket name as a directory.

    This class is mainly used for testing, but could potentially be
    used to setup storage locally that is then uploaded to S3
    through the AWS web portal.
    """

    def __init__(self,
                 bucket_name,
                 aws_access_key_id=None,
                 aws_secret_access_key=None,
                 region_name=None):
        # credentials and region are accepted for interface
        # compatibility but unused
        self._bucket_name = os.path.abspath(
            os.path.normpath(bucket_name))

    # called within upload to create directory
    # heirarchy on the fly
    def _makedirs_if_needed(self, key):
        target = os.path.join(self._bucket_name, key)
        parent = os.path.dirname(target)
        if not os.path.exists(parent):
            os.makedirs(parent)
        # the key itself must not collide with a directory
        assert not os.path.isdir(target)

    def exists(self, key):
        """True if 'key' names an existing file or directory under
        the bucket directory."""
        return os.path.exists(os.path.join(self._bucket_name, key))

    def download(self, key):
        """Return the contents of the file at 'key' as bytes."""
        with open(os.path.join(self._bucket_name, key), 'rb') as f:
            return f.read()

    def upload(self, key_block):
        """Write a (key, block) pair as a file under the bucket
        directory, creating parent directories as needed."""
        key, block = key_block
        self._makedirs_if_needed(key)
        with open(os.path.join(self._bucket_name, key), 'wb') as f:
            f.write(block)

    def clear(self, key, threadpool=None):
        """Remove the file or directory tree at 'key' if present
        (the threadpool argument is accepted for interface
        compatibility but unused)."""
        target = os.path.join(self._bucket_name, key)
        if os.path.exists(target):
            if os.path.isdir(target):
                shutil.rmtree(target, ignore_errors=True)
            else:
                os.remove(target)
--- /dev/null
+__all__ = ('HeapStorage',)
+
+import struct
+
+from pyoram.util.virtual_heap import SizedVirtualHeap
+from pyoram.storage.block_storage import (BlockStorageInterface,
+ BlockStorageTypeFactory)
+
class HeapStorageInterface(object):
    """Abstract interface for heap-organized (bucket tree) storage.

    Concrete implementations store one bucket per block of an
    underlying block storage device and expose path-based reads and
    writes over the virtual heap. Supports use as a context manager
    (close() is called on exit).
    """

    def __enter__(self):
        return self
    def __exit__(self, *args):
        self.close()

    #
    # Abstract Interface
    #

    def clone_device(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    @classmethod
    def compute_storage_size(cls, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @classmethod
    def setup(cls, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    @property
    def header_data(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def bucket_count(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def bucket_size(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def blocks_per_bucket(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def storage_name(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def virtual_heap(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    @property
    def bucket_storage(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    def update_header_data(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def close(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def read_path(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover
    def write_path(self, *args, **kwds):
        raise NotImplementedError # pragma: no cover

    @property
    def bytes_sent(self):
        raise NotImplementedError # pragma: no cover
    @property
    def bytes_received(self):
        raise NotImplementedError # pragma: no cover
+
class HeapStorage(HeapStorageInterface):
    """
    Heap-organized storage on top of a flat block storage device.
    Each block of the underlying device holds one bucket of the
    virtual heap; the heap geometry is recorded in a packed header
    prepended to the user header data.
    """

    # heap header layout: (heap_base, heap_height, blocks_per_bucket)
    _header_struct_string = "!LLL"
    _header_offset = struct.calcsize(_header_struct_string)

    def _new_storage(self, storage, **kwds):
        """Construct the underlying block storage device from a
        storage name, using the 'storage_type' keyword (default
        'file') to select the device class.

        (fix: the original popped 'storage_type' and then did
        nothing; this completes the helper and __init__ now uses it)
        """
        storage_type = kwds.pop('storage_type', 'file')
        return BlockStorageTypeFactory(storage_type)(storage, **kwds)

    def __init__(self, storage, **kwds):
        """
        Open heap storage over an existing block storage device, or
        over a storage name from which a device is constructed.

        Args:
            storage: a BlockStorageInterface instance, or a storage
                name forwarded to the device constructor.
            **kwds: device constructor keywords (only allowed when
                storage is a name, not a device).

        Raises:
            ValueError: if keywords are passed alongside an
                already-constructed storage device.
        """
        if isinstance(storage, BlockStorageInterface):
            self._storage = storage
            if len(kwds):
                raise ValueError(
                    "Keywords not used when initializing "
                    "with a storage device: %s"
                    % (str(kwds)))
        else:
            self._storage = self._new_storage(storage, **kwds)

        # recover the heap geometry from the device's header
        heap_base, heap_height, blocks_per_bucket = \
            struct.unpack(
                self._header_struct_string,
                self._storage.header_data[:self._header_offset])
        self._vheap = SizedVirtualHeap(
            heap_base,
            heap_height,
            blocks_per_bucket=blocks_per_bucket)

    #
    # Define HeapStorageInterface Methods
    #

    def clone_device(self):
        """Return a HeapStorage wrapping a clone of the underlying
        block storage device."""
        return HeapStorage(self._storage.clone_device())

    @classmethod
    def compute_storage_size(cls,
                             block_size,
                             heap_height,
                             blocks_per_bucket=1,
                             heap_base=2,
                             ignore_header=False,
                             storage_type='file',
                             **kwds):
        """Total storage size in bytes for a heap with the given
        geometry; includes the heap header unless ignore_header."""
        assert (block_size > 0) and (block_size == int(block_size))
        assert heap_height >= 0
        assert blocks_per_bucket >= 1
        assert heap_base >= 2
        assert 'block_count' not in kwds
        vheap = SizedVirtualHeap(
            heap_base,
            heap_height,
            blocks_per_bucket=blocks_per_bucket)
        if ignore_header:
            return BlockStorageTypeFactory(storage_type).\
                compute_storage_size(
                    vheap.blocks_per_bucket * block_size,
                    vheap.bucket_count(),
                    ignore_header=True,
                    **kwds)
        else:
            return cls._header_offset + \
                   BlockStorageTypeFactory(storage_type).\
                   compute_storage_size(
                       vheap.blocks_per_bucket * block_size,
                       vheap.bucket_count(),
                       ignore_header=False,
                       **kwds)

    @classmethod
    def setup(cls,
              storage_name,
              block_size,
              heap_height,
              blocks_per_bucket=1,
              heap_base=2,
              storage_type='file',
              **kwds):
        """
        Create a new heap storage space. The heap geometry is packed
        in front of any user 'header_data' keyword; 'block_count' is
        derived from the geometry and may not be passed.

        Raises:
            ValueError: on invalid geometry or a 'block_count'
                keyword.
            TypeError: if 'header_data' is not bytes.
        """
        if 'block_count' in kwds:
            raise ValueError("'block_count' keyword is not accepted")
        if heap_height < 0:
            raise ValueError(
                "heap height must be 0 or greater. Invalid value: %s"
                % (heap_height))
        if blocks_per_bucket < 1:
            raise ValueError(
                "blocks_per_bucket must be 1 or greater. "
                "Invalid value: %s" % (blocks_per_bucket))
        if heap_base < 2:
            raise ValueError(
                "heap base must be 2 or greater. Invalid value: %s"
                % (heap_base))

        vheap = SizedVirtualHeap(
            heap_base,
            heap_height,
            blocks_per_bucket=blocks_per_bucket)

        user_header_data = kwds.pop('header_data', bytes())
        if type(user_header_data) is not bytes:
            raise TypeError(
                "'header_data' must be of type bytes. "
                "Invalid type: %s" % (type(user_header_data)))
        # prepend the heap geometry to the user header
        kwds['header_data'] = \
            struct.pack(cls._header_struct_string,
                        heap_base,
                        heap_height,
                        blocks_per_bucket) + \
            user_header_data

        return HeapStorage(
            BlockStorageTypeFactory(storage_type).setup(
                storage_name,
                vheap.blocks_per_bucket * block_size,
                vheap.bucket_count(),
                **kwds))

    @property
    def header_data(self):
        """The user portion of the header (heap geometry stripped)."""
        return self._storage.header_data[self._header_offset:]

    @property
    def bucket_count(self):
        """Number of buckets (one per underlying block)."""
        return self._storage.block_count

    @property
    def bucket_size(self):
        """Size of each bucket in bytes."""
        return self._storage.block_size

    @property
    def blocks_per_bucket(self):
        """Number of logical blocks stored in each bucket."""
        return self._vheap.blocks_per_bucket

    @property
    def storage_name(self):
        """Name of the underlying storage device."""
        return self._storage.storage_name

    @property
    def virtual_heap(self):
        """The SizedVirtualHeap describing this storage's geometry."""
        return self._vheap

    @property
    def bucket_storage(self):
        """The underlying block storage device."""
        return self._storage

    def update_header_data(self, new_header_data):
        """Replace the user header, preserving the packed heap
        geometry in front of it."""
        self._storage.update_header_data(
            self._storage.header_data[:self._header_offset] + \
            new_header_data)

    def close(self):
        """Close the underlying block storage device.
        (fix: removed leftover debug print statements)"""
        self._storage.close()

    def read_path(self, b, level_start=0):
        """Read the buckets along the root-to-b path, starting at
        heap level 'level_start'."""
        assert 0 <= b < self._vheap.bucket_count()
        bucket_list = self._vheap.Node(b).bucket_path_from_root()
        assert 0 <= level_start < len(bucket_list)
        return self._storage.read_blocks(bucket_list[level_start:])

    def write_path(self, b, buckets, level_start=0):
        """Write the given buckets along the root-to-b path,
        starting at heap level 'level_start'."""
        assert 0 <= b < self._vheap.bucket_count()
        bucket_list = self._vheap.Node(b).bucket_path_from_root()
        assert 0 <= level_start < len(bucket_list)
        self._storage.write_blocks(bucket_list[level_start:],
                                   buckets)

    @property
    def bytes_sent(self):
        """Total payload bytes written through the device."""
        return self._storage.bytes_sent

    @property
    def bytes_received(self):
        """Total payload bytes read through the device."""
        return self._storage.bytes_received
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="(0, 0)}"];
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="{{0}}"];
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="''}"];
+ 1 [penwidth=1,label="'0'}"];
+ 3 [penwidth=1,label="'00'}"];
+ 7 [penwidth=1,label="'000'}"];
+ 3 -> 7 ;
+ 8 [penwidth=1,label="'001'}"];
+ 3 -> 8 ;
+ 1 -> 3 ;
+ 4 [penwidth=1,label="'01'}"];
+ 9 [penwidth=1,label="'010'}"];
+ 4 -> 9 ;
+ 10 [penwidth=1,label="'011'}"];
+ 4 -> 10 ;
+ 1 -> 4 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="'1'}"];
+ 5 [penwidth=1,label="'10'}"];
+ 11 [penwidth=1,label="'100'}"];
+ 5 -> 11 ;
+ 12 [penwidth=1,label="'101'}"];
+ 5 -> 12 ;
+ 2 -> 5 ;
+ 6 [penwidth=1,label="'11'}"];
+ 13 [penwidth=1,label="'110'}"];
+ 6 -> 13 ;
+ 14 [penwidth=1,label="'111'}"];
+ 6 -> 14 ;
+ 2 -> 6 ;
+ 0 -> 2 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="{{0}}"];
+ 1 [penwidth=1,label="{{1}}"];
+ 3 [penwidth=1,label="{{3}}"];
+ 7 [penwidth=1,label="{{7}}"];
+ 3 -> 7 ;
+ 8 [penwidth=1,label="{{8}}"];
+ 3 -> 8 ;
+ 1 -> 3 ;
+ 4 [penwidth=1,label="{{4}}"];
+ 9 [penwidth=1,label="{{9}}"];
+ 4 -> 9 ;
+ 10 [penwidth=1,label="{{10}}"];
+ 4 -> 10 ;
+ 1 -> 4 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="{{2}}"];
+ 5 [penwidth=1,label="{{5}}"];
+ 11 [penwidth=1,label="{{11}}"];
+ 5 -> 11 ;
+ 12 [penwidth=1,label="{{12}}"];
+ 5 -> 12 ;
+ 2 -> 5 ;
+ 6 [penwidth=1,label="{{6}}"];
+ 13 [penwidth=1,label="{{13}}"];
+ 6 -> 13 ;
+ 14 [penwidth=1,label="{{14}}"];
+ 6 -> 14 ;
+ 2 -> 6 ;
+ 0 -> 2 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="''}"];
+ 1 [penwidth=1,label="'0'}"];
+ 3 [penwidth=1,label="'00'}"];
+ 7 [penwidth=1,label="'000'}"];
+ 3 -> 7 ;
+ 8 [penwidth=1,label="'001'}"];
+ 3 -> 8 ;
+ 1 -> 3 ;
+ 4 [penwidth=1,label="'01'}"];
+ 9 [penwidth=1,label="'010'}"];
+ 4 -> 9 ;
+ 10 [penwidth=1,label="'011'}"];
+ 4 -> 10 ;
+ 1 -> 4 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="'1'}"];
+ 5 [penwidth=1,label="'10'}"];
+ 11 [penwidth=1,label="'100'}"];
+ 5 -> 11 ;
+ 12 [penwidth=1,label="'101'}"];
+ 5 -> 12 ;
+ 2 -> 5 ;
+ 6 [penwidth=1,label="'11'}"];
+ 13 [penwidth=1,label="'110'}"];
+ 6 -> 13 ;
+ 14 [penwidth=1,label="'111'}"];
+ 6 -> 14 ;
+ 2 -> 6 ;
+ 0 -> 2 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="{{0}|{1}}"];
+ 1 [penwidth=1,label="{{2}|{3}}"];
+ 3 [penwidth=1,label="{{6}|{7}}"];
+ 7 [penwidth=1,label="{{14}|{15}}"];
+ 3 -> 7 ;
+ 8 [penwidth=1,label="{{16}|{17}}"];
+ 3 -> 8 ;
+ 1 -> 3 ;
+ 4 [penwidth=1,label="{{8}|{9}}"];
+ 9 [penwidth=1,label="{{18}|{19}}"];
+ 4 -> 9 ;
+ 10 [penwidth=1,label="{{20}|{21}}"];
+ 4 -> 10 ;
+ 1 -> 4 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="{{4}|{5}}"];
+ 5 [penwidth=1,label="{{10}|{11}}"];
+ 11 [penwidth=1,label="{{22}|{23}}"];
+ 5 -> 11 ;
+ 12 [penwidth=1,label="{{24}|{25}}"];
+ 5 -> 12 ;
+ 2 -> 5 ;
+ 6 [penwidth=1,label="{{12}|{13}}"];
+ 13 [penwidth=1,label="{{26}|{27}}"];
+ 6 -> 13 ;
+ 14 [penwidth=1,label="{{28}|{29}}"];
+ 6 -> 14 ;
+ 2 -> 6 ;
+ 0 -> 2 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="''}"];
+ 1 [penwidth=1,label="'0'}"];
+ 4 [penwidth=1,label="'00'}"];
+ 13 [penwidth=1,label="'000'}"];
+ 4 -> 13 ;
+ 14 [penwidth=1,label="'001'}"];
+ 4 -> 14 ;
+ 15 [penwidth=1,label="'002'}"];
+ 4 -> 15 ;
+ 1 -> 4 ;
+ 5 [penwidth=1,label="'01'}"];
+ 16 [penwidth=1,label="'010'}"];
+ 5 -> 16 ;
+ 17 [penwidth=1,label="'011'}"];
+ 5 -> 17 ;
+ 18 [penwidth=1,label="'012'}"];
+ 5 -> 18 ;
+ 1 -> 5 ;
+ 6 [penwidth=1,label="'02'}"];
+ 19 [penwidth=1,label="'020'}"];
+ 6 -> 19 ;
+ 20 [penwidth=1,label="'021'}"];
+ 6 -> 20 ;
+ 21 [penwidth=1,label="'022'}"];
+ 6 -> 21 ;
+ 1 -> 6 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="'1'}"];
+ 7 [penwidth=1,label="'10'}"];
+ 22 [penwidth=1,label="'100'}"];
+ 7 -> 22 ;
+ 23 [penwidth=1,label="'101'}"];
+ 7 -> 23 ;
+ 24 [penwidth=1,label="'102'}"];
+ 7 -> 24 ;
+ 2 -> 7 ;
+ 8 [penwidth=1,label="'11'}"];
+ 25 [penwidth=1,label="'110'}"];
+ 8 -> 25 ;
+ 26 [penwidth=1,label="'111'}"];
+ 8 -> 26 ;
+ 27 [penwidth=1,label="'112'}"];
+ 8 -> 27 ;
+ 2 -> 8 ;
+ 9 [penwidth=1,label="'12'}"];
+ 28 [penwidth=1,label="'120'}"];
+ 9 -> 28 ;
+ 29 [penwidth=1,label="'121'}"];
+ 9 -> 29 ;
+ 30 [penwidth=1,label="'122'}"];
+ 9 -> 30 ;
+ 2 -> 9 ;
+ 0 -> 2 ;
+ 3 [penwidth=1,label="'2'}"];
+ 10 [penwidth=1,label="'20'}"];
+ 31 [penwidth=1,label="'200'}"];
+ 10 -> 31 ;
+ 32 [penwidth=1,label="'201'}"];
+ 10 -> 32 ;
+ 33 [penwidth=1,label="'202'}"];
+ 10 -> 33 ;
+ 3 -> 10 ;
+ 11 [penwidth=1,label="'21'}"];
+ 34 [penwidth=1,label="'210'}"];
+ 11 -> 34 ;
+ 35 [penwidth=1,label="'211'}"];
+ 11 -> 35 ;
+ 36 [penwidth=1,label="'212'}"];
+ 11 -> 36 ;
+ 3 -> 11 ;
+ 12 [penwidth=1,label="'22'}"];
+ 37 [penwidth=1,label="'220'}"];
+ 12 -> 37 ;
+ 38 [penwidth=1,label="'221'}"];
+ 12 -> 38 ;
+ 39 [penwidth=1,label="'222'}"];
+ 12 -> 39 ;
+ 3 -> 12 ;
+ 0 -> 3 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="{{0}}"];
+ 1 [penwidth=1,label="{{1}}"];
+ 4 [penwidth=1,label="{{4}}"];
+ 13 [penwidth=1,label="{{13}}"];
+ 4 -> 13 ;
+ 14 [penwidth=1,label="{{14}}"];
+ 4 -> 14 ;
+ 15 [penwidth=1,label="{{15}}"];
+ 4 -> 15 ;
+ 1 -> 4 ;
+ 5 [penwidth=1,label="{{5}}"];
+ 16 [penwidth=1,label="{{16}}"];
+ 5 -> 16 ;
+ 17 [penwidth=1,label="{{17}}"];
+ 5 -> 17 ;
+ 18 [penwidth=1,label="{{18}}"];
+ 5 -> 18 ;
+ 1 -> 5 ;
+ 6 [penwidth=1,label="{{6}}"];
+ 19 [penwidth=1,label="{{19}}"];
+ 6 -> 19 ;
+ 20 [penwidth=1,label="{{20}}"];
+ 6 -> 20 ;
+ 21 [penwidth=1,label="{{21}}"];
+ 6 -> 21 ;
+ 1 -> 6 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="{{2}}"];
+ 7 [penwidth=1,label="{{7}}"];
+ 22 [penwidth=1,label="{{22}}"];
+ 7 -> 22 ;
+ 23 [penwidth=1,label="{{23}}"];
+ 7 -> 23 ;
+ 24 [penwidth=1,label="{{24}}"];
+ 7 -> 24 ;
+ 2 -> 7 ;
+ 8 [penwidth=1,label="{{8}}"];
+ 25 [penwidth=1,label="{{25}}"];
+ 8 -> 25 ;
+ 26 [penwidth=1,label="{{26}}"];
+ 8 -> 26 ;
+ 27 [penwidth=1,label="{{27}}"];
+ 8 -> 27 ;
+ 2 -> 8 ;
+ 9 [penwidth=1,label="{{9}}"];
+ 28 [penwidth=1,label="{{28}}"];
+ 9 -> 28 ;
+ 29 [penwidth=1,label="{{29}}"];
+ 9 -> 29 ;
+ 30 [penwidth=1,label="{{30}}"];
+ 9 -> 30 ;
+ 2 -> 9 ;
+ 0 -> 2 ;
+ 3 [penwidth=1,label="{{3}}"];
+ 10 [penwidth=1,label="{{10}}"];
+ 31 [penwidth=1,label="{{31}}"];
+ 10 -> 31 ;
+ 32 [penwidth=1,label="{{32}}"];
+ 10 -> 32 ;
+ 33 [penwidth=1,label="{{33}}"];
+ 10 -> 33 ;
+ 3 -> 10 ;
+ 11 [penwidth=1,label="{{11}}"];
+ 34 [penwidth=1,label="{{34}}"];
+ 11 -> 34 ;
+ 35 [penwidth=1,label="{{35}}"];
+ 11 -> 35 ;
+ 36 [penwidth=1,label="{{36}}"];
+ 11 -> 36 ;
+ 3 -> 11 ;
+ 12 [penwidth=1,label="{{12}}"];
+ 37 [penwidth=1,label="{{37}}"];
+ 12 -> 37 ;
+ 38 [penwidth=1,label="{{38}}"];
+ 12 -> 38 ;
+ 39 [penwidth=1,label="{{39}}"];
+ 12 -> 39 ;
+ 3 -> 12 ;
+ 0 -> 3 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="''}"];
+ 1 [penwidth=1,label="'0'}"];
+ 4 [penwidth=1,label="'00'}"];
+ 13 [penwidth=1,label="'000'}"];
+ 4 -> 13 ;
+ 14 [penwidth=1,label="'001'}"];
+ 4 -> 14 ;
+ 15 [penwidth=1,label="'002'}"];
+ 4 -> 15 ;
+ 1 -> 4 ;
+ 5 [penwidth=1,label="'01'}"];
+ 16 [penwidth=1,label="'010'}"];
+ 5 -> 16 ;
+ 17 [penwidth=1,label="'011'}"];
+ 5 -> 17 ;
+ 18 [penwidth=1,label="'012'}"];
+ 5 -> 18 ;
+ 1 -> 5 ;
+ 6 [penwidth=1,label="'02'}"];
+ 19 [penwidth=1,label="'020'}"];
+ 6 -> 19 ;
+ 20 [penwidth=1,label="'021'}"];
+ 6 -> 20 ;
+ 21 [penwidth=1,label="'022'}"];
+ 6 -> 21 ;
+ 1 -> 6 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="'1'}"];
+ 7 [penwidth=1,label="'10'}"];
+ 22 [penwidth=1,label="'100'}"];
+ 7 -> 22 ;
+ 23 [penwidth=1,label="'101'}"];
+ 7 -> 23 ;
+ 24 [penwidth=1,label="'102'}"];
+ 7 -> 24 ;
+ 2 -> 7 ;
+ 8 [penwidth=1,label="'11'}"];
+ 25 [penwidth=1,label="'110'}"];
+ 8 -> 25 ;
+ 26 [penwidth=1,label="'111'}"];
+ 8 -> 26 ;
+ 27 [penwidth=1,label="'112'}"];
+ 8 -> 27 ;
+ 2 -> 8 ;
+ 9 [penwidth=1,label="'12'}"];
+ 28 [penwidth=1,label="'120'}"];
+ 9 -> 28 ;
+ 29 [penwidth=1,label="'121'}"];
+ 9 -> 29 ;
+ 30 [penwidth=1,label="'122'}"];
+ 9 -> 30 ;
+ 2 -> 9 ;
+ 0 -> 2 ;
+ 3 [penwidth=1,label="'2'}"];
+ 10 [penwidth=1,label="'20'}"];
+ 31 [penwidth=1,label="'200'}"];
+ 10 -> 31 ;
+ 32 [penwidth=1,label="'201'}"];
+ 10 -> 32 ;
+ 33 [penwidth=1,label="'202'}"];
+ 10 -> 33 ;
+ 3 -> 10 ;
+ 11 [penwidth=1,label="'21'}"];
+ 34 [penwidth=1,label="'210'}"];
+ 11 -> 34 ;
+ 35 [penwidth=1,label="'211'}"];
+ 11 -> 35 ;
+ 36 [penwidth=1,label="'212'}"];
+ 11 -> 36 ;
+ 3 -> 11 ;
+ 12 [penwidth=1,label="'22'}"];
+ 37 [penwidth=1,label="'220'}"];
+ 12 -> 37 ;
+ 38 [penwidth=1,label="'221'}"];
+ 12 -> 38 ;
+ 39 [penwidth=1,label="'222'}"];
+ 12 -> 39 ;
+ 3 -> 12 ;
+ 0 -> 3 ;
+}
--- /dev/null
+// Created by SizedVirtualHeap.write_as_dot(...)
+digraph heaptree {
+node [shape=record]
+ 0 [penwidth=1,label="{{0}|{1}}"];
+ 1 [penwidth=1,label="{{2}|{3}}"];
+ 4 [penwidth=1,label="{{8}|{9}}"];
+ 13 [penwidth=1,label="{{26}|{27}}"];
+ 4 -> 13 ;
+ 14 [penwidth=1,label="{{28}|{29}}"];
+ 4 -> 14 ;
+ 15 [penwidth=1,label="{{30}|{31}}"];
+ 4 -> 15 ;
+ 1 -> 4 ;
+ 5 [penwidth=1,label="{{10}|{11}}"];
+ 16 [penwidth=1,label="{{32}|{33}}"];
+ 5 -> 16 ;
+ 17 [penwidth=1,label="{{34}|{35}}"];
+ 5 -> 17 ;
+ 18 [penwidth=1,label="{{36}|{37}}"];
+ 5 -> 18 ;
+ 1 -> 5 ;
+ 6 [penwidth=1,label="{{12}|{13}}"];
+ 19 [penwidth=1,label="{{38}|{39}}"];
+ 6 -> 19 ;
+ 20 [penwidth=1,label="{{40}|{41}}"];
+ 6 -> 20 ;
+ 21 [penwidth=1,label="{{42}|{43}}"];
+ 6 -> 21 ;
+ 1 -> 6 ;
+ 0 -> 1 ;
+ 2 [penwidth=1,label="{{4}|{5}}"];
+ 7 [penwidth=1,label="{{14}|{15}}"];
+ 22 [penwidth=1,label="{{44}|{45}}"];
+ 7 -> 22 ;
+ 23 [penwidth=1,label="{{46}|{47}}"];
+ 7 -> 23 ;
+ 24 [penwidth=1,label="{{48}|{49}}"];
+ 7 -> 24 ;
+ 2 -> 7 ;
+ 8 [penwidth=1,label="{{16}|{17}}"];
+ 25 [penwidth=1,label="{{50}|{51}}"];
+ 8 -> 25 ;
+ 26 [penwidth=1,label="{{52}|{53}}"];
+ 8 -> 26 ;
+ 27 [penwidth=1,label="{{54}|{55}}"];
+ 8 -> 27 ;
+ 2 -> 8 ;
+ 9 [penwidth=1,label="{{18}|{19}}"];
+ 28 [penwidth=1,label="{{56}|{57}}"];
+ 9 -> 28 ;
+ 29 [penwidth=1,label="{{58}|{59}}"];
+ 9 -> 29 ;
+ 30 [penwidth=1,label="{{60}|{61}}"];
+ 9 -> 30 ;
+ 2 -> 9 ;
+ 0 -> 2 ;
+ 3 [penwidth=1,label="{{6}|{7}}"];
+ 10 [penwidth=1,label="{{20}|{21}}"];
+ 31 [penwidth=1,label="{{62}|{63}}"];
+ 10 -> 31 ;
+ 32 [penwidth=1,label="{{64}|{65}}"];
+ 10 -> 32 ;
+ 33 [penwidth=1,label="{{66}|{67}}"];
+ 10 -> 33 ;
+ 3 -> 10 ;
+ 11 [penwidth=1,label="{{22}|{23}}"];
+ 34 [penwidth=1,label="{{68}|{69}}"];
+ 11 -> 34 ;
+ 35 [penwidth=1,label="{{70}|{71}}"];
+ 11 -> 35 ;
+ 36 [penwidth=1,label="{{72}|{73}}"];
+ 11 -> 36 ;
+ 3 -> 11 ;
+ 12 [penwidth=1,label="{{24}|{25}}"];
+ 37 [penwidth=1,label="{{74}|{75}}"];
+ 12 -> 37 ;
+ 38 [penwidth=1,label="{{76}|{77}}"];
+ 12 -> 38 ;
+ 39 [penwidth=1,label="{{78}|{79}}"];
+ 12 -> 39 ;
+ 3 -> 12 ;
+ 0 -> 3 ;
+}
--- /dev/null
+import unittest2
+
+from pyoram.crypto.aes import AES
+
class TestAES(unittest2.TestCase):
    """Tests for the AES wrapper: key generation and CTR/GCM round-trips."""

    def test_KeyGen(self):
        """KeyGen returns keys of the requested length, unique across calls."""
        self.assertEqual(len(AES.key_sizes), 3)
        self.assertEqual(len(set(AES.key_sizes)), 3)
        for keysize in AES.key_sizes:
            key_list = []
            key_set = set()
            for i in range(10):
                k = AES.KeyGen(keysize)
                self.assertEqual(len(k), keysize)
                key_list.append(k)
                key_set.add(k)
            self.assertEqual(len(key_list), 10)
            # make sure every key is unique
            self.assertEqual(len(key_list), len(key_set))

    def test_CTR(self):
        """Round-trip a set of plaintexts through CTR mode."""
        self._test_Enc_Dec(
            AES.CTREnc,
            AES.CTRDec,
            lambda i, size: bytes(bytearray([i]) * size))

    def test_GCM(self):
        """Round-trip a set of plaintexts through GCM mode."""
        self._test_Enc_Dec(
            AES.GCMEnc,
            AES.GCMDec,
            lambda i, size: bytes(bytearray([i]) * size))

    def _test_Enc_Dec(self,
                      enc_func,
                      dec_func,
                      get_plaintext):
        """Shared encrypt/decrypt driver.

        Builds plaintexts at several multiples of the AES block size
        (including non-integral multiples), encrypts with every supported
        key size, checks decryption restores the plaintext, checks the
        ciphertext length overhead (one block for CTR, two for GCM), and
        checks that re-encrypting the same plaintext yields a different
        ciphertext of the same length (IND-CPA style randomization).
        """
        blocksize_factor = [0.5, 1, 1.5, 2, 2.5]
        plaintext_blocks = []
        for i, f in enumerate(blocksize_factor):
            size = AES.block_size * f
            size = int(round(size))
            if int(f) != f:
                # non-integral factors must not land on a block boundary
                assert (size % AES.block_size) != 0
            plaintext_blocks.append(get_plaintext(i, size))

        assert len(AES.key_sizes) > 0
        ciphertext_blocks = {}
        keys = {}
        for keysize in AES.key_sizes:
            key = AES.KeyGen(keysize)
            keys[keysize] = key
            ciphertext_blocks[keysize] = []
            for block in plaintext_blocks:
                ciphertext_blocks[keysize].append(
                    enc_func(key, block))

        self.assertEqual(len(ciphertext_blocks),
                         len(AES.key_sizes))
        self.assertEqual(len(keys),
                         len(AES.key_sizes))

        plaintext_decrypted_blocks = {}
        for keysize in keys:
            key = keys[keysize]
            plaintext_decrypted_blocks[keysize] = []
            for block in ciphertext_blocks[keysize]:
                plaintext_decrypted_blocks[keysize].append(
                    dec_func(key, block))

        self.assertEqual(len(plaintext_decrypted_blocks),
                         len(AES.key_sizes))

        for i in range(len(blocksize_factor)):
            for keysize in AES.key_sizes:
                self.assertEqual(
                    plaintext_blocks[i],
                    plaintext_decrypted_blocks[keysize][i])
                self.assertNotEqual(
                    plaintext_blocks[i],
                    ciphertext_blocks[keysize][i])
                if enc_func is AES.CTREnc:
                    # CTR ciphertext is one block longer (presumably the IV)
                    self.assertEqual(
                        len(ciphertext_blocks[keysize][i]),
                        len(plaintext_blocks[i]) + AES.block_size)
                else:
                    assert enc_func is AES.GCMEnc
                    # GCM ciphertext is two blocks longer
                    # (presumably IV plus auth tag)
                    self.assertEqual(
                        len(ciphertext_blocks[keysize][i]),
                        len(plaintext_blocks[i]) + 2*AES.block_size)
                # check IND-CPA
                key = keys[keysize]
                alt_ciphertext = enc_func(key, plaintext_blocks[i])
                self.assertNotEqual(
                    ciphertext_blocks[keysize][i],
                    alt_ciphertext)
                self.assertEqual(
                    len(ciphertext_blocks[keysize][i]),
                    len(alt_ciphertext))
                self.assertNotEqual(
                    ciphertext_blocks[keysize][i][:AES.block_size],
                    alt_ciphertext[:AES.block_size])
                self.assertNotEqual(
                    ciphertext_blocks[keysize][i][AES.block_size:],
                    alt_ciphertext[AES.block_size:])
+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest2.main()                                   # pragma: no cover
--- /dev/null
+import os
+import shutil
+import unittest2
+import tempfile
+import struct
+
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.storage.block_storage_file import \
+ BlockStorageFile
+from pyoram.storage.block_storage_mmap import \
+ BlockStorageMMap
+from pyoram.storage.block_storage_ram import \
+ BlockStorageRAM
+from pyoram.storage.block_storage_sftp import \
+ BlockStorageSFTP
+from pyoram.storage.block_storage_s3 import \
+ BlockStorageS3
+from pyoram.storage.boto3_s3_wrapper import \
+ (Boto3S3Wrapper,
+ MockBoto3S3Wrapper)
+
+import six
+from six.moves import xrange
+from six import BytesIO
+
# Directory containing this test module (used to locate baseline files).
thisdir = os.path.dirname(os.path.abspath(__file__))

# Optional dependency probe: the s3 tests are skipped when boto3 is not
# usable. Fix: narrowed the bare `except:` (which also swallowed
# KeyboardInterrupt/SystemExit) to `except Exception:` while keeping the
# best-effort behavior for any import-time failure.
try:
    import boto3
    has_boto3 = True
except Exception:                                      # pragma: no cover
    has_boto3 = False                                  # pragma: no cover
+
class TestBlockStorageTypeFactory(unittest2.TestCase):
    """Checks that BlockStorageTypeFactory maps names to device classes."""

    def test_file(self):
        """'file' resolves to BlockStorageFile."""
        self.assertIs(BlockStorageTypeFactory('file'), BlockStorageFile)

    def test_mmap(self):
        """'mmap' resolves to BlockStorageMMap."""
        self.assertIs(BlockStorageTypeFactory('mmap'), BlockStorageMMap)

    def test_ram(self):
        """'ram' resolves to BlockStorageRAM."""
        self.assertIs(BlockStorageTypeFactory('ram'), BlockStorageRAM)

    def test_sftp(self):
        """'sftp' resolves to BlockStorageSFTP."""
        self.assertIs(BlockStorageTypeFactory('sftp'), BlockStorageSFTP)

    def test_s3(self):
        """'s3' resolves to BlockStorageS3."""
        self.assertIs(BlockStorageTypeFactory('s3'), BlockStorageS3)

    def test_invalid(self):
        """An unknown storage name raises ValueError."""
        with self.assertRaises(ValueError):
            BlockStorageTypeFactory(None)

    def test_register_invalid_name(self):
        """Re-registering an already-registered name raises ValueError."""
        with self.assertRaises(ValueError):
            BlockStorageTypeFactory.register_device('s3', BlockStorageFile)

    def test_register_invalid_type(self):
        """Registering a class that is not a storage device raises TypeError."""
        with self.assertRaises(TypeError):
            BlockStorageTypeFactory.register_device('new_str_type', str)
+
class _TestBlockStorage(object):
    """Shared test body for every BlockStorage backend.

    Not collected directly by the test runner (leading underscore, no
    TestCase base); concrete subclasses mix this in and set the two class
    attributes below.
    """

    # BlockStorage class under test; set by concrete subclasses.
    _type = None
    # Extra keyword arguments passed to _type(...) and _type.setup(...).
    _type_kwds = None
+
+ @classmethod
+ def _read_storage(cls, storage):
+ with open(storage.storage_name, 'rb') as f:
+ return f.read()
+
+ @classmethod
+ def _remove_storage(cls, name):
+ if os.path.exists(name):
+ if os.path.isdir(name):
+ shutil.rmtree(name, ignore_errors=True)
+ else:
+ os.remove(name)
+
    @classmethod
    def _check_exists(cls, name):
        """Return True if *name* exists.

        NOTE(review): checks the local filesystem; presumably overridden by
        remote-backend subclasses -- confirm against subclasses outside this
        chunk.
        """
        return os.path.exists(name)
+
+ @classmethod
+ def _get_empty_existing(cls):
+ return os.path.join(thisdir,
+ "baselines",
+ "exists.empty")
+
+ @classmethod
+ def _get_dummy_noexist(cls):
+ fd, name = tempfile.mkstemp(dir=os.getcwd())
+ os.close(fd)
+ return name
+
+ def _open_teststorage(self, **kwds):
+ kwds.update(self._type_kwds)
+ return self._type(self._testfname, **kwds)
+
    def _reopen_storage(self, storage):
        """Re-open *storage* by name using this class's keyword configuration."""
        return self._type(storage.storage_name, **self._type_kwds)
+
    @classmethod
    def setUpClass(cls):
        """Create one shared storage file with predictable block contents.

        Creates cls._block_count blocks of cls._block_size bytes each,
        where block i is filled with the byte value i.
        """
        assert cls._type is not None
        assert cls._type_kwds is not None
        cls._dummy_name = cls._get_dummy_noexist()
        if cls._check_exists(cls._dummy_name):
            cls._remove_storage(cls._dummy_name)
        # NOTE(review): this second, os.path-based check looks redundant with
        # the one above; presumably _check_exists/_remove_storage are
        # overridden by remote-backend subclasses, in which case this cleans
        # up a local leftover file -- confirm against subclasses.
        if os.path.exists(cls._dummy_name):
            _TestBlockStorage.\
                _remove_storage(cls._dummy_name)       # pragma: no cover
        cls._block_size = 25
        cls._block_count = 5
        cls._testfname = cls.__name__ + "_testfile.bin"
        cls._blocks = []
        f = cls._type.setup(
            cls._testfname,
            block_size=cls._block_size,
            block_count=cls._block_count,
            initialize=lambda i: bytes(bytearray([i])*cls._block_size),
            ignore_existing=True,
            **cls._type_kwds)
        f.close()
        # keep the closed handle: later tests only read its .storage_name
        cls._original_f = f
        for i in range(cls._block_count):
            data = bytearray([i])*cls._block_size
            cls._blocks.append(data)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._remove_storage(cls._testfname)
+ cls._remove_storage(cls._dummy_name)
+
    def test_setup_fails(self):
        """setup() rejects bad arguments without creating anything on disk."""
        self.assertEqual(self._check_exists(self._dummy_name), False)
        # refusing to clobber an existing file
        with self.assertRaises(IOError):
            self._type.setup(
                self._get_empty_existing(),
                block_size=10,
                block_count=10,
                **self._type_kwds)
        self.assertEqual(self._check_exists(self._dummy_name), False)
        with self.assertRaises(IOError):
            self._type.setup(
                self._get_empty_existing(),
                block_size=10,
                block_count=10,
                ignore_existing=False,
                **self._type_kwds)
        self.assertEqual(self._check_exists(self._dummy_name), False)
        # non-positive block_size
        with self.assertRaises(ValueError):
            self._type.setup(self._dummy_name,
                             block_size=0,
                             block_count=1,
                             **self._type_kwds)
        self.assertEqual(self._check_exists(self._dummy_name), False)
        # non-positive block_count
        with self.assertRaises(ValueError):
            self._type.setup(self._dummy_name,
                             block_size=1,
                             block_count=0,
                             **self._type_kwds)
        self.assertEqual(self._check_exists(self._dummy_name), False)
        # header_data must be bytes, not int
        with self.assertRaises(TypeError):
            self._type.setup(self._dummy_name,
                             block_size=1,
                             block_count=1,
                             header_data=2,
                             **self._type_kwds)
        self.assertEqual(self._check_exists(self._dummy_name), False)
        # TODO: The multiprocessing module is bad
        # about handling exceptions raised on the
        # thread's stack.
        #with self.assertRaises(ValueError):
        #    def _init(i):
        #        raise ValueError
        #    self._type.setup(self._dummy_name,
        #                     block_size=1,
        #                     block_count=1,
        #                     initialize=_init,
        #                     **self._type_kwds)
        #self.assertEqual(self._check_exists(self._dummy_name), False)
+
    def test_setup(self):
        """setup() creates storage of exactly the computed size, no header data."""
        fname = ".".join(self.id().split(".")[1:])
        fname += ".bin"
        fname = os.path.join(thisdir, fname)
        self._remove_storage(fname)
        bsize = 10
        bcount = 11
        fsetup = self._type.setup(fname, bsize, bcount, **self._type_kwds)
        fsetup.close()
        flen = len(self._read_storage(fsetup))
        self.assertEqual(
            flen,
            self._type.compute_storage_size(bsize,
                                            bcount))
        # total size must exceed the payload-only size (room for the header)
        self.assertEqual(
            flen >
            self._type.compute_storage_size(bsize,
                                            bcount,
                                            ignore_header=True),
            True)
        with self._reopen_storage(fsetup) as f:
            self.assertEqual(f.header_data, bytes())
            self.assertEqual(fsetup.header_data, bytes())
            self.assertEqual(f.block_size, bsize)
            self.assertEqual(fsetup.block_size, bsize)
            self.assertEqual(f.block_count, bcount)
            self.assertEqual(fsetup.block_count, bcount)
            self.assertEqual(f.storage_name, fsetup.storage_name)
            self.assertEqual(fsetup.storage_name, fsetup.storage_name)
            # RAM-backed storage has no on-disk name
            if self._type is not BlockStorageRAM:
                self.assertEqual(fsetup.storage_name, fname)
            else:
                self.assertEqual(fsetup.storage_name, None)
        self._remove_storage(fname)
+
    def test_setup_withdata(self):
        """setup() with header_data enlarges the storage and persists the bytes."""
        fname = ".".join(self.id().split(".")[1:])
        fname += ".bin"
        fname = os.path.join(thisdir, fname)
        self._remove_storage(fname)
        bsize = 10
        bcount = 11
        header_data = bytes(bytearray([0,1,2]))
        fsetup = self._type.setup(fname,
                                  bsize,
                                  bcount,
                                  header_data=header_data,
                                  **self._type_kwds)
        fsetup.close()

        flen = len(self._read_storage(fsetup))
        self.assertEqual(
            flen,
            self._type.compute_storage_size(bsize,
                                            bcount,
                                            header_data=header_data))
        self.assertTrue(len(header_data) > 0)
        # header data strictly increases the computed storage size
        self.assertEqual(
            self._type.compute_storage_size(bsize,
                                            bcount) <
            self._type.compute_storage_size(bsize,
                                            bcount,
                                            header_data=header_data),
            True)
        self.assertEqual(
            flen >
            self._type.compute_storage_size(bsize,
                                            bcount,
                                            header_data=header_data,
                                            ignore_header=True),
            True)
        with self._reopen_storage(fsetup) as f:
            self.assertEqual(f.header_data, header_data)
            self.assertEqual(fsetup.header_data, header_data)
            self.assertEqual(f.block_size, bsize)
            self.assertEqual(fsetup.block_size, bsize)
            self.assertEqual(f.block_count, bcount)
            self.assertEqual(fsetup.block_count, bcount)
            self.assertEqual(f.storage_name, fsetup.storage_name)
            self.assertEqual(fsetup.storage_name, fsetup.storage_name)
            # RAM-backed storage has no on-disk name
            if self._type is not BlockStorageRAM:
                self.assertEqual(fsetup.storage_name, fname)
            else:
                self.assertEqual(fsetup.storage_name, None)

        self._remove_storage(fname)
+
    def test_init_noexists(self):
        """Opening a storage path that does not exist raises IOError."""
        self.assertEqual(self._check_exists(self._dummy_name), False)
        with self.assertRaises(IOError):
            with self._type(self._dummy_name, **self._type_kwds) as f:
                pass                                   # pragma: no cover
+
    def test_init_exists(self):
        """Opening existing storage exposes its metadata without modifying it."""
        self.assertEqual(self._check_exists(self._testfname), True)
        databefore = self._read_storage(self._original_f)
        with self._open_teststorage() as f:
            self.assertEqual(f.block_size, self._block_size)
            self.assertEqual(f.block_count, self._block_count)
            self.assertEqual(f.storage_name, self._testfname)
            self.assertEqual(f.header_data, bytes())
        self.assertEqual(self._check_exists(self._testfname), True)
        dataafter = self._read_storage(self._original_f)
        self.assertEqual(databefore, dataafter)
+
    def test_read_block(self):
        """read_block returns the expected data and counts received bytes."""
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            # four full passes (forward twice, reverse twice) so the
            # bytes_received total below is count*size*4
            for i, data in enumerate(self._blocks):
                self.assertEqual(list(bytearray(f.read_block(i))),
                                 list(self._blocks[i]))
            for i, data in enumerate(self._blocks):
                self.assertEqual(list(bytearray(f.read_block(i))),
                                 list(self._blocks[i]))
            for i, data in reversed(list(enumerate(self._blocks))):
                self.assertEqual(list(bytearray(f.read_block(i))),
                                 list(self._blocks[i]))
            for i, data in reversed(list(enumerate(self._blocks))):
                self.assertEqual(list(bytearray(f.read_block(i))),
                                 list(self._blocks[i]))

            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received,
                             self._block_count*self._block_size*4)

        # counters reset on a fresh handle
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            self.assertEqual(list(bytearray(f.read_block(0))),
                             list(self._blocks[0]))
            self.assertEqual(list(bytearray(f.read_block(self._block_count-1))),
                             list(self._blocks[-1]))

            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received,
                             self._block_size*2)
+
    def test_write_block(self):
        """write_block round-trips data and counts sent/received bytes."""
        data = bytearray([self._block_count])*self._block_size
        self.assertEqual(len(data) > 0, True)
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            for i in xrange(self._block_count):
                self.assertNotEqual(list(bytearray(f.read_block(i))),
                                    list(data))
            for i in xrange(self._block_count):
                f.write_block(i, bytes(data))
            for i in xrange(self._block_count):
                self.assertEqual(list(bytearray(f.read_block(i))),
                                 list(data))
            # restore the original contents for the other tests
            for i, block in enumerate(self._blocks):
                f.write_block(i, bytes(block))

            self.assertEqual(f.bytes_sent,
                             self._block_count*self._block_size*2)
            self.assertEqual(f.bytes_received,
                             self._block_count*self._block_size*2)
+
    def test_read_blocks(self):
        """read_blocks returns blocks in request order, allowing repeats."""
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            data = f.read_blocks(list(xrange(self._block_count)))
            self.assertEqual(len(data), self._block_count)
            for i, block in enumerate(data):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))
            data = f.read_blocks([0])
            self.assertEqual(len(data), 1)
            self.assertEqual(list(bytearray(data[0])),
                             list(self._blocks[0]))
            self.assertEqual(len(self._blocks) > 1, True)
            # out-of-order request: blocks 1..N-1 then block 0 again
            data = f.read_blocks(list(xrange(1, self._block_count)) + [0])
            self.assertEqual(len(data), self._block_count)
            for i, block in enumerate(data[:-1], 1):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))
            self.assertEqual(list(bytearray(data[-1])),
                             list(self._blocks[0]))

            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received,
                             (2*self._block_count+1)*self._block_size)
+
    def test_yield_blocks(self):
        """yield_blocks behaves like read_blocks but as a generator."""
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            data = list(f.yield_blocks(list(xrange(self._block_count))))
            self.assertEqual(len(data), self._block_count)
            for i, block in enumerate(data):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))
            data = list(f.yield_blocks([0]))
            self.assertEqual(len(data), 1)
            self.assertEqual(list(bytearray(data[0])),
                             list(self._blocks[0]))
            self.assertEqual(len(self._blocks) > 1, True)
            # out-of-order request: blocks 1..N-1 then block 0 again
            data = list(f.yield_blocks(list(xrange(1, self._block_count)) + [0]))
            self.assertEqual(len(data), self._block_count)
            for i, block in enumerate(data[:-1], 1):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))
            self.assertEqual(list(bytearray(data[-1])),
                             list(self._blocks[0]))

            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received,
                             (2*self._block_count+1)*self._block_size)
+
    def test_write_blocks(self):
        """write_blocks replaces many blocks at once; contents restored after."""
        data = [bytearray([self._block_count])*self._block_size
                for i in xrange(self._block_count)]
        with self._open_teststorage() as f:
            self.assertEqual(f.bytes_sent, 0)
            self.assertEqual(f.bytes_received, 0)

            orig = f.read_blocks(list(xrange(self._block_count)))
            self.assertEqual(len(orig), self._block_count)
            for i, block in enumerate(orig):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))
            f.write_blocks(list(xrange(self._block_count)),
                           [bytes(b) for b in data])
            new = f.read_blocks(list(xrange(self._block_count)))
            self.assertEqual(len(new), self._block_count)
            for i, block in enumerate(new):
                self.assertEqual(list(bytearray(block)),
                                 list(data[i]))
            # restore the original contents for the other tests
            f.write_blocks(list(xrange(self._block_count)),
                           [bytes(b) for b in self._blocks])
            orig = f.read_blocks(list(xrange(self._block_count)))
            self.assertEqual(len(orig), self._block_count)
            for i, block in enumerate(orig):
                self.assertEqual(list(bytearray(block)),
                                 list(self._blocks[i]))

            self.assertEqual(f.bytes_sent,
                             self._block_count*self._block_size*2)
            self.assertEqual(f.bytes_received,
                             self._block_count*self._block_size*3)
+
    def test_update_header_data(self):
        """update_header_data persists and must keep the header length fixed."""
        fname = ".".join(self.id().split(".")[1:])
        fname += ".bin"
        fname = os.path.join(thisdir, fname)
        self._remove_storage(fname)
        bsize = 10
        bcount = 11
        header_data = bytes(bytearray([0,1,2]))
        fsetup = self._type.setup(fname,
                                  block_size=bsize,
                                  block_count=bcount,
                                  header_data=header_data,
                                  **self._type_kwds)
        fsetup.close()
        new_header_data = bytes(bytearray([1,1,1]))
        with self._reopen_storage(fsetup) as f:
            self.assertEqual(f.header_data, header_data)
            f.update_header_data(new_header_data)
            self.assertEqual(f.header_data, new_header_data)
        with self._reopen_storage(fsetup) as f:
            self.assertEqual(f.header_data, new_header_data)
        # shorter replacement header is rejected
        with self.assertRaises(ValueError):
            with self._reopen_storage(fsetup) as f:
                f.update_header_data(bytes(bytearray([1,1])))
        # longer replacement header is rejected
        with self.assertRaises(ValueError):
            with self._reopen_storage(fsetup) as f:
                f.update_header_data(bytes(bytearray([1,1,1,1])))
        with self._reopen_storage(fsetup) as f:
            self.assertEqual(f.header_data, new_header_data)
        self._remove_storage(fname)
+
+ def test_locked_flag(self):
+ with self._open_teststorage() as f:
+ with self.assertRaises(IOError):
+ with self._open_teststorage() as f1:
+ pass # pragma: no cover
+ with self.assertRaises(IOError):
+ with self._open_teststorage() as f1:
+ pass # pragma: no cover
+ with self._open_teststorage(ignore_lock=True) as f1:
+ pass
+ with self.assertRaises(IOError):
+ with self._open_teststorage() as f1:
+ pass # pragma: no cover
+ with self._open_teststorage(ignore_lock=True) as f1:
+ pass
+ with self._open_teststorage(ignore_lock=True) as f1:
+ pass
+ with self._open_teststorage(ignore_lock=True) as f:
+ pass
+
    def test_read_block_cloned(self):
        """A cloned device reads correctly and keeps its own byte counters."""
        with self._open_teststorage() as forig:
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
            with forig.clone_device() as f:
                self.assertEqual(forig.bytes_sent, 0)
                self.assertEqual(forig.bytes_received, 0)
                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received, 0)

                # four full passes, mirroring test_read_block
                for i, data in enumerate(self._blocks):
                    self.assertEqual(list(bytearray(f.read_block(i))),
                                     list(self._blocks[i]))
                for i, data in enumerate(self._blocks):
                    self.assertEqual(list(bytearray(f.read_block(i))),
                                     list(self._blocks[i]))
                for i, data in reversed(list(enumerate(self._blocks))):
                    self.assertEqual(list(bytearray(f.read_block(i))),
                                     list(self._blocks[i]))
                for i, data in reversed(list(enumerate(self._blocks))):
                    self.assertEqual(list(bytearray(f.read_block(i))),
                                     list(self._blocks[i]))

                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received,
                                 self._block_count*self._block_size*4)

            with forig.clone_device() as f:
                self.assertEqual(forig.bytes_sent, 0)
                self.assertEqual(forig.bytes_received, 0)
                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received, 0)

                self.assertEqual(list(bytearray(f.read_block(0))),
                                 list(self._blocks[0]))
                self.assertEqual(list(bytearray(f.read_block(self._block_count-1))),
                                 list(self._blocks[-1]))

                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received,
                                 self._block_size*2)
            # clone traffic never shows up on the original device
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
+
    def test_write_block_cloned(self):
        """A cloned device writes correctly and keeps its own byte counters."""
        data = bytearray([self._block_count])*self._block_size
        self.assertEqual(len(data) > 0, True)
        with self._open_teststorage() as forig:
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
            with forig.clone_device() as f:
                self.assertEqual(forig.bytes_sent, 0)
                self.assertEqual(forig.bytes_received, 0)
                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received, 0)

                for i in xrange(self._block_count):
                    self.assertNotEqual(list(bytearray(f.read_block(i))),
                                        list(data))
                for i in xrange(self._block_count):
                    f.write_block(i, bytes(data))
                for i in xrange(self._block_count):
                    self.assertEqual(list(bytearray(f.read_block(i))),
                                     list(data))
                # restore the original contents for the other tests
                for i, block in enumerate(self._blocks):
                    f.write_block(i, bytes(block))

                self.assertEqual(f.bytes_sent,
                                 self._block_count*self._block_size*2)
                self.assertEqual(f.bytes_received,
                                 self._block_count*self._block_size*2)
            # clone traffic never shows up on the original device
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
+
    def test_read_blocks_cloned(self):
        """A cloned device supports bulk reads with independent counters."""
        with self._open_teststorage() as forig:
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
            with forig.clone_device() as f:
                self.assertEqual(forig.bytes_sent, 0)
                self.assertEqual(forig.bytes_received, 0)
                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received, 0)

                data = f.read_blocks(list(xrange(self._block_count)))
                self.assertEqual(len(data), self._block_count)
                for i, block in enumerate(data):
                    self.assertEqual(list(bytearray(block)),
                                     list(self._blocks[i]))
                data = f.read_blocks([0])
                self.assertEqual(len(data), 1)
                self.assertEqual(list(bytearray(data[0])),
                                 list(self._blocks[0]))
                self.assertEqual(len(self._blocks) > 1, True)
                # out-of-order request: blocks 1..N-1 then block 0 again
                data = f.read_blocks(list(xrange(1, self._block_count)) + [0])
                self.assertEqual(len(data), self._block_count)
                for i, block in enumerate(data[:-1], 1):
                    self.assertEqual(list(bytearray(block)),
                                     list(self._blocks[i]))
                self.assertEqual(list(bytearray(data[-1])),
                                 list(self._blocks[0]))

                self.assertEqual(f.bytes_sent, 0)
                self.assertEqual(f.bytes_received,
                                 (2*self._block_count + 1)*self._block_size)
            # clone traffic never shows up on the original device
            self.assertEqual(forig.bytes_sent, 0)
            self.assertEqual(forig.bytes_received, 0)
+
+    def test_yield_blocks_cloned(self):
+        # Same coverage as test_read_blocks_cloned, but through the
+        # generator-based yield_blocks interface.
+        with self._open_teststorage() as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                data = list(f.yield_blocks(list(xrange(self._block_count))))
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                data = list(f.yield_blocks([0]))
+                self.assertEqual(len(data), 1)
+                self.assertEqual(list(bytearray(data[0])),
+                                 list(self._blocks[0]))
+                self.assertEqual(len(self._blocks) > 1, True)
+                data = list(f.yield_blocks(list(xrange(1, self._block_count)) + [0]))
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data[:-1], 1):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                self.assertEqual(list(bytearray(data[-1])),
+                                 list(self._blocks[0]))
+
+                self.assertEqual(f.bytes_sent, 0)
+                # block_count + 1 + block_count blocks read in total
+                self.assertEqual(f.bytes_received,
+                                 (2*self._block_count + 1)*self._block_size)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_write_blocks_cloned(self):
+        # Bulk write_blocks/read_blocks round-trip through a cloned
+        # device; counters accrue on the clone, not the original.
+        data = [bytearray([self._block_count])*self._block_size
+                for i in xrange(self._block_count)]
+        with self._open_teststorage() as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                orig = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(orig), self._block_count)
+                for i, block in enumerate(orig):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                f.write_blocks(list(xrange(self._block_count)),
+                               [bytes(b) for b in data])
+                new = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(new), self._block_count)
+                for i, block in enumerate(new):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(data[i]))
+                # restore the original contents for subsequent tests
+                f.write_blocks(list(xrange(self._block_count)),
+                               [bytes(b) for b in self._blocks])
+                orig = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(orig), self._block_count)
+                for i, block in enumerate(orig):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+
+                # two bulk writes, three bulk reads
+                self.assertEqual(f.bytes_sent,
+                                 self._block_count*self._block_size*2)
+                self.assertEqual(f.bytes_received,
+                                 self._block_count*self._block_size*3)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+class TestBlockStorageFile(_TestBlockStorage,
+                           unittest2.TestCase):
+    # File-backed storage with default constructor options.
+    _type = BlockStorageFile
+    _type_kwds = {}
+
+class TestBlockStorageFileNoThreadPool(_TestBlockStorage,
+                                       unittest2.TestCase):
+    # File-backed storage with the thread pool disabled.
+    _type = BlockStorageFile
+    _type_kwds = {'threadpool_size': 0}
+
+class TestBlockStorageFileThreadPool(_TestBlockStorage,
+                                     unittest2.TestCase):
+    # File-backed storage with a single-worker thread pool.
+    _type = BlockStorageFile
+    _type_kwds = {'threadpool_size': 1}
+
+class TestBlockStorageMMap(_TestBlockStorage,
+                           unittest2.TestCase):
+    # Memory-mapped file storage with default options.
+    _type = BlockStorageMMap
+    _type_kwds = {}
+
+class _TestBlockStorageRAM(_TestBlockStorage):
+    # Mixin adapting the generic block-storage tests to the in-RAM
+    # storage type, which is constructed from a byte buffer rather than
+    # a file name and therefore has no storage name on disk.
+
+    @classmethod
+    def _read_storage(cls, storage):
+        # RAM storage exposes its backing bytes via .data.
+        return storage.data
+
+    @classmethod
+    def _remove_storage(cls, name):
+        # Nothing exists on disk, so there is nothing to remove.
+        pass
+
+    @classmethod
+    def _check_exists(cls, name):
+        # RAM storage always "exists" while the object is alive.
+        return True
+
+    def _open_teststorage(self, **kwds):
+        # Open a fresh handle over the shared byte buffer.
+        kwds.update(self._type_kwds)
+        return self._type(self._original_f.data, **kwds)
+
+    def _reopen_storage(self, storage):
+        return self._type(storage.data, **self._type_kwds)
+
+    #
+    # Override some of the test methods
+    #
+
+    def test_setup_fails(self):
+        # Argument validation: zero block size, zero block count, and a
+        # non-bytes header_data must all be rejected.
+        with self.assertRaises(ValueError):
+            self._type.setup(self._dummy_name,
+                             block_size=0,
+                             block_count=1,
+                             **self._type_kwds)
+        with self.assertRaises(ValueError):
+            self._type.setup(self._dummy_name,
+                             block_size=1,
+                             block_count=0,
+                             **self._type_kwds)
+        with self.assertRaises(TypeError):
+            self._type.setup(self._dummy_name,
+                             block_size=1,
+                             block_count=1,
+                             header_data=2,
+                             **self._type_kwds)
+
+    def test_init_noexists(self):
+        # Invalid constructor input: wrong type, None, and a bytearray
+        # too short to contain a storage header (struct unpack fails).
+        with self.assertRaises(TypeError):
+            with self._type(2, **self._type_kwds) as f:
+                pass # pragma: no cover
+        with self.assertRaises(TypeError):
+            with self._type(None, **self._type_kwds) as f:
+                pass # pragma: no cover
+        with self.assertRaises(struct.error):
+            with self._type(bytearray(), **self._type_kwds) as f:
+                pass # pragma: no cover
+
+    def test_init_exists(self):
+        # Opening over an existing buffer must not modify its contents.
+        databefore = self._read_storage(self._original_f)
+        with self._open_teststorage() as f:
+            self.assertEqual(f.block_size, self._block_size)
+            self.assertEqual(f.block_count, self._block_count)
+            self.assertEqual(f.storage_name, self._original_f.storage_name)
+            # RAM storage has no name
+            self.assertEqual(f.storage_name, None)
+            self.assertEqual(f.header_data, bytes())
+        dataafter = self._read_storage(self._original_f)
+        self.assertEqual(databefore, dataafter)
+
+    def test_tofile_fromfile_fileobj(self):
+        # Serialize to a file object and reload; while a loaded copy is
+        # open its data differs from the original (lock flag set), and
+        # matches again after close.
+        out1 = BytesIO()
+        self._original_f.tofile(out1)
+        out1.seek(0)
+        self.assertEqual(len(self._original_f.data) > 0, True)
+        self.assertEqual(self._original_f.data, out1.read())
+        out1.seek(0)
+        in1 = self._type.fromfile(out1)
+        self.assertNotEqual(self._original_f.data, in1.data)
+        out2 = BytesIO()
+        in1.tofile(out2)
+        self.assertNotEqual(self._original_f.data, in1.data)
+        in1.close()
+        self.assertEqual(self._original_f.data, in1.data)
+        out2.seek(0)
+        # out2 was written while in1 was open, so it carries the lock
+        # flag and loading it without ignore_lock fails
+        with self.assertRaises(IOError):
+            with self._type.fromfile(out2) as in2:
+                pass # pragma: no cover
+        out2.seek(0)
+        with self._type.fromfile(out2, ignore_lock=True) as in2:
+            self.assertEqual(self._original_f.data, in1.data)
+            self.assertNotEqual(self._original_f.data, in2.data)
+        self.assertEqual(self._original_f.data, in1.data)
+        self.assertNotEqual(self._original_f.data, in2.data)
+
+    def test_tofile_fromfile_filename(self):
+        # Same round-trip as the fileobj test, but through named
+        # temporary files on disk.
+
+        def _create():
+            fd, out = tempfile.mkstemp()
+            os.close(fd)
+            return out
+        def _read(name):
+            with open(name, 'rb') as f:
+                return f.read()
+
+        out1 = _create()
+        self._original_f.tofile(out1)
+        self.assertEqual(len(self._original_f.data) > 0, True)
+        self.assertEqual(self._original_f.data, _read(out1))
+        in1 = self._type.fromfile(out1)
+        self.assertNotEqual(self._original_f.data, in1.data)
+        out2 = _create()
+        in1.tofile(out2)
+        self.assertNotEqual(self._original_f.data, in1.data)
+        in1.close()
+        self.assertEqual(self._original_f.data, in1.data)
+        # out2 was written while in1 was open (lock flag set)
+        with self.assertRaises(IOError):
+            with self._type.fromfile(out2) as in2:
+                pass # pragma: no cover
+        with self._type.fromfile(out2, ignore_lock=True) as in2:
+            self.assertEqual(self._original_f.data, in1.data)
+            self.assertNotEqual(self._original_f.data, in2.data)
+        self.assertEqual(self._original_f.data, in1.data)
+        self.assertNotEqual(self._original_f.data, in2.data)
+
+class TestBlockStorageRAM(_TestBlockStorageRAM,
+                          unittest2.TestCase):
+    # Concrete RAM-backed storage tests.
+    _type = BlockStorageRAM
+    _type_kwds = {}
+
+class _dummy_sftp_file(object):
+    # Minimal stand-in for a paramiko SFTPFile backed by a local file.
+    # Unknown attributes are delegated to the underlying file object.
+    def __init__(self, *args, **kwds):
+        self._f = open(*args, **kwds)
+    def __enter__(self):
+        return self
+    def __exit__(self, *args):
+        self._f.close()
+    def readv(self, chunks):
+        # Emulate SFTPFile.readv: gather (offset, size) reads.
+        data = []
+        for offset, size in chunks:
+            self._f.seek(offset)
+            data.append(self._f.read(size))
+        return data
+    def __getattr__(self, key):
+        return getattr(self._f, key)
+    def set_pipelined(self):
+        # No-op; pipelining is an SFTP-only optimization.
+        pass
+
+class dummy_sftp(object):
+    # Stand-in for a paramiko SFTP client that maps every operation
+    # onto the local filesystem.
+    remove = os.remove
+    stat = os.stat
+    @staticmethod
+    def open(*args, **kwds):
+        return _dummy_sftp_file(*args, **kwds)
+    @staticmethod
+    def close():
+        pass
+
+class dummy_sshclient(object):
+    # Stand-in for a paramiko SSHClient that hands out the dummy SFTP
+    # client above.
+    @staticmethod
+    def open_sftp():
+        return dummy_sftp
+
+class TestBlockStorageSFTP(_TestBlockStorage,
+                           unittest2.TestCase):
+    # SFTP storage backend exercised against the local-filesystem
+    # dummy SSH client defined above.
+    _type = BlockStorageSFTP
+    _type_kwds = {'sshclient': dummy_sshclient}
+
+    def test_setup_fails_no_sshclient(self):
+        # setup() without an sshclient keyword must be rejected and
+        # must not create anything.
+        self.assertEqual(self._check_exists(self._dummy_name), False)
+        kwds = dict(self._type_kwds)
+        del kwds['sshclient']
+        with self.assertRaises(ValueError):
+            self._type.setup(self._dummy_name,
+                             block_size=1,
+                             block_count=1,
+                             **kwds)
+        self.assertEqual(self._check_exists(self._dummy_name), False)
+
+    def test_init_exists_no_sshclient(self):
+        # Opening without an sshclient fails; opening with one succeeds
+        # and leaves the stored bytes untouched.
+        self.assertEqual(self._check_exists(self._testfname), True)
+        kwds = dict(self._type_kwds)
+        del kwds['sshclient']
+        with self.assertRaises(ValueError):
+            with self._type(self._testfname, **kwds) as f:
+                pass # pragma: no cover
+
+        databefore = self._read_storage(self._original_f)
+        with self._open_teststorage() as f:
+            self.assertEqual(f.block_size, self._block_size)
+            self.assertEqual(f.block_count, self._block_count)
+            self.assertEqual(f.storage_name, self._testfname)
+            self.assertEqual(f.header_data, bytes())
+        self.assertEqual(self._check_exists(self._testfname), True)
+        dataafter = self._read_storage(self._original_f)
+        self.assertEqual(databefore, dataafter)
+
+
+class _TestBlockStorageS3Mock(_TestBlockStorage):
+    # S3 storage backend exercised against the local mock wrapper;
+    # blocks live as individual files ("b0", "b1", ...) plus an index
+    # file inside a directory named after the storage.
+    _type = BlockStorageS3
+    _type_kwds = {}
+
+    @classmethod
+    def _read_storage(cls, storage):
+        # Reconstruct one contiguous byte stream from the mock layout:
+        # index file first, then each block file in order.
+        import glob
+        data = bytearray()
+        name = storage.storage_name
+        prefix_len = len(os.path.join(name,"b"))
+        # highest numbered block file determines the block count
+        nblocks = max(int(bfile[prefix_len:]) for bfile in glob.glob(name+"/b*")) + 1
+        with open(os.path.join(name, BlockStorageS3._index_name), 'rb') as f:
+            data.extend(f.read())
+        for i in range(nblocks):
+            with open(os.path.join(name, "b"+str(i)), 'rb') as f:
+                data.extend(f.read())
+        return data
+
+    def test_init_exists_no_bucket(self):
+        # Opening without a bucket_name fails, and neither the failed
+        # nor the successful open modifies the stored bytes.
+        self.assertEqual(self._check_exists(self._testfname), True)
+        databefore = self._read_storage(self._original_f)
+        with self._open_teststorage() as f:
+            self.assertEqual(f.block_size, self._block_size)
+            self.assertEqual(f.block_count, self._block_count)
+            self.assertEqual(f.storage_name, self._testfname)
+            self.assertEqual(f.header_data, bytes())
+        self.assertEqual(self._check_exists(self._testfname), True)
+        dataafter = self._read_storage(self._original_f)
+        self.assertEqual(databefore, dataafter)
+        kwds = dict(self._type_kwds)
+        del kwds['bucket_name']
+        with self.assertRaises(ValueError):
+            with self._type(self._testfname, **kwds) as f:
+                pass # pragma: no cover
+        dataafter = self._read_storage(self._original_f)
+        self.assertEqual(databefore, dataafter)
+
+    def test_setup_fails_no_bucket(self):
+        # setup() without a bucket_name must be rejected and must not
+        # create anything.
+        self.assertEqual(self._check_exists(self._dummy_name), False)
+        kwds = dict(self._type_kwds)
+        del kwds['bucket_name']
+        with self.assertRaises(ValueError):
+            self._type.setup(self._dummy_name,
+                             block_size=1,
+                             block_count=1,
+                             **kwds)
+        self.assertEqual(self._check_exists(self._dummy_name), False)
+
+    def test_setup_ignore_existing(self):
+        # A second setup() on the same name raises IOError unless
+        # ignore_existing=True is passed.
+        self.assertEqual(self._check_exists(self._dummy_name), False)
+        with self._type.setup(self._dummy_name,
+                              block_size=1,
+                              block_count=1,
+                              **self._type_kwds) as f:
+            pass
+        self.assertEqual(self._check_exists(self._dummy_name), True)
+        with self.assertRaises(IOError):
+            with self._type.setup(self._dummy_name,
+                                  block_size=1,
+                                  block_count=1,
+                                  **self._type_kwds) as f:
+                pass # pragma: no cover
+        self.assertEqual(self._check_exists(self._dummy_name), True)
+        with self._type.setup(self._dummy_name,
+                              block_size=1,
+                              block_count=1,
+                              ignore_existing=True,
+                              **self._type_kwds) as f:
+            pass
+        self.assertEqual(self._check_exists(self._dummy_name), True)
+        self._remove_storage(self._dummy_name)
+
+class TestBlockStorageS3Mock(_TestBlockStorageS3Mock,
+                             unittest2.TestCase):
+    # Mocked S3 backend with default thread pool settings; the current
+    # directory stands in for the bucket.
+    _type_kwds = {'s3_wrapper': MockBoto3S3Wrapper,
+                  'bucket_name': '.'}
+
+class TestBlockStorageS3MockNoThreadPool(_TestBlockStorageS3Mock,
+                                         unittest2.TestCase):
+    # Mocked S3 backend with the thread pool disabled.
+    _type_kwds = {'s3_wrapper': MockBoto3S3Wrapper,
+                  'bucket_name': '.',
+                  'threadpool_size': 0}
+
+class TestBlockStorageS3MockThreadPool(_TestBlockStorageS3Mock,
+                                       unittest2.TestCase):
+    # Mocked S3 backend with a four-worker thread pool.
+    _type_kwds = {'s3_wrapper': MockBoto3S3Wrapper,
+                  'bucket_name': '.',
+                  'threadpool_size': 4}
+
+@unittest2.skipIf((os.environ.get('PYORAM_AWS_TEST_BUCKET') is None) or \
+                  (not has_boto3),
+                  "No PYORAM_AWS_TEST_BUCKET defined in environment or "
+                  "boto3 is not available")
+class TestBlockStorageS3(_TestBlockStorage,
+                         unittest2.TestCase):
+    # Integration tests against a real S3 bucket; only runs when the
+    # PYORAM_AWS_TEST_BUCKET environment variable is set and boto3 is
+    # importable.
+    _type = BlockStorageS3
+    _type_kwds = {'bucket_name': os.environ.get('PYORAM_AWS_TEST_BUCKET')}
+
+    @classmethod
+    def _read_storage(cls, storage):
+        # Concatenate the index object and every block object
+        # ("<name>/b0", "<name>/b1", ...) into a single byte stream.
+        data = bytearray()
+        name = storage.storage_name
+        s3 = Boto3S3Wrapper(cls._type_kwds['bucket_name'])
+        prefix_len = len(name+"/b")
+        # highest numbered block key determines the block count
+        nblocks = 1 + max(int(obj.key[prefix_len:]) for obj
+                          in s3._bucket.objects.filter(Prefix=name+"/b"))
+        data.extend(s3.download(name+"/"+BlockStorageS3._index_name))
+        for i in range(nblocks):
+            data.extend(s3.download(name+"/b"+str(i)))
+        return data
+
+    @classmethod
+    def _remove_storage(cls, name):
+        Boto3S3Wrapper(cls._type_kwds['bucket_name']).clear(name)
+
+    @classmethod
+    def _check_exists(cls, name):
+        return Boto3S3Wrapper(cls._type_kwds['bucket_name']).exists(name)
+
+    @classmethod
+    def _get_empty_existing(cls):
+        return "exists.empty"
+
+    @classmethod
+    def _get_dummy_noexist(cls):
+        # Generate temp-file-style names until one is found that does
+        # not already exist in the bucket.
+        s3 = Boto3S3Wrapper(cls._type_kwds['bucket_name'])
+        fd, name = tempfile.mkstemp(dir=os.getcwd())
+        os.close(fd)
+        os.remove(name)
+        while s3.exists(name):
+            fd, name = tempfile.mkstemp(dir=os.getcwd())
+            os.close(fd)
+            os.remove(name)
+        return name
+
+# Allow running this test module directly.
+if __name__ == "__main__":
+    unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+import struct
+
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+from pyoram.crypto.aes import AES
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+
+class _TestEncryptedBlockStorage(object):
+
+ _type_name = None
+ _aes_mode = None
+ _test_key = None
+ _test_key_size = None
+
+    @classmethod
+    def setUpClass(cls):
+        # Build one encrypted test storage shared by every test in the
+        # class, and record the expected plaintext block contents.
+        assert cls._type_name is not None
+        assert cls._aes_mode is not None
+        # at most one of _test_key / _test_key_size may be specified
+        assert not ((cls._test_key is not None) and \
+                    (cls._test_key_size is not None))
+        # reserve a unique name that is guaranteed not to exist
+        fd, cls._dummy_name = tempfile.mkstemp()
+        os.close(fd)
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        cls._block_size = 25
+        cls._block_count = 5
+        cls._testfname = cls.__name__ + "_testfile.bin"
+        cls._blocks = []
+        f = EncryptedBlockStorage.setup(
+            cls._testfname,
+            cls._block_size,
+            cls._block_count,
+            key_size=cls._test_key_size,
+            key=cls._test_key,
+            storage_type=cls._type_name,
+            aes_mode=cls._aes_mode,
+            initialize=lambda i: bytes(bytearray([i])*cls._block_size),
+            ignore_existing=True)
+        f.close()
+        # the key property remains readable after close
+        cls._key = f.key
+        for i in range(cls._block_count):
+            data = bytearray([i])*cls._block_size
+            cls._blocks.append(data)
+
+    @classmethod
+    def tearDownClass(cls):
+        # Best-effort removal of the files created in setUpClass.
+        try:
+            os.remove(cls._testfname)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+
+    def test_setup_fails(self):
+        # setup() argument validation: existing target, bad sizes, bad
+        # header type, bad aes_mode, and every invalid key combination.
+        # Each case must leave the dummy name uncreated.
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # target already exists
+        with self.assertRaises(IOError):
+            EncryptedBlockStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                block_count=10,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # target already exists, ignore_existing explicitly False
+        with self.assertRaises(IOError):
+            EncryptedBlockStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                block_count=10,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                storage_type=self._type_name,
+                aes_mode=self._aes_mode,
+                ignore_existing=False)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # zero block size
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=0,
+                block_count=1,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # zero block count
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=0,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # non-bytes header_data
+        with self.assertRaises(TypeError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name,
+                header_data=2)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # missing aes_mode
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key=self._test_key,
+                key_size=self._test_key_size,
+                aes_mode=None,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # negative key_size
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key_size=-1,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # non-bytes key
+        with self.assertRaises(TypeError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key=-1,
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # both key and key_size supplied
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key=AES.KeyGen(AES.key_sizes[-1]),
+                key_size=AES.key_sizes[-1],
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # key of unsupported length
+        with self.assertRaises(ValueError):
+            EncryptedBlockStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                block_count=1,
+                key=os.urandom(AES.key_sizes[-1]+100),
+                aes_mode=self._aes_mode,
+                storage_type=self._type_name)
+
+    def test_setup(self):
+        # setup() creates a file of exactly the computed storage size,
+        # the stored metadata round-trips on reopen, and corrupting the
+        # plaintext index is detected as a ValueError.
+        # derive a unique file name from the test id
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        bcount = 11
+        fsetup = EncryptedBlockStorage.setup(
+            fname,
+            bsize,
+            bcount,
+            key=self._test_key,
+            key_size=self._test_key_size,
+            aes_mode=self._aes_mode,
+            storage_type=self._type_name)
+        fsetup.close()
+        self.assertEqual(type(fsetup.raw_storage),
+                         BlockStorageTypeFactory(self._type_name))
+        with open(fname, 'rb') as f:
+            flen = len(f.read())
+            self.assertEqual(
+                flen,
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name))
+            # the header contributes a nonzero amount to the file size
+            self.assertEqual(
+                flen >
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name,
+                    ignore_header=True),
+                True)
+        with EncryptedBlockStorage(fname,
+                                   key=fsetup.key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, bytes())
+            self.assertEqual(fsetup.header_data, bytes())
+            self.assertEqual(f.key, fsetup.key)
+            self.assertEqual(f.block_size, bsize)
+            self.assertEqual(fsetup.block_size, bsize)
+            self.assertEqual(f.block_count, bcount)
+            self.assertEqual(fsetup.block_count, bcount)
+            self.assertEqual(f.storage_name, fname)
+            self.assertEqual(fsetup.storage_name, fname)
+        # tamper with the plaintext index
+        with open(fname, 'r+b') as f:
+            f.seek(0)
+            f.write(struct.pack("!L",0))
+        with self.assertRaises(ValueError):
+            with EncryptedBlockStorage(fname,
+                                       key=fsetup.key,
+                                       storage_type=self._type_name) as f:
+                pass # pragma: no cover
+        os.remove(fname)
+
+    def test_setup_withdata(self):
+        # setup() with header_data: the computed file size accounts for
+        # the header bytes, and the header round-trips on reopen.
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        bcount = 11
+        header_data = bytes(bytearray([0,1,2]))
+        fsetup = EncryptedBlockStorage.setup(
+            fname,
+            block_size=bsize,
+            block_count=bcount,
+            key=self._test_key,
+            key_size=self._test_key_size,
+            aes_mode=self._aes_mode,
+            storage_type=self._type_name,
+            header_data=header_data)
+        fsetup.close()
+        self.assertEqual(type(fsetup.raw_storage),
+                         BlockStorageTypeFactory(self._type_name))
+        with open(fname, 'rb') as f:
+            flen = len(f.read())
+            self.assertEqual(
+                flen,
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name,
+                    header_data=header_data))
+            self.assertTrue(len(header_data) > 0)
+            # adding header_data strictly increases the computed size
+            self.assertEqual(
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name) <
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name,
+                    header_data=header_data),
+                True)
+            self.assertEqual(
+                flen >
+                EncryptedBlockStorage.compute_storage_size(
+                    bsize,
+                    bcount,
+                    aes_mode=self._aes_mode,
+                    storage_type=self._type_name,
+                    header_data=header_data,
+                    ignore_header=True),
+                True)
+        with EncryptedBlockStorage(fname,
+                                   key=fsetup.key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, header_data)
+            self.assertEqual(fsetup.header_data, header_data)
+            self.assertEqual(f.key, fsetup.key)
+            self.assertEqual(f.block_size, bsize)
+            self.assertEqual(fsetup.block_size, bsize)
+            self.assertEqual(f.block_count, bcount)
+            self.assertEqual(fsetup.block_count, bcount)
+            self.assertEqual(f.storage_name, fname)
+            self.assertEqual(fsetup.storage_name, fname)
+        os.remove(fname)
+
+    def test_init_noexists(self):
+        # Opening a storage name that does not exist raises IOError.
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        with self.assertRaises(IOError):
+            with EncryptedBlockStorage(
+                    self._dummy_name,
+                    key=self._key,
+                    storage_type=self._type_name) as f:
+                pass # pragma: no cover
+
+    def test_init_exists(self):
+        # Opening an existing storage: a missing key is rejected,
+        # combining an already-open raw storage with storage_type is
+        # rejected, and a valid open leaves the file bytes untouched.
+        self.assertEqual(os.path.exists(self._testfname), True)
+        with open(self._testfname, 'rb') as f:
+            databefore = f.read()
+        # no key supplied
+        with self.assertRaises(ValueError):
+            with EncryptedBlockStorage(self._testfname,
+                                       storage_type=self._type_name) as f:
+                pass # pragma: no cover
+        # raw storage object combined with storage_type is rejected
+        with self.assertRaises(ValueError):
+            with BlockStorageTypeFactory(self._type_name)(self._testfname) as fb:
+                with EncryptedBlockStorage(fb,
+                                           key=self._key,
+                                           storage_type=self._type_name) as f:
+                    pass # pragma: no cover
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.key, self._key)
+            self.assertEqual(f.block_size, self._block_size)
+            self.assertEqual(f.block_count, self._block_count)
+            self.assertEqual(f.storage_name, self._testfname)
+            self.assertEqual(f.header_data, bytes())
+        self.assertEqual(os.path.exists(self._testfname), True)
+        with open(self._testfname, 'rb') as f:
+            dataafter = f.read()
+        self.assertEqual(databefore, dataafter)
+
+    def test_read_block(self):
+        # read_block returns decrypted plaintext; received-byte counts
+        # are measured in the underlying storage's block size
+        # (f._storage.block_size), not the plaintext block size.
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            for i, data in enumerate(self._blocks):
+                self.assertEqual(list(bytearray(f.read_block(i))),
+                                 list(self._blocks[i]))
+            for i, data in enumerate(self._blocks):
+                self.assertEqual(list(bytearray(f.read_block(i))),
+                                 list(self._blocks[i]))
+            for i, data in reversed(list(enumerate(self._blocks))):
+                self.assertEqual(list(bytearray(f.read_block(i))),
+                                 list(self._blocks[i]))
+            for i, data in reversed(list(enumerate(self._blocks))):
+                self.assertEqual(list(bytearray(f.read_block(i))),
+                                 list(self._blocks[i]))
+
+            self.assertEqual(f.bytes_sent, 0)
+            # four full passes over the storage
+            self.assertEqual(f.bytes_received,
+                             self._block_count*f._storage.block_size*4)
+
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            self.assertEqual(list(bytearray(f.read_block(0))),
+                             list(self._blocks[0]))
+            self.assertEqual(list(bytearray(f.read_block(self._block_count-1))),
+                             list(self._blocks[-1]))
+
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received,
+                             f._storage.block_size*2)
+
+    def test_write_block(self):
+        # write_block/read_block round-trip on encrypted storage;
+        # counters use the underlying storage's block size.
+        data = bytearray([self._block_count])*self._block_size
+        self.assertEqual(len(data) > 0, True)
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            for i in xrange(self._block_count):
+                self.assertNotEqual(list(bytearray(f.read_block(i))),
+                                    list(data))
+            for i in xrange(self._block_count):
+                f.write_block(i, bytes(data))
+            for i in xrange(self._block_count):
+                self.assertEqual(list(bytearray(f.read_block(i))),
+                                 list(data))
+            # restore the original contents for subsequent tests
+            for i, block in enumerate(self._blocks):
+                f.write_block(i, bytes(block))
+
+            self.assertEqual(f.bytes_sent,
+                             self._block_count*f._storage.block_size*2)
+            self.assertEqual(f.bytes_received,
+                             self._block_count*f._storage.block_size*2)
+
+    def test_read_blocks(self):
+        # Bulk read_blocks: full range, single index, and a rotated
+        # index list must all decrypt to the expected plaintext.
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            data = f.read_blocks(list(xrange(self._block_count)))
+            self.assertEqual(len(data), self._block_count)
+            for i, block in enumerate(data):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+            data = f.read_blocks([0])
+            self.assertEqual(len(data), 1)
+            self.assertEqual(list(bytearray(data[0])),
+                             list(self._blocks[0]))
+            self.assertEqual(len(self._blocks) > 1, True)
+            data = f.read_blocks(list(xrange(1, self._block_count)) + [0])
+            self.assertEqual(len(data), self._block_count)
+            for i, block in enumerate(data[:-1], 1):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+            self.assertEqual(list(bytearray(data[-1])),
+                             list(self._blocks[0]))
+
+            self.assertEqual(f.bytes_sent, 0)
+            # block_count + 1 + block_count blocks read in total
+            self.assertEqual(f.bytes_received,
+                             (2*self._block_count+1)*f._storage.block_size)
+
+    def test_yield_blocks(self):
+        # Same coverage as test_read_blocks, but through the
+        # generator-based yield_blocks interface.
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            data = list(f.yield_blocks(list(xrange(self._block_count))))
+            self.assertEqual(len(data), self._block_count)
+            for i, block in enumerate(data):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+            data = list(f.yield_blocks([0]))
+            self.assertEqual(len(data), 1)
+            self.assertEqual(list(bytearray(data[0])),
+                             list(self._blocks[0]))
+            self.assertEqual(len(self._blocks) > 1, True)
+            data = list(f.yield_blocks(list(xrange(1, self._block_count)) + [0]))
+            self.assertEqual(len(data), self._block_count)
+            for i, block in enumerate(data[:-1], 1):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+            self.assertEqual(list(bytearray(data[-1])),
+                             list(self._blocks[0]))
+
+            self.assertEqual(f.bytes_sent, 0)
+            # block_count + 1 + block_count blocks read in total
+            self.assertEqual(f.bytes_received,
+                             (2*self._block_count+1)*f._storage.block_size)
+
+    def test_write_blocks(self):
+        # Bulk write_blocks/read_blocks round-trip; contents are
+        # restored at the end for subsequent tests.
+        data = [bytearray([self._block_count])*self._block_size
+                for i in xrange(self._block_count)]
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            orig = f.read_blocks(list(xrange(self._block_count)))
+            self.assertEqual(len(orig), self._block_count)
+            for i, block in enumerate(orig):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+            f.write_blocks(list(xrange(self._block_count)),
+                           [bytes(b) for b in data])
+            new = f.read_blocks(list(xrange(self._block_count)))
+            self.assertEqual(len(new), self._block_count)
+            for i, block in enumerate(new):
+                self.assertEqual(list(bytearray(block)),
+                                 list(data[i]))
+            f.write_blocks(list(xrange(self._block_count)),
+                           [bytes(b) for b in self._blocks])
+            orig = f.read_blocks(list(xrange(self._block_count)))
+            self.assertEqual(len(orig), self._block_count)
+            for i, block in enumerate(orig):
+                self.assertEqual(list(bytearray(block)),
+                                 list(self._blocks[i]))
+
+            # two bulk writes, three bulk reads
+            self.assertEqual(f.bytes_sent,
+                             self._block_count*f._storage.block_size*2)
+            self.assertEqual(f.bytes_received,
+                             self._block_count*f._storage.block_size*3)
+
+ def test_update_header_data(self):
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ bcount = 11
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = EncryptedBlockStorage.setup(
+ fname,
+ block_size=bsize,
+ block_count=bcount,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ header_data=header_data)
+ fsetup.close()
+ new_header_data = bytes(bytearray([1,1,1]))
+ with EncryptedBlockStorage(fname,
+ key=fsetup.key,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, header_data)
+ f.update_header_data(new_header_data)
+ self.assertEqual(f.header_data, new_header_data)
+ with EncryptedBlockStorage(fname,
+ key=fsetup.key,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ with self.assertRaises(ValueError):
+ with EncryptedBlockStorage(fname,
+ key=fsetup.key,
+ storage_type=self._type_name) as f:
+ f.update_header_data(bytes(bytearray([1,1])))
+ with self.assertRaises(ValueError):
+ with EncryptedBlockStorage(fname,
+ key=fsetup.key,
+ storage_type=self._type_name) as f:
+ f.update_header_data(bytes(bytearray([1,1,1,1])))
+ with EncryptedBlockStorage(fname,
+ key=fsetup.key,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ os.remove(fname)
+
+    def test_locked_flag(self):
+        # While a storage is open, reopening the same file raises
+        # IOError unless ignore_lock=True; after close a normal open
+        # succeeds again.
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            with self.assertRaises(IOError):
+                with EncryptedBlockStorage(self._testfname,
+                                           key=self._key,
+                                           storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            with self.assertRaises(IOError):
+                with EncryptedBlockStorage(self._testfname,
+                                           key=self._key,
+                                           storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            with EncryptedBlockStorage(self._testfname,
+                                       key=self._key,
+                                       storage_type=self._type_name,
+                                       ignore_lock=True) as f1:
+                pass
+            # the lock is still held by f after the ignore_lock open
+            with self.assertRaises(IOError):
+                with EncryptedBlockStorage(self._testfname,
+                                           key=self._key,
+                                           storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            with EncryptedBlockStorage(self._testfname,
+                                       key=self._key,
+                                       storage_type=self._type_name,
+                                       ignore_lock=True) as f1:
+                pass
+            with EncryptedBlockStorage(self._testfname,
+                                       key=self._key,
+                                       storage_type=self._type_name,
+                                       ignore_lock=True) as f1:
+                pass
+        # lock released on close; a normal open succeeds
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as f:
+            pass
+
+    def test_read_block_cloned(self):
+        # Same coverage as test_read_block, but through clone_device();
+        # byte counters accrue on the clone, never the original handle.
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                for i, data in enumerate(self._blocks):
+                    self.assertEqual(list(bytearray(f.read_block(i))),
+                                     list(self._blocks[i]))
+                for i, data in enumerate(self._blocks):
+                    self.assertEqual(list(bytearray(f.read_block(i))),
+                                     list(self._blocks[i]))
+                for i, data in reversed(list(enumerate(self._blocks))):
+                    self.assertEqual(list(bytearray(f.read_block(i))),
+                                     list(self._blocks[i]))
+                for i, data in reversed(list(enumerate(self._blocks))):
+                    self.assertEqual(list(bytearray(f.read_block(i))),
+                                     list(self._blocks[i]))
+
+                self.assertEqual(f.bytes_sent, 0)
+                # four full passes over the storage
+                self.assertEqual(f.bytes_received,
+                                 self._block_count*f._storage.block_size*4)
+
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                self.assertEqual(list(bytearray(f.read_block(0))),
+                                 list(self._blocks[0]))
+                self.assertEqual(list(bytearray(f.read_block(self._block_count-1))),
+                                 list(self._blocks[-1]))
+
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received,
+                                 f._storage.block_size*2)
+
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_write_block_cloned(self):
+        """write_block on a cloned device round-trips data and restores
+        the original contents; counters accrue on the clone only."""
+        data = bytearray([self._block_count])*self._block_size
+        self.assertEqual(len(data) > 0, True)
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                # sentinel data must differ from every stored block
+                for i in xrange(self._block_count):
+                    self.assertNotEqual(list(bytearray(f.read_block(i))),
+                                        list(data))
+                for i in xrange(self._block_count):
+                    f.write_block(i, bytes(data))
+                for i in xrange(self._block_count):
+                    self.assertEqual(list(bytearray(f.read_block(i))),
+                                     list(data))
+                # restore the original contents for later tests
+                for i, block in enumerate(self._blocks):
+                    f.write_block(i, bytes(block))
+
+                # two write passes and two read passes over every block
+                self.assertEqual(f.bytes_sent,
+                                 self._block_count*f._storage.block_size*2)
+                self.assertEqual(f.bytes_received,
+                                 self._block_count*f._storage.block_size*2)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_read_blocks_cloned(self):
+        """Batch read_blocks on a clone honors request order (including a
+        repeated index) and counts traffic on the clone only."""
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                data = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                data = f.read_blocks([0])
+                self.assertEqual(len(data), 1)
+                self.assertEqual(list(bytearray(data[0])),
+                                 list(self._blocks[0]))
+                self.assertEqual(len(self._blocks) > 1, True)
+                # out-of-order request: blocks 1..N-1 followed by 0
+                data = f.read_blocks(list(xrange(1, self._block_count)) + [0])
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data[:-1], 1):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                self.assertEqual(list(bytearray(data[-1])),
+                                 list(self._blocks[0]))
+
+                # N + 1 + N block reads in total
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received,
+                                 (2*self._block_count + 1)*f._storage.block_size)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_yield_blocks_cloned(self):
+        """yield_blocks (generator variant of read_blocks) on a clone
+        returns the same data and traffic accounting as read_blocks."""
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                data = list(f.yield_blocks(list(xrange(self._block_count))))
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                data = list(f.yield_blocks([0]))
+                self.assertEqual(len(data), 1)
+                self.assertEqual(list(bytearray(data[0])),
+                                 list(self._blocks[0]))
+                self.assertEqual(len(self._blocks) > 1, True)
+                # out-of-order request: blocks 1..N-1 followed by 0
+                data = list(f.yield_blocks(list(xrange(1, self._block_count)) + [0]))
+                self.assertEqual(len(data), self._block_count)
+                for i, block in enumerate(data[:-1], 1):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                self.assertEqual(list(bytearray(data[-1])),
+                                 list(self._blocks[0]))
+
+                # N + 1 + N block reads in total
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received,
+                                 (2*self._block_count + 1)*f._storage.block_size)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_write_blocks_cloned(self):
+        """Batch write_blocks on a clone overwrites all blocks and restores
+        them; sent counts 2 write passes, received counts 3 read passes."""
+        data = [bytearray([self._block_count])*self._block_size
+                for i in xrange(self._block_count)]
+        with EncryptedBlockStorage(self._testfname,
+                                   key=self._key,
+                                   storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                orig = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(orig), self._block_count)
+                for i, block in enumerate(orig):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+                f.write_blocks(list(xrange(self._block_count)),
+                               [bytes(b) for b in data])
+                new = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(new), self._block_count)
+                for i, block in enumerate(new):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(data[i]))
+                # restore the original contents for later tests
+                f.write_blocks(list(xrange(self._block_count)),
+                               [bytes(b) for b in self._blocks])
+                orig = f.read_blocks(list(xrange(self._block_count)))
+                self.assertEqual(len(orig), self._block_count)
+                for i, block in enumerate(orig):
+                    self.assertEqual(list(bytearray(block)),
+                                     list(self._blocks[i]))
+
+                self.assertEqual(f.bytes_sent,
+                                 self._block_count*f._storage.block_size*2)
+                self.assertEqual(f.bytes_received,
+                                 self._block_count*f._storage.block_size*3)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+# Concrete parameterizations of _TestEncryptedBlockStorage: each class
+# fixes a storage backend ('file' or 'mmap'), an AES mode ('ctr' or 'gcm'),
+# and either an explicit key (_test_key) or a key size (_test_key_size).
+class TestEncryptedBlockStorageFileCTRKey(_TestEncryptedBlockStorage,
+                                          unittest2.TestCase):
+    _type_name = 'file'
+    _aes_mode = 'ctr'
+    _test_key = AES.KeyGen(16)
+
+class TestEncryptedBlockStorageFileCTR32(_TestEncryptedBlockStorage,
+                                         unittest2.TestCase):
+    _type_name = 'file'
+    _aes_mode = 'ctr'
+    _test_key_size = 16
+
+class TestEncryptedBlockStorageFileGCMKey(_TestEncryptedBlockStorage,
+                                          unittest2.TestCase):
+    _type_name = 'file'
+    _aes_mode = 'gcm'
+    _test_key = AES.KeyGen(24)
+
+class TestEncryptedBlockStorageFileGCM32(_TestEncryptedBlockStorage,
+                                         unittest2.TestCase):
+    _type_name = 'file'
+    _aes_mode = 'gcm'
+    _test_key_size = 24
+
+class TestEncryptedBlockStorageMMapFileCTRKey(_TestEncryptedBlockStorage,
+                                              unittest2.TestCase):
+    _type_name = 'mmap'
+    _aes_mode = 'ctr'
+    _test_key = AES.KeyGen(32)
+
+class TestEncryptedBlockStorageMMapFileCTR32(_TestEncryptedBlockStorage,
+                                             unittest2.TestCase):
+    _type_name = 'mmap'
+    _aes_mode = 'ctr'
+    _test_key_size = 32
+
+# NOTE(review): unlike every other *Key variant, this class defines no
+# _test_key — presumably the base class falls back to a default; confirm
+# this omission is intentional.
+class TestEncryptedBlockStorageMMapFileGCMKey(_TestEncryptedBlockStorage,
+                                              unittest2.TestCase):
+    _type_name = 'mmap'
+    _aes_mode = 'gcm'
+
+# NOTE(review): no _test_key_size here, unlike the other *32 variants —
+# confirm intentional.
+class TestEncryptedBlockStorageMMapFileGCM32(_TestEncryptedBlockStorage,
+                                             unittest2.TestCase):
+    _type_name = 'mmap'
+    _aes_mode = 'gcm'
+
+if __name__ == "__main__":
+    unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+
+from pyoram.util.virtual_heap import \
+ SizedVirtualHeap
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ EncryptedHeapStorage
+from pyoram.crypto.aes import AES
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+
+class TestEncryptedHeapStorage(unittest2.TestCase):
+    """Tests for EncryptedHeapStorage backed by a file on disk."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create a shared encrypted heap-storage test file whose bucket i
+        is filled with byte value i, plus a guaranteed-nonexistent name."""
+        # reserve a temp filename, then delete it so only the name remains
+        fd, cls._dummy_name = tempfile.mkstemp()
+        os.close(fd)
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        cls._block_size = 25
+        cls._blocks_per_bucket = 3
+        cls._heap_base = 4
+        cls._heap_height = 2
+        # node count of a complete k-ary heap: (k^(h+1) - 1) / (k - 1)
+        cls._bucket_count = \
+            ((cls._heap_base**(cls._heap_height+1)) - 1)//(cls._heap_base-1)
+        cls._block_count = cls._bucket_count * \
+                           cls._blocks_per_bucket
+        cls._testfname = cls.__name__ + "_testfile.bin"
+        cls._buckets = []
+        cls._type_name = "file"
+        f = EncryptedHeapStorage.setup(
+            cls._testfname,
+            block_size=cls._block_size,
+            heap_height=cls._heap_height,
+            key_size=AES.key_sizes[-1],
+            heap_base=cls._heap_base,
+            blocks_per_bucket=cls._blocks_per_bucket,
+            storage_type=cls._type_name,
+            initialize=lambda i: bytes(bytearray([i]) * \
+                                       cls._block_size * \
+                                       cls._blocks_per_bucket),
+            ignore_existing=True)
+        f.close()
+        # the generated key is only available from the setup object
+        cls._key = f.key
+        # keep an in-memory reference copy of every bucket's contents
+        for i in range(cls._bucket_count):
+            data = bytearray([i]) * \
+                   cls._block_size * \
+                   cls._blocks_per_bucket
+            cls._buckets.append(data)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Remove the shared test file and the dummy name, if present."""
+        try:
+            os.remove(cls._testfname)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+
+    def test_setup_fails(self):
+        """Every invalid setup() argument combination raises and must not
+        leave a file behind at the dummy name."""
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # existing file without ignore_existing
+        with self.assertRaises(IOError):
+            EncryptedHeapStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # existing file with ignore_existing explicitly False
+        with self.assertRaises(IOError):
+            EncryptedHeapStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                storage_type=self._type_name,
+                ignore_existing=False)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad block_size
+        with self.assertRaises(ValueError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=0,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad heap_height
+        # NOTE(review): this call omits key_size, unlike the sibling cases;
+        # presumably ValueError fires on heap_height first — confirm.
+        with self.assertRaises(ValueError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=-1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad blocks_per_bucket
+        with self.assertRaises(ValueError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=0,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad heap_base
+        with self.assertRaises(ValueError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                heap_base=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad header_data
+        with self.assertRaises(TypeError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                storage_type=self._type_name,
+                header_data=2)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # uses block_count
+        with self.assertRaises(ValueError):
+            EncryptedHeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                key_size=AES.key_sizes[-1],
+                blocks_per_bucket=1,
+                block_count=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+
+    def test_setup(self):
+        """setup() with no header data produces a file of exactly the
+        computed size whose properties survive a reopen."""
+        # unique per-test filename derived from the test id
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        heap_height = 2
+        blocks_per_bucket = 3
+        fsetup = EncryptedHeapStorage.setup(
+            fname,
+            bsize,
+            heap_height,
+            key_size=AES.key_sizes[-1],
+            blocks_per_bucket=blocks_per_bucket)
+        fsetup.close()
+        self.assertEqual(type(fsetup.raw_storage),
+                         BlockStorageTypeFactory(self._type_name))
+        with open(fname, 'rb') as f:
+            flen = len(f.read())
+            self.assertEqual(
+                flen,
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    blocks_per_bucket=blocks_per_bucket))
+            # the header occupies space beyond the payload
+            self.assertEqual(
+                flen >
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    blocks_per_bucket=blocks_per_bucket,
+                    ignore_header=True),
+                True)
+        with EncryptedHeapStorage(
+                fname,
+                key=fsetup.key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, bytes())
+            self.assertEqual(fsetup.header_data, bytes())
+            self.assertEqual(f.key, fsetup.key)
+            self.assertEqual(f.blocks_per_bucket,
+                             blocks_per_bucket)
+            self.assertEqual(fsetup.blocks_per_bucket,
+                             blocks_per_bucket)
+            # default heap_base is 2, hence 2^(h+1) - 1 buckets
+            self.assertEqual(f.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(fsetup.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(f.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(fsetup.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(f.storage_name, fname)
+            self.assertEqual(fsetup.storage_name, fname)
+        os.remove(fname)
+
+    def test_setup_withdata(self):
+        """setup() with header_data grows the file accordingly and the
+        header round-trips through a reopen."""
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        heap_height = 2
+        blocks_per_bucket = 1
+        header_data = bytes(bytearray([0,1,2]))
+        fsetup = EncryptedHeapStorage.setup(
+            fname,
+            bsize,
+            heap_height,
+            key_size=AES.key_sizes[-1],
+            blocks_per_bucket=blocks_per_bucket,
+            header_data=header_data)
+        fsetup.close()
+        self.assertEqual(type(fsetup.raw_storage),
+                         BlockStorageTypeFactory(self._type_name))
+        with open(fname, 'rb') as f:
+            flen = len(f.read())
+            self.assertEqual(
+                flen,
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    header_data=header_data))
+            self.assertTrue(len(header_data) > 0)
+            # header_data strictly increases the computed size
+            self.assertEqual(
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    storage_type=self._type_name) <
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    storage_type=self._type_name,
+                    header_data=header_data),
+                True)
+            self.assertEqual(
+                flen >
+                EncryptedHeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    storage_type=self._type_name,
+                    header_data=header_data,
+                    ignore_header=True),
+                True)
+        with EncryptedHeapStorage(
+                fname,
+                key=fsetup.key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, header_data)
+            self.assertEqual(fsetup.header_data, header_data)
+            self.assertEqual(f.key, fsetup.key)
+            self.assertEqual(f.blocks_per_bucket,
+                             blocks_per_bucket)
+            self.assertEqual(fsetup.blocks_per_bucket,
+                             blocks_per_bucket)
+            self.assertEqual(f.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(fsetup.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(f.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(fsetup.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(f.storage_name, fname)
+            self.assertEqual(fsetup.storage_name, fname)
+        os.remove(fname)
+
+    def test_init_noexists(self):
+        """Opening a nonexistent storage file raises IOError."""
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        with self.assertRaises(IOError):
+            with EncryptedHeapStorage(
+                    self._dummy_name,
+                    key=self._key,
+                    storage_type=self._type_name) as f:
+                pass # pragma: no cover
+
+    def test_init_exists(self):
+        """Opening an existing file exposes the stored geometry and does
+        not modify the file; wrapping an already-open block storage in
+        EncryptedHeapStorage raises ValueError."""
+        self.assertEqual(os.path.exists(self._testfname), True)
+        with open(self._testfname, 'rb') as f:
+            databefore = f.read()
+        with self.assertRaises(ValueError):
+            with EncryptedBlockStorage(self._testfname,
+                                       key=self._key,
+                                       storage_type=self._type_name) as fb:
+                with EncryptedHeapStorage(fb, key=self._key) as f:
+                    pass # pragma: no cover
+        with EncryptedHeapStorage(
+                self._testfname,
+                key=self._key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.key, self._key)
+            self.assertEqual(f.bucket_size,
+                             self._block_size * \
+                             self._blocks_per_bucket)
+            self.assertEqual(f.bucket_count,
+                             self._bucket_count)
+            self.assertEqual(f.storage_name, self._testfname)
+            self.assertEqual(f.header_data, bytes())
+        self.assertEqual(os.path.exists(self._testfname), True)
+        with open(self._testfname, 'rb') as f:
+            dataafter = f.read()
+        # a plain open/close cycle must not alter the file contents
+        self.assertEqual(databefore, dataafter)
+
+    def test_read_path(self):
+        """read_path(b) returns the root-to-b bucket contents in order and
+        only accrues received bytes."""
+        with EncryptedHeapStorage(
+                self._testfname,
+                key=self._key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            self.assertEqual(
+                f.virtual_heap.first_bucket_at_level(0), 0)
+            self.assertNotEqual(
+                f.virtual_heap.last_leaf_bucket(), 0)
+            total_buckets = 0
+            for b in range(f.virtual_heap.first_bucket_at_level(0),
+                           f.virtual_heap.last_leaf_bucket()+1):
+                data = f.read_path(b)
+                bucket_path = f.virtual_heap.Node(b).\
+                              bucket_path_from_root()
+                total_buckets += len(bucket_path)
+                # a node at level L has a root path of L+1 buckets
+                self.assertEqual(f.virtual_heap.Node(b).level+1,
+                                 len(bucket_path))
+                for i, bucket in zip(bucket_path, data):
+                    self.assertEqual(list(bytearray(bucket)),
+                                     list(self._buckets[i]))
+
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received,
+                             total_buckets*f.bucket_storage._storage.block_size)
+
+    def test_write_path(self):
+        """write_path(b, ...) replaces the root-to-b buckets; each path is
+        read 3 times and written twice (overwrite, then restore)."""
+        data = [bytearray([self._bucket_count]) * \
+                self._block_size * \
+                self._blocks_per_bucket
+                for i in xrange(self._block_count)]
+        with EncryptedHeapStorage(
+                self._testfname,
+                key=self._key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.bytes_sent, 0)
+            self.assertEqual(f.bytes_received, 0)
+
+            self.assertEqual(
+                f.virtual_heap.first_bucket_at_level(0), 0)
+            self.assertNotEqual(
+                f.virtual_heap.last_leaf_bucket(), 0)
+            total_buckets = 0
+            for b in range(f.virtual_heap.first_bucket_at_level(0),
+                           f.virtual_heap.last_leaf_bucket()+1):
+                orig = f.read_path(b)
+                bucket_path = f.virtual_heap.Node(b).\
+                              bucket_path_from_root()
+                total_buckets += len(bucket_path)
+                self.assertNotEqual(len(bucket_path), 0)
+                self.assertEqual(f.virtual_heap.Node(b).level+1,
+                                 len(bucket_path))
+                self.assertEqual(len(orig), len(bucket_path))
+                for i, bucket in zip(bucket_path, orig):
+                    self.assertEqual(list(bytearray(bucket)),
+                                     list(self._buckets[i]))
+                # overwrite the path with sentinel data
+                f.write_path(b, [bytes(data[i])
+                                 for i in bucket_path])
+
+                new = f.read_path(b)
+                self.assertEqual(len(new), len(bucket_path))
+                for i, bucket in zip(bucket_path, new):
+                    self.assertEqual(list(bytearray(bucket)),
+                                     list(data[i]))
+
+                # restore the original contents for later tests
+                f.write_path(b, [bytes(self._buckets[i])
+                                 for i in bucket_path])
+
+                orig = f.read_path(b)
+                self.assertEqual(len(orig), len(bucket_path))
+                for i, bucket in zip(bucket_path, orig):
+                    self.assertEqual(list(bytearray(bucket)),
+                                     list(self._buckets[i]))
+
+            self.assertEqual(f.bytes_sent,
+                             total_buckets*f.bucket_storage._storage.block_size*2)
+            self.assertEqual(f.bytes_received,
+                             total_buckets*f.bucket_storage._storage.block_size*3)
+
+    def test_update_header_data(self):
+        """update_header_data persists a same-length replacement and
+        rejects shorter or longer replacements with ValueError."""
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        heap_height = 2
+        blocks_per_bucket = 1
+        header_data = bytes(bytearray([0,1,2]))
+        fsetup = EncryptedHeapStorage.setup(
+            fname,
+            block_size=bsize,
+            heap_height=heap_height,
+            key_size=AES.key_sizes[-1],
+            blocks_per_bucket=blocks_per_bucket,
+            header_data=header_data)
+        fsetup.close()
+        new_header_data = bytes(bytearray([1,1,1]))
+        with EncryptedHeapStorage(
+                fname,
+                key=fsetup.key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, header_data)
+            f.update_header_data(new_header_data)
+            self.assertEqual(f.header_data, new_header_data)
+        # the new header survives a reopen
+        with EncryptedHeapStorage(
+                fname,
+                key=fsetup.key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, new_header_data)
+        # too short (2 bytes) is rejected
+        with self.assertRaises(ValueError):
+            with EncryptedHeapStorage(
+                    fname,
+                    key=fsetup.key,
+                    storage_type=self._type_name) as f:
+                f.update_header_data(bytes(bytearray([1,1])))
+        # too long (4 bytes) is rejected
+        with self.assertRaises(ValueError):
+            with EncryptedHeapStorage(
+                    fname,
+                    key=fsetup.key,
+                    storage_type=self._type_name) as f:
+                f.update_header_data(bytes(bytearray([1,1,1,1])))
+        # failed updates must not have modified the stored header
+        with EncryptedHeapStorage(
+                fname,
+                key=fsetup.key,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, new_header_data)
+        os.remove(fname)
+
+    def test_locked_flag(self):
+        """While a heap storage is open, re-opening it raises IOError
+        unless ignore_lock=True is passed; the lock clears on close."""
+        with EncryptedHeapStorage(self._testfname,
+                                  key=self._key,
+                                  storage_type=self._type_name) as f:
+            with self.assertRaises(IOError):
+                with EncryptedHeapStorage(self._testfname,
+                                          key=self._key,
+                                          storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            with self.assertRaises(IOError):
+                with EncryptedHeapStorage(self._testfname,
+                                          key=self._key,
+                                          storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            # ignore_lock bypasses the check
+            with EncryptedHeapStorage(self._testfname,
+                                      key=self._key,
+                                      storage_type=self._type_name,
+                                      ignore_lock=True) as f1:
+                pass
+            # the ignore_lock open above must not have cleared the lock
+            with self.assertRaises(IOError):
+                with EncryptedHeapStorage(self._testfname,
+                                          key=self._key,
+                                          storage_type=self._type_name) as f1:
+                    pass # pragma: no cover
+            with EncryptedHeapStorage(self._testfname,
+                                      key=self._key,
+                                      storage_type=self._type_name,
+                                      ignore_lock=True) as f1:
+                pass
+        # after close, both open modes succeed again
+        with EncryptedHeapStorage(self._testfname,
+                                  key=self._key,
+                                  storage_type=self._type_name,
+                                  ignore_lock=True) as f1:
+            pass
+        with EncryptedHeapStorage(self._testfname,
+                                  key=self._key,
+                                  storage_type=self._type_name) as f:
+            pass
+
+    def test_read_path_cloned(self):
+        """read_path on a cloned device behaves like test_read_path and
+        never accrues traffic on the original device."""
+        with EncryptedHeapStorage(
+                self._testfname,
+                key=self._key,
+                storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                self.assertEqual(
+                    f.virtual_heap.first_bucket_at_level(0), 0)
+                self.assertNotEqual(
+                    f.virtual_heap.last_leaf_bucket(), 0)
+                total_buckets = 0
+                for b in range(f.virtual_heap.first_bucket_at_level(0),
+                               f.virtual_heap.last_leaf_bucket()+1):
+                    data = f.read_path(b)
+                    bucket_path = f.virtual_heap.Node(b).\
+                                  bucket_path_from_root()
+                    total_buckets += len(bucket_path)
+                    self.assertEqual(f.virtual_heap.Node(b).level+1,
+                                     len(bucket_path))
+                    for i, bucket in zip(bucket_path, data):
+                        self.assertEqual(list(bytearray(bucket)),
+                                         list(self._buckets[i]))
+
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received,
+                                 total_buckets*f.bucket_storage._storage.block_size)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+    def test_write_path_cloned(self):
+        """write_path on a cloned device behaves like test_write_path and
+        never accrues traffic on the original device."""
+        data = [bytearray([self._bucket_count]) * \
+                self._block_size * \
+                self._blocks_per_bucket
+                for i in xrange(self._block_count)]
+        with EncryptedHeapStorage(
+                self._testfname,
+                key=self._key,
+                storage_type=self._type_name) as forig:
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+            with forig.clone_device() as f:
+                self.assertEqual(forig.bytes_sent, 0)
+                self.assertEqual(forig.bytes_received, 0)
+                self.assertEqual(f.bytes_sent, 0)
+                self.assertEqual(f.bytes_received, 0)
+
+                self.assertEqual(
+                    f.virtual_heap.first_bucket_at_level(0), 0)
+                self.assertNotEqual(
+                    f.virtual_heap.last_leaf_bucket(), 0)
+                total_buckets = 0
+                for b in range(f.virtual_heap.first_bucket_at_level(0),
+                               f.virtual_heap.last_leaf_bucket()+1):
+                    orig = f.read_path(b)
+                    bucket_path = f.virtual_heap.Node(b).\
+                                  bucket_path_from_root()
+                    total_buckets += len(bucket_path)
+                    self.assertNotEqual(len(bucket_path), 0)
+                    self.assertEqual(f.virtual_heap.Node(b).level+1,
+                                     len(bucket_path))
+                    self.assertEqual(len(orig), len(bucket_path))
+                    for i, bucket in zip(bucket_path, orig):
+                        self.assertEqual(list(bytearray(bucket)),
+                                         list(self._buckets[i]))
+                    # overwrite the path with sentinel data
+                    f.write_path(b, [bytes(data[i])
+                                     for i in bucket_path])
+
+                    new = f.read_path(b)
+                    self.assertEqual(len(new), len(bucket_path))
+                    for i, bucket in zip(bucket_path, new):
+                        self.assertEqual(list(bytearray(bucket)),
+                                         list(data[i]))
+
+                    # restore the original contents for later tests
+                    f.write_path(b, [bytes(self._buckets[i])
+                                     for i in bucket_path])
+
+                    orig = f.read_path(b)
+                    self.assertEqual(len(orig), len(bucket_path))
+                    for i, bucket in zip(bucket_path, orig):
+                        self.assertEqual(list(bytearray(bucket)),
+                                         list(self._buckets[i]))
+
+                self.assertEqual(f.bytes_sent,
+                                 total_buckets*f.bucket_storage._storage.block_size*2)
+                self.assertEqual(f.bytes_received,
+                                 total_buckets*f.bucket_storage._storage.block_size*3)
+            self.assertEqual(forig.bytes_sent, 0)
+            self.assertEqual(forig.bytes_received, 0)
+
+if __name__ == "__main__":
+    unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import glob
+import sys
+import unittest2
+
+# Locate the repository's examples/ directory relative to this file
+# (three levels up from this test module).
+thisfile = os.path.abspath(__file__)
+thisdir = os.path.dirname(thisfile)
+topdir = os.path.dirname(
+    os.path.dirname(
+        os.path.dirname(thisdir)))
+exdir = os.path.join(topdir, 'examples')
+examples = glob.glob(os.path.join(exdir,"*.py"))
+
+assert os.path.exists(exdir)
+assert thisfile not in examples
+
+# Map test name -> (example path, module basename) for every example script.
+tdict = {}
+for fname in examples:
+    basename = os.path.basename(fname)
+    assert basename.endswith('.py')
+    assert len(basename) >= 3
+    basename = basename[:-3]
+    tname = 'test_'+basename
+    tdict[tname] = fname, basename
+
+assert len(tdict) == len(examples)
+
+# Skip the S3 examples unless an AWS test bucket is configured.
+assert 'test_encrypted_storage_s3' in tdict
+assert 'test_path_oram_s3' in tdict
+if 'PYORAM_AWS_TEST_BUCKET' not in os.environ:
+    del tdict['test_encrypted_storage_s3']
+    del tdict['test_path_oram_s3']
+# Skip the SFTP examples unless an SSH test host is configured.
+assert 'test_encrypted_storage_sftp' in tdict
+assert 'test_path_oram_sftp' in tdict
+assert 'test_path_oram_sftp_setup' in tdict
+assert 'test_path_oram_sftp_test' in tdict
+if 'PYORAM_SSH_TEST_HOST' not in os.environ:
+    del tdict['test_encrypted_storage_sftp']
+    del tdict['test_path_oram_sftp']
+    del tdict['test_path_oram_sftp_setup']
+    del tdict['test_path_oram_sftp_test']
+
+def _execute_example(example_name):
+    """Import the named example script with examples/ on sys.path and run
+    its main() entry point; always restores sys.path afterwards."""
+    filename, basename = tdict[example_name]
+    assert os.path.exists(filename)
+    try:
+        sys.path.insert(0, exdir)
+        m = __import__(basename)
+        m.main()
+    finally:
+        sys.path.remove(exdir)
+
+# this is recognized by nosetests as
+# a dynamic test generator
+def test_generator():
+    """Yield one (callable, example_name) test per runnable example."""
+    for example_name in sorted(tdict):
+        yield _execute_example, example_name
+
+if __name__ == "__main__":
+    # run the generated tests directly when executed as a script
+    for tfunc, tname in test_generator(): # pragma: no cover
+        tfunc(tname) # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+
+from pyoram.util.virtual_heap import \
+ SizedVirtualHeap
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.storage.block_storage_file import \
+ BlockStorageFile
+from pyoram.storage.heap_storage import \
+ HeapStorage
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+
+class TestHeapStorage(unittest2.TestCase):
+    """Tests for the unencrypted HeapStorage backed by a file on disk."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Create a shared heap-storage test file whose bucket i is filled
+        with byte value i, plus a guaranteed-nonexistent dummy name."""
+        # reserve a temp filename, then delete it so only the name remains
+        fd, cls._dummy_name = tempfile.mkstemp()
+        os.close(fd)
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        cls._block_size = 25
+        cls._blocks_per_bucket = 3
+        cls._heap_base = 4
+        cls._heap_height = 2
+        # node count of a complete k-ary heap: (k^(h+1) - 1) / (k - 1)
+        cls._bucket_count = \
+            ((cls._heap_base**(cls._heap_height+1)) - 1)//(cls._heap_base-1)
+        cls._block_count = cls._bucket_count * \
+                           cls._blocks_per_bucket
+        cls._testfname = cls.__name__ + "_testfile.bin"
+        cls._buckets = []
+        cls._type_name = "file"
+        f = HeapStorage.setup(
+            cls._testfname,
+            block_size=cls._block_size,
+            heap_height=cls._heap_height,
+            heap_base=cls._heap_base,
+            blocks_per_bucket=cls._blocks_per_bucket,
+            storage_type=cls._type_name,
+            initialize=lambda i: bytes(bytearray([i]) * \
+                                       cls._block_size * \
+                                       cls._blocks_per_bucket),
+            ignore_existing=True)
+        f.close()
+        # keep an in-memory reference copy of every bucket's contents
+        for i in range(cls._bucket_count):
+            data = bytearray([i]) * \
+                   cls._block_size * \
+                   cls._blocks_per_bucket
+            cls._buckets.append(data)
+
+    @classmethod
+    def tearDownClass(cls):
+        """Remove the shared test file and the dummy name, if present."""
+        try:
+            os.remove(cls._testfname)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+        try:
+            os.remove(cls._dummy_name)
+        except OSError: # pragma: no cover
+            pass # pragma: no cover
+
+    def test_setup_fails(self):
+        """Every invalid setup() argument combination raises and must not
+        leave a file behind at the dummy name."""
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # existing file without ignore_existing
+        with self.assertRaises(IOError):
+            HeapStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                heap_height=1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # existing file with ignore_existing explicitly False
+        with self.assertRaises(IOError):
+            HeapStorage.setup(
+                os.path.join(thisdir,
+                             "baselines",
+                             "exists.empty"),
+                block_size=10,
+                heap_height=1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name,
+                ignore_existing=False)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad block_size
+        with self.assertRaises(ValueError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=0,
+                heap_height=1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad heap_height
+        with self.assertRaises(ValueError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=-1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad blocks_per_bucket
+        with self.assertRaises(ValueError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                blocks_per_bucket=0,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad heap_base
+        with self.assertRaises(ValueError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                blocks_per_bucket=1,
+                heap_base=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # bad header_data
+        with self.assertRaises(TypeError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                blocks_per_bucket=1,
+                storage_type=self._type_name,
+                header_data=2)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+        # uses block_count
+        with self.assertRaises(ValueError):
+            HeapStorage.setup(
+                self._dummy_name,
+                block_size=1,
+                heap_height=1,
+                blocks_per_bucket=1,
+                block_count=1,
+                storage_type=self._type_name)
+        self.assertEqual(os.path.exists(self._dummy_name), False)
+
+    def test_setup(self):
+        """setup() with no header data produces a file of exactly the
+        computed size whose properties survive a reopen."""
+        # unique per-test filename derived from the test id
+        fname = ".".join(self.id().split(".")[1:])
+        fname += ".bin"
+        fname = os.path.join(thisdir, fname)
+        if os.path.exists(fname):
+            os.remove(fname) # pragma: no cover
+        bsize = 10
+        heap_height = 2
+        blocks_per_bucket = 3
+        fsetup = HeapStorage.setup(
+            fname,
+            bsize,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+        fsetup.close()
+        self.assertEqual(type(fsetup.bucket_storage),
+                         BlockStorageTypeFactory(self._type_name))
+        with open(fname, 'rb') as f:
+            flen = len(f.read())
+            self.assertEqual(
+                flen,
+                HeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    blocks_per_bucket=blocks_per_bucket))
+            # the header occupies space beyond the payload
+            self.assertEqual(
+                flen >
+                HeapStorage.compute_storage_size(
+                    bsize,
+                    heap_height,
+                    blocks_per_bucket=blocks_per_bucket,
+                    ignore_header=True),
+                True)
+        with HeapStorage(
+                fname,
+                storage_type=self._type_name) as f:
+            self.assertEqual(f.header_data, bytes())
+            self.assertEqual(fsetup.header_data, bytes())
+            self.assertEqual(f.blocks_per_bucket,
+                             blocks_per_bucket)
+            self.assertEqual(fsetup.blocks_per_bucket,
+                             blocks_per_bucket)
+            # default heap_base is 2, hence 2^(h+1) - 1 buckets
+            self.assertEqual(f.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(fsetup.bucket_count,
+                             2**(heap_height+1) - 1)
+            self.assertEqual(f.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(fsetup.bucket_size,
+                             bsize * blocks_per_bucket)
+            self.assertEqual(f.storage_name, fname)
+            self.assertEqual(fsetup.storage_name, fname)
+        os.remove(fname)
+
+ def test_setup_withdata(self):
+ # Same round trip as test_setup, but with user-supplied header_data;
+ # verifies the header bytes survive setup/reopen and that they
+ # increase the computed storage size.
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ heap_height = 2
+ blocks_per_bucket = 1
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = HeapStorage.setup(
+ fname,
+ bsize,
+ heap_height,
+ blocks_per_bucket=blocks_per_bucket,
+ header_data=header_data)
+ fsetup.close()
+ self.assertEqual(type(fsetup.bucket_storage),
+ BlockStorageTypeFactory(self._type_name))
+ with open(fname, 'rb') as f:
+ flen = len(f.read())
+ self.assertEqual(
+ flen,
+ HeapStorage.compute_storage_size(
+ bsize,
+ heap_height,
+ header_data=header_data))
+ self.assertTrue(len(header_data) > 0)
+ # Non-empty header data must strictly grow the storage size.
+ self.assertEqual(
+ HeapStorage.compute_storage_size(
+ bsize,
+ heap_height,
+ storage_type=self._type_name) <
+ HeapStorage.compute_storage_size(
+ bsize,
+ heap_height,
+ storage_type=self._type_name,
+ header_data=header_data),
+ True)
+ self.assertEqual(
+ flen >
+ HeapStorage.compute_storage_size(
+ bsize,
+ heap_height,
+ storage_type=self._type_name,
+ header_data=header_data,
+ ignore_header=True),
+ True)
+ # Reopen and confirm geometry and the stored header bytes.
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, header_data)
+ self.assertEqual(fsetup.header_data, header_data)
+ self.assertEqual(f.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(fsetup.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(f.bucket_count,
+ 2**(heap_height+1) - 1)
+ self.assertEqual(fsetup.bucket_count,
+ 2**(heap_height+1) - 1)
+ self.assertEqual(f.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(fsetup.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(f.storage_name, fname)
+ self.assertEqual(fsetup.storage_name, fname)
+ os.remove(fname)
+
+ def test_init_noexists(self):
+ # Opening a HeapStorage on a path that does not exist must raise
+ # IOError rather than silently creating a file.
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ with self.assertRaises(IOError):
+ with HeapStorage(
+ self._dummy_name,
+ storage_type=self._type_name) as f:
+ pass # pragma: no cover
+
+ def test_init_exists(self):
+ # Opening an existing file must succeed, expose the geometry set up
+ # in setUpClass, and leave the file bytes untouched.
+ self.assertEqual(os.path.exists(self._testfname), True)
+ with open(self._testfname, 'rb') as f:
+ databefore = f.read()
+ # Wrapping an already-open storage object while also passing
+ # storage_type is rejected with ValueError.
+ with self.assertRaises(ValueError):
+ with BlockStorageFile(self._testfname) as fb:
+ with HeapStorage(fb, storage_type='file') as f:
+ pass # pragma: no cover
+ with HeapStorage(
+ self._testfname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.bucket_size,
+ self._block_size * \
+ self._blocks_per_bucket)
+ self.assertEqual(f.bucket_count,
+ self._bucket_count)
+ self.assertEqual(f.storage_name, self._testfname)
+ self.assertEqual(f.header_data, bytes())
+ self.assertEqual(os.path.exists(self._testfname), True)
+ # Read-only access: the raw file contents must be unchanged.
+ with open(self._testfname, 'rb') as f:
+ dataafter = f.read()
+ self.assertEqual(databefore, dataafter)
+
+ def test_read_path(self):
+
+ # For every bucket b in the virtual heap, read_path(b) must return
+ # the buckets on the root-to-b path in order, and only bytes
+ # received (not sent) should be counted.
+ with HeapStorage(
+ self._testfname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ total_buckets = 0
+ for b in range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1):
+ data = f.read_path(b)
+ bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ total_buckets += len(bucket_path)
+ # Path length equals the node's level + 1 (root is level 0).
+ self.assertEqual(f.virtual_heap.Node(b).level+1,
+ len(bucket_path))
+ for i, bucket in zip(bucket_path, data):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # Reads only: bytes_received tallies every bucket read, no sends.
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received,
+ total_buckets*f.bucket_storage.block_size)
+
+ def test_write_path(self):
+ # For every bucket path: read original contents, overwrite with
+ # sentinel data, verify, then restore originals and verify again.
+ # Also checks the bytes_sent / bytes_received accounting.
+ data = [bytearray([self._bucket_count]) * \
+ self._block_size * \
+ self._blocks_per_bucket
+ for i in xrange(self._block_count)]
+ with HeapStorage(
+ self._testfname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ total_buckets = 0
+ for b in range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1):
+ orig = f.read_path(b)
+ bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ total_buckets += len(bucket_path)
+ self.assertNotEqual(len(bucket_path), 0)
+ self.assertEqual(f.virtual_heap.Node(b).level+1,
+ len(bucket_path))
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+ # Overwrite the path with sentinel data ...
+ f.write_path(b, [bytes(data[i])
+ for i in bucket_path])
+
+ new = f.read_path(b)
+ self.assertEqual(len(new), len(bucket_path))
+ for i, bucket in zip(bucket_path, new):
+ self.assertEqual(list(bytearray(bucket)),
+ list(data[i]))
+
+ # ... then restore the original bucket contents.
+ f.write_path(b, [bytes(self._buckets[i])
+ for i in bucket_path])
+
+ orig = f.read_path(b)
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # Per path iteration: 2 writes and 3 reads of each bucket.
+ self.assertEqual(f.bytes_sent,
+ total_buckets*f.bucket_storage.block_size*2)
+ self.assertEqual(f.bytes_received,
+ total_buckets*f.bucket_storage.block_size*3)
+
+ def test_update_header_data(self):
+ # update_header_data must persist new header bytes of the SAME
+ # length; shorter or longer replacements raise ValueError and do
+ # not corrupt the stored header.
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ heap_height = 2
+ blocks_per_bucket = 1
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = HeapStorage.setup(
+ fname,
+ block_size=bsize,
+ heap_height=heap_height,
+ blocks_per_bucket=blocks_per_bucket,
+ header_data=header_data)
+ fsetup.close()
+ new_header_data = bytes(bytearray([1,1,1]))
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, header_data)
+ f.update_header_data(new_header_data)
+ self.assertEqual(f.header_data, new_header_data)
+ # Update must persist across a close/reopen.
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ # Too short (2 bytes vs 3) is rejected.
+ with self.assertRaises(ValueError):
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ f.update_header_data(bytes(bytearray([1,1])))
+ # Too long (4 bytes vs 3) is rejected.
+ with self.assertRaises(ValueError):
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ f.update_header_data(bytes(bytearray([1,1,1,1])))
+ # Failed updates must not have modified the stored header.
+ with HeapStorage(
+ fname,
+ storage_type=self._type_name) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ os.remove(fname)
+
+ def test_locked_flag(self):
+ # While a HeapStorage holds the file open, a second open of the
+ # same file raises IOError unless ignore_lock=True is passed.
+ # After the outer context exits, a plain open succeeds again.
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name) as f:
+ with self.assertRaises(IOError):
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name) as f1:
+ pass # pragma: no cover
+ with self.assertRaises(IOError):
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name) as f1:
+ pass # pragma: no cover
+ # ignore_lock bypasses the lock check ...
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name,
+ ignore_lock=True) as f1:
+ pass
+ # ... but the lock is still held afterwards.
+ with self.assertRaises(IOError):
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name) as f1:
+ pass # pragma: no cover
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name,
+ ignore_lock=True) as f1:
+ pass
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name,
+ ignore_lock=True) as f1:
+ pass
+ # Lock released on exit: a normal open works again.
+ with HeapStorage(self._testfname,
+ storage_type=self._type_name) as f:
+ pass
+
+ def test_read_path_cloned(self):
+
+ # Same as test_read_path, but performed through clone_device();
+ # the traffic counters must accrue on the clone only, never on
+ # the original device.
+ with HeapStorage(
+ self._testfname,
+ storage_type=self._type_name) as forig:
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+ with forig.clone_device() as f:
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ total_buckets = 0
+ for b in range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1):
+ data = f.read_path(b)
+ bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ total_buckets += len(bucket_path)
+ self.assertEqual(f.virtual_heap.Node(b).level+1,
+ len(bucket_path))
+ for i, bucket in zip(bucket_path, data):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # Clone tallied the reads; the original stayed at zero.
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received,
+ total_buckets*f.bucket_storage.block_size)
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+
+ def test_write_path_cloned(self):
+ # Same as test_write_path, but through clone_device(); verifies
+ # traffic accounting happens on the clone only.
+ data = [bytearray([self._bucket_count]) * \
+ self._block_size * \
+ self._blocks_per_bucket
+ for i in xrange(self._block_count)]
+ with HeapStorage(
+ self._testfname,
+ storage_type=self._type_name) as forig:
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+ with forig.clone_device() as f:
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ total_buckets = 0
+ for b in range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1):
+ orig = f.read_path(b)
+ bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ total_buckets += len(bucket_path)
+ self.assertNotEqual(len(bucket_path), 0)
+ self.assertEqual(f.virtual_heap.Node(b).level+1,
+ len(bucket_path))
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+ # Overwrite with sentinel data, verify, then restore.
+ f.write_path(b, [bytes(data[i])
+ for i in bucket_path])
+
+ new = f.read_path(b)
+ self.assertEqual(len(new), len(bucket_path))
+ for i, bucket in zip(bucket_path, new):
+ self.assertEqual(list(bytearray(bucket)),
+ list(data[i]))
+
+ f.write_path(b, [bytes(self._buckets[i])
+ for i in bucket_path])
+
+ orig = f.read_path(b)
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # 2 writes / 3 reads per bucket on the clone; original untouched.
+ self.assertEqual(f.bytes_sent,
+ total_buckets*f.bucket_storage.block_size*2)
+ self.assertEqual(f.bytes_received,
+ total_buckets*f.bucket_storage.block_size*3)
+ self.assertEqual(forig.bytes_sent, 0)
+ self.assertEqual(forig.bytes_received, 0)
+
+# Allow running this test module directly.
+if __name__ == "__main__":
+ unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+
+import pyoram.util.misc
+
+class Test(unittest2.TestCase):
+ # Unit tests for the helper functions in pyoram.util.misc.
+
+ def test_log2floor(self):
+ # floor(log2(n)) for small n, including exact powers of two.
+ self.assertEqual(pyoram.util.misc.log2floor(1), 0)
+ self.assertEqual(pyoram.util.misc.log2floor(2), 1)
+ self.assertEqual(pyoram.util.misc.log2floor(3), 1)
+ self.assertEqual(pyoram.util.misc.log2floor(4), 2)
+ self.assertEqual(pyoram.util.misc.log2floor(5), 2)
+ self.assertEqual(pyoram.util.misc.log2floor(6), 2)
+ self.assertEqual(pyoram.util.misc.log2floor(7), 2)
+ self.assertEqual(pyoram.util.misc.log2floor(8), 3)
+ self.assertEqual(pyoram.util.misc.log2floor(9), 3)
+
+ def test_log2ceil(self):
+ # ceil(log2(n)) for small n; equals log2floor only at powers of two.
+ self.assertEqual(pyoram.util.misc.log2ceil(1), 0)
+ self.assertEqual(pyoram.util.misc.log2ceil(2), 1)
+ self.assertEqual(pyoram.util.misc.log2ceil(3), 2)
+ self.assertEqual(pyoram.util.misc.log2ceil(4), 2)
+ self.assertEqual(pyoram.util.misc.log2ceil(5), 3)
+ self.assertEqual(pyoram.util.misc.log2ceil(6), 3)
+ self.assertEqual(pyoram.util.misc.log2ceil(7), 3)
+ self.assertEqual(pyoram.util.misc.log2ceil(8), 3)
+ self.assertEqual(pyoram.util.misc.log2ceil(9), 4)
+
+ def test_intdivceil(self):
+ # Ceiling integer division, including arbitrary-precision ints.
+
+ with self.assertRaises(ZeroDivisionError):
+ pyoram.util.misc.intdivceil(0, 0)
+ with self.assertRaises(ZeroDivisionError):
+ pyoram.util.misc.intdivceil(1, 0)
+
+ self.assertEqual(pyoram.util.misc.intdivceil(1, 1), 1)
+ self.assertEqual(pyoram.util.misc.intdivceil(2, 3), 1)
+ # Contrast with floor division, which rounds down.
+ self.assertEqual(2 // 3, 0)
+ # Big-integer cases: exact multiple, one above, one below.
+ self.assertEqual(pyoram.util.misc.intdivceil(
+ 123123123123123123123123123123123123123123123123,
+ 123123123123123123123123123123123123123123123123), 1)
+ self.assertEqual(pyoram.util.misc.intdivceil(
+ 2 * 123123123123123123123123123123123123123123123123,
+ 123123123123123123123123123123123123123123123123), 2)
+ self.assertEqual(pyoram.util.misc.intdivceil(
+ 2 * 123123123123123123123123123123123123123123123123 + 1,
+ 123123123123123123123123123123123123123123123123), 3)
+ self.assertEqual(pyoram.util.misc.intdivceil(
+ 2 * 123123123123123123123123123123123123123123123123 - 1,
+ 123123123123123123123123123123123123123123123123), 2)
+ self.assertEqual(
+ (2 * 123123123123123123123123123123123123123123123123 - 1) // \
+ 123123123123123123123123123123123123123123123123,
+ 1)
+
+ def test_MemorySize(self):
+ # Human-readable size formatting: unit chosen by decimal magnitude
+ # (KB at 1000, MB at 1e6, ...); sub-byte values render in bits.
+ self.assertTrue("b" in str(pyoram.util.misc.MemorySize(0.1)))
+ self.assertTrue("B" in str(pyoram.util.misc.MemorySize(1)))
+ self.assertTrue("B" in str(pyoram.util.misc.MemorySize(999)))
+ self.assertTrue("KB" in str(pyoram.util.misc.MemorySize(1000)))
+ self.assertTrue("KB" in str(pyoram.util.misc.MemorySize(999999)))
+ self.assertTrue("MB" in str(pyoram.util.misc.MemorySize(1000000)))
+ self.assertTrue("MB" in str(pyoram.util.misc.MemorySize(999999999)))
+ self.assertTrue("GB" in str(pyoram.util.misc.MemorySize(1000000000)))
+ self.assertTrue("GB" in str(pyoram.util.misc.MemorySize(9999999999)))
+ self.assertTrue("TB" in str(pyoram.util.misc.MemorySize(1000000000000)))
+ # Explicit input units: 8 bits roll over into 1 byte.
+ self.assertTrue("b" in str(pyoram.util.misc.MemorySize(1, unit="b")))
+ self.assertTrue("b" in str(pyoram.util.misc.MemorySize(2, unit="b")))
+ self.assertTrue("b" in str(pyoram.util.misc.MemorySize(7.9, unit="b")))
+
+ self.assertTrue("B" in str(pyoram.util.misc.MemorySize(8, unit="b")))
+ self.assertTrue("B" in str(pyoram.util.misc.MemorySize(1, unit="B")))
+ self.assertTrue("B" in str(pyoram.util.misc.MemorySize(999, unit="B")))
+
+ self.assertTrue("KB" in str(pyoram.util.misc.MemorySize(1000, unit="B")))
+ self.assertTrue("KB" in str(pyoram.util.misc.MemorySize(1, unit="KB")))
+ self.assertTrue("KB" in str(pyoram.util.misc.MemorySize(999, unit="KB")))
+ self.assertTrue("MB" in str(pyoram.util.misc.MemorySize(1000, unit="KB")))
+ self.assertTrue("MB" in str(pyoram.util.misc.MemorySize(1, unit="MB")))
+ self.assertTrue("MB" in str(pyoram.util.misc.MemorySize(999, unit="MB")))
+ self.assertTrue("GB" in str(pyoram.util.misc.MemorySize(1000, unit="MB")))
+ self.assertTrue("GB" in str(pyoram.util.misc.MemorySize(1, unit="GB")))
+ self.assertTrue("GB" in str(pyoram.util.misc.MemorySize(999, unit="GB")))
+ self.assertTrue("TB" in str(pyoram.util.misc.MemorySize(1000, unit="GB")))
+ self.assertTrue("TB" in str(pyoram.util.misc.MemorySize(1, unit="TB")))
+
+ # Binary (1024-based) accessor properties.
+ self.assertEqual(pyoram.util.misc.MemorySize(1024).KiB, 1)
+ self.assertEqual(pyoram.util.misc.MemorySize(1024**2).MiB, 1)
+ self.assertEqual(pyoram.util.misc.MemorySize(1024**3).GiB, 1)
+ self.assertEqual(pyoram.util.misc.MemorySize(1024**4).TiB, 1)
+
+ def test_saveload_private_key(self):
+ # A key saved to disk must load back byte-identical; the temp file
+ # is removed even if an assertion fails.
+ with tempfile.NamedTemporaryFile(delete=False) as f:
+ filename = f.name
+ try:
+ key = os.urandom(32)
+ pyoram.util.misc.save_private_key(filename, key)
+ loaded_key = pyoram.util.misc.load_private_key(filename)
+ self.assertEqual(key, loaded_key)
+ finally:
+ os.remove(filename)
+
+ def test_chunkiter(self):
+ # chunkiter splits a list into consecutive chunks of the given
+ # size; the final chunk may be shorter, and [] yields nothing.
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 1)),
+ [[1],[2],[3],[4],[5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 2)),
+ [[1,2],[3,4],[5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 3)),
+ [[1,2,3],[4,5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 4)),
+ [[1,2,3,4],[5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 5)),
+ [[1,2,3,4,5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([1,2,3,4,5], 6)),
+ [[1,2,3,4,5]])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([], 1)),
+ [])
+ self.assertEqual(list(pyoram.util.misc.chunkiter([], 2)),
+ [])
+
+# Allow running this test module directly.
+if __name__ == "__main__":
+ unittest2.main() # pragma: no cover
--- /dev/null
+import sys
+import unittest2
+
+import pyoram
+
+is_pypy = False
+try:
+ import __pypy__
+ is_pypy = True
+except ImportError:
+ is_pypy = False
+
+class Test(unittest2.TestCase):
+ # Diagnostic tests: print which interpreter produced this coverage
+ # run and touch the package version attribute.
+
+ # See what Python versions the combined
+ # coverage report includes
+ def test_show_coverage(self):
+ # Print version info only for the interpreter combinations this
+ # project's CI matrix covers (CPython 2.7/3.3/3.4/3.5, PyPy 2.7).
+ if not is_pypy:
+ if sys.version_info.major == 2:
+ if sys.version_info.minor == 7:
+ print(sys.version_info)
+ elif sys.version_info.major == 3:
+ if sys.version_info.minor == 3:
+ print(sys.version_info)
+ elif sys.version_info.minor == 4:
+ print(sys.version_info)
+ elif sys.version_info.minor == 5:
+ print(sys.version_info)
+ if is_pypy:
+ if sys.version_info.major == 2:
+ if sys.version_info.minor == 7:
+ print(sys.version_info)
+
+ def test_version(self):
+ # Smoke test: the version attribute exists and is importable.
+ pyoram.__version__
+
+# Allow running this test module directly.
+if __name__ == "__main__":
+ unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+
+from pyoram.oblivious_storage.tree.path_oram import \
+ PathORAM
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ EncryptedHeapStorage
+from pyoram.crypto.aes import AES
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+
+class _TestPathORAMBase(object):
+ # Abstract base for PathORAM tests. Concrete subclasses must set all
+ # of the configuration attributes below (setUpClass asserts this).
+ # The leading underscore keeps test discovery from running the base
+ # class itself.
+
+ _type_name = None
+ _aes_mode = None
+ _test_key = None
+ _test_key_size = None
+ _bucket_capacity = None
+ _heap_base = None
+ _kwds = None
+
+ @classmethod
+ def setUpClass(cls):
+ # Validate subclass configuration, then build one shared PathORAM
+ # file whose blocks are initialized to bytearray([i])*block_size;
+ # the resulting key/stash/position_map are cached for all tests.
+ assert cls._type_name is not None
+ assert cls._aes_mode is not None
+ # At most one of _test_key / _test_key_size may be given.
+ assert not ((cls._test_key is not None) and \
+ (cls._test_key_size is not None))
+ assert cls._bucket_capacity is not None
+ assert cls._heap_base is not None
+ assert cls._kwds is not None
+ # Reserve a unique path that is guaranteed NOT to exist, for the
+ # "file missing" tests.
+ fd, cls._dummy_name = tempfile.mkstemp()
+ os.close(fd)
+ try:
+ os.remove(cls._dummy_name)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+ cls._block_size = 25
+ cls._block_count = 47
+ cls._testfname = cls.__name__ + "_testfile.bin"
+ cls._blocks = []
+ f = PathORAM.setup(
+ cls._testfname,
+ cls._block_size,
+ cls._block_count,
+ bucket_capacity=cls._bucket_capacity,
+ heap_base=cls._heap_base,
+ key_size=cls._test_key_size,
+ key=cls._test_key,
+ storage_type=cls._type_name,
+ aes_mode=cls._aes_mode,
+ initialize=lambda i: bytes(bytearray([i])*cls._block_size),
+ ignore_existing=True,
+ **cls._kwds)
+ f.close()
+ # Cache the ORAM client state needed to reopen the file later.
+ cls._key = f.key
+ cls._stash = f.stash
+ cls._position_map = f.position_map
+ # Expected plaintext contents of every block.
+ for i in range(cls._block_count):
+ data = bytearray([i])*cls._block_size
+ cls._blocks.append(data)
+
+ @classmethod
+ def tearDownClass(cls):
+ # Best-effort cleanup of the shared test file and the dummy path.
+ try:
+ os.remove(cls._testfname)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+ try:
+ os.remove(cls._dummy_name)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+
+ def test_setup_fails(self):
+ # Each invalid argument combination must raise and must NOT leave
+ # a file behind at self._dummy_name.
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # target file already exists -> IOError
+ with self.assertRaises(IOError):
+ PathORAM.setup(
+ os.path.join(thisdir,
+ "baselines",
+ "exists.empty"),
+ block_size=10,
+ block_count=10,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # existing file with explicit ignore_existing=False -> IOError
+ with self.assertRaises(IOError):
+ PathORAM.setup(
+ os.path.join(thisdir,
+ "baselines",
+ "exists.empty"),
+ block_size=10,
+ block_count=10,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ storage_type=self._type_name,
+ aes_mode=self._aes_mode,
+ ignore_existing=False,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # block_count == 0 -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=0,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # block_size == 0 -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=0,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # header_data not a bytes-like object -> TypeError
+ with self.assertRaises(TypeError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ header_data=2,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # aes_mode=None -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=None,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # bucket_capacity == 0 -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=0,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # heap_base == 1 (a heap needs base >= 2) -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=1,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # negative key_size -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key_size=-1,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # key of a non-bytes type -> TypeError
+ with self.assertRaises(TypeError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=-1,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # both key and key_size supplied -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=AES.KeyGen(AES.key_sizes[-1]),
+ key_size=AES.key_sizes[-1],
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ # key of an unsupported length -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=os.urandom(AES.key_sizes[-1]+100),
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ # heap_height may not be passed to PathORAM.setup (it is derived
+ # from block_count) -> ValueError
+ with self.assertRaises(ValueError):
+ PathORAM.setup(
+ self._dummy_name,
+ block_size=1,
+ block_count=1,
+ heap_height=1,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._key,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+
+ def test_setup(self):
+ # Round-trip check of PathORAM.setup: create an ORAM file, verify
+ # the on-disk size matches compute_storage_size, then reopen with
+ # the returned key/stash/position_map and confirm its properties.
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ bcount = 11
+ fsetup = PathORAM.setup(
+ fname,
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ **self._kwds)
+ fsetup.close()
+ self.assertEqual(type(fsetup.raw_storage),
+ BlockStorageTypeFactory(self._type_name))
+ # test that these can be called with default keyword values
+ fsetup.stash_digest(fsetup.stash)
+ fsetup.position_map_digest(fsetup.position_map)
+ with open(fname, 'rb') as f:
+ flen = len(f.read())
+ self.assertEqual(
+ flen,
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name))
+ # File is strictly larger than the header-less storage size.
+ self.assertEqual(
+ flen >
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ ignore_header=True),
+ True)
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.header_data, bytes())
+ self.assertEqual(fsetup.header_data, bytes())
+ self.assertEqual(f.key, fsetup.key)
+ self.assertEqual(f.block_size, bsize)
+ self.assertEqual(fsetup.block_size, bsize)
+ self.assertEqual(f.block_count, bcount)
+ self.assertEqual(fsetup.block_count, bcount)
+ self.assertEqual(f.storage_name, fname)
+ self.assertEqual(fsetup.storage_name, fname)
+ os.remove(fname)
+
+ def test_setup_withdata(self):
+ # Same round trip as test_setup, but with user-supplied header_data;
+ # verifies the header bytes survive setup/reopen and that they
+ # increase the computed storage size.
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ bcount = 11
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = PathORAM.setup(
+ fname,
+ block_size=bsize,
+ block_count=bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ header_data=header_data,
+ **self._kwds)
+ fsetup.close()
+ self.assertEqual(type(fsetup.raw_storage),
+ BlockStorageTypeFactory(self._type_name))
+ with open(fname, 'rb') as f:
+ flen = len(f.read())
+ self.assertEqual(
+ flen,
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ header_data=header_data))
+ self.assertTrue(len(header_data) > 0)
+ # Non-empty header data must strictly grow the storage size.
+ self.assertEqual(
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name) <
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ header_data=header_data),
+ True)
+ self.assertEqual(
+ flen >
+ PathORAM.compute_storage_size(
+ bsize,
+ bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ aes_mode=self._aes_mode,
+ storage_type=self._type_name,
+ header_data=header_data,
+ ignore_header=True),
+ True)
+ # Reopen and confirm the stored header bytes and properties.
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.header_data, header_data)
+ self.assertEqual(fsetup.header_data, header_data)
+ self.assertEqual(f.key, fsetup.key)
+ self.assertEqual(f.block_size, bsize)
+ self.assertEqual(fsetup.block_size, bsize)
+ self.assertEqual(f.block_count, bcount)
+ self.assertEqual(fsetup.block_count, bcount)
+ self.assertEqual(f.storage_name, fname)
+ self.assertEqual(fsetup.storage_name, fname)
+ os.remove(fname)
+
+ def test_init_noexists(self):
+ # Opening a PathORAM on a path that does not exist must raise
+ # IOError rather than silently creating a file.
+ self.assertEqual(os.path.exists(self._dummy_name), False)
+ with self.assertRaises(IOError):
+ with PathORAM(
+ self._dummy_name,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+
+ def test_init_exists(self):
+ # Opening the existing ORAM file with valid credentials succeeds;
+ # each invalid credential/state combination raises ValueError.
+ self.assertEqual(os.path.exists(self._testfname), True)
+ with open(self._testfname, 'rb') as f:
+ databefore = f.read()
+ # no key
+ with self.assertRaises(ValueError):
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # stash does not match digest
+ with self.assertRaises(ValueError):
+ with PathORAM(self._testfname,
+ {1: bytes()},
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # stash hash invalid key (negative)
+ with self.assertRaises(ValueError):
+ with PathORAM(self._testfname,
+ {-1: bytes()},
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # position map has invalid item (negative)
+ with self.assertRaises(ValueError):
+ with PathORAM(self._testfname,
+ self._stash,
+ [-1],
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # position map does not match digest
+ with self.assertRaises(ValueError):
+ with PathORAM(self._testfname,
+ self._stash,
+ [1],
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # wrapping an already-open storage object while also passing
+ # storage_type is rejected
+ with self.assertRaises(ValueError):
+ with EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._type_name) as fb:
+ with PathORAM(fb,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass # pragma: no cover
+ # valid open: properties reflect setUpClass configuration
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.key, self._key)
+ self.assertEqual(f.block_size, self._block_size)
+ self.assertEqual(f.block_count, self._block_count)
+ self.assertEqual(f.storage_name, self._testfname)
+ self.assertEqual(f.header_data, bytes())
+ self.assertEqual(os.path.exists(self._testfname), True)
+ with open(self._testfname, 'rb') as f:
+ dataafter = f.read()
+ # NOTE(review): only the trailing block_count*block_size bytes are
+ # compared — presumably earlier bytes (header/digests) legitimately
+ # change on open/close; TODO confirm against the PathORAM format.
+ self.assertEqual(databefore[-(self._block_count*self._block_size):],
+ dataafter[-(self._block_count*self._block_size):])
+
+ def test_read_block(self):
+ # Every block must read back its setUpClass contents, repeatedly
+ # and in both orders (ORAM reads shuffle physical placement, but
+ # the logical contents must be stable).
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ for i, data in enumerate(self._blocks):
+ self.assertEqual(list(bytearray(f.read_block(i))),
+ list(self._blocks[i]))
+ for i, data in enumerate(self._blocks):
+ self.assertEqual(list(bytearray(f.read_block(i))),
+ list(self._blocks[i]))
+ for i, data in reversed(list(enumerate(self._blocks))):
+ self.assertEqual(list(bytearray(f.read_block(i))),
+ list(self._blocks[i]))
+ for i, data in reversed(list(enumerate(self._blocks))):
+ self.assertEqual(list(bytearray(f.read_block(i))),
+ list(self._blocks[i]))
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(list(bytearray(f.read_block(0))),
+ list(self._blocks[0]))
+ self.assertEqual(list(bytearray(f.read_block(self._block_count-1))),
+ list(self._blocks[-1]))
+
+ # test eviction behavior of the tree oram helper
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ # NOTE(review): this section drives the private _oram helper
+ # directly; it is order-sensitive (load -> push_down -> fill ->
+ # evict) — do not reorder.
+ oram = f._oram
+ vheap = oram.storage_heap.virtual_heap
+ Z = vheap.blocks_per_bucket
+ # True if any slot in the bucket at `level` of the loaded path
+ # is empty.
+ def _has_vacancies(level):
+ return any(oram.path_block_ids[i] == oram.empty_block_id
+ for i in range(level*Z, (level+1)*Z))
+
+ # Remap every block to a random leaf, pulling each block off
+ # its old path and into the stash.
+ for i in range(len(f.position_map)):
+ b = f.position_map[i]
+ f.position_map[i] = vheap.random_leaf_bucket()
+ oram.load_path(b)
+ block = oram.extract_block_from_path(i)
+ if block is not None:
+ oram.stash[i] = block
+
+ # track where everyone should be able to move
+ # to, unless the bucket fills up
+ eviction_levels = {}
+ for id_, level in zip(oram.path_block_ids,
+ oram.path_block_eviction_levels):
+ eviction_levels[id_] = level
+ for id_ in oram.stash:
+ block_id, block_addr = \
+ oram.get_block_info(oram.stash[id_])
+ assert block_id == id_
+ eviction_levels[id_] = \
+ vheap.clib.calculate_last_common_level(
+ vheap.k, b, block_addr)
+
+ oram.push_down_path()
+ oram.fill_path_from_stash()
+ oram.evict_path()
+
+ # check that everyone was pushed down greedily
+ oram.load_path(b)
+ for pos, id_ in enumerate(oram.path_block_ids):
+ current_level = pos // Z
+ if (id_ != oram.empty_block_id):
+ eviction_level = eviction_levels[id_]
+ # A block never sits above its eviction level; if it
+ # could have gone deeper, that level must be full.
+ self.assertEqual(current_level <= eviction_level, True)
+ if current_level < eviction_level:
+ self.assertEqual(_has_vacancies(eviction_level), False)
+ # Anything left in the stash must have had no vacancy at its
+ # eviction level.
+ for id_ in oram.stash:
+ self.assertEqual(
+ _has_vacancies(eviction_levels[id_]), False)
+
+ def test_write_block(self):
+ data = bytearray([self._block_count])*self._block_size
+ self.assertEqual(len(data) > 0, True)
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ for i in xrange(self._block_count):
+ self.assertNotEqual(list(bytearray(f.read_block(i))),
+ list(data))
+ for i in xrange(self._block_count):
+ f.write_block(i, bytes(data))
+ for i in xrange(self._block_count):
+ self.assertEqual(list(bytearray(f.read_block(i))),
+ list(data))
+ for i, block in enumerate(self._blocks):
+ f.write_block(i, bytes(block))
+
+ def test_read_blocks(self):
+ """Bulk reads return one result per requested index, in request
+ order, for full, single-element, and rotated index lists."""
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ data = f.read_blocks(list(xrange(self._block_count)))
+ self.assertEqual(len(data), self._block_count)
+ for i, block in enumerate(data):
+ self.assertEqual(list(bytearray(block)),
+ list(self._blocks[i]))
+ # single-element request
+ data = f.read_blocks([0])
+ self.assertEqual(len(data), 1)
+ self.assertEqual(list(bytearray(data[0])),
+ list(self._blocks[0]))
+ self.assertEqual(len(self._blocks) > 1, True)
+ # a rotated request order (1..n-1 then 0) is honored in the
+ # order of the returned list
+ data = f.read_blocks(list(xrange(1, self._block_count)) + [0])
+ self.assertEqual(len(data), self._block_count)
+ for i, block in enumerate(data[:-1], 1):
+ self.assertEqual(list(bytearray(block)),
+ list(self._blocks[i]))
+ self.assertEqual(list(bytearray(data[-1])),
+ list(self._blocks[0]))
+
+ def test_write_blocks(self):
+ """Bulk overwrites are observed by bulk reads; the original
+ contents are written back and re-verified afterwards."""
+ data = [bytearray([self._block_count])*self._block_size
+ for i in xrange(self._block_count)]
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ # baseline: storage still holds the original blocks
+ orig = f.read_blocks(list(xrange(self._block_count)))
+ self.assertEqual(len(orig), self._block_count)
+ for i, block in enumerate(orig):
+ self.assertEqual(list(bytearray(block)),
+ list(self._blocks[i]))
+ # bulk overwrite, then verify the new contents read back
+ f.write_blocks(list(xrange(self._block_count)),
+ [bytes(b) for b in data])
+ new = f.read_blocks(list(xrange(self._block_count)))
+ self.assertEqual(len(new), self._block_count)
+ for i, block in enumerate(new):
+ self.assertEqual(list(bytearray(block)),
+ list(data[i]))
+ # restore and re-verify the original contents
+ f.write_blocks(list(xrange(self._block_count)),
+ [bytes(b) for b in self._blocks])
+ orig = f.read_blocks(list(xrange(self._block_count)))
+ self.assertEqual(len(orig), self._block_count)
+ for i, block in enumerate(orig):
+ self.assertEqual(list(bytearray(block)),
+ list(self._blocks[i]))
+
+ def test_update_header_data(self):
+ """Header data can be replaced with a same-length value and the
+ change persists across reopens; updates with a different length
+ raise ValueError and leave the stored header unchanged."""
+ # per-test scratch file named after the test id
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ bcount = 11
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = PathORAM.setup(
+ fname,
+ block_size=bsize,
+ block_count=bcount,
+ bucket_capacity=self._bucket_capacity,
+ heap_base=self._heap_base,
+ key=self._test_key,
+ key_size=self._test_key_size,
+ header_data=header_data,
+ **self._kwds)
+ fsetup.close()
+ new_header_data = bytes(bytearray([1,1,1]))
+ # same-length update succeeds
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.header_data, header_data)
+ f.update_header_data(new_header_data)
+ self.assertEqual(f.header_data, new_header_data)
+ # the update persists across a reopen
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ # shorter header data is rejected
+ with self.assertRaises(ValueError):
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ f.update_header_data(bytes(bytearray([1,1])))
+ # longer header data is rejected
+ with self.assertRaises(ValueError):
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ f.update_header_data(bytes(bytearray([1,1,1,1])))
+ # rejected updates left the stored header unchanged
+ with PathORAM(fname,
+ fsetup.stash,
+ fsetup.position_map,
+ key=fsetup.key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ os.remove(fname)
+
+ def test_locked_flag(self):
+ """While one PathORAM holds the storage lock, a second open of the
+ same file raises IOError unless ignore_lock=True is passed; the
+ lock is released on context exit."""
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ # a concurrent open fails while the lock is held
+ with self.assertRaises(IOError):
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f1:
+ pass # pragma: no cover
+ # and keeps failing on retry
+ with self.assertRaises(IOError):
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f1:
+ pass # pragma: no cover
+ # ignore_lock bypasses the check
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ ignore_lock=True,
+ **self._kwds) as f1:
+ pass
+ # an ignore_lock open does not release the original lock
+ with self.assertRaises(IOError):
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f1:
+ pass # pragma: no cover
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ ignore_lock=True,
+ **self._kwds) as f1:
+ pass
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ ignore_lock=True,
+ **self._kwds) as f1:
+ pass
+ # after the first context exits, a normal open succeeds again
+ with PathORAM(self._testfname,
+ self._stash,
+ self._position_map,
+ key=self._key,
+ storage_type=self._type_name,
+ **self._kwds) as f:
+ pass
+
+# Concrete test fixtures: each class pins the storage backend, AES mode,
+# bucket capacity (Z), heap base, and caching/concurrency options that the
+# shared _TestPathORAMBase tests run against.
+class TestPathORAMB2Z1(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 1
+ _heap_base = 2
+ _kwds = {'cached_levels': 0}
+
+class TestPathORAMB2Z2(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'mmap'
+ _aes_mode = 'gcm'
+ _bucket_capacity = 2
+ _heap_base = 2
+ _kwds = {'cached_levels': 0}
+
+class TestPathORAMB2Z3(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'mmap'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 3
+ _heap_base = 2
+ _kwds = {'cached_levels': 1}
+
+class TestPathORAMB2Z4(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'gcm'
+ _bucket_capacity = 4
+ _heap_base = 2
+ _kwds = {'cached_levels': 1}
+
+class TestPathORAMB2Z5(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 5
+ _heap_base = 2
+ _kwds = {'cached_levels': 2,
+ 'concurrency_level': 0}
+
+class TestPathORAMB3Z1(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 1
+ _heap_base = 3
+ _kwds = {'cached_levels': 2,
+ 'concurrency_level': 1}
+
+class TestPathORAMB3Z2(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'mmap'
+ _aes_mode = 'gcm'
+ _bucket_capacity = 2
+ _heap_base = 3
+ _kwds = {}
+
+class TestPathORAMB3Z3(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'mmap'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 3
+ _heap_base = 3
+ _kwds = {}
+
+class TestPathORAMB3Z4(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'gcm'
+ _bucket_capacity = 4
+ _heap_base = 3
+ _kwds = {}
+
+class TestPathORAMB3Z5(_TestPathORAMBase,
+ unittest2.TestCase):
+ _type_name = 'file'
+ _aes_mode = 'ctr'
+ _bucket_capacity = 5
+ _heap_base = 3
+ _kwds = {}
+
+if __name__ == "__main__":
+ unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import unittest2
+import tempfile
+import random
+
+from pyoram.storage.block_storage import \
+ BlockStorageTypeFactory
+from pyoram.encrypted_storage.top_cached_encrypted_heap_storage import \
+ TopCachedEncryptedHeapStorage
+from pyoram.encrypted_storage.encrypted_block_storage import \
+ EncryptedBlockStorage
+from pyoram.encrypted_storage.encrypted_heap_storage import \
+ EncryptedHeapStorage
+from pyoram.crypto.aes import AES
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+
+# Shared test logic for TopCachedEncryptedHeapStorage; mixed into concrete
+# TestCase subclasses that supply the configuration attributes below.
+class _TestTopCachedEncryptedHeapStorage(object):
+
+ # keyword arguments forwarded to TopCachedEncryptedHeapStorage
+ _init_kwds = None
+ # block-storage backend name passed as storage_type
+ _storage_type = None
+ # branching factor of the virtual heap
+ _heap_base = None
+ # height of the virtual heap
+ _heap_height = None
+
+ @classmethod
+ def setUpClass(cls):
+ """Create the encrypted heap test file once per fixture class and
+ record the expected plaintext contents of every bucket."""
+ assert cls._init_kwds is not None
+ assert cls._storage_type is not None
+ assert cls._heap_base is not None
+ assert cls._heap_height is not None
+ # reserve (then immediately delete) a temp filename for tests
+ # that need a nonexistent path
+ fd, cls._dummy_name = tempfile.mkstemp()
+ os.close(fd)
+ try:
+ os.remove(cls._dummy_name)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+ cls._block_size = 50
+ cls._blocks_per_bucket = 3
+ # bucket count of a complete heap: (base**(height+1) - 1)/(base - 1)
+ cls._bucket_count = \
+ ((cls._heap_base**(cls._heap_height+1)) - 1)//(cls._heap_base-1)
+ cls._block_count = cls._bucket_count * \
+ cls._blocks_per_bucket
+ cls._testfname = cls.__name__ + "_testfile.bin"
+ cls._buckets = []
+ # bucket i is initialized with byte value i repeated across the
+ # whole bucket
+ f = EncryptedHeapStorage.setup(
+ cls._testfname,
+ cls._block_size,
+ cls._heap_height,
+ heap_base=cls._heap_base,
+ blocks_per_bucket=cls._blocks_per_bucket,
+ storage_type=cls._storage_type,
+ initialize=lambda i: bytes(bytearray([i]) * \
+ cls._block_size * \
+ cls._blocks_per_bucket),
+ ignore_existing=True)
+ f.close()
+ cls._key = f.key
+ # keep a plaintext copy of every bucket for later comparisons
+ for i in range(cls._bucket_count):
+ data = bytearray([i]) * \
+ cls._block_size * \
+ cls._blocks_per_bucket
+ cls._buckets.append(data)
+
+ @classmethod
+ def tearDownClass(cls):
+ try:
+ os.remove(cls._testfname)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+ try:
+ os.remove(cls._dummy_name)
+ except OSError: # pragma: no cover
+ pass # pragma: no cover
+
+ def test_factory(self):
+ """With cached_levels=0 the TopCachedEncryptedHeapStorage
+ constructor returns the wrapped storage object itself rather than
+ a caching wrapper."""
+ kwds = dict(self._init_kwds)
+ kwds['cached_levels'] = 0
+ with EncryptedHeapStorage(
+ self._testfname,
+ key=self._key,
+ storage_type=self._storage_type) as f1:
+ with TopCachedEncryptedHeapStorage(f1, **kwds) as f2:
+ self.assertTrue(f1 is f2)
+
+ def test_setup(self):
+ """setup() creates a file whose size matches
+ compute_storage_size(), and reopening it reports the same
+ geometry and key as the setup object."""
+ # per-test scratch file named after the test id
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ blocks_per_bucket = 3
+ fsetup = EncryptedHeapStorage.setup(
+ fname,
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ storage_type=self._storage_type,
+ blocks_per_bucket=blocks_per_bucket)
+ fsetup.close()
+ self.assertEqual(type(fsetup.raw_storage),
+ BlockStorageTypeFactory(self._storage_type))
+ with open(fname, 'rb') as f:
+ flen = len(f.read())
+ # on-disk size matches the computed total storage size
+ self.assertEqual(
+ flen,
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ blocks_per_bucket=blocks_per_bucket))
+ # ignoring the header yields a strictly smaller size
+ self.assertEqual(
+ flen >
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ blocks_per_bucket=blocks_per_bucket,
+ ignore_header=True),
+ True)
+ # reopen and compare reported properties against the setup object
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.header_data, bytes())
+ self.assertEqual(fsetup.header_data, bytes())
+ self.assertEqual(f.key, fsetup.key)
+ self.assertEqual(f.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(fsetup.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(f.bucket_count,
+ (self._heap_base**(self._heap_height+1) - 1)//(self._heap_base-1))
+ self.assertEqual(fsetup.bucket_count,
+ (self._heap_base**(self._heap_height+1) - 1)//(self._heap_base-1))
+ self.assertEqual(f.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(fsetup.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(f.storage_name, fname)
+ self.assertEqual(fsetup.storage_name, fname)
+ os.remove(fname)
+
+ def test_setup_withdata(self):
+ """Like test_setup, but with user header data: the header bytes
+ are stored, counted in the computed storage size, and reported
+ back on reopen."""
+ # per-test scratch file named after the test id
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ blocks_per_bucket = 1
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = EncryptedHeapStorage.setup(
+ fname,
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ storage_type=self._storage_type,
+ blocks_per_bucket=blocks_per_bucket,
+ header_data=header_data)
+ fsetup.close()
+ self.assertEqual(type(fsetup.raw_storage),
+ BlockStorageTypeFactory(self._storage_type))
+ with open(fname, 'rb') as f:
+ flen = len(f.read())
+ # on-disk size matches the computed size including header data
+ self.assertEqual(
+ flen,
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ header_data=header_data))
+ self.assertTrue(len(header_data) > 0)
+ # header data strictly increases the computed size
+ self.assertEqual(
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ storage_type=self._storage_type) <
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ storage_type=self._storage_type,
+ header_data=header_data),
+ True)
+ # ignoring the header yields a strictly smaller size
+ self.assertEqual(
+ flen >
+ TopCachedEncryptedHeapStorage.compute_storage_size(
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ storage_type=self._storage_type,
+ header_data=header_data,
+ ignore_header=True),
+ True)
+ # reopen and compare reported properties against the setup object
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.header_data, header_data)
+ self.assertEqual(fsetup.header_data, header_data)
+ self.assertEqual(f.key, fsetup.key)
+ self.assertEqual(f.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(fsetup.blocks_per_bucket,
+ blocks_per_bucket)
+ self.assertEqual(f.bucket_count,
+ (self._heap_base**(self._heap_height+1) - 1)//(self._heap_base-1))
+ self.assertEqual(fsetup.bucket_count,
+ (self._heap_base**(self._heap_height+1) - 1)//(self._heap_base-1))
+ self.assertEqual(f.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(fsetup.bucket_size,
+ bsize * blocks_per_bucket)
+ self.assertEqual(f.storage_name, fname)
+ self.assertEqual(fsetup.storage_name, fname)
+ os.remove(fname)
+
+ def test_init_exists(self):
+ """Opening an existing file reports the stored geometry, raises
+ ValueError when handed an incompatible storage object, and does
+ not modify the underlying data."""
+ self.assertEqual(os.path.exists(self._testfname), True)
+ # snapshot the raw (encrypted) block contents before opening
+ with EncryptedBlockStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type) as f:
+ databefore = f.read_blocks(list(range(f.block_count)))
+ # wrapping a block storage (not a heap storage) raises ValueError
+ with self.assertRaises(ValueError):
+ with EncryptedBlockStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type) as fb:
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(fb, key=self._key),
+ **self._init_kwds) as f:
+ pass # pragma: no cover
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.key, self._key)
+ self.assertEqual(f.bucket_size,
+ self._block_size * \
+ self._blocks_per_bucket)
+ self.assertEqual(f.bucket_count,
+ self._bucket_count)
+ self.assertEqual(f.storage_name, self._testfname)
+ self.assertEqual(f.header_data, bytes())
+ self.assertEqual(os.path.exists(self._testfname), True)
+ # opening and closing left the stored data untouched
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ dataafter = f.bucket_storage.read_blocks(
+ list(range(f.bucket_storage.block_count)))
+ self.assertEqual(databefore, dataafter)
+
+ def test_read_path(self):
+ """read_path returns the expected buckets for every root-to-bucket
+ path and every level_start, and bytes_received only counts
+ buckets fetched from below the cached levels."""
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ total_buckets = 0
+ for b in range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1):
+ full_bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ all_level_starts = list(range(len(full_bucket_path)+1))
+ for level_start in all_level_starts:
+ data = f.read_path(b, level_start=level_start)
+ bucket_path = full_bucket_path[level_start:]
+
+ # tally only the buckets served by the external (uncached)
+ # portion of the path
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_buckets += len(bucket_path)
+ else:
+ total_buckets += len(full_bucket_path[f._external_level:])
+
+ self.assertEqual(f.virtual_heap.Node(b).level+1-level_start,
+ len(bucket_path))
+ for i, bucket in zip(bucket_path, data):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # reads never send data; received bytes match the tally above
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received,
+ total_buckets*f.bucket_storage._storage.block_size)
+
+ def test_write_path(self):
+ """write_path round-trips bucket data for every path and
+ level_start (original contents restored afterwards), and
+ bytes_sent/bytes_received count only traffic below the cached
+ levels."""
+ data = [bytearray([self._bucket_count]) * \
+ self._block_size * \
+ self._blocks_per_bucket
+ for i in xrange(self._block_count)]
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+
+ self.assertEqual(
+ f.virtual_heap.first_bucket_at_level(0), 0)
+ self.assertNotEqual(
+ f.virtual_heap.last_leaf_bucket(), 0)
+ all_buckets = list(range(f.virtual_heap.first_bucket_at_level(0),
+ f.virtual_heap.last_leaf_bucket()+1))
+ random.shuffle(all_buckets)
+ # only buckets at or below f._external_level generate traffic;
+ # these counters track the expected totals
+ total_read_buckets = 0
+ total_write_buckets = 0
+ for b in all_buckets:
+ full_bucket_path = f.virtual_heap.Node(b).\
+ bucket_path_from_root()
+ all_level_starts = list(range(len(full_bucket_path)+1))
+ random.shuffle(all_level_starts)
+ for level_start in all_level_starts:
+ # read the original path contents
+ orig = f.read_path(b, level_start=level_start)
+ bucket_path = full_bucket_path[level_start:]
+
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_read_buckets += len(bucket_path)
+ else:
+ total_read_buckets += len(full_bucket_path[f._external_level:])
+
+ if level_start != len(full_bucket_path):
+ self.assertNotEqual(len(bucket_path), 0)
+ self.assertEqual(f.virtual_heap.Node(b).level+1-level_start,
+ len(bucket_path))
+ self.assertEqual(len(orig), len(bucket_path))
+
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # overwrite the path with marker data
+ f.write_path(b, [bytes(data[i])
+ for i in bucket_path],
+ level_start=level_start)
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_write_buckets += len(bucket_path)
+ else:
+ total_write_buckets += len(full_bucket_path[f._external_level:])
+
+ # the marker data reads back
+ new = f.read_path(b, level_start=level_start)
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_read_buckets += len(bucket_path)
+ else:
+ total_read_buckets += len(full_bucket_path[f._external_level:])
+
+ self.assertEqual(len(new), len(bucket_path))
+ for i, bucket in zip(bucket_path, new):
+ self.assertEqual(list(bytearray(bucket)),
+ list(data[i]))
+
+ # restore the original bucket contents
+ f.write_path(b, [bytes(self._buckets[i])
+ for i in bucket_path],
+ level_start=level_start)
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_write_buckets += len(bucket_path)
+ else:
+ total_write_buckets += len(full_bucket_path[f._external_level:])
+
+
+ # verify the restore took effect
+ orig = f.read_path(b, level_start=level_start)
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ elif level_start >= f._external_level:
+ total_read_buckets += len(bucket_path)
+ else:
+ total_read_buckets += len(full_bucket_path[f._external_level:])
+
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # full-path read (default level_start)
+ full_orig = f.read_path(b)
+ if len(full_bucket_path) <= f._external_level:
+ pass
+ else:
+ total_read_buckets += len(full_bucket_path[f._external_level:])
+
+ for i, bucket in zip(full_bucket_path, full_orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+ # also read each child's full path, when the child exists
+ for c in xrange(self._heap_base):
+ cn = f.virtual_heap.Node(b).child_node(c)
+ if not f.virtual_heap.is_nil_node(cn):
+ cb = cn.bucket
+ bucket_path = f.virtual_heap.Node(cb).\
+ bucket_path_from_root()
+ orig = f.read_path(cb)
+ if len(bucket_path) <= f._external_level:
+ pass
+ else:
+ total_read_buckets += len(bucket_path[f._external_level:])
+ self.assertEqual(len(orig), len(bucket_path))
+ for i, bucket in zip(bucket_path, orig):
+ self.assertEqual(list(bytearray(bucket)),
+ list(self._buckets[i]))
+
+ # traffic counters match the expected external bucket totals
+ self.assertEqual(f.bytes_sent,
+ total_write_buckets*f.bucket_storage._storage.block_size)
+ self.assertEqual(f.bytes_received,
+ total_read_buckets*f.bucket_storage._storage.block_size)
+
+ def test_update_header_data(self):
+ """Header data can be replaced with a same-length value and the
+ change persists across reopens; updates with a different length
+ raise ValueError and leave the stored header unchanged."""
+ # per-test scratch file named after the test id
+ fname = ".".join(self.id().split(".")[1:])
+ fname += ".bin"
+ fname = os.path.join(thisdir, fname)
+ if os.path.exists(fname):
+ os.remove(fname) # pragma: no cover
+ bsize = 10
+ blocks_per_bucket = 1
+ header_data = bytes(bytearray([0,1,2]))
+ fsetup = EncryptedHeapStorage.setup(
+ fname,
+ bsize,
+ self._heap_height,
+ heap_base=self._heap_base,
+ blocks_per_bucket=blocks_per_bucket,
+ header_data=header_data)
+ fsetup.close()
+ new_header_data = bytes(bytearray([1,1,1]))
+ # same-length update succeeds
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.header_data, header_data)
+ f.update_header_data(new_header_data)
+ self.assertEqual(f.header_data, new_header_data)
+ # the update persists across a reopen
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ # shorter header data is rejected
+ with self.assertRaises(ValueError):
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ f.update_header_data(bytes(bytearray([1,1])))
+ # longer header data is rejected
+ with self.assertRaises(ValueError):
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ f.update_header_data(bytes(bytearray([1,1,1,1])))
+ # rejected updates left the stored header unchanged
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(
+ fname,
+ key=fsetup.key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ self.assertEqual(f.header_data, new_header_data)
+ os.remove(fname)
+
+ def test_locked_flag(self):
+ """While one storage object holds the lock, a second open of the
+ same file raises IOError unless ignore_lock=True is passed; the
+ lock is released on context exit."""
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ # a concurrent open fails while the lock is held
+ with self.assertRaises(IOError):
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f1:
+ pass # pragma: no cover
+ # and keeps failing on retry
+ with self.assertRaises(IOError):
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f1:
+ pass # pragma: no cover
+ # ignore_lock bypasses the check
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type,
+ ignore_lock=True),
+ **self._init_kwds) as f1:
+ pass
+ # an ignore_lock open does not release the original lock
+ with self.assertRaises(IOError):
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f1:
+ pass # pragma: no cover
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type,
+ ignore_lock=True),
+ **self._init_kwds) as f1:
+ pass
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type,
+ ignore_lock=True),
+ **self._init_kwds) as f1:
+ pass
+ # after the first context exits, a normal open succeeds again
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ pass
+
+ def test_cache_size(self):
+ """The in-memory cache holds exactly the buckets of the cached
+ top levels, and only the root device reports the bytes read to
+ populate it."""
+ with TopCachedEncryptedHeapStorage(
+ EncryptedHeapStorage(self._testfname,
+ key=self._key,
+ storage_type=self._storage_type),
+ **self._init_kwds) as f:
+ # default is one cached level; negative means cache everything
+ num_cached_levels = self._init_kwds.get('cached_levels', 1)
+ if num_cached_levels < 0:
+ num_cached_levels = f.virtual_heap.levels
+ cache_bucket_count = 0
+ for l in xrange(num_cached_levels):
+ if l <= f.virtual_heap.last_level:
+ cache_bucket_count += f.virtual_heap.bucket_count_at_level(l)
+ self.assertEqual(cache_bucket_count > 0, True)
+ self.assertEqual(len(f.cached_bucket_data),
+ cache_bucket_count * f.bucket_size)
+
+ # the wrapper reports no traffic; the root device read exactly
+ # the cached buckets
+ self.assertEqual(f.bytes_sent, 0)
+ self.assertEqual(f.bytes_received, 0)
+ self.assertEqual(f._root_device.bytes_sent, 0)
+ self.assertEqual(
+ f._root_device.bytes_received,
+ cache_bucket_count*f._root_device.bucket_storage._storage.block_size)
+
+# Concrete test fixtures: each class pins a storage backend, heap geometry,
+# and cached_levels/concurrency_level options for the shared
+# _TestTopCachedEncryptedHeapStorage tests.
+class TestTopCachedEncryptedHeapStorageCacheMMapDefault(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageMMapCache1(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 1}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageMMapCache2(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 2}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageMMapCache3(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 3}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageMMapCache4(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 4}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageMMapCache5(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 5}
+ _storage_type = 'mmap'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageCacheFileDefault(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache1(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 1}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache2(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 2}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache3(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 3}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache4(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 4}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache5(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 5}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCacheBigConcurrency0(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 20,
+ 'concurrency_level': 0}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache6Concurrency1(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 6,
+ 'concurrency_level': 1}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache3ConcurrencyBig(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 3,
+ 'concurrency_level': 20}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 7
+
+class TestTopCachedEncryptedHeapStorageFileCache3Concurrency1Base3(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': 3,
+ 'concurrency_level': 3}
+ _storage_type = 'file'
+ _heap_base = 3
+ _heap_height = 4
+
+class TestTopCachedEncryptedHeapStorageFileCacheAll(
+ _TestTopCachedEncryptedHeapStorage,
+ unittest2.TestCase):
+ _init_kwds = {'cached_levels': -1}
+ _storage_type = 'file'
+ _heap_base = 2
+ _heap_height = 3
+
+if __name__ == "__main__":
+ unittest2.main() # pragma: no cover
--- /dev/null
+import os
+import subprocess
+import unittest2
+
+import pyoram
+from pyoram.util.virtual_heap import \
+ (VirtualHeap,
+ SizedVirtualHeap,
+ max_k_labeled,
+ calculate_bucket_count_in_heap_with_height,
+ calculate_bucket_count_in_heap_at_level,
+ calculate_bucket_level,
+ calculate_last_common_level,
+ calculate_necessary_heap_height,
+ basek_string_to_base10_integer,
+ numerals,
+ _clib)
+
+from six.moves import xrange
+
+thisdir = os.path.dirname(os.path.abspath(__file__))
+baselinedir = os.path.join(thisdir, "baselines")
+
+try:
+ has_dot = not subprocess.call(['dot','-?'],
+ stdout=subprocess.PIPE)
+except:
+ has_dot = False
+
+_test_bases = list(xrange(2, 15)) + [max_k_labeled+1]
+_test_labeled_bases = list(xrange(2, 15)) + [max_k_labeled]
+
+def _do_preorder(x):
+ # Yield bucket ids of the subtree rooted at node x in preorder,
+ # truncated below heap level 2.
+ if x.level > 2:
+ return
+ yield x.bucket
+ for c in xrange(x.k):
+ for b in _do_preorder(x.child_node(c)):
+ yield b
+
+def _do_postorder(x):
+ # Yield bucket ids of the subtree rooted at node x in postorder,
+ # truncated below heap level 2.
+ if x.level > 2:
+ return
+ for c in xrange(x.k):
+ for b in _do_postorder(x.child_node(c)):
+ yield b
+ yield x.bucket
+
+def _do_inorder(x):
+ assert x.k == 2
+ if x.level > 2:
+ return
+ for b in _do_inorder(x.child_node(0)):
+ yield b
+ yield x.bucket
+ for b in _do_inorder(x.child_node(1)):
+ yield b
+
class TestVirtualHeapNode(unittest2.TestCase):
    """Tests for the node type exposed as ``VirtualHeap(k).Node``.

    Covers construction, level computation, ordering/equality/hash
    semantics, tree navigation (children, parents, ancestors, paths),
    string forms (repr/str/label), and path-membership queries.
    """

    def test_init(self):
        """Nodes record the heap base k, their bucket id, and level."""
        for k in _test_bases:
            vh = VirtualHeap(k)
            node = vh.Node(0)
            self.assertEqual(node.k, k)
            self.assertEqual(node.bucket, 0)
            self.assertEqual(node.level, 0)
            # Buckets 1..k are exactly the first level below the root.
            for b in xrange(1, k+1):
                node = vh.Node(b)
                self.assertEqual(node.k, k)
                self.assertEqual(node.bucket, b)
                self.assertEqual(node.level, 1)

    def test_level(self):
        """Level boundaries for consecutive bucket ids (k=2 and k=3)."""
        Node = VirtualHeap(2).Node
        self.assertEqual(Node(0).level, 0)
        self.assertEqual(Node(1).level, 1)
        self.assertEqual(Node(2).level, 1)
        self.assertEqual(Node(3).level, 2)
        self.assertEqual(Node(4).level, 2)
        self.assertEqual(Node(5).level, 2)
        self.assertEqual(Node(6).level, 2)
        self.assertEqual(Node(7).level, 3)

        Node = VirtualHeap(3).Node
        self.assertEqual(Node(0).level, 0)
        self.assertEqual(Node(1).level, 1)
        self.assertEqual(Node(2).level, 1)
        self.assertEqual(Node(3).level, 1)
        self.assertEqual(Node(4).level, 2)
        self.assertEqual(Node(5).level, 2)
        self.assertEqual(Node(6).level, 2)
        self.assertEqual(Node(7).level, 2)
        self.assertEqual(Node(8).level, 2)
        self.assertEqual(Node(9).level, 2)
        self.assertEqual(Node(10).level, 2)
        self.assertEqual(Node(11).level, 2)
        self.assertEqual(Node(12).level, 2)
        self.assertEqual(Node(13).level, 3)

    def test_hash(self):
        """Nodes hash/compare by bucket id; the heap base k is ignored."""
        x1 = VirtualHeap(3).Node(5)
        x2 = VirtualHeap(2).Node(5)
        self.assertNotEqual(id(x1), id(x2))
        # Equal despite different k: equality is bucket-based.
        self.assertEqual(x1, x2)
        self.assertEqual(x1, x1)
        self.assertEqual(x2, x2)

        all_node_set = set()
        all_node_list = list()
        for k in _test_bases:
            node_set = set()
            node_list = list()
            Node = VirtualHeap(k).Node
            for height in xrange(k+2):
                node = Node(height)
                node_set.add(node)
                all_node_set.add(node)
                node_list.append(node)
                all_node_list.append(node)
            # Within one base no bucket id repeats, so set == list.
            self.assertEqual(sorted(node_set),
                             sorted(node_list))
        # Across bases bucket ids repeat; the set collapses duplicates,
        # so the aggregate set and list differ.
        self.assertNotEqual(sorted(all_node_set),
                            sorted(all_node_list))
    def test_int(self):
        """int(node) returns the node's bucket id."""
        Node2 = VirtualHeap(2).Node
        Node3 = VirtualHeap(3).Node
        for b in range(100):
            self.assertEqual(int(Node2(b)), b)
            self.assertEqual(int(Node3(b)), b)

    def test_lt(self):
        """node < int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) < 4, False)
        self.assertEqual(Node(5) < 5, False)
        self.assertEqual(Node(5) < 6, True)

    def test_le(self):
        """node <= int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) <= 4, False)
        self.assertEqual(Node(5) <= 5, True)
        self.assertEqual(Node(5) <= 6, True)

    def test_eq(self):
        """node == int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) == 4, False)
        self.assertEqual(Node(5) == 5, True)
        self.assertEqual(Node(5) == 6, False)

    def test_ne(self):
        """node != int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) != 4, True)
        self.assertEqual(Node(5) != 5, False)
        self.assertEqual(Node(5) != 6, True)

    def test_gt(self):
        """node > int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) > 4, True)
        self.assertEqual(Node(5) > 5, False)
        self.assertEqual(Node(5) > 6, False)

    def test_ge(self):
        """node >= int compares by bucket id."""
        Node = VirtualHeap(3).Node
        self.assertEqual(Node(5) >= 4, True)
        self.assertEqual(Node(5) >= 5, True)
        self.assertEqual(Node(5) >= 6, False)

    def test_last_common_level_k2(self):
        """Full pairwise last_common_level grid for buckets 0-7, k=2."""
        Node = VirtualHeap(2).Node
        n0 = Node(0)
        n1 = Node(1)
        n2 = Node(2)
        n3 = Node(3)
        n4 = Node(4)
        n5 = Node(5)
        n6 = Node(6)
        n7 = Node(7)
        self.assertEqual(n0.last_common_level(n0), 0)
        self.assertEqual(n0.last_common_level(n1), 0)
        self.assertEqual(n0.last_common_level(n2), 0)
        self.assertEqual(n0.last_common_level(n3), 0)
        self.assertEqual(n0.last_common_level(n4), 0)
        self.assertEqual(n0.last_common_level(n5), 0)
        self.assertEqual(n0.last_common_level(n6), 0)
        self.assertEqual(n0.last_common_level(n7), 0)

        self.assertEqual(n1.last_common_level(n0), 0)
        self.assertEqual(n1.last_common_level(n1), 1)
        self.assertEqual(n1.last_common_level(n2), 0)
        self.assertEqual(n1.last_common_level(n3), 1)
        self.assertEqual(n1.last_common_level(n4), 1)
        self.assertEqual(n1.last_common_level(n5), 0)
        self.assertEqual(n1.last_common_level(n6), 0)
        self.assertEqual(n1.last_common_level(n7), 1)

        self.assertEqual(n2.last_common_level(n0), 0)
        self.assertEqual(n2.last_common_level(n1), 0)
        self.assertEqual(n2.last_common_level(n2), 1)
        self.assertEqual(n2.last_common_level(n3), 0)
        self.assertEqual(n2.last_common_level(n4), 0)
        self.assertEqual(n2.last_common_level(n5), 1)
        self.assertEqual(n2.last_common_level(n6), 1)
        self.assertEqual(n2.last_common_level(n7), 0)

        self.assertEqual(n3.last_common_level(n0), 0)
        self.assertEqual(n3.last_common_level(n1), 1)
        self.assertEqual(n3.last_common_level(n2), 0)
        self.assertEqual(n3.last_common_level(n3), 2)
        self.assertEqual(n3.last_common_level(n4), 1)
        self.assertEqual(n3.last_common_level(n5), 0)
        self.assertEqual(n3.last_common_level(n6), 0)
        self.assertEqual(n3.last_common_level(n7), 2)

        self.assertEqual(n4.last_common_level(n0), 0)
        self.assertEqual(n4.last_common_level(n1), 1)
        self.assertEqual(n4.last_common_level(n2), 0)
        self.assertEqual(n4.last_common_level(n3), 1)
        self.assertEqual(n4.last_common_level(n4), 2)
        self.assertEqual(n4.last_common_level(n5), 0)
        self.assertEqual(n4.last_common_level(n6), 0)
        self.assertEqual(n4.last_common_level(n7), 1)

        self.assertEqual(n5.last_common_level(n0), 0)
        self.assertEqual(n5.last_common_level(n1), 0)
        self.assertEqual(n5.last_common_level(n2), 1)
        self.assertEqual(n5.last_common_level(n3), 0)
        self.assertEqual(n5.last_common_level(n4), 0)
        self.assertEqual(n5.last_common_level(n5), 2)
        self.assertEqual(n5.last_common_level(n6), 1)
        self.assertEqual(n5.last_common_level(n7), 0)

        self.assertEqual(n6.last_common_level(n0), 0)
        self.assertEqual(n6.last_common_level(n1), 0)
        self.assertEqual(n6.last_common_level(n2), 1)
        self.assertEqual(n6.last_common_level(n3), 0)
        self.assertEqual(n6.last_common_level(n4), 0)
        self.assertEqual(n6.last_common_level(n5), 1)
        self.assertEqual(n6.last_common_level(n6), 2)
        self.assertEqual(n6.last_common_level(n7), 0)

        self.assertEqual(n7.last_common_level(n0), 0)
        self.assertEqual(n7.last_common_level(n1), 1)
        self.assertEqual(n7.last_common_level(n2), 0)
        self.assertEqual(n7.last_common_level(n3), 2)
        self.assertEqual(n7.last_common_level(n4), 1)
        self.assertEqual(n7.last_common_level(n5), 0)
        self.assertEqual(n7.last_common_level(n6), 0)
        self.assertEqual(n7.last_common_level(n7), 3)

    def test_last_common_level_k3(self):
        """Full pairwise last_common_level grid for buckets 0-7, k=3."""
        Node = VirtualHeap(3).Node
        n0 = Node(0)
        n1 = Node(1)
        n2 = Node(2)
        n3 = Node(3)
        n4 = Node(4)
        n5 = Node(5)
        n6 = Node(6)
        n7 = Node(7)
        self.assertEqual(n0.last_common_level(n0), 0)
        self.assertEqual(n0.last_common_level(n1), 0)
        self.assertEqual(n0.last_common_level(n2), 0)
        self.assertEqual(n0.last_common_level(n3), 0)
        self.assertEqual(n0.last_common_level(n4), 0)
        self.assertEqual(n0.last_common_level(n5), 0)
        self.assertEqual(n0.last_common_level(n6), 0)
        self.assertEqual(n0.last_common_level(n7), 0)

        self.assertEqual(n1.last_common_level(n0), 0)
        self.assertEqual(n1.last_common_level(n1), 1)
        self.assertEqual(n1.last_common_level(n2), 0)
        self.assertEqual(n1.last_common_level(n3), 0)
        self.assertEqual(n1.last_common_level(n4), 1)
        self.assertEqual(n1.last_common_level(n5), 1)
        self.assertEqual(n1.last_common_level(n6), 1)
        self.assertEqual(n1.last_common_level(n7), 0)

        self.assertEqual(n2.last_common_level(n0), 0)
        self.assertEqual(n2.last_common_level(n1), 0)
        self.assertEqual(n2.last_common_level(n2), 1)
        self.assertEqual(n2.last_common_level(n3), 0)
        self.assertEqual(n2.last_common_level(n4), 0)
        self.assertEqual(n2.last_common_level(n5), 0)
        self.assertEqual(n2.last_common_level(n6), 0)
        self.assertEqual(n2.last_common_level(n7), 1)

        self.assertEqual(n3.last_common_level(n0), 0)
        self.assertEqual(n3.last_common_level(n1), 0)
        self.assertEqual(n3.last_common_level(n2), 0)
        self.assertEqual(n3.last_common_level(n3), 1)
        self.assertEqual(n3.last_common_level(n4), 0)
        self.assertEqual(n3.last_common_level(n5), 0)
        self.assertEqual(n3.last_common_level(n6), 0)
        self.assertEqual(n3.last_common_level(n7), 0)

        self.assertEqual(n4.last_common_level(n0), 0)
        self.assertEqual(n4.last_common_level(n1), 1)
        self.assertEqual(n4.last_common_level(n2), 0)
        self.assertEqual(n4.last_common_level(n3), 0)
        self.assertEqual(n4.last_common_level(n4), 2)
        self.assertEqual(n4.last_common_level(n5), 1)
        self.assertEqual(n4.last_common_level(n6), 1)
        self.assertEqual(n4.last_common_level(n7), 0)

        self.assertEqual(n5.last_common_level(n0), 0)
        self.assertEqual(n5.last_common_level(n1), 1)
        self.assertEqual(n5.last_common_level(n2), 0)
        self.assertEqual(n5.last_common_level(n3), 0)
        self.assertEqual(n5.last_common_level(n4), 1)
        self.assertEqual(n5.last_common_level(n5), 2)
        self.assertEqual(n5.last_common_level(n6), 1)
        self.assertEqual(n5.last_common_level(n7), 0)

        self.assertEqual(n6.last_common_level(n0), 0)
        self.assertEqual(n6.last_common_level(n1), 1)
        self.assertEqual(n6.last_common_level(n2), 0)
        self.assertEqual(n6.last_common_level(n3), 0)
        self.assertEqual(n6.last_common_level(n4), 1)
        self.assertEqual(n6.last_common_level(n5), 1)
        self.assertEqual(n6.last_common_level(n6), 2)
        self.assertEqual(n6.last_common_level(n7), 0)

        self.assertEqual(n7.last_common_level(n0), 0)
        self.assertEqual(n7.last_common_level(n1), 0)
        self.assertEqual(n7.last_common_level(n2), 1)
        self.assertEqual(n7.last_common_level(n3), 0)
        self.assertEqual(n7.last_common_level(n4), 0)
        self.assertEqual(n7.last_common_level(n5), 0)
        self.assertEqual(n7.last_common_level(n6), 0)
        self.assertEqual(n7.last_common_level(n7), 2)

    def test_child_node(self):
        """child_node() navigation, verified via pre/post/in-order walks."""
        root = VirtualHeap(2).Node(0)
        self.assertEqual(list(_do_preorder(root)),
                         [0, 1, 3, 4, 2, 5, 6])
        self.assertEqual(list(_do_postorder(root)),
                         [3, 4, 1, 5, 6, 2, 0])
        self.assertEqual(list(_do_inorder(root)),
                         [3, 1, 4, 0, 5, 2, 6])

        root = VirtualHeap(3).Node(0)
        self.assertEqual(
            list(_do_preorder(root)),
            [0, 1, 4, 5, 6, 2, 7, 8, 9, 3, 10, 11, 12])
        self.assertEqual(
            list(_do_postorder(root)),
            [4, 5, 6, 1, 7, 8, 9, 2, 10, 11, 12, 3, 0])

    def test_parent_node(self):
        """parent_node() returns the parent bucket; None for the root."""
        Node = VirtualHeap(2).Node
        self.assertEqual(Node(0).parent_node(),
                         None)
        self.assertEqual(Node(1).parent_node(),
                         Node(0))
        self.assertEqual(Node(2).parent_node(),
                         Node(0))
        self.assertEqual(Node(3).parent_node(),
                         Node(1))
        self.assertEqual(Node(4).parent_node(),
                         Node(1))
        self.assertEqual(Node(5).parent_node(),
                         Node(2))
        self.assertEqual(Node(6).parent_node(),
                         Node(2))
        self.assertEqual(Node(7).parent_node(),
                         Node(3))

        Node = VirtualHeap(3).Node
        self.assertEqual(Node(0).parent_node(),
                         None)
        self.assertEqual(Node(1).parent_node(),
                         Node(0))
        self.assertEqual(Node(2).parent_node(),
                         Node(0))
        self.assertEqual(Node(3).parent_node(),
                         Node(0))
        self.assertEqual(Node(4).parent_node(),
                         Node(1))
        self.assertEqual(Node(5).parent_node(),
                         Node(1))
        self.assertEqual(Node(6).parent_node(),
                         Node(1))
        self.assertEqual(Node(7).parent_node(),
                         Node(2))
        self.assertEqual(Node(8).parent_node(),
                         Node(2))
        self.assertEqual(Node(9).parent_node(),
                         Node(2))
        self.assertEqual(Node(10).parent_node(),
                         Node(3))
        self.assertEqual(Node(11).parent_node(),
                         Node(3))
        self.assertEqual(Node(12).parent_node(),
                         Node(3))
        self.assertEqual(Node(13).parent_node(),
                         Node(4))

    def test_ancestor_node_at_level(self):
        """Ancestor lookup by level; None when the level is below the node."""
        Node = VirtualHeap(2).Node
        self.assertEqual(Node(0).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(0).ancestor_node_at_level(1),
                         None)
        self.assertEqual(Node(1).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(1).ancestor_node_at_level(1),
                         Node(1))
        self.assertEqual(Node(1).ancestor_node_at_level(2),
                         None)
        self.assertEqual(Node(3).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(3).ancestor_node_at_level(1),
                         Node(1))
        self.assertEqual(Node(3).ancestor_node_at_level(2),
                         Node(3))
        self.assertEqual(Node(3).ancestor_node_at_level(3),
                         None)

        Node = VirtualHeap(3).Node
        self.assertEqual(Node(0).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(0).ancestor_node_at_level(1),
                         None)
        self.assertEqual(Node(1).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(1).ancestor_node_at_level(1),
                         Node(1))
        self.assertEqual(Node(1).ancestor_node_at_level(2),
                         None)
        self.assertEqual(Node(4).ancestor_node_at_level(0),
                         Node(0))
        self.assertEqual(Node(4).ancestor_node_at_level(1),
                         Node(1))
        self.assertEqual(Node(4).ancestor_node_at_level(2),
                         Node(4))
        self.assertEqual(Node(4).ancestor_node_at_level(3),
                         None)

    def test_path_to_root(self):
        """bucket_path_to_root() yields leaf-to-root, checked via int()."""
        # NOTE(review): despite the method name, this exercises the same
        # bucket_path_to_root() data as test_bucket_path_to_root, but
        # through int() conversion of each yielded value.
        Node = VirtualHeap(2).Node
        self.assertEqual(list(int(n) for n in Node(0).bucket_path_to_root()),
                         list(reversed([0])))
        self.assertEqual(list(int(n) for n in Node(7).bucket_path_to_root()),
                         list(reversed([0, 1, 3, 7])))
        self.assertEqual(list(int(n) for n in Node(8).bucket_path_to_root()),
                         list(reversed([0, 1, 3, 8])))
        self.assertEqual(list(int(n) for n in Node(9).bucket_path_to_root()),
                         list(reversed([0, 1, 4, 9])))
        self.assertEqual(list(int(n) for n in Node(10).bucket_path_to_root()),
                         list(reversed([0, 1, 4, 10])))
        self.assertEqual(list(int(n) for n in Node(11).bucket_path_to_root()),
                         list(reversed([0, 2, 5, 11])))
        self.assertEqual(list(int(n) for n in Node(12).bucket_path_to_root()),
                         list(reversed([0, 2, 5, 12])))
        self.assertEqual(list(int(n) for n in Node(13).bucket_path_to_root()),
                         list(reversed([0, 2, 6, 13])))
        self.assertEqual(list(int(n) for n in Node(14).bucket_path_to_root()),
                         list(reversed([0, 2, 6, 14])))

    def test_path_from_root(self):
        """bucket_path_from_root() yields root-to-leaf, checked via int()."""
        Node = VirtualHeap(2).Node
        self.assertEqual(list(int(n) for n in Node(0).bucket_path_from_root()),
                         [0])
        self.assertEqual(list(int(n) for n in Node(7).bucket_path_from_root()),
                         [0, 1, 3, 7])
        self.assertEqual(list(int(n) for n in Node(8).bucket_path_from_root()),
                         [0, 1, 3, 8])
        self.assertEqual(list(int(n) for n in Node(9).bucket_path_from_root()),
                         [0, 1, 4, 9])
        self.assertEqual(list(int(n) for n in Node(10).bucket_path_from_root()),
                         [0, 1, 4, 10])
        self.assertEqual(list(int(n) for n in Node(11).bucket_path_from_root()),
                         [0, 2, 5, 11])
        self.assertEqual(list(int(n) for n in Node(12).bucket_path_from_root()),
                         [0, 2, 5, 12])
        self.assertEqual(list(int(n) for n in Node(13).bucket_path_from_root()),
                         [0, 2, 6, 13])
        self.assertEqual(list(int(n) for n in Node(14).bucket_path_from_root()),
                         [0, 2, 6, 14])

    def test_bucket_path_to_root(self):
        """bucket_path_to_root() compared directly (no int() conversion)."""
        Node = VirtualHeap(2).Node
        self.assertEqual(list(Node(0).bucket_path_to_root()),
                         list(reversed([0])))
        self.assertEqual(list(Node(7).bucket_path_to_root()),
                         list(reversed([0, 1, 3, 7])))
        self.assertEqual(list(Node(8).bucket_path_to_root()),
                         list(reversed([0, 1, 3, 8])))
        self.assertEqual(list(Node(9).bucket_path_to_root()),
                         list(reversed([0, 1, 4, 9])))
        self.assertEqual(list(Node(10).bucket_path_to_root()),
                         list(reversed([0, 1, 4, 10])))
        self.assertEqual(list(Node(11).bucket_path_to_root()),
                         list(reversed([0, 2, 5, 11])))
        self.assertEqual(list(Node(12).bucket_path_to_root()),
                         list(reversed([0, 2, 5, 12])))
        self.assertEqual(list(Node(13).bucket_path_to_root()),
                         list(reversed([0, 2, 6, 13])))
        self.assertEqual(list(Node(14).bucket_path_to_root()),
                         list(reversed([0, 2, 6, 14])))

    def test_bucket_path_from_root(self):
        """bucket_path_from_root() returns a root-to-leaf bucket list."""
        Node = VirtualHeap(2).Node
        self.assertEqual(Node(0).bucket_path_from_root(),
                         [0])
        self.assertEqual(Node(7).bucket_path_from_root(),
                         [0, 1, 3, 7])
        self.assertEqual(Node(8).bucket_path_from_root(),
                         [0, 1, 3, 8])
        self.assertEqual(Node(9).bucket_path_from_root(),
                         [0, 1, 4, 9])
        self.assertEqual(Node(10).bucket_path_from_root(),
                         [0, 1, 4, 10])
        self.assertEqual(Node(11).bucket_path_from_root(),
                         [0, 2, 5, 11])
        self.assertEqual(Node(12).bucket_path_from_root(),
                         [0, 2, 5, 12])
        self.assertEqual(Node(13).bucket_path_from_root(),
                         [0, 2, 6, 13])
        self.assertEqual(Node(14).bucket_path_from_root(),
                         [0, 2, 6, 14])

    def test_repr(self):
        """repr() includes k, bucket, level and label; labels past
        max_k_labeled render as '<unknown>'."""
        Node = VirtualHeap(2).Node
        self.assertEqual(
            repr(Node(0)),
            "VirtualHeapNode(k=2, bucket=0, level=0, label='')")
        self.assertEqual(
            repr(Node(7)),
            "VirtualHeapNode(k=2, bucket=7, level=3, label='000')")

        Node = VirtualHeap(3).Node
        self.assertEqual(
            repr(Node(0)),
            "VirtualHeapNode(k=3, bucket=0, level=0, label='')")
        self.assertEqual(
            repr(Node(7)),
            "VirtualHeapNode(k=3, bucket=7, level=2, label='10')")

        Node = VirtualHeap(5).Node
        self.assertEqual(
            repr(Node(25)),
            "VirtualHeapNode(k=5, bucket=25, level=2, label='34')")

        Node = VirtualHeap(max_k_labeled).Node
        self.assertEqual(
            repr(Node(0)),
            ("VirtualHeapNode(k=%d, bucket=0, level=0, label='')"
             % (max_k_labeled)))
        self.assertEqual(max_k_labeled >= 2, True)
        self.assertEqual(
            repr(Node(1)),
            ("VirtualHeapNode(k=%d, bucket=1, level=1, label='0')"
             % (max_k_labeled)))

        # One base past the labeling limit: non-root labels are unknown.
        Node = VirtualHeap(max_k_labeled+1).Node
        self.assertEqual(
            repr(Node(0)),
            ("VirtualHeapNode(k=%d, bucket=0, level=0, label='')"
             % (max_k_labeled+1)))
        self.assertEqual(
            repr(Node(1)),
            ("VirtualHeapNode(k=%d, bucket=1, level=1, label='<unknown>')"
             % (max_k_labeled+1)))
        self.assertEqual(
            repr(Node(max_k_labeled+1)),
            ("VirtualHeapNode(k=%d, bucket=%d, level=1, label='<unknown>')"
             % (max_k_labeled+1,
                max_k_labeled+1)))
        self.assertEqual(
            repr(Node(max_k_labeled+2)),
            ("VirtualHeapNode(k=%d, bucket=%d, level=2, label='<unknown>')"
             % (max_k_labeled+1,
                max_k_labeled+2)))

    def test_str(self):
        """str(node) is the '(level, offset-within-level)' form."""
        Node = VirtualHeap(2).Node
        self.assertEqual(
            str(Node(0)),
            "(0, 0)")
        self.assertEqual(
            str(Node(7)),
            "(3, 0)")

        Node = VirtualHeap(3).Node
        self.assertEqual(
            str(Node(0)),
            "(0, 0)")
        self.assertEqual(
            str(Node(7)),
            "(2, 3)")

        Node = VirtualHeap(5).Node
        self.assertEqual(
            str(Node(25)),
            "(2, 19)")

    def test_label(self):
        """label() spells the root-to-node path in base-k numerals."""
        Node = VirtualHeap(2).Node
        self.assertEqual(Node(0).label(), "")
        self.assertEqual(Node(1).label(), "0")
        self.assertEqual(Node(2).label(), "1")
        self.assertEqual(Node(3).label(), "00")
        self.assertEqual(Node(4).label(), "01")
        self.assertEqual(Node(5).label(), "10")
        self.assertEqual(Node(6).label(), "11")
        self.assertEqual(Node(7).label(), "000")
        self.assertEqual(Node(8).label(), "001")
        self.assertEqual(Node(9).label(), "010")
        self.assertEqual(Node(10).label(), "011")
        self.assertEqual(Node(11).label(), "100")
        self.assertEqual(Node(12).label(), "101")
        self.assertEqual(Node(13).label(), "110")
        self.assertEqual(Node(14).label(), "111")
        self.assertEqual(Node(15).label(), "0000")
        self.assertEqual(Node(30).label(), "1111")

        # Round-trip: a label decoded as a base-k integer plus the count
        # of buckets above the node's level recovers the bucket id.
        for k in _test_labeled_bases:
            Node = VirtualHeap(k).Node
            for b in xrange(calculate_bucket_count_in_heap_with_height(k, 2)+1):
                label = Node(b).label()
                level = Node(b).level
                if label == "":
                    self.assertEqual(b, 0)
                else:
                    self.assertEqual(
                        b,
                        basek_string_to_base10_integer(k, label) + \
                        calculate_bucket_count_in_heap_with_height(k, level-1))

    def test_is_node_on_path(self):
        """is_node_on_path(other) is True iff other is on this node's
        root path (ancestors and the node itself)."""
        Node = VirtualHeap(2).Node
        self.assertEqual(
            Node(0).is_node_on_path(
                Node(0)),
            True)
        self.assertEqual(
            Node(0).is_node_on_path(
                Node(1)),
            False)
        self.assertEqual(
            Node(0).is_node_on_path(
                Node(2)),
            False)
        self.assertEqual(
            Node(0).is_node_on_path(
                Node(3)),
            False)

        Node = VirtualHeap(5).Node
        self.assertEqual(
            Node(20).is_node_on_path(
                Node(21)),
            False)
        self.assertEqual(
            Node(21).is_node_on_path(
                Node(4)),
            True)

        Node = VirtualHeap(3).Node
        self.assertEqual(
            Node(7).is_node_on_path(
                Node(0)),
            True)
        self.assertEqual(
            Node(7).is_node_on_path(
                Node(2)),
            True)
        self.assertEqual(
            Node(7).is_node_on_path(
                Node(7)),
            True)
        self.assertEqual(
            Node(7).is_node_on_path(
                Node(8)),
            False)
+
class TestVirtualHeap(unittest2.TestCase):
    """Tests for the unbounded VirtualHeap container.

    Covers construction, label-to-bucket decoding, per-level
    bucket/node/block counts and boundaries, bucket/block index
    conversion, random node selection, and the root node.
    """

    def test_init(self):
        """Constructor stores k and blocks_per_bucket; Node shares k."""
        vh = VirtualHeap(2, blocks_per_bucket=4)
        self.assertEqual(vh.k, 2)
        self.assertEqual(vh.Node.k, 2)
        self.assertEqual(vh.blocks_per_bucket, 4)
        vh = VirtualHeap(5, blocks_per_bucket=7)
        self.assertEqual(vh.k, 5)
        self.assertEqual(vh.Node.k, 5)
        self.assertEqual(vh.blocks_per_bucket, 7)

    def test_node_label_to_bucket(self):
        """node_label_to_bucket decodes a base-k path label to a bucket."""
        vh = VirtualHeap(2)
        self.assertEqual(vh.node_label_to_bucket(""), 0)
        self.assertEqual(vh.node_label_to_bucket("0"), 1)
        self.assertEqual(vh.node_label_to_bucket("1"), 2)
        self.assertEqual(vh.node_label_to_bucket("00"), 3)
        self.assertEqual(vh.node_label_to_bucket("01"), 4)
        self.assertEqual(vh.node_label_to_bucket("10"), 5)
        self.assertEqual(vh.node_label_to_bucket("11"), 6)
        self.assertEqual(vh.node_label_to_bucket("000"), 7)
        self.assertEqual(vh.node_label_to_bucket("001"), 8)
        self.assertEqual(vh.node_label_to_bucket("010"), 9)
        self.assertEqual(vh.node_label_to_bucket("011"), 10)
        self.assertEqual(vh.node_label_to_bucket("100"), 11)
        self.assertEqual(vh.node_label_to_bucket("101"), 12)
        self.assertEqual(vh.node_label_to_bucket("110"), 13)
        self.assertEqual(vh.node_label_to_bucket("111"), 14)
        self.assertEqual(vh.node_label_to_bucket("0000"), 15)
        self.assertEqual(vh.node_label_to_bucket("1111"),
                         calculate_bucket_count_in_heap_with_height(2, 4)-1)

        vh = VirtualHeap(3)
        self.assertEqual(vh.node_label_to_bucket(""), 0)
        self.assertEqual(vh.node_label_to_bucket("0"), 1)
        self.assertEqual(vh.node_label_to_bucket("1"), 2)
        self.assertEqual(vh.node_label_to_bucket("2"), 3)
        self.assertEqual(vh.node_label_to_bucket("00"), 4)
        self.assertEqual(vh.node_label_to_bucket("01"), 5)
        self.assertEqual(vh.node_label_to_bucket("02"), 6)
        self.assertEqual(vh.node_label_to_bucket("10"), 7)
        self.assertEqual(vh.node_label_to_bucket("11"), 8)
        self.assertEqual(vh.node_label_to_bucket("12"), 9)
        self.assertEqual(vh.node_label_to_bucket("20"), 10)
        self.assertEqual(vh.node_label_to_bucket("21"), 11)
        self.assertEqual(vh.node_label_to_bucket("22"), 12)
        self.assertEqual(vh.node_label_to_bucket("000"), 13)
        self.assertEqual(vh.node_label_to_bucket("222"),
                         calculate_bucket_count_in_heap_with_height(3, 3)-1)

        # The all-largest-symbol label of length h is the last bucket of
        # a height-h heap, for every labelable base.
        for k in xrange(2, max_k_labeled+1):
            for h in xrange(5):
                vh = VirtualHeap(k)
                largest_symbol = numerals[k-1]
                self.assertEqual(vh.k, k)
                self.assertEqual(vh.node_label_to_bucket(""), 0)
                self.assertEqual(vh.node_label_to_bucket(largest_symbol * h),
                                 calculate_bucket_count_in_heap_with_height(k, h)-1)

    def test_ObjectCountAtLevel(self):
        """Level l holds k**l buckets/nodes and k**l * blocks_per_bucket
        blocks."""
        for k in _test_bases:
            for height in xrange(k+2):
                for blocks_per_bucket in xrange(1, 5):
                    vh = VirtualHeap(k, blocks_per_bucket=blocks_per_bucket)
                    for l in xrange(height+1):
                        cnt = k**l
                        self.assertEqual(vh.bucket_count_at_level(l), cnt)
                        self.assertEqual(vh.node_count_at_level(l), cnt)
                        self.assertEqual(vh.block_count_at_level(l),
                                         cnt * blocks_per_bucket)

    def test_bucket_to_block(self):
        """bucket_to_block is bucket id scaled by blocks_per_bucket."""
        for k in xrange(2, 6):
            for blocks_per_bucket in xrange(1, 5):
                heap = VirtualHeap(k, blocks_per_bucket=blocks_per_bucket)
                for b in xrange(20):
                    self.assertEqual(heap.bucket_to_block(b),
                                     blocks_per_bucket * b)

    def test_node_count_at_level(self):
        """Spot-checks of node_count_at_level for k = 2, 3, 4."""
        self.assertEqual(VirtualHeap(2).node_count_at_level(0), 1)
        self.assertEqual(VirtualHeap(2).node_count_at_level(1), 2)
        self.assertEqual(VirtualHeap(2).node_count_at_level(2), 4)
        self.assertEqual(VirtualHeap(2).node_count_at_level(3), 8)
        self.assertEqual(VirtualHeap(2).node_count_at_level(4), 16)

        self.assertEqual(VirtualHeap(3).node_count_at_level(0), 1)
        self.assertEqual(VirtualHeap(3).node_count_at_level(1), 3)
        self.assertEqual(VirtualHeap(3).node_count_at_level(2), 9)
        self.assertEqual(VirtualHeap(3).node_count_at_level(3), 27)
        self.assertEqual(VirtualHeap(3).node_count_at_level(4), 81)

        self.assertEqual(VirtualHeap(4).node_count_at_level(0), 1)
        self.assertEqual(VirtualHeap(4).node_count_at_level(1), 4)
        self.assertEqual(VirtualHeap(4).node_count_at_level(2), 16)
        self.assertEqual(VirtualHeap(4).node_count_at_level(3), 64)
        self.assertEqual(VirtualHeap(4).node_count_at_level(4), 256)

    def test_first_node_at_level(self):
        """First bucket id of each level for k = 2, 3, 4."""
        self.assertEqual(VirtualHeap(2).first_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(2).first_node_at_level(1), 1)
        self.assertEqual(VirtualHeap(2).first_node_at_level(2), 3)
        self.assertEqual(VirtualHeap(2).first_node_at_level(3), 7)
        self.assertEqual(VirtualHeap(2).first_node_at_level(4), 15)

        self.assertEqual(VirtualHeap(3).first_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(3).first_node_at_level(1), 1)
        self.assertEqual(VirtualHeap(3).first_node_at_level(2), 4)
        self.assertEqual(VirtualHeap(3).first_node_at_level(3), 13)
        self.assertEqual(VirtualHeap(3).first_node_at_level(4), 40)

        self.assertEqual(VirtualHeap(4).first_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(4).first_node_at_level(1), 1)
        self.assertEqual(VirtualHeap(4).first_node_at_level(2), 5)
        self.assertEqual(VirtualHeap(4).first_node_at_level(3), 21)
        self.assertEqual(VirtualHeap(4).first_node_at_level(4), 85)

    def test_last_node_at_level(self):
        """Last bucket id of each level for k = 2, 3, 4."""
        self.assertEqual(VirtualHeap(2).last_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(2).last_node_at_level(1), 2)
        self.assertEqual(VirtualHeap(2).last_node_at_level(2), 6)
        self.assertEqual(VirtualHeap(2).last_node_at_level(3), 14)
        self.assertEqual(VirtualHeap(2).last_node_at_level(4), 30)

        self.assertEqual(VirtualHeap(3).last_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(3).last_node_at_level(1), 3)
        self.assertEqual(VirtualHeap(3).last_node_at_level(2), 12)
        self.assertEqual(VirtualHeap(3).last_node_at_level(3), 39)
        self.assertEqual(VirtualHeap(3).last_node_at_level(4), 120)

        self.assertEqual(VirtualHeap(4).last_node_at_level(0), 0)
        self.assertEqual(VirtualHeap(4).last_node_at_level(1), 4)
        self.assertEqual(VirtualHeap(4).last_node_at_level(2), 20)
        self.assertEqual(VirtualHeap(4).last_node_at_level(3), 84)
        self.assertEqual(VirtualHeap(4).last_node_at_level(4), 340)

    def test_random_node_up_to_level(self):
        """Random nodes never fall below the requested level bound."""
        for k in xrange(2,6):
            heap = VirtualHeap(k)
            for l in xrange(4):
                # Sample enough times to make level violations likely to
                # surface if they can occur.
                for t in xrange(2 * calculate_bucket_count_in_heap_with_height(k, l)):
                    node = heap.random_node_up_to_level(l)
                    self.assertEqual(node.level <= l, True)

    def test_random_node_at_level(self):
        """Random nodes land exactly on the requested level."""
        for k in xrange(2,6):
            heap = VirtualHeap(k)
            for l in xrange(4):
                for t in xrange(2 * calculate_bucket_count_in_heap_at_level(k, l)):
                    node = heap.random_node_at_level(l)
                    self.assertEqual(node.level == l, True)

    def test_first_block_at_level(self):
        """First block id of a level is first bucket * blocks_per_bucket."""
        for blocks_per_bucket in xrange(1, 5):
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(0), 0 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(1), 1 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(2), 3 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(3), 7 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(4), 15 * blocks_per_bucket)

            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(0), 0 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(1), 1 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(2), 4 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(3), 13 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(4), 40 * blocks_per_bucket)

            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(0), 0 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(1), 1 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(2), 5 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(3), 21 * blocks_per_bucket)
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             first_block_at_level(4), 85 * blocks_per_bucket)

    def test_last_block_at_level(self):
        """Last block id of a level is the last block of its last bucket."""
        for blocks_per_bucket in xrange(1, 5):
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(0), 0 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(1), 2 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(2), 6 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(3), 14 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(2, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(4), 30 * blocks_per_bucket + (blocks_per_bucket-1))

            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(0), 0 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(1), 3 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(2), 12 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(3), 39 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(3, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(4), 120 * blocks_per_bucket + (blocks_per_bucket-1))

            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(0), 0 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(1), 4 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(2), 20 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(3), 84 * blocks_per_bucket + (blocks_per_bucket-1))
            self.assertEqual(VirtualHeap(4, blocks_per_bucket=blocks_per_bucket).\
                             last_block_at_level(4), 340 * blocks_per_bucket + (blocks_per_bucket-1))

    def test_block_to_bucket(self):
        """block_to_bucket is integer division by blocks_per_bucket."""
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=1).block_to_bucket(0), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=1).block_to_bucket(1), 1)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=1).block_to_bucket(2), 2)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=1).block_to_bucket(3), 3)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=1).block_to_bucket(4), 4)

        self.assertEqual(VirtualHeap(2, blocks_per_bucket=2).block_to_bucket(0), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=2).block_to_bucket(1), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=2).block_to_bucket(2), 1)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=2).block_to_bucket(3), 1)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=2).block_to_bucket(4), 2)

        self.assertEqual(VirtualHeap(2, blocks_per_bucket=3).block_to_bucket(0), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=3).block_to_bucket(1), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=3).block_to_bucket(2), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=3).block_to_bucket(3), 1)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=3).block_to_bucket(4), 1)

        self.assertEqual(VirtualHeap(2, blocks_per_bucket=4).block_to_bucket(0), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=4).block_to_bucket(1), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=4).block_to_bucket(2), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=4).block_to_bucket(3), 0)
        self.assertEqual(VirtualHeap(2, blocks_per_bucket=4).block_to_bucket(4), 1)

    def test_root_node(self):
        """root_node() is bucket 0 at level 0 with no parent."""
        for k in range(2, 6):
            for blocks_per_bucket in range(1, 5):
                heap = VirtualHeap(k, blocks_per_bucket=blocks_per_bucket)
                root = heap.root_node()
                self.assertEqual(root, 0)
                self.assertEqual(root.bucket, 0)
                self.assertEqual(root.level, 0)
                self.assertEqual(root.parent_node(), None)
+
class TestSizedVirtualHeap(unittest2.TestCase):
    """Exercises SizedVirtualHeap size queries, random sampling,
    and dot/PDF visualization output."""

    def test_init(self):
        for k, height, bpb in [(2, 8, 4), (5, 9, 7)]:
            vh = SizedVirtualHeap(k, height, blocks_per_bucket=bpb)
            self.assertEqual(vh.k, k)
            self.assertEqual(vh.Node.k, k)
            self.assertEqual(vh.blocks_per_bucket, bpb)

    def test_height(self):
        for k, height, bpb in [(2, 3, 4), (5, 6, 7)]:
            vh = SizedVirtualHeap(k, height, blocks_per_bucket=bpb)
            self.assertEqual(vh.height, height)

    def test_levels(self):
        for k, height, bpb in [(2, 3, 4), (5, 6, 7)]:
            vh = SizedVirtualHeap(k, height, blocks_per_bucket=bpb)
            self.assertEqual(vh.levels, height + 1)

    def test_first_level(self):
        for k, height, bpb in [(2, 3, 4), (5, 6, 7)]:
            vh = SizedVirtualHeap(k, height, blocks_per_bucket=bpb)
            self.assertEqual(vh.first_level, 0)

    def test_last_level(self):
        for k, height, bpb in [(2, 3, 4), (5, 6, 7)]:
            vh = SizedVirtualHeap(k, height, blocks_per_bucket=bpb)
            self.assertEqual(vh.last_level, height)
            self.assertEqual(vh.last_level, vh.levels - 1)
            self.assertEqual(vh.last_level, vh.height)

    def test_ObjectCount(self):
        for k in _test_bases:
            for height in xrange(k + 2):
                for bpb in xrange(1, 5):
                    vh = SizedVirtualHeap(k,
                                          height,
                                          blocks_per_bucket=bpb)
                    # geometric series 1 + k + ... + k^height
                    cnt = (((k**(height + 1)) - 1) // (k - 1))
                    self.assertEqual(vh.bucket_count(), cnt)
                    self.assertEqual(vh.node_count(), cnt)
                    self.assertEqual(vh.block_count(), cnt * bpb)

    def test_LeafObjectCount(self):
        for k in _test_bases:
            for height in xrange(k + 2):
                for bpb in xrange(1, 5):
                    vh = SizedVirtualHeap(k,
                                          height,
                                          blocks_per_bucket=bpb)
                    self.assertEqual(vh.leaf_bucket_count(),
                                     vh.bucket_count_at_level(vh.height))
                    self.assertEqual(vh.leaf_node_count(),
                                     vh.node_count_at_level(vh.height))
                    self.assertEqual(vh.leaf_block_count(),
                                     vh.block_count_at_level(vh.height))

    def test_FirstLeafObject(self):
        vh = SizedVirtualHeap(2, 3, blocks_per_bucket=3)
        self.assertEqual(vh.first_leaf_node(), 7)
        self.assertEqual(vh.first_leaf_block(), 7 * 3)

    def test_LastLeafObject(self):
        vh = SizedVirtualHeap(2, 3, blocks_per_bucket=3)
        self.assertEqual(vh.last_leaf_node(), 14)
        self.assertEqual(vh.last_leaf_block(), 14 * 3 + 2)

    def test_random_node(self):
        height = 3
        for k in xrange(2, 6):
            heap = SizedVirtualHeap(k, height)
            for _ in xrange(2 * heap.bucket_count()):
                node = heap.random_node()
                self.assertEqual(0 <= node.level <= height, True)

    def test_random_leaf_node(self):
        height = 3
        for k in xrange(2, 6):
            heap = SizedVirtualHeap(k, height)
            for _ in xrange(2 * heap.bucket_count()):
                node = heap.random_leaf_node()
                self.assertEqual(node.level, height)

    def _assert_file_equals_baselines(self, fname, bname):
        # Compare a generated file to its checked-in baseline, then
        # remove the generated copy.
        with open(fname) as f:
            flines = f.readlines()
        with open(bname) as f:
            blines = f.readlines()
        self.assertListEqual(flines, blines)
        os.remove(fname)

    def test_write_as_dot(self):
        cases = [(2, 3, 1, None),
                 (2, 3, 2, None),
                 (3, 3, 1, None),
                 (3, 3, 2, None),
                 (3, 10, 2, 4),
                 (200, 0, 1, None)]
        for k, h, b, maxl in cases:
            # baselines for truncated drawings are named by the
            # effective (truncated) height, maxl - 1
            if maxl is None:
                label = "k%d_h%d_b%d" % (k, h, b)
            else:
                label = "k%d_h%d_b%d" % (k, maxl - 1, b)
            heap = SizedVirtualHeap(k, h, blocks_per_bucket=b)

            fname = label + ".dot"
            with open(os.path.join(thisdir, fname), "w") as f:
                heap.write_as_dot(f, max_levels=maxl)
            self._assert_file_equals_baselines(
                os.path.join(thisdir, fname),
                os.path.join(baselinedir, fname))

            data = list(range(heap.block_count()))
            fname = label + "_data.dot"
            with open(os.path.join(thisdir, fname), "w") as f:
                heap.write_as_dot(f, data=data, max_levels=maxl)
            self._assert_file_equals_baselines(
                os.path.join(thisdir, fname),
                os.path.join(baselinedir, fname))

    def test_save_image_as_pdf(self):

        def remove_quietly(path):
            try:
                os.remove(path)
            except OSError:  # pragma: no cover
                pass  # pragma: no cover

        cases = [(2, 3, 1, None),
                 (2, 3, 2, None),
                 (3, 3, 1, None),
                 (3, 3, 2, None),
                 (3, 10, 2, 4)]
        for k, h, b, maxl in cases:
            if maxl is None:
                label = "k%d_h%d_b%d" % (k, h, b)
            else:
                label = "k%d_h%d_b%d" % (k, maxl - 1, b)
            heap = SizedVirtualHeap(k, h, blocks_per_bucket=b)

            fname = label + ".pdf"
            remove_quietly(os.path.join(thisdir, fname))
            rc = heap.save_image_as_pdf(os.path.join(thisdir, label),
                                        max_levels=maxl)
            if not has_dot:
                self.assertEqual(rc, False)
            else:
                self.assertEqual(rc, True)
                self.assertEqual(
                    os.path.exists(os.path.join(thisdir, fname)), True)
            remove_quietly(os.path.join(thisdir, fname))

            data = list(range(heap.block_count()))
            fname = label + "_data.pdf"
            remove_quietly(os.path.join(thisdir, fname))
            rc = heap.save_image_as_pdf(os.path.join(thisdir, fname),
                                        data=data,
                                        max_levels=maxl)
            if not has_dot:
                self.assertEqual(rc, False)
            else:
                self.assertEqual(rc, True)
                self.assertEqual(
                    os.path.exists(os.path.join(thisdir, fname)), True)
            remove_quietly(os.path.join(thisdir, fname))
+
class TestMisc(unittest2.TestCase):
    """Checks the module-level helper functions against known values
    and against the faster implementations in _clib."""

    def test_calculate_bucket_level(self):
        # expected level of buckets 0, 1, 2, ... in a k-ary heap
        expected_levels = {
            2: [0, 1, 1, 2, 2, 2, 2, 3],
            3: [0] + [1] * 3 + [2] * 9 + [3],
            4: [0] + [1] * 4 + [2] * 16 + [3],
        }
        for k, levels in sorted(expected_levels.items()):
            for b, level in enumerate(levels):
                self.assertEqual(calculate_bucket_level(k, b), level)

    def test_clib_calculate_bucket_level(self):
        for k in _test_bases:
            for b in xrange(
                    calculate_bucket_count_in_heap_with_height(k, 2) + 2):
                self.assertEqual(_clib.calculate_bucket_level(k, b),
                                 calculate_bucket_level(k, b))
        # large bases / bucket indices
        for k, b in [(89, 14648774),
                     (89, 14648775),
                     (90, 14648774),
                     (90, 14648775)]:
            self.assertEqual(_clib.calculate_bucket_level(k, b),
                             calculate_bucket_level(k, b))

    def test_clib_calculate_last_common_level(self):
        for k in range(2, 8):
            bound = calculate_bucket_count_in_heap_with_height(k, 2) + 2
            for b1 in xrange(bound):
                for b2 in xrange(bound):
                    self.assertEqual(
                        _clib.calculate_last_common_level(k, b1, b2),
                        calculate_last_common_level(k, b1, b2))
        for k in [89, 90]:
            for b1 in [0, 100, 10000, 14648774, 14648775]:
                for b2 in [0, 100, 10000, 14648774, 14648775]:
                    self.assertEqual(
                        _clib.calculate_last_common_level(k, b1, b2),
                        calculate_last_common_level(k, b1, b2))

    def test_calculate_necessary_heap_height(self):
        # expected height needed to store n = 1, 2, 3, ... buckets
        expected_heights = {
            2: [0, 1, 1, 2, 2, 2, 2, 3],
            3: [0] + [1] * 3 + [2] * 9 + [3] * 3,
        }
        for k, heights in sorted(expected_heights.items()):
            for i, height in enumerate(heights):
                self.assertEqual(
                    calculate_necessary_heap_height(k, i + 1), height)
+
# Allow the test module to be run directly.
if __name__ == "__main__":
    unittest2.main()  # pragma: no cover
--- /dev/null
+import pyoram.util.misc
+import pyoram.util.virtual_heap
--- /dev/null
+import base64
+
+import six
+
def log2floor(n):
    """Exact value of floor(log2(n)) for a positive integer.

    Uses only integer bit arithmetic; no floating point.
    """
    assert n > 0
    highest_bit = n.bit_length()
    return highest_bit - 1
+
def log2ceil(n):
    """Exact value of ceil(log2(n)) for a positive integer.

    Uses only integer bit arithmetic; no floating point.
    """
    assert n > 0
    # ceil(log2(n)) == bit_length(n - 1) for n >= 1
    return (n - 1).bit_length()
+
def intdivceil(x, y):
    """
    Returns the exact value of ceil(x / y).
    No floating point calculations are used.
    Requires positive integer types. The result
    is undefined if at least one of the inputs
    is floating point.
    """
    # NOTE: the previous docstring said ceil(x // y), which is just
    # x // y; the code computes ceil of the true quotient x / y.
    quotient, remainder = divmod(x, y)
    return quotient + 1 if remainder else quotient
+
def save_private_key(filename, key):
    """Store a key (bytes) to a file, base64-encoded."""
    encoded = base64.b64encode(key)
    with open(filename, "wb") as f:
        f.write(encoded)
+
def load_private_key(filename):
    """Read back a key stored by save_private_key (base64 decode)."""
    with open(filename, "rb") as f:
        encoded = f.read()
    return base64.b64decode(encoded)
+
+from fractions import Fraction
+
class MemorySize(object):
    """An exact amount of memory, stored internally as a byte count
    using rational (Fraction) arithmetic to avoid rounding."""

    # unit label -> callable producing the exact byte count
    to_bytes = {
        'b':   (lambda x: Fraction(x, 8)),
        'B':   (lambda x: Fraction(x, 1)),
        'KB':  (lambda x: Fraction(1000 * x, 1)),
        'MB':  (lambda x: Fraction((1000**2) * x, 1)),
        'GB':  (lambda x: Fraction((1000**3) * x, 1)),
        'TB':  (lambda x: Fraction((1000**4) * x, 1)),
        'KiB': (lambda x: Fraction(1024 * x, 1)),
        'MiB': (lambda x: Fraction((1024**2) * x, 1)),
        'GiB': (lambda x: Fraction((1024**3) * x, 1)),
        'TiB': (lambda x: Fraction((1024**4) * x, 1)),
    }

    def __init__(self, size, unit='B'):
        assert size >= 0
        # Fraction.from_float converts the numeric size exactly
        self.numbytes = MemorySize.to_bytes[unit](Fraction.from_float(size))

    def __str__(self):
        # Report in the largest unit whose value is at least 1.
        unit_pairs = (('b', 'B'), ('B', 'KB'), ('KB', 'MB'),
                      ('MB', 'GB'), ('GB', 'TB'))
        for small, big in unit_pairs:
            if getattr(self, big) < 1:
                return "%.3f %s" % (getattr(self, small), small)
        return "%.3f TB" % (self.TB)

    @property
    def b(self): return self.numbytes*8
    @property
    def B(self): return self.numbytes

    @property
    def KB(self): return self.B/1000
    @property
    def MB(self): return self.KB/1000
    @property
    def GB(self): return self.MB/1000
    @property
    def TB(self): return self.GB/1000

    @property
    def KiB(self): return self.B/1024
    @property
    def MiB(self): return self.KiB/1024
    @property
    def GiB(self): return self.MiB/1024
    @property
    def TiB(self): return self.GiB/1024
+
def chunkiter(objs, n=100):
    """
    Chunk an iterator of unknown size. The optional
    keyword 'n' sets the chunk size (default 100).

    Yields lists of up to n consecutive items; the final
    chunk may be shorter. Raises ValueError if n < 1
    (previously this case looped forever yielding empty lists).
    """
    if n < 1:
        raise ValueError("chunk size must be at least 1: %s" % (n,))
    chunk = []
    for obj in objs:
        chunk.append(obj)
        if len(chunk) == n:
            yield chunk
            chunk = []
    # emit any leftover partial chunk
    if chunk:
        yield chunk
--- /dev/null
+__all__ = ("VirtualHeap",
+ "SizedVirtualHeap")
+
+import os
+import sys
+import subprocess
+import random
+import string
+import tempfile
+
+from six.moves import xrange
+
+from pyoram.util._virtual_heap_helper import lib as _clib
+from pyoram.util.misc import log2floor
+
# Printable characters usable as base-k digits for node labels.
# Whitespace and characters that would be awkward in labels or
# file names ('+', '-', '"', "'", '\', '/') are excluded.
_excluded_numerals = set(string.whitespace) | set('+-"\'\\/')
numerals = ''.join(c for c in string.printable
                   if c not in _excluded_numerals)
numeral_index = dict((c, i) for i, c in enumerate(numerals))

# The maximum heap base for which base k labels
# can be produced.
max_k_labeled = len(numerals)
+
def base10_integer_to_basek_string(k, x):
    """Convert an integer into a base k string."""
    if not (2 <= k <= max_k_labeled):
        raise ValueError("k must be in range [2, %d]: %s"
                         % (max_k_labeled, k))
    if x == 0:
        return numerals[0]
    digits = []
    while x:
        x, r = divmod(x, k)
        digits.append(numerals[r])
    # most significant digit first
    return ''.join(reversed(digits))
+
def basek_string_to_base10_integer(k, x):
    """Convert a base k string into an integer."""
    assert 1 < k <= max_k_labeled
    # Horner's rule over the digit string
    value = 0
    for c in x:
        value = value * k + numeral_index[c]
    return value
+
# _clib defines a faster version of this function
def calculate_bucket_level(k, b):
    """
    Calculate the level in which a 0-based bucket
    lives inside of a k-ary heap.
    """
    assert k >= 2
    if k == 2:
        # floor(log2(b + 1)) via integer bit arithmetic
        return (b + 1).bit_length() - 1
    v = (k - 1) * (b + 1) + 1
    level = 0
    threshold = k  # k**(level + 1)
    while threshold < v:
        level += 1
        threshold *= k
    return level
+
# _clib defines a faster version of this function
def calculate_last_common_level(k, b1, b2):
    """
    Calculate the highest level after which the
    paths from the root to these buckets diverge.
    """
    level1 = calculate_bucket_level(k, b1)
    level2 = calculate_bucket_level(k, b2)
    # raise the deeper bucket until both are on the same level
    while level1 > level2:
        b1 = (b1 - 1) // k
        level1 -= 1
    while level2 > level1:
        b2 = (b2 - 1) // k
        level2 -= 1
    # ascend in lockstep until the two paths meet
    while b1 != b2:
        b1 = (b1 - 1) // k
        b2 = (b2 - 1) // k
        level1 -= 1
    return level1
+
def calculate_necessary_heap_height(k, n):
    """
    Calculate the necessary k-ary heap height
    to store n buckets.
    """
    assert n >= 1
    # the height equals the level of the last (0-based) bucket
    return calculate_bucket_level(k, n - 1)
+
def calculate_bucket_count_in_heap_with_height(k, h):
    """
    Calculate the number of buckets in a
    k-ary heap of height h.
    """
    assert h >= 0
    # geometric series: 1 + k + k^2 + ... + k^h
    return (k**(h + 1) - 1) // (k - 1)
+
def calculate_bucket_count_in_heap_at_level(k, l):
    """
    Calculate the number of buckets in a
    k-ary heap at level l.
    """
    assert l >= 0
    # each level is k times wider than the one above it
    return k**l
+
def calculate_leaf_bucket_count_in_heap_with_height(k, h):
    """
    Calculate the number of buckets in the
    leaf-level of a k-ary heap of height h.
    """
    # the leaf level of a heap of height h is level h
    return calculate_bucket_count_in_heap_at_level(k, h)
+
def create_node_type(k):
    """Build a VirtualHeapNode class specialized to arity k.

    The returned class stores a 0-based bucket index (plus its cached
    level) and provides parent/child/path arithmetic on the implicit
    k-ary heap layout.
    """

    class VirtualHeapNode(object):
        __slots__ = ("bucket", "level")

        def __init__(self, bucket):
            assert bucket >= 0
            self.bucket = bucket
            # cache the level; most operations need it
            self.level = _clib.calculate_bucket_level(self.k, self.bucket)

        # Nodes hash and compare through their bucket index so they
        # interoperate with plain integers.
        def __hash__(self):
            return self.bucket.__hash__()

        def __int__(self):
            return self.bucket

        def __lt__(self, other):
            return self.bucket < other

        def __le__(self, other):
            return self.bucket <= other

        def __eq__(self, other):
            return self.bucket == other

        def __ne__(self, other):
            return self.bucket != other

        def __gt__(self, other):
            return self.bucket > other

        def __ge__(self, other):
            return self.bucket >= other

        def last_common_level(self, n):
            """Deepest level shared by the root paths of self and n."""
            return _clib.calculate_last_common_level(self.k,
                                                     self.bucket,
                                                     n.bucket)

        def child_node(self, c):
            """The c-th (0-based) child of this node."""
            assert type(c) is int
            assert 0 <= c < self.k
            return VirtualHeapNode(self.k * self.bucket + 1 + c)

        def parent_node(self):
            """The parent of this node, or None for the root."""
            if self.bucket != 0:
                return VirtualHeapNode((self.bucket - 1) // self.k)
            return None

        def ancestor_node_at_level(self, level):
            """Ancestor at the given level, or None if level > self.level."""
            if level > self.level:
                return None
            current = self
            while current.level != level:
                current = current.parent_node()
            return current

        def path_to_root(self):
            """Yield nodes from this node up to and including the root."""
            bucket = self.bucket
            yield self
            while bucket != 0:
                bucket = (bucket - 1) // self.k
                yield type(self)(bucket)

        def path_from_root(self):
            """List of nodes from the root down to this node."""
            return list(reversed(list(self.path_to_root())))

        def bucket_path_to_root(self):
            """Yield bucket indices from this node up to the root."""
            bucket = self.bucket
            yield bucket
            while bucket != 0:
                bucket = (bucket - 1) // self.k
                yield bucket

        def bucket_path_from_root(self):
            """List of bucket indices from the root down to this node."""
            return list(reversed(list(self.bucket_path_to_root())))

        #
        # Expensive Functions
        #

        def __repr__(self):
            try:
                label = self.label()
            except ValueError:
                # presumably, k is too large to generate a label
                label = "<unknown>"
            return ("VirtualHeapNode(k=%s, bucket=%s, level=%s, label=%r)"
                    % (self.k, self.bucket, self.level, label))

        def __str__(self):
            """Returns a tuple (<level>, <bucket offset within level>)."""
            if self.bucket != 0:
                offset = self.bucket - \
                    calculate_bucket_count_in_heap_with_height(
                        self.k, self.level - 1)
                return "(%s, %s)" % (self.level, offset)
            assert self.level == 0
            return "(0, 0)"

        def label(self):
            """Base-k digit string locating this node within its level."""
            assert 0 <= self.bucket
            if self.level == 0:
                return ''
            offset = self.bucket - \
                calculate_bucket_count_in_heap_with_height(self.k,
                                                           self.level - 1)
            basek = base10_integer_to_basek_string(self.k, offset)
            # pad to one digit per level below the root
            return basek.zfill(self.level)

        def is_node_on_path(self, n):
            """True if n lies on the path from the root to this node."""
            if n.level <= self.level:
                n_label = n.label()
                if n_label == "":
                    return True
                return self.label().startswith(n_label)
            return False

    VirtualHeapNode.k = k

    return VirtualHeapNode
+
class VirtualHeap(object):
    """An unbounded k-ary heap of buckets, each holding a fixed
    number of blocks. Provides conversions between bucket indices,
    block indices, and Node objects (see create_node_type)."""

    clib = _clib
    # generator backed by the OS entropy source
    random = random.SystemRandom()

    def __init__(self, k, blocks_per_bucket=1):
        assert 1 < k
        assert blocks_per_bucket >= 1
        self._k = k
        self._blocks_per_bucket = blocks_per_bucket
        self.Node = create_node_type(k)

    @property
    def k(self):
        """The heap arity (children per node)."""
        return self._k

    def node_label_to_bucket(self, label):
        """Invert Node.label(): map a base-k label to its bucket index."""
        if len(label) > 0:
            level_start = calculate_bucket_count_in_heap_with_height(
                self.k, len(label) - 1)
            return level_start + \
                basek_string_to_base10_integer(self.k, label)
        return 0

    #
    # Buckets (0-based integer, equivalent to block for heap
    # with blocks_per_bucket=1)
    #

    @property
    def blocks_per_bucket(self):
        """Number of blocks stored in each bucket."""
        return self._blocks_per_bucket

    def bucket_count_at_level(self, l):
        return calculate_bucket_count_in_heap_at_level(self.k, l)

    def first_bucket_at_level(self, l):
        if l > 0:
            return calculate_bucket_count_in_heap_with_height(self.k, l - 1)
        return 0

    def last_bucket_at_level(self, l):
        return calculate_bucket_count_in_heap_with_height(self.k, l) - 1

    def random_bucket_up_to_level(self, l):
        return self.random.randint(self.first_bucket_at_level(0),
                                   self.last_bucket_at_level(l))

    def random_bucket_at_level(self, l):
        return self.random.randint(self.first_bucket_at_level(l),
                                   self.first_bucket_at_level(l + 1) - 1)

    #
    # Nodes (a class that helps with heap path calculations)
    #

    def root_node(self):
        return self.first_node_at_level(0)

    def node_count_at_level(self, l):
        return self.bucket_count_at_level(l)

    def first_node_at_level(self, l):
        return self.Node(self.first_bucket_at_level(l))

    def last_node_at_level(self, l):
        return self.Node(self.last_bucket_at_level(l))

    def random_node_up_to_level(self, l):
        return self.Node(self.random_bucket_up_to_level(l))

    def random_node_at_level(self, l):
        return self.Node(self.random_bucket_at_level(l))

    #
    # Block (0-based integer)
    #

    def bucket_to_block(self, b):
        assert b >= 0
        return b * self.blocks_per_bucket

    def block_to_bucket(self, s):
        assert s >= 0
        return s // self.blocks_per_bucket

    def first_block_in_bucket(self, b):
        return self.bucket_to_block(b)

    def last_block_in_bucket(self, b):
        return self.bucket_to_block(b) + self.blocks_per_bucket - 1

    def block_count_at_level(self, l):
        return self.bucket_count_at_level(l) * self.blocks_per_bucket

    def first_block_at_level(self, l):
        return self.bucket_to_block(self.first_bucket_at_level(l))

    def last_block_at_level(self, l):
        return self.bucket_to_block(self.first_bucket_at_level(l + 1)) - 1
+
class SizedVirtualHeap(VirtualHeap):
    """A VirtualHeap bounded to a fixed height, adding total-size
    queries, leaf-level helpers, and graphviz visualization."""

    def __init__(self, k, height, blocks_per_bucket=1):
        super(SizedVirtualHeap, self).__init__(
            k, blocks_per_bucket=blocks_per_bucket)
        self._height = height

    #
    # Size properties
    #

    @property
    def height(self):
        """Level of the deepest (leaf) buckets."""
        return self._height

    @property
    def levels(self):
        """Number of levels (height + 1)."""
        return self.height + 1

    @property
    def first_level(self):
        """Always 0 (the root level)."""
        return 0

    @property
    def last_level(self):
        """Same as height."""
        return self.height

    #
    # Buckets (0-based integer, equivalent to block for heap
    # with blocks_per_bucket=1)
    #

    def bucket_count(self):
        """Total number of buckets in the heap."""
        return calculate_bucket_count_in_heap_with_height(self.k,
                                                          self.height)

    def leaf_bucket_count(self):
        """Number of buckets on the leaf level."""
        return calculate_leaf_bucket_count_in_heap_with_height(self.k,
                                                               self.height)

    def first_leaf_bucket(self):
        return self.first_bucket_at_level(self.height)

    def last_leaf_bucket(self):
        return self.last_bucket_at_level(self.height)

    def random_bucket(self):
        return self.random.randint(self.first_bucket_at_level(0),
                                   self.last_leaf_bucket())

    def random_leaf_bucket(self):
        return self.random_bucket_at_level(self.height)

    #
    # Nodes (a class that helps with heap path calculations)
    #

    def is_nil_node(self, n):
        """True if the node index falls outside this sized heap."""
        return n.bucket >= self.bucket_count()

    def node_count(self):
        return self.bucket_count()

    def leaf_node_count(self):
        return self.leaf_bucket_count()

    def first_leaf_node(self):
        return self.Node(self.first_leaf_bucket())

    def last_leaf_node(self):
        return self.Node(self.last_leaf_bucket())

    def random_leaf_node(self):
        return self.Node(self.random_leaf_bucket())

    def random_node(self):
        return self.Node(self.random_bucket())

    #
    # Block (0-based integer)
    #

    def block_count(self):
        return self.bucket_count() * self.blocks_per_bucket

    def leaf_block_count(self):
        return self.leaf_bucket_count() * self.blocks_per_bucket

    def first_leaf_block(self):
        return self.first_block_in_bucket(self.first_leaf_bucket())

    def last_leaf_block(self):
        return self.last_block_in_bucket(self.last_leaf_bucket())

    #
    # Visualization
    #

    def write_as_dot(self, f, data=None, max_levels=None):
        """Write the tree in the dot language format to f.

        If data is given, each node record shows its blocks' data;
        otherwise nodes show their base-k labels (or (level, offset)
        tuples when k is too large to label). max_levels limits how
        many levels are drawn.
        """
        assert (max_levels is None) or (max_levels >= 0)

        def visit_node(n, levels):
            lbl = "{"
            if data is None:
                if self.k <= max_k_labeled:
                    # escape characters special to dot record labels
                    lbl = repr(n.label()).\
                        replace("{", "\\{").\
                        replace("}", "\\}").\
                        replace("|", "\\|").\
                        replace("<", "\\<").\
                        replace(">", "\\>")
                else:
                    lbl = str(n)
            else:
                s = self.bucket_to_block(n.bucket)
                for i in xrange(self.blocks_per_bucket):
                    lbl += "{%s}" % (data[s + i])
                    if i + 1 != self.blocks_per_bucket:
                        lbl += "|"
                lbl += "}"
            f.write(" %s [penwidth=%s,label=\"%s\"];\n"
                    % (n.bucket, 1, lbl))
            levels += 1
            if (max_levels is None) or (levels <= max_levels):
                for i in xrange(self.k):
                    cn = n.child_node(i)
                    if not self.is_nil_node(cn):
                        visit_node(cn, levels)
                        f.write(" %s -> %s ;\n" % (n.bucket, cn.bucket))

        f.write("// Created by SizedVirtualHeap.write_as_dot(...)\n")
        f.write("digraph heaptree {\n")
        f.write("node [shape=record]\n")

        if (max_levels is None) or (max_levels > 0):
            visit_node(self.root_node(), 1)
        f.write("}\n")

    def save_image_as_pdf(self, filename, data=None, max_levels=None):
        """Render the heap to <filename>.pdf using graphviz 'dot'.

        Returns True on success and False if 'dot' is missing or
        exits with an error (previously a nonzero exit status from
        'dot' was silently treated as success). On failure the
        temporary DOT file is left on disk and its path is written
        to stderr.
        """
        assert (max_levels is None) or (max_levels >= 0)
        if not filename.endswith('.pdf'):
            filename = filename + '.pdf'
        tmpfd, tmpname = tempfile.mkstemp(suffix='dot')
        with open(tmpname, 'w') as f:
            self.write_as_dot(f, data=data, max_levels=max_levels)
        os.close(tmpfd)
        try:
            rc = subprocess.call(['dot',
                                  tmpname,
                                  '-Tpdf',
                                  '-o',
                                  ('%s' % filename)])
        except OSError:
            # 'dot' executable not available
            rc = 1
        if rc != 0:
            sys.stderr.write(
                "DOT -> PDF conversion failed. See DOT file: %s\n"
                % (tmpname))
            return False
        os.remove(tmpname)
        return True
Table t1 = new Table("http://dc-6.calit2.uci.edu/test.iotcloud/", "reallysecret", 361, -1);
+
+ try {
+ Thread.sleep(5000);
+ } catch (Exception e) {
+
+ }
+
+ long start = System.currentTimeMillis();
t1.rebuild();
+
+ long stop1 = System.currentTimeMillis();
+ t1.update();
+ long stop2 = System.currentTimeMillis();
+ System.out.println("Done......");
+ System.out.println(stop1 - start);
+ System.out.println(stop2 - stop1);
- String pingTimerKey = "sensorController";
- IoTString ipingTimerKey = new IoTString(pingTimerKey);
+ // t1.startTransaction();
+ // t1.addKV(ipingTimerKey, ipingTimer);
+ // t1.addKV(ia1, senDat);
+ // t1.commitTransaction();
- String a1 = "sensor";
- IoTString ia1 = new IoTString(a1);
- System.out.println("Starting System");
+ // String pingTimerKey = "sensorController";
+ // IoTString ipingTimerKey = new IoTString(pingTimerKey);
+ // String a1 = "sensor";
+ // IoTString ia1 = new IoTString(a1);
- while (true) {
- try {
+ // System.out.println("Starting System");
- Runtime runTime = Runtime.getRuntime();
- // Process proc = runTime.exec("/opt/vc/bin/vcgencmd measure_temp | tr -d 'temp=' | tr -d \"'C\"");
- Process proc = runTime.exec("/opt/vc/bin/vcgencmd measure_temp");
- BufferedReader reader = new BufferedReader(new InputStreamReader(proc.getInputStream()));
- String line = null;
- String dat = "";
- while ((line = reader.readLine()) != null) {
- System.out.println(line);
- dat = line;
- }
- reader.close();
+ // while (true) {
+ // try {
- String pingTimer = Long.toString(System.currentTimeMillis());
- IoTString ipingTimer = new IoTString(pingTimer);
- IoTString senDat = new IoTString(dat);
+ // // Runtime runTime = Runtime.getRuntime();
+ // // // Process proc = runTime.exec("/opt/vc/bin/vcgencmd measure_temp | tr -d 'temp=' | tr -d \"'C\"");
+ // // Process proc = runTime.exec("/opt/vc/bin/vcgencmd measure_temp");
+ // // BufferedReader reader = new BufferedReader(new InputStreamReader(proc.getInputStream()));
+ // // String line = null;
+ // // String dat = "";
+ // // while ((line = reader.readLine()) != null) {
+ // // System.out.println(line);
+ // // dat = line;
+ // // }
+ // // reader.close();
- t1.update();
- t1.startTransaction();
- t1.addKV(ipingTimerKey, ipingTimer);
- t1.addKV(ia1, senDat);
- t1.commitTransaction();
+ // // String pingTimer = Long.toString(System.currentTimeMillis());
+ // // IoTString ipingTimer = new IoTString(pingTimer);
+ // IoTString senDat = new IoTString(dat);
+ // t1.update();
+ // t1.startTransaction();
+ // t1.addKV(ipingTimerKey, ipingTimer);
+ // t1.addKV(ia1, senDat);
+ // t1.commitTransaction();
- Thread.sleep(5000);
+ // Thread.sleep(5000);
- } catch (Error e) {
- e.printStackTrace();
+ // } catch (Error e) {
+ // e.printStackTrace();
- Runtime runTime = Runtime.getRuntime();
- runTime.exec("gpio mode 4 out");
+ // Runtime runTime = Runtime.getRuntime();
+ // runTime.exec("gpio mode 4 out");
- while (true) {
- runTime.exec("gpio write 4 1");
- Thread.sleep(500);
- runTime.exec("gpio write 4 0");
- Thread.sleep(500);
- }
- }
- }
+ // while (true) {
+ // runTime.exec("gpio write 4 1");
+ // Thread.sleep(500);
+ // runTime.exec("gpio write 4 0");
+ // Thread.sleep(500);
+ // }
+ // }
+ // }
}
}
\ No newline at end of file