From: bdemsky Date: Thu, 29 Mar 2018 11:36:55 +0000 (-0700) Subject: rename files X-Git-Url: http://plrg.eecs.uci.edu/git/?p=iotcloud.git;a=commitdiff_plain;h=786e40250f31eff04eec25bbcaae3cd916fedb14 rename files --- diff --git a/version2/src/C/Abort.cc b/version2/src/C/Abort.cc deleted file mode 100644 index f653e97..0000000 --- a/version2/src/C/Abort.cc +++ /dev/null @@ -1,44 +0,0 @@ -#include "Abort.h" -#include "ByteBuffer.h" - -Abort::Abort(Slot *slot, int64_t _transactionClientLocalSequenceNumber, int64_t _transactionSequenceNumber, int64_t _transactionMachineId, int64_t _transactionArbitrator, int64_t _arbitratorLocalSequenceNumber) : - Entry(slot), - transactionClientLocalSequenceNumber(_transactionClientLocalSequenceNumber), - transactionSequenceNumber(_transactionSequenceNumber), - transactionMachineId(_transactionMachineId), - transactionArbitrator(_transactionArbitrator), - arbitratorLocalSequenceNumber(_arbitratorLocalSequenceNumber), - abortId(Pair(transactionMachineId, transactionClientLocalSequenceNumber)) { -} - -Abort::Abort(Slot *slot, int64_t _transactionClientLocalSequenceNumber, int64_t _transactionSequenceNumber, int64_t _sequenceNumber, int64_t _transactionMachineId, int64_t _transactionArbitrator, int64_t _arbitratorLocalSequenceNumber) : - Entry(slot), - transactionClientLocalSequenceNumber(_transactionClientLocalSequenceNumber), - transactionSequenceNumber(_transactionSequenceNumber), - sequenceNumber(_sequenceNumber), - transactionMachineId(_transactionMachineId), - transactionArbitrator(_transactionArbitrator), - arbitratorLocalSequenceNumber(_arbitratorLocalSequenceNumber), - abortId(Pair(transactionMachineId, transactionClientLocalSequenceNumber)) { -} - -Entry *Abort_decode(Slot *slot, ByteBuffer *bb) { - int64_t transactionClientLocalSequenceNumber = bb->getLong(); - int64_t transactionSequenceNumber = bb->getLong(); - int64_t sequenceNumber = bb->getLong(); - int64_t transactionMachineId = bb->getLong(); - int64_t transactionArbitrator = bb->getLong(); - int64_t arbitratorLocalSequenceNumber = bb->getLong(); - - return new Abort(slot, transactionClientLocalSequenceNumber, transactionSequenceNumber, sequenceNumber, transactionMachineId, transactionArbitrator, arbitratorLocalSequenceNumber); -} - -void Abort::encode(ByteBuffer *bb) { - bb->put(TypeAbort); - bb->putLong(transactionClientLocalSequenceNumber); - bb->putLong(transactionSequenceNumber); - bb->putLong(sequenceNumber); - bb->putLong(transactionMachineId); - bb->putLong(transactionArbitrator); - bb->putLong(arbitratorLocalSequenceNumber); -} diff --git a/version2/src/C/Abort.cpp b/version2/src/C/Abort.cpp new file mode 100644 index 0000000..f653e97 --- /dev/null +++ b/version2/src/C/Abort.cpp @@ -0,0 +1,44 @@ +#include "Abort.h" +#include "ByteBuffer.h" + +Abort::Abort(Slot *slot, int64_t _transactionClientLocalSequenceNumber, int64_t _transactionSequenceNumber, int64_t _transactionMachineId, int64_t _transactionArbitrator, int64_t _arbitratorLocalSequenceNumber) : + Entry(slot), + transactionClientLocalSequenceNumber(_transactionClientLocalSequenceNumber), + transactionSequenceNumber(_transactionSequenceNumber), + transactionMachineId(_transactionMachineId), + transactionArbitrator(_transactionArbitrator), + arbitratorLocalSequenceNumber(_arbitratorLocalSequenceNumber), + abortId(Pair(transactionMachineId, transactionClientLocalSequenceNumber)) { +} + +Abort::Abort(Slot *slot, int64_t _transactionClientLocalSequenceNumber, int64_t _transactionSequenceNumber, int64_t _sequenceNumber, int64_t _transactionMachineId, int64_t _transactionArbitrator, int64_t _arbitratorLocalSequenceNumber) : + Entry(slot), + transactionClientLocalSequenceNumber(_transactionClientLocalSequenceNumber), + transactionSequenceNumber(_transactionSequenceNumber), + sequenceNumber(_sequenceNumber), + transactionMachineId(_transactionMachineId), + transactionArbitrator(_transactionArbitrator), + arbitratorLocalSequenceNumber(_arbitratorLocalSequenceNumber), + abortId(Pair(transactionMachineId, transactionClientLocalSequenceNumber)) { +} + +Entry *Abort_decode(Slot *slot, ByteBuffer *bb) { + int64_t transactionClientLocalSequenceNumber = bb->getLong(); + int64_t transactionSequenceNumber = bb->getLong(); + int64_t sequenceNumber = bb->getLong(); + int64_t transactionMachineId = bb->getLong(); + int64_t transactionArbitrator = bb->getLong(); + int64_t arbitratorLocalSequenceNumber = bb->getLong(); + + return new Abort(slot, transactionClientLocalSequenceNumber, transactionSequenceNumber, sequenceNumber, transactionMachineId, transactionArbitrator, arbitratorLocalSequenceNumber); +} + +void Abort::encode(ByteBuffer *bb) { + bb->put(TypeAbort); + bb->putLong(transactionClientLocalSequenceNumber); + bb->putLong(transactionSequenceNumber); + bb->putLong(sequenceNumber); + bb->putLong(transactionMachineId); + bb->putLong(transactionArbitrator); + bb->putLong(arbitratorLocalSequenceNumber); +} diff --git a/version2/src/C/ArbitrationRound.cc b/version2/src/C/ArbitrationRound.cc deleted file mode 100644 index cc316f0..0000000 --- a/version2/src/C/ArbitrationRound.cc +++ /dev/null @@ -1,123 +0,0 @@ -#include "ArbitrationRound.h" -#include "Commit.h" -#include "CommitPart.h" - -ArbitrationRound::ArbitrationRound(Commit *_commit, Hashset *_abortsBefore) : - abortsBefore(_abortsBefore), - parts(new Vector()), - commit(_commit), - currentSize(0), - didSendPart(false), - didGenerateParts(false) { - - if (commit != NULL) { - commit->createCommitParts(); - currentSize += commit->getNumberOfParts(); - } - - currentSize += abortsBefore->size(); -} - -ArbitrationRound::~ArbitrationRound() { - delete abortsBefore; - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - Entry * part = parts->get(i); - part->releaseRef(); - } - delete parts; - if (commit != NULL) - delete commit; -} - -void ArbitrationRound::generateParts() { - if (didGenerateParts) { - return; - } - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - Entry * part = parts->get(i); - part->releaseRef(); - } - parts->clear(); - SetIterator *abit = abortsBefore->iterator(); - while (abit->hasNext()) - parts->add((Entry *)abit->next()); - delete abit; - if (commit != NULL) { - Vector *cParts = commit->getParts(); - uint cPartsSize = cParts->size(); - for (uint i = 0; i < cPartsSize; i++) { - CommitPart * part = cParts->get(i); - part->acquireRef(); - parts->add((Entry *)part); - } - } -} - -Vector *ArbitrationRound::getParts() { - return parts; -} - -void ArbitrationRound::removeParts(Vector *removeParts) { - uint size = removeParts->size(); - for(uint i=0; i < size; i++) { - Entry * e = removeParts->get(i); - if (parts->remove(e)) - e->releaseRef(); - } - didSendPart = true; -} - - -bool ArbitrationRound::isDoneSending() { - if ((commit == NULL) && abortsBefore->isEmpty()) { - return true; - } - return parts->isEmpty(); -} - -Commit *ArbitrationRound::getCommit() { - return commit; -} - -void ArbitrationRound::setCommit(Commit *_commit) { - if (commit != NULL) { - currentSize -= commit->getNumberOfParts(); - } - commit = _commit; - - if (commit != NULL) { - currentSize += commit->getNumberOfParts(); - } -} - -void ArbitrationRound::addAbort(Abort *abort) { - abortsBefore->add(abort); - currentSize++; -} - -void ArbitrationRound::addAborts(Hashset *aborts) { - abortsBefore->addAll(aborts); - currentSize += aborts->size(); -} - -Hashset *ArbitrationRound::getAborts() { - return abortsBefore; -} - -int ArbitrationRound::getAbortsCount() { - return abortsBefore->size(); -} - -int ArbitrationRound::getCurrentSize() { - return currentSize; -} - -bool ArbitrationRound::isFull() { - return currentSize >= ArbitrationRound_MAX_PARTS; -} - -bool ArbitrationRound::getDidSendPart() { - return didSendPart; -} diff --git a/version2/src/C/ArbitrationRound.cpp b/version2/src/C/ArbitrationRound.cpp new file mode 100644 index 0000000..cc316f0 --- /dev/null +++ b/version2/src/C/ArbitrationRound.cpp @@ -0,0 +1,123 @@ +#include "ArbitrationRound.h" +#include "Commit.h" +#include "CommitPart.h" + +ArbitrationRound::ArbitrationRound(Commit *_commit, Hashset *_abortsBefore) : + abortsBefore(_abortsBefore), + parts(new Vector()), + commit(_commit), + currentSize(0), + didSendPart(false), + didGenerateParts(false) { + + if (commit != NULL) { + commit->createCommitParts(); + currentSize += commit->getNumberOfParts(); + } + + currentSize += abortsBefore->size(); +} + +ArbitrationRound::~ArbitrationRound() { + delete abortsBefore; + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + Entry * part = parts->get(i); + part->releaseRef(); + } + delete parts; + if (commit != NULL) + delete commit; +} + +void ArbitrationRound::generateParts() { + if (didGenerateParts) { + return; + } + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + Entry * part = parts->get(i); + part->releaseRef(); + } + parts->clear(); + SetIterator *abit = abortsBefore->iterator(); + while (abit->hasNext()) + parts->add((Entry *)abit->next()); + delete abit; + if (commit != NULL) { + Vector *cParts = commit->getParts(); + uint cPartsSize = cParts->size(); + for (uint i = 0; i < cPartsSize; i++) { + CommitPart * part = cParts->get(i); + part->acquireRef(); + parts->add((Entry *)part); + } + } +} + +Vector *ArbitrationRound::getParts() { + return parts; +} + +void ArbitrationRound::removeParts(Vector *removeParts) { + uint size = removeParts->size(); + for(uint i=0; i < size; i++) { + Entry * e = removeParts->get(i); + if (parts->remove(e)) + e->releaseRef(); + } + didSendPart = true; +} + + +bool ArbitrationRound::isDoneSending() { + if ((commit == NULL) && abortsBefore->isEmpty()) { + return true; + } + return parts->isEmpty(); +} + +Commit *ArbitrationRound::getCommit() { + return commit; +} + +void ArbitrationRound::setCommit(Commit *_commit) { + if (commit != NULL) { + currentSize -= commit->getNumberOfParts(); + } + commit = _commit; + + if (commit != NULL) { + currentSize += commit->getNumberOfParts(); + } +} + +void ArbitrationRound::addAbort(Abort *abort) { + abortsBefore->add(abort); + currentSize++; +} + +void ArbitrationRound::addAborts(Hashset *aborts) { + abortsBefore->addAll(aborts); + currentSize += aborts->size(); +} + +Hashset *ArbitrationRound::getAborts() { + return abortsBefore; +} + +int ArbitrationRound::getAbortsCount() { + return abortsBefore->size(); +} + +int ArbitrationRound::getCurrentSize() { + return currentSize; +} + +bool ArbitrationRound::isFull() { + return currentSize >= ArbitrationRound_MAX_PARTS; +} + +bool ArbitrationRound::getDidSendPart() { + return didSendPart; +} diff --git a/version2/src/C/ByteBuffer.cc b/version2/src/C/ByteBuffer.cc deleted file mode 100644 index 22e28cd..0000000 --- a/version2/src/C/ByteBuffer.cc +++ /dev/null @@ -1,81 +0,0 @@ -#include "ByteBuffer.h" -#include - -ByteBuffer::ByteBuffer(Array *array) : - buffer(array), - offset(0) { -} - -void ByteBuffer::put(char c) { - buffer->set(offset++, c); -} - -void ByteBuffer::putInt(int32_t l) { - buffer->set(offset++, (char)(l >> 24)); - buffer->set(offset++, (char)((l >> 16) & 0xff)); - buffer->set(offset++, (char)((l >> 8) & 0xff)); - buffer->set(offset++, (char)(l & 0xff)); -} - -void ByteBuffer::putLong(int64_t l) { - buffer->set(offset++, (char)(l >> 56)); - buffer->set(offset++, (char)((l >> 48) & 0xff)); - buffer->set(offset++, (char)((l >> 40) & 0xff)); - buffer->set(offset++, (char)((l >> 32) & 0xff)); - buffer->set(offset++, (char)((l >> 24) & 0xff)); - buffer->set(offset++, (char)((l >> 16) & 0xff)); - buffer->set(offset++, (char)((l >> 8) & 0xff)); - buffer->set(offset++, (char)(l & 0xff)); -} - -void ByteBuffer::put(Array *array) { - memcpy(&buffer->internalArray()[offset], array->internalArray(), array->length()); - offset += array->length(); -} - -int64_t ByteBuffer::getLong() { - char *array = &buffer->internalArray()[offset]; - offset += 8; - return (((int64_t)(unsigned char)array[0]) << 56) | - (((int64_t)(unsigned char)array[1]) << 48) | - (((int64_t)(unsigned char)array[2]) << 40) | - (((int64_t)(unsigned char)array[3]) << 32) | - (((int64_t)(unsigned char)array[4]) << 24) | - (((int64_t)(unsigned char)array[5]) << 16) | - (((int64_t)(unsigned char)array[6]) << 8) | - (((int64_t)(unsigned char)array[7])); -} - -int32_t ByteBuffer::getInt() { - char *array = &buffer->internalArray()[offset]; - offset += 4; - return (((int32_t)(unsigned char)array[0]) << 24) | - (((int32_t)(unsigned char)array[1]) << 16) | - (((int32_t)(unsigned char)array[2]) << 8) | - (((int32_t)(unsigned char)array[3])); -} - -char ByteBuffer::get() { - return buffer->get(offset++); -} - -void ByteBuffer::get(Array *array) { - memcpy(array->internalArray(), &buffer->internalArray()[offset], array->length()); - offset += array->length(); -} - -void ByteBuffer::position(int32_t newPosition) { - offset = newPosition; -} - -Array *ByteBuffer::array() { - return buffer; -} - -ByteBuffer *ByteBuffer_wrap(Array *array) { - return new ByteBuffer(array); -} - -ByteBuffer *ByteBuffer_allocate(uint size) { - return new ByteBuffer(new Array(size)); -} diff --git a/version2/src/C/ByteBuffer.cpp b/version2/src/C/ByteBuffer.cpp new file mode 100644 index 0000000..22e28cd --- /dev/null +++ b/version2/src/C/ByteBuffer.cpp @@ -0,0 +1,81 @@ +#include "ByteBuffer.h" +#include + +ByteBuffer::ByteBuffer(Array *array) : + buffer(array), + offset(0) { +} + +void ByteBuffer::put(char c) { + buffer->set(offset++, c); +} + +void ByteBuffer::putInt(int32_t l) { + buffer->set(offset++, (char)(l >> 24)); + buffer->set(offset++, (char)((l >> 16) & 0xff)); + buffer->set(offset++, (char)((l >> 8) & 0xff)); + buffer->set(offset++, (char)(l & 0xff)); +} + +void ByteBuffer::putLong(int64_t l) { + buffer->set(offset++, (char)(l >> 56)); + buffer->set(offset++, (char)((l >> 48) & 0xff)); + buffer->set(offset++, (char)((l >> 40) & 0xff)); + buffer->set(offset++, (char)((l >> 32) & 0xff)); + buffer->set(offset++, (char)((l >> 24) & 0xff)); + buffer->set(offset++, (char)((l >> 16) & 0xff)); + buffer->set(offset++, (char)((l >> 8) & 0xff)); + buffer->set(offset++, (char)(l & 0xff)); +} + +void ByteBuffer::put(Array *array) { + memcpy(&buffer->internalArray()[offset], array->internalArray(), array->length()); + offset += array->length(); +} + +int64_t ByteBuffer::getLong() { + char *array = &buffer->internalArray()[offset]; + offset += 8; + return (((int64_t)(unsigned char)array[0]) << 56) | + (((int64_t)(unsigned char)array[1]) << 48) | + (((int64_t)(unsigned char)array[2]) << 40) | + (((int64_t)(unsigned char)array[3]) << 32) | + (((int64_t)(unsigned char)array[4]) << 24) | + (((int64_t)(unsigned char)array[5]) << 16) | + (((int64_t)(unsigned char)array[6]) << 8) | + (((int64_t)(unsigned char)array[7])); +} + +int32_t ByteBuffer::getInt() { + char *array = &buffer->internalArray()[offset]; + offset += 4; + return (((int32_t)(unsigned char)array[0]) << 24) | + (((int32_t)(unsigned char)array[1]) << 16) | + (((int32_t)(unsigned char)array[2]) << 8) | + (((int32_t)(unsigned char)array[3])); +} + +char ByteBuffer::get() { + return buffer->get(offset++); +} + +void ByteBuffer::get(Array *array) { + memcpy(array->internalArray(), &buffer->internalArray()[offset], array->length()); + offset += array->length(); +} + +void ByteBuffer::position(int32_t newPosition) { + offset = newPosition; +} + +Array *ByteBuffer::array() { + return buffer; +} + +ByteBuffer *ByteBuffer_wrap(Array *array) { + return new ByteBuffer(array); +} + +ByteBuffer *ByteBuffer_allocate(uint size) { + return new ByteBuffer(new Array(size)); +} diff --git a/version2/src/C/CloudComm.cc b/version2/src/C/CloudComm.cc deleted file mode 100644 index 8fe0cee..0000000 --- a/version2/src/C/CloudComm.cc +++ /dev/null @@ -1,819 +0,0 @@ -#include "CloudComm.h" -#include "TimingSingleton.h" -#include "SecureRandom.h" -#include "IoTString.h" -#include "Error.h" -#include "URL.h" -#include "Mac.h" -#include "Table.h" -#include "Slot.h" -#include "Crypto.h" -#include "ByteBuffer.h" -#include "aes.h" -#include -#include -#include -#include -#include -#include - -/** - * Empty Constructor needed for child class. - */ -CloudComm::CloudComm() : - baseurl(NULL), - key(NULL), - mac(NULL), - password(NULL), - random(NULL), - salt(NULL), - table(NULL), - listeningPort(-1), - doEnd(false), - timer(TimingSingleton_getInstance()), - getslot(new Array("getslot", 7)), - putslot(new Array("putslot", 7)) -{ -} - -void *threadWrapper(void *cloud) { - CloudComm *c = (CloudComm *) cloud; - c->localServerWorkerFunction(); - return NULL; -} - -/** - * Constructor for actual use. Takes in the url and password. - */ -CloudComm::CloudComm(Table *_table, IoTString *_baseurl, IoTString *_password, int _listeningPort) : - baseurl(new IoTString(_baseurl)), - key(NULL), - mac(NULL), - password(new IoTString(_password)), - random(new SecureRandom()), - salt(NULL), - table(_table), - listeningPort(_listeningPort), - doEnd(false), - timer(TimingSingleton_getInstance()), - getslot(new Array("getslot", 7)), - putslot(new Array("putslot", 7)) { - if (listeningPort > 0) { - pthread_create(&localServerThread, NULL, threadWrapper, this); - } -} - -CloudComm::~CloudComm() { - delete getslot; - delete putslot; - if (salt) - delete salt; - if (password) - delete password; - if (random) - delete random; - if (baseurl) - delete baseurl; - if (mac) - delete mac; - if (key) - delete key; -} - -/** - * Generates Key from password. - */ -AESKey *CloudComm::initKey() { - try { - AESKey *key = new AESKey(password->internalBytes(), - salt, - 65536, - 128); - return key; - } catch (Exception *e) { - throw new Error("Failed generating key."); - } -} - -/** - * Inits all the security stuff - */ - -void CloudComm::initSecurity() { - // try to get the salt and if one does not exist set one - if (!getSalt()) { - //Set the salt - setSalt(); - } - - initCrypt(); -} - -/** - * Inits the HMAC generator. - */ -void CloudComm::initCrypt() { - if (password == NULL) { - return; - } - try { - key = initKey(); - delete password; - password = NULL;// drop password - mac = new Mac(); - mac->init(key); - } catch (Exception *e) { - throw new Error("Failed To Initialize Ciphers"); - } -} - -/* - * Builds the URL for the given request. - */ -IoTString *CloudComm::buildRequest(bool isput, int64_t sequencenumber, int64_t maxentries) { - const char *reqstring = isput ? "req=putslot" : "req=getslot"; - char *buffer = (char *) malloc(baseurl->length() + 200); - memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); - int offset = baseurl->length(); - offset += sprintf(&buffer[offset], "?%s&seq=%" PRId64, reqstring, sequencenumber); - if (maxentries != 0) - sprintf(&buffer[offset], "&max=%" PRId64, maxentries); - IoTString *urlstr = new IoTString(buffer); - free(buffer); - return urlstr; -} - -void loopWrite(int fd, char *array, int bytestowrite) { - int byteswritten = 0; - while (bytestowrite) { - int bytes = write(fd, &array[byteswritten], bytestowrite); - if (bytes >= 0) { - byteswritten += bytes; - bytestowrite -= bytes; - } else { - printf("Error in write\n"); - exit(-1); - } - } -} - -void loopRead(int fd, char *array, int bytestoread) { - int bytesread = 0; - while (bytestoread) { - int bytes = read(fd, &array[bytesread], bytestoread); - if (bytes >= 0) { - bytesread += bytes; - bytestoread -= bytes; - } else { - printf("Error in read\n"); - exit(-1); - } - } -} - -WebConnection openURL(IoTString *url) { - if (url->length() < 7 || memcmp(url->internalBytes()->internalArray(), "http://", 7)) { - printf("BOGUS URL\n"); - exit(-1); - } - int i = 7; - for (; i < url->length(); i++) - if (url->get(i) == '/') - break; - - if ( i == url->length()) { - printf("ERROR in openURL\n"); - exit(-1); - } - - char *host = (char *) malloc(i - 6); - memcpy(host, &url->internalBytes()->internalArray()[7], i - 7); - host[i - 7] = 0; - printf("%s\n", host); - - char *message = (char *)malloc(sizeof("POST HTTP/1.1\r\n") + sizeof("Host: \r\n") + 2 * url->length()); - - /* fill in the parameters */ - int post = sprintf(message,"POST "); - /* copy data */ - memcpy(&message[post], &url->internalBytes()->internalArray()[i], url->length() - i); - int endpost = sprintf(&message[post + url->length() - i], " HTTP/1.1\r\n"); - - int hostlen = sprintf(&message[endpost + post + url->length() - i], "Host: "); - memcpy(&message[endpost + post + url->length() + hostlen - i], host, i - 7); - sprintf(&message[endpost + post + url->length() + hostlen - 7], "\r\n"); - - /* create the socket */ - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) {printf("ERROR opening socket\n"); exit(-1);} - - /* lookup the ip address */ - struct hostent *server = gethostbyname(host); - free(host); - - if (server == NULL) {printf("ERROR, no such host"); exit(-1);} - - /* fill in the structure */ - struct sockaddr_in serv_addr; - - memset(&serv_addr,0,sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - serv_addr.sin_port = htons(80); - memcpy(&serv_addr.sin_addr.s_addr,server->h_addr,server->h_length); - - /* connect the socket */ - if (connect(sockfd,(struct sockaddr *)&serv_addr,sizeof(serv_addr)) < 0) { - printf("ERROR connecting"); - exit(-1); - } - - /* send the request */ - int total = strlen(message); - loopWrite(sockfd, message, total); - free(message); - return (WebConnection) {sockfd, -1}; -} - -int createSocket(IoTString *name, int port) { - char *host = (char *) malloc(name->length() + 1); - memcpy(host, name->internalBytes()->internalArray(), name->length()); - host[name->length()] = 0; - printf("%s\n", host); - /* How big is the message? */ - - /* create the socket */ - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) {printf("ERROR opening socket\n"); exit(-1);} - - /* lookup the ip address */ - struct hostent *server = gethostbyname(host); - free(host); - - if (server == NULL) {printf("ERROR, no such host"); exit(-1);} - - /* fill in the structure */ - struct sockaddr_in serv_addr; - - memset(&serv_addr,0,sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - serv_addr.sin_port = htons(port); - memcpy(&serv_addr.sin_addr.s_addr,server->h_addr,server->h_length); - - /* connect the socket */ - if (connect(sockfd,(struct sockaddr *)&serv_addr,sizeof(serv_addr)) < 0) { - printf("ERROR connecting"); - exit(-1); - } - - return sockfd; -} - -int createSocket(int port) { - int fd; - struct sockaddr_in sin; - - bzero(&sin, sizeof(sin)); - sin.sin_family = AF_INET; - sin.sin_port = htons(port); - sin.sin_addr.s_addr = htonl(INADDR_ANY); - fd = socket(AF_INET, SOCK_STREAM, 0); - int n = 1; - if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (char *)&n, sizeof (n)) < 0) { - close(fd); - printf("Create Socket Error\n"); - exit(-1); - } - if (bind(fd, (struct sockaddr *) &sin, sizeof(sin)) < 0) { - close(fd); - exit(-1); - } - if (listen(fd, 5) < 0) { - close(fd); - exit(-1); - } - return fd; -} - -int acceptSocket(int socket) { - struct sockaddr_in sin; - unsigned int sinlen = sizeof(sin); - int newfd = accept(socket, (struct sockaddr *)&sin, &sinlen); - int flag = 1; - setsockopt(newfd, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(flag)); - if (newfd < 0) { - printf("Accept Error\n"); - exit(-1); - } - return newfd; -} - -void writeSocketData(int fd, Array *data) { - loopWrite(fd, data->internalArray(), data->length()); -} - -void writeSocketInt(int fd, int32_t value) { - char array[4]; - array[0] = value >> 24; - array[1] = (value >> 16) & 0xff; - array[2] = (value >> 8) & 0xff; - array[3] = value & 0xff; - loopWrite(fd, array, 4); -} - -int readSocketInt(int fd) { - char array[4]; - loopRead(fd, array, 4); - return (((int32_t)(unsigned char) array[0]) << 24) | - (((int32_t)(unsigned char) array[1]) << 16) | - (((int32_t)(unsigned char) array[2]) << 8) | - ((int32_t)(unsigned char) array[3]); -} - -void readSocketData(int fd, Array *data) { - loopRead(fd, data->internalArray(), data->length()); -} - -void writeURLDataAndClose(WebConnection *wc, Array *data) { - dprintf(wc->fd, "Content-Length: %d\r\n\r\n", data->length()); - loopWrite(wc->fd, data->internalArray(), data->length()); -} - -void closeURLReq(WebConnection *wc) { - dprintf(wc->fd, "\r\n"); -} - -void readURLData(WebConnection *wc, Array *output) { - loopRead(wc->fd, output->internalArray(), output->length()); -} - -int readURLInt(WebConnection *wc) { - char array[4]; - loopRead(wc->fd, array, 4); - return (((int32_t)(unsigned char) array[0]) << 24) | - (((int32_t)(unsigned char) array[1]) << 16) | - (((int32_t)(unsigned char) array[2]) << 8) | - ((int32_t)(unsigned char) array[3]); -} - -void readLine(WebConnection *wc, char *response, int numBytes) { - int offset = 0; - char newchar; - while (true) { - int bytes = read(wc->fd, &newchar, 1); - if (bytes <= 0) - break; - if (offset == (numBytes - 1)) { - printf("Response too long"); - exit(-1); - } - response[offset++] = newchar; - if (newchar == '\n') - break; - } - response[offset] = 0; -} - -int getResponseCode(WebConnection *wc) { - char response[600]; - readLine(wc, response, sizeof(response)); - int ver1 = 0, ver2 = 0, respcode = 0; - sscanf(response, "HTTP/%d.%d %d", &ver1, &ver2, &respcode); - printf("Response code %d\n", respcode); - return respcode; -} - -void readHeaders(WebConnection *wc) { - char response[600]; - int numBytes; - - while (true) { - readLine(wc, response, sizeof(response)); - if (response[0] == '\r') - return; - else if (memcmp(response, "Content-Length:", sizeof("Content-Length:") - 1) == 0) { - sscanf(response, "Content-Length: %d", &numBytes); - wc->numBytes = numBytes; - } - } -} - -void CloudComm::setSalt() { - if (salt != NULL) { - // Salt already sent to server so don't set it again - return; - } - - WebConnection wc = {-1, -1}; - try { - Array *saltTmp = new Array(CloudComm_SALT_SIZE); - random->nextBytes(saltTmp); - - char *buffer = (char *) malloc(baseurl->length() + 100); - memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); - int offset = baseurl->length(); - offset += sprintf(&buffer[offset], "?req=setsalt"); - IoTString *urlstr = new IoTString(buffer); - free(buffer); - - timer->startTime(); - wc = openURL(urlstr); - delete urlstr; - writeURLDataAndClose(&wc, saltTmp); - - int responsecode = getResponseCode(&wc); - if (responsecode != HttpURLConnection_HTTP_OK) { - throw new Error("Invalid response"); - } - close(wc.fd); - - timer->endTime(); - salt = saltTmp; - } catch (Exception *e) { - timer->endTime(); - throw new ServerException("Failed setting salt", ServerException_TypeConnectTimeout); - } -} - -bool CloudComm::getSalt() { - WebConnection wc = {-1, -1}; - IoTString *urlstr = NULL; - - try { - char *buffer = (char *) malloc(baseurl->length() + 100); - memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); - int offset = baseurl->length(); - offset += sprintf(&buffer[offset], "?req=getsalt"); - urlstr = new IoTString(buffer); - free(buffer); - } catch (Exception *e) { - throw new Error("getSlot failed"); - } - try { - timer->startTime(); - wc = openURL(urlstr); - delete urlstr; - urlstr = NULL; - closeURLReq(&wc); - timer->endTime(); - } catch (SocketTimeoutException *e) { - if (urlstr) - delete urlstr; - timer->endTime(); - throw new ServerException("getSalt failed", ServerException_TypeConnectTimeout); - } catch (Exception *e) { - if (urlstr) - delete urlstr; - throw new Error("getSlot failed"); - } - - try { - timer->startTime(); - int responsecode = getResponseCode(&wc); - readHeaders(&wc); - if (responsecode != HttpURLConnection_HTTP_OK) { - throw new Error("Invalid response"); - } - if (wc.numBytes == 0) { - timer->endTime(); - close(wc.fd); - return false; - } - - - int salt_length = readURLInt(&wc); - Array *tmp = new Array(salt_length); - readURLData(&wc, tmp); - close(wc.fd); - - salt = tmp; - timer->endTime(); - return true; - } catch (SocketTimeoutException *e) { - timer->endTime(); - throw new ServerException("getSalt failed", ServerException_TypeInputTimeout); - } catch (Exception *e) { - throw new Error("getSlot failed"); - } -} - -Array *CloudComm::createIV(int64_t machineId, int64_t localSequenceNumber) { - ByteBuffer *buffer = ByteBuffer_allocate(CloudComm_IV_SIZE); - buffer->putLong(machineId); - int64_t localSequenceNumberShifted = localSequenceNumber << 16; - buffer->putLong(localSequenceNumberShifted); - return buffer->array(); -} - -Array *AESEncrypt(Array *ivBytes, AESKey *key, Array *data) { - Array *output = new Array(data->length()); - aes_encrypt_ctr((BYTE *)data->internalArray(), data->length(), (BYTE *) output->internalArray(), (WORD *)key->getKeySchedule(), key->getKey()->length() * 8, (BYTE *)ivBytes->internalArray()); - return output; -} - -Array *AESDecrypt(Array *ivBytes, AESKey *key, Array *data) { - Array *output = new Array(data->length()); - aes_decrypt_ctr((BYTE *)data->internalArray(), data->length(), (BYTE *)output->internalArray(), (WORD *)key->getKeySchedule(), key->getKey()->length() * 8, (BYTE *)ivBytes->internalArray()); - return output; -} - -Array *CloudComm::encryptSlotAndPrependIV(Array *rawData, Array *ivBytes) { - try { - Array *encryptedBytes = AESEncrypt(ivBytes, key, rawData); - Array *chars = new Array(encryptedBytes->length() + CloudComm_IV_SIZE); - System_arraycopy(ivBytes, 0, chars, 0, ivBytes->length()); - System_arraycopy(encryptedBytes, 0, chars, CloudComm_IV_SIZE, encryptedBytes->length()); - delete encryptedBytes; - return chars; - } catch (Exception *e) { - throw new Error("Failed To Encrypt"); - } -} - -Array *CloudComm::stripIVAndDecryptSlot(Array *rawData) { - try { - Array *ivBytes = new Array(CloudComm_IV_SIZE); - Array *encryptedBytes = new Array(rawData->length() - CloudComm_IV_SIZE); - System_arraycopy(rawData, 0, ivBytes, 0, CloudComm_IV_SIZE); - System_arraycopy(rawData, CloudComm_IV_SIZE, encryptedBytes, 0, encryptedBytes->length()); - Array * data = AESDecrypt(ivBytes, key, encryptedBytes); - delete encryptedBytes; - delete ivBytes; - return data; - } catch (Exception *e) { - throw new Error("Failed To Decrypt"); - } -} - -/* - * API for putting a slot into the queue. Returns NULL on success. - * On failure, the server will send slots with newer sequence - * numbers. - */ -Array *CloudComm::putSlot(Slot *slot, int max) { - WebConnection wc = {-1, -1}; - try { - if (salt == NULL) { - if (!getSalt()) { - throw new ServerException("putSlot failed", ServerException_TypeSalt); - } - initCrypt(); - } - - int64_t sequencenumber = slot->getSequenceNumber(); - Array *slotBytes = slot->encode(mac); - Array * ivBytes = slot->getSlotCryptIV(); - Array *chars = encryptSlotAndPrependIV(slotBytes, ivBytes); - delete ivBytes; - delete slotBytes; - IoTString *url = buildRequest(true, sequencenumber, max); - timer->startTime(); - wc = openURL(url); - delete url; - writeURLDataAndClose(&wc, chars); - delete chars; - timer->endTime(); - } catch (ServerException *e) { - timer->endTime(); - throw e; - } catch (SocketTimeoutException *e) { - timer->endTime(); - throw new ServerException("putSlot failed", ServerException_TypeConnectTimeout); - } catch (Exception *e) { - throw new Error("putSlot failed"); - } - - Array *resptype = NULL; - try { - int respcode = getResponseCode(&wc); - readHeaders(&wc); - timer->startTime(); - resptype = new Array(7); - readURLData(&wc, resptype); - timer->endTime(); - - if (resptype->equals(getslot)) { - delete resptype; - Array *tmp = processSlots(&wc); - close(wc.fd); - return tmp; - } else if (resptype->equals(putslot)) { - delete resptype; - close(wc.fd); - return NULL; - } else { - delete resptype; - close(wc.fd); - throw new Error("Bad response to putslot"); - } - } catch (SocketTimeoutException *e) { - if (resptype != NULL) - delete resptype; - timer->endTime(); - close(wc.fd); - throw new ServerException("putSlot failed", ServerException_TypeInputTimeout); - } catch (Exception *e) { - if (resptype != NULL) - delete resptype; - throw new Error("putSlot failed"); - } -} - -/** - * Request the server to send all slots with the given - * sequencenumber or newer-> - */ -Array *CloudComm::getSlots(int64_t sequencenumber) { - WebConnection wc = {-1, -1}; - try { - if (salt == NULL) { - if (!getSalt()) { - throw new ServerException("getSlots failed", ServerException_TypeSalt); - } - initCrypt(); - } - - IoTString *url = buildRequest(false, sequencenumber, 0); - timer->startTime(); - wc = openURL(url); - delete url; - closeURLReq(&wc); - timer->endTime(); - } catch (SocketTimeoutException *e) { - timer->endTime(); - throw new ServerException("getSlots failed", ServerException_TypeConnectTimeout); - } catch (ServerException *e) { - timer->endTime(); - - throw e; - } catch (Exception *e) { - throw new Error("getSlots failed"); - } - - try { - timer->startTime(); - int responsecode = getResponseCode(&wc); - readHeaders(&wc); - Array *resptype = new Array(7); - readURLData(&wc, resptype); - timer->endTime(); - if (!resptype->equals(getslot)) - throw new Error("Bad Response: "); - - delete resptype; - Array *tmp = processSlots(&wc); - close(wc.fd); - return tmp; - } catch (SocketTimeoutException *e) { - timer->endTime(); - close(wc.fd); - throw new ServerException("getSlots failed", ServerException_TypeInputTimeout); - } catch (Exception *e) { - throw new Error("getSlots failed"); - } -} - -/** - * Method that actually handles building Slot objects from the - * server response. Shared by both putSlot and getSlots. - */ -Array *CloudComm::processSlots(WebConnection *wc) { - int numberofslots = readURLInt(wc); - Array *sizesofslots = new Array(numberofslots); - Array *slots = new Array(numberofslots); - - for (int i = 0; i < numberofslots; i++) - sizesofslots->set(i, readURLInt(wc)); - for (int i = 0; i < numberofslots; i++) { - Array *rawData = new Array(sizesofslots->get(i)); - readURLData(wc, rawData); - Array *data = stripIVAndDecryptSlot(rawData); - delete rawData; - slots->set(i, Slot_decode(table, data, mac)); - delete data; - } - delete sizesofslots; - return slots; -} - -Array *CloudComm::sendLocalData(Array *sendData, int64_t localSequenceNumber, IoTString *host, int port) { - if (salt == NULL) - return NULL; - try { - printf("Passing Locally\n"); - mac->update(sendData, 0, sendData->length()); - Array *genmac = mac->doFinal(); - Array *totalData = new Array(sendData->length() + genmac->length()); - System_arraycopy(sendData, 0, totalData, 0, sendData->length()); - System_arraycopy(genmac, 0, totalData, sendData->length(), genmac->length()); - - // Encrypt the data for sending - Array *iv = createIV(table->getMachineId(), table->getLocalSequenceNumber()); - Array *encryptedData = encryptSlotAndPrependIV(totalData, iv); - - // Open a TCP socket connection to a local device - int socket = createSocket(host, port); - - timer->startTime(); - // Send data to output (length of data, the data) - writeSocketInt(socket, encryptedData->length()); - writeSocketData(socket, encryptedData); - - int lengthOfReturnData = readSocketInt(socket); - Array *returnData = new Array(lengthOfReturnData); - readSocketData(socket, returnData); - timer->endTime(); - returnData = stripIVAndDecryptSlot(returnData); - - // We are done with this socket - close(socket); - mac->update(returnData, 0, returnData->length() - CloudComm_HMAC_SIZE); - Array *realmac = mac->doFinal(); - Array *recmac = new Array(CloudComm_HMAC_SIZE); - System_arraycopy(returnData, returnData->length() - realmac->length(), recmac, 0, realmac->length()); - - if (!recmac->equals(realmac)) - throw new Error("Local Error: Invalid HMAC! Potential Attack!"); - - Array *returnData2 = new Array(lengthOfReturnData - recmac->length()); - System_arraycopy(returnData, 0, returnData2, 0, returnData2->length()); - - return returnData2; - } catch (Exception *e) { - printf("Exception\n"); - } - - return NULL; -} - -void CloudComm::localServerWorkerFunction() { - int inputSocket = -1; - - try { - // Local server socket - inputSocket = createSocket(listeningPort); - } catch (Exception *e) { - throw new Error("Local server setup failure..."); - } - - while (!doEnd) { - try { - // Accept incoming socket - int socket = acceptSocket(inputSocket); - - // Get the encrypted data from the server - int dataSize = readSocketInt(socket); - Array *readData = new Array(dataSize); - readSocketData(socket, readData); - timer->endTime(); - - // Decrypt the data - readData = stripIVAndDecryptSlot(readData); - mac->update(readData, 0, readData->length() - CloudComm_HMAC_SIZE); - Array *genmac = mac->doFinal(); - Array *recmac = new Array(CloudComm_HMAC_SIZE); - System_arraycopy(readData, readData->length() - recmac->length(), recmac, 0, recmac->length()); - - if (!recmac->equals(genmac)) - throw new Error("Local Error: Invalid HMAC! Potential Attack!"); - - Array *returnData = new Array(readData->length() - recmac->length()); - System_arraycopy(readData, 0, returnData, 0, returnData->length()); - - // Process the data - Array *sendData = table->acceptDataFromLocal(returnData); - mac->update(sendData, 0, sendData->length()); - Array *realmac = mac->doFinal(); - Array *totalData = new Array(sendData->length() + realmac->length()); - System_arraycopy(sendData, 0, totalData, 0, sendData->length()); - System_arraycopy(realmac, 0, totalData, sendData->length(), realmac->length()); - - // Encrypt the data for sending - Array *iv = createIV(table->getMachineId(), table->getLocalSequenceNumber()); - Array *encryptedData = encryptSlotAndPrependIV(totalData, iv); - - timer->startTime(); - // Send data to output (length of data, the data) - writeSocketInt(socket, encryptedData->length()); - writeSocketData(socket, encryptedData); - close(socket); - } catch (Exception *e) { - } - } - - if (inputSocket != -1) { - try { - close(inputSocket); - } catch (Exception *e) { - throw new Error("Local server close failure..."); - } - } -} - -void CloudComm::closeCloud() { - doEnd = true; - - if (listeningPort > 0) { - if (pthread_join(localServerThread, NULL) != 0) - throw new Error("Local Server thread join issue..."); - } -} diff --git a/version2/src/C/CloudComm.cpp b/version2/src/C/CloudComm.cpp new file mode 100644 index 0000000..8fe0cee --- /dev/null +++ b/version2/src/C/CloudComm.cpp @@ -0,0 +1,819 @@ +#include "CloudComm.h" +#include "TimingSingleton.h" +#include "SecureRandom.h" +#include "IoTString.h" +#include "Error.h" +#include "URL.h" +#include "Mac.h" +#include "Table.h" +#include "Slot.h" +#include "Crypto.h" +#include "ByteBuffer.h" +#include "aes.h" +#include +#include +#include +#include +#include +#include + +/** + * Empty Constructor needed for child class. + */ +CloudComm::CloudComm() : + baseurl(NULL), + key(NULL), + mac(NULL), + password(NULL), + random(NULL), + salt(NULL), + table(NULL), + listeningPort(-1), + doEnd(false), + timer(TimingSingleton_getInstance()), + getslot(new Array("getslot", 7)), + putslot(new Array("putslot", 7)) +{ +} + +void *threadWrapper(void *cloud) { + CloudComm *c = (CloudComm *) cloud; + c->localServerWorkerFunction(); + return NULL; +} + +/** + * Constructor for actual use. Takes in the url and password. + */ +CloudComm::CloudComm(Table *_table, IoTString *_baseurl, IoTString *_password, int _listeningPort) : + baseurl(new IoTString(_baseurl)), + key(NULL), + mac(NULL), + password(new IoTString(_password)), + random(new SecureRandom()), + salt(NULL), + table(_table), + listeningPort(_listeningPort), + doEnd(false), + timer(TimingSingleton_getInstance()), + getslot(new Array("getslot", 7)), + putslot(new Array("putslot", 7)) { + if (listeningPort > 0) { + pthread_create(&localServerThread, NULL, threadWrapper, this); + } +} + +CloudComm::~CloudComm() { + delete getslot; + delete putslot; + if (salt) + delete salt; + if (password) + delete password; + if (random) + delete random; + if (baseurl) + delete baseurl; + if (mac) + delete mac; + if (key) + delete key; +} + +/** + * Generates Key from password. + */ +AESKey *CloudComm::initKey() { + try { + AESKey *key = new AESKey(password->internalBytes(), + salt, + 65536, + 128); + return key; + } catch (Exception *e) { + throw new Error("Failed generating key."); + } +} + +/** + * Inits all the security stuff + */ + +void CloudComm::initSecurity() { + // try to get the salt and if one does not exist set one + if (!getSalt()) { + //Set the salt + setSalt(); + } + + initCrypt(); +} + +/** + * Inits the HMAC generator. + */ +void CloudComm::initCrypt() { + if (password == NULL) { + return; + } + try { + key = initKey(); + delete password; + password = NULL;// drop password + mac = new Mac(); + mac->init(key); + } catch (Exception *e) { + throw new Error("Failed To Initialize Ciphers"); + } +} + +/* + * Builds the URL for the given request. + */ +IoTString *CloudComm::buildRequest(bool isput, int64_t sequencenumber, int64_t maxentries) { + const char *reqstring = isput ? "req=putslot" : "req=getslot"; + char *buffer = (char *) malloc(baseurl->length() + 200); + memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); + int offset = baseurl->length(); + offset += sprintf(&buffer[offset], "?%s&seq=%" PRId64, reqstring, sequencenumber); + if (maxentries != 0) + sprintf(&buffer[offset], "&max=%" PRId64, maxentries); + IoTString *urlstr = new IoTString(buffer); + free(buffer); + return urlstr; +} + +void loopWrite(int fd, char *array, int bytestowrite) { + int byteswritten = 0; + while (bytestowrite) { + int bytes = write(fd, &array[byteswritten], bytestowrite); + if (bytes >= 0) { + byteswritten += bytes; + bytestowrite -= bytes; + } else { + printf("Error in write\n"); + exit(-1); + } + } +} + +void loopRead(int fd, char *array, int bytestoread) { + int bytesread = 0; + while (bytestoread) { + int bytes = read(fd, &array[bytesread], bytestoread); + if (bytes >= 0) { + bytesread += bytes; + bytestoread -= bytes; + } else { + printf("Error in read\n"); + exit(-1); + } + } +} + +WebConnection openURL(IoTString *url) { + if (url->length() < 7 || memcmp(url->internalBytes()->internalArray(), "http://", 7)) { + printf("BOGUS URL\n"); + exit(-1); + } + int i = 7; + for (; i < url->length(); i++) + if (url->get(i) == '/') + break; + + if ( i == url->length()) { + printf("ERROR in openURL\n"); + exit(-1); + } + + char *host = (char *) malloc(i - 6); + memcpy(host, &url->internalBytes()->internalArray()[7], i - 7); + host[i - 7] = 0; + printf("%s\n", host); + + char *message = (char *)malloc(sizeof("POST HTTP/1.1\r\n") + sizeof("Host: \r\n") + 2 * url->length()); + + /* fill in the parameters */ + int post = sprintf(message,"POST "); + /* copy data */ + memcpy(&message[post], &url->internalBytes()->internalArray()[i], url->length() - i); + int endpost = sprintf(&message[post + url->length() - i], " HTTP/1.1\r\n"); + + int hostlen = sprintf(&message[endpost + post + url->length() - i], "Host: "); + memcpy(&message[endpost + post + url->length() + hostlen - i], host, i - 7); + sprintf(&message[endpost + post + url->length() + hostlen - 7], "\r\n"); + + /* create the socket */ + int sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd < 0) {printf("ERROR opening socket\n"); exit(-1);} + + /* lookup the ip address */ + struct hostent *server = gethostbyname(host); + free(host); + + if (server == NULL) {printf("ERROR, no such host"); exit(-1);} + + /* fill in the structure */ + struct sockaddr_in serv_addr; + + memset(&serv_addr,0,sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(80); + memcpy(&serv_addr.sin_addr.s_addr,server->h_addr,server->h_length); + + /* connect the socket */ + if (connect(sockfd,(struct sockaddr *)&serv_addr,sizeof(serv_addr)) < 0) { + printf("ERROR connecting"); + exit(-1); + } + + /* send the request */ + int total = strlen(message); + loopWrite(sockfd, message, total); + free(message); + return (WebConnection) {sockfd, -1}; +} + +int createSocket(IoTString *name, int port) { + char *host = (char *) malloc(name->length() + 1); + memcpy(host, name->internalBytes()->internalArray(), name->length()); + host[name->length()] = 0; + printf("%s\n", host); + /* How big is the message? */ + + /* create the socket */ + int sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (sockfd < 0) {printf("ERROR opening socket\n"); exit(-1);} + + /* lookup the ip address */ + struct hostent *server = gethostbyname(host); + free(host); + + if (server == NULL) {printf("ERROR, no such host"); exit(-1);} + + /* fill in the structure */ + struct sockaddr_in serv_addr; + + memset(&serv_addr,0,sizeof(serv_addr)); + serv_addr.sin_family = AF_INET; + serv_addr.sin_port = htons(port); + memcpy(&serv_addr.sin_addr.s_addr,server->h_addr,server->h_length); + + /* connect the socket */ + if (connect(sockfd,(struct sockaddr *)&serv_addr,sizeof(serv_addr)) < 0) { + printf("ERROR connecting"); + exit(-1); + } + + return sockfd; +} + +int createSocket(int port) { + int fd; + struct sockaddr_in sin; + + bzero(&sin, sizeof(sin)); + sin.sin_family = AF_INET; + sin.sin_port = htons(port); + sin.sin_addr.s_addr = htonl(INADDR_ANY); + fd = socket(AF_INET, SOCK_STREAM, 0); + int n = 1; + if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (char *)&n, sizeof (n)) < 0) { + close(fd); + printf("Create Socket Error\n"); + exit(-1); + } + if (bind(fd, (struct sockaddr *) &sin, sizeof(sin)) < 0) { + close(fd); + exit(-1); + } + if (listen(fd, 5) < 0) { + close(fd); + exit(-1); + } + return fd; +} + +int acceptSocket(int socket) { + struct sockaddr_in sin; + unsigned int sinlen = sizeof(sin); + int newfd = accept(socket, (struct sockaddr *)&sin, &sinlen); + int flag = 1; + setsockopt(newfd, IPPROTO_TCP, TCP_NODELAY, (char *) &flag, sizeof(flag)); + if (newfd < 0) { + printf("Accept Error\n"); + exit(-1); + } + return newfd; +} + +void writeSocketData(int fd, Array *data) { + loopWrite(fd, data->internalArray(), data->length()); +} + +void writeSocketInt(int fd, int32_t value) { + char array[4]; + array[0] = value >> 24; + array[1] = (value >> 16) & 0xff; + array[2] = (value >> 8) & 0xff; + array[3] = value & 0xff; + loopWrite(fd, array, 4); +} + +int readSocketInt(int fd) { + char array[4]; + loopRead(fd, array, 4); + return (((int32_t)(unsigned char) array[0]) << 24) | + (((int32_t)(unsigned char) array[1]) << 16) | + (((int32_t)(unsigned char) array[2]) << 8) | + ((int32_t)(unsigned char) array[3]); +} + +void readSocketData(int fd, Array *data) { + loopRead(fd, data->internalArray(), data->length()); +} + +void writeURLDataAndClose(WebConnection *wc, Array *data) { + dprintf(wc->fd, "Content-Length: %d\r\n\r\n", data->length()); + loopWrite(wc->fd, data->internalArray(), data->length()); +} + +void closeURLReq(WebConnection *wc) { + dprintf(wc->fd, "\r\n"); +} + +void readURLData(WebConnection *wc, Array *output) { + loopRead(wc->fd, output->internalArray(), output->length()); +} + +int readURLInt(WebConnection *wc) { + char array[4]; + loopRead(wc->fd, array, 4); + return (((int32_t)(unsigned char) array[0]) << 24) | + (((int32_t)(unsigned char) array[1]) << 16) | + (((int32_t)(unsigned char) array[2]) << 8) | + ((int32_t)(unsigned char) array[3]); +} + +void readLine(WebConnection *wc, char *response, int numBytes) { + int offset = 0; + char newchar; + while (true) { + int bytes = read(wc->fd, &newchar, 1); + if (bytes <= 0) + break; + if (offset == (numBytes - 1)) { + printf("Response too long"); + exit(-1); + } + response[offset++] = newchar; + if (newchar == '\n') + break; + } + response[offset] = 0; +} + +int getResponseCode(WebConnection *wc) { + char response[600]; + readLine(wc, response, sizeof(response)); + int ver1 = 0, ver2 = 0, respcode = 0; + sscanf(response, "HTTP/%d.%d %d", &ver1, &ver2, &respcode); + printf("Response code %d\n", respcode); + return respcode; +} + +void readHeaders(WebConnection *wc) { + char response[600]; + int numBytes; + + while (true) { + readLine(wc, response, sizeof(response)); + if (response[0] == '\r') + return; + else if (memcmp(response, "Content-Length:", sizeof("Content-Length:") - 1) == 0) { + sscanf(response, "Content-Length: %d", &numBytes); + wc->numBytes = numBytes; + } + } +} + +void CloudComm::setSalt() { + if (salt != NULL) { + // Salt already sent to server so don't set it again + return; + } + + WebConnection wc = {-1, -1}; + try { + Array *saltTmp = new Array(CloudComm_SALT_SIZE); + random->nextBytes(saltTmp); + + char *buffer = (char *) malloc(baseurl->length() + 100); + memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); + int offset = baseurl->length(); + offset += sprintf(&buffer[offset], "?req=setsalt"); + IoTString *urlstr = new IoTString(buffer); + free(buffer); + + timer->startTime(); + wc = openURL(urlstr); + delete urlstr; + writeURLDataAndClose(&wc, saltTmp); + + int responsecode = getResponseCode(&wc); + if (responsecode != HttpURLConnection_HTTP_OK) { + throw new Error("Invalid response"); + } + close(wc.fd); + + timer->endTime(); + salt = saltTmp; + } catch (Exception *e) { + timer->endTime(); + throw new ServerException("Failed setting salt", ServerException_TypeConnectTimeout); + } +} + +bool CloudComm::getSalt() { + WebConnection wc = {-1, -1}; + IoTString *urlstr = NULL; + + try { + char *buffer = (char *) malloc(baseurl->length() + 100); + memcpy(buffer, baseurl->internalBytes()->internalArray(), baseurl->length()); + int offset = baseurl->length(); + offset += sprintf(&buffer[offset], "?req=getsalt"); + urlstr = new IoTString(buffer); + free(buffer); + } catch (Exception *e) { + throw new Error("getSlot failed"); + } + try { + timer->startTime(); + wc = openURL(urlstr); + delete urlstr; + urlstr = NULL; + closeURLReq(&wc); + timer->endTime(); + } catch (SocketTimeoutException *e) { + if (urlstr) + delete urlstr; + timer->endTime(); + throw new ServerException("getSalt failed", ServerException_TypeConnectTimeout); + } catch (Exception *e) { + if (urlstr) + delete urlstr; + throw new Error("getSlot failed"); + } + + try { + timer->startTime(); + int responsecode = getResponseCode(&wc); + readHeaders(&wc); + if (responsecode != HttpURLConnection_HTTP_OK) { + throw new Error("Invalid response"); + } + if (wc.numBytes == 0) { + timer->endTime(); + close(wc.fd); + return false; + } + + + int salt_length = readURLInt(&wc); + Array *tmp = new Array(salt_length); + readURLData(&wc, tmp); + close(wc.fd); + + salt = tmp; + timer->endTime(); + return true; + } catch (SocketTimeoutException *e) { + timer->endTime(); + throw new ServerException("getSalt failed", ServerException_TypeInputTimeout); + } catch (Exception *e) { + throw new Error("getSlot failed"); + } +} + +Array *CloudComm::createIV(int64_t machineId, int64_t localSequenceNumber) { + ByteBuffer *buffer = ByteBuffer_allocate(CloudComm_IV_SIZE); + buffer->putLong(machineId); + int64_t localSequenceNumberShifted = localSequenceNumber << 16; + buffer->putLong(localSequenceNumberShifted); + return buffer->array(); +} + +Array *AESEncrypt(Array *ivBytes, AESKey *key, Array *data) { + Array *output = new Array(data->length()); + aes_encrypt_ctr((BYTE *)data->internalArray(), data->length(), (BYTE *) output->internalArray(), (WORD *)key->getKeySchedule(), key->getKey()->length() * 8, (BYTE *)ivBytes->internalArray()); + return output; +} + +Array *AESDecrypt(Array *ivBytes, AESKey *key, Array *data) { + Array *output = new Array(data->length()); + aes_decrypt_ctr((BYTE *)data->internalArray(), data->length(), (BYTE *)output->internalArray(), (WORD *)key->getKeySchedule(), key->getKey()->length() * 8, (BYTE *)ivBytes->internalArray()); + return output; +} + +Array *CloudComm::encryptSlotAndPrependIV(Array *rawData, Array *ivBytes) { + try { + Array *encryptedBytes = AESEncrypt(ivBytes, key, rawData); + Array *chars = new Array(encryptedBytes->length() + CloudComm_IV_SIZE); + System_arraycopy(ivBytes, 0, chars, 0, ivBytes->length()); + System_arraycopy(encryptedBytes, 0, chars, CloudComm_IV_SIZE, encryptedBytes->length()); + delete encryptedBytes; + return chars; + } catch (Exception *e) { + throw new Error("Failed To Encrypt"); + } +} + +Array *CloudComm::stripIVAndDecryptSlot(Array *rawData) { + try { + Array *ivBytes = new Array(CloudComm_IV_SIZE); + Array *encryptedBytes = new Array(rawData->length() - CloudComm_IV_SIZE); + System_arraycopy(rawData, 0, ivBytes, 0, CloudComm_IV_SIZE); + System_arraycopy(rawData, CloudComm_IV_SIZE, encryptedBytes, 0, encryptedBytes->length()); + Array * data = AESDecrypt(ivBytes, key, encryptedBytes); + delete encryptedBytes; + delete ivBytes; + return data; + } catch (Exception *e) { + throw new Error("Failed To Decrypt"); + } +} + +/* + * API for putting a slot into the queue. Returns NULL on success. + * On failure, the server will send slots with newer sequence + * numbers. + */ +Array *CloudComm::putSlot(Slot *slot, int max) { + WebConnection wc = {-1, -1}; + try { + if (salt == NULL) { + if (!getSalt()) { + throw new ServerException("putSlot failed", ServerException_TypeSalt); + } + initCrypt(); + } + + int64_t sequencenumber = slot->getSequenceNumber(); + Array *slotBytes = slot->encode(mac); + Array * ivBytes = slot->getSlotCryptIV(); + Array *chars = encryptSlotAndPrependIV(slotBytes, ivBytes); + delete ivBytes; + delete slotBytes; + IoTString *url = buildRequest(true, sequencenumber, max); + timer->startTime(); + wc = openURL(url); + delete url; + writeURLDataAndClose(&wc, chars); + delete chars; + timer->endTime(); + } catch (ServerException *e) { + timer->endTime(); + throw e; + } catch (SocketTimeoutException *e) { + timer->endTime(); + throw new ServerException("putSlot failed", ServerException_TypeConnectTimeout); + } catch (Exception *e) { + throw new Error("putSlot failed"); + } + + Array *resptype = NULL; + try { + int respcode = getResponseCode(&wc); + readHeaders(&wc); + timer->startTime(); + resptype = new Array(7); + readURLData(&wc, resptype); + timer->endTime(); + + if (resptype->equals(getslot)) { + delete resptype; + Array *tmp = processSlots(&wc); + close(wc.fd); + return tmp; + } else if (resptype->equals(putslot)) { + delete resptype; + close(wc.fd); + return NULL; + } else { + delete resptype; + close(wc.fd); + throw new Error("Bad response to putslot"); + } + } catch (SocketTimeoutException *e) { + if (resptype != NULL) + delete resptype; + timer->endTime(); + close(wc.fd); + throw new ServerException("putSlot failed", ServerException_TypeInputTimeout); + } catch (Exception *e) { + if (resptype != NULL) + delete resptype; + throw new Error("putSlot failed"); + } +} + +/** + * Request the server to send all slots with the given + * sequencenumber or newer-> + */ +Array *CloudComm::getSlots(int64_t sequencenumber) { + WebConnection wc = {-1, -1}; + try { + if (salt == NULL) { + if (!getSalt()) { + throw new ServerException("getSlots failed", ServerException_TypeSalt); + } + initCrypt(); + } + + IoTString *url = buildRequest(false, sequencenumber, 0); + timer->startTime(); + wc = openURL(url); + delete url; + closeURLReq(&wc); + timer->endTime(); + } catch (SocketTimeoutException *e) { + timer->endTime(); + throw new ServerException("getSlots failed", ServerException_TypeConnectTimeout); + } catch (ServerException *e) { + timer->endTime(); + + throw e; + } catch (Exception *e) { + throw new Error("getSlots failed"); + } + + try { + timer->startTime(); + int responsecode = getResponseCode(&wc); + readHeaders(&wc); + Array *resptype = new Array(7); + readURLData(&wc, resptype); + timer->endTime(); + if (!resptype->equals(getslot)) + throw new Error("Bad Response: "); + + delete resptype; + Array *tmp = processSlots(&wc); + close(wc.fd); + return tmp; + } catch (SocketTimeoutException *e) { + timer->endTime(); + close(wc.fd); + throw new ServerException("getSlots failed", ServerException_TypeInputTimeout); + } catch (Exception *e) { + throw new Error("getSlots failed"); + } +} + +/** + * Method that actually handles building Slot objects from the + * server response. Shared by both putSlot and getSlots. + */ +Array *CloudComm::processSlots(WebConnection *wc) { + int numberofslots = readURLInt(wc); + Array *sizesofslots = new Array(numberofslots); + Array *slots = new Array(numberofslots); + + for (int i = 0; i < numberofslots; i++) + sizesofslots->set(i, readURLInt(wc)); + for (int i = 0; i < numberofslots; i++) { + Array *rawData = new Array(sizesofslots->get(i)); + readURLData(wc, rawData); + Array *data = stripIVAndDecryptSlot(rawData); + delete rawData; + slots->set(i, Slot_decode(table, data, mac)); + delete data; + } + delete sizesofslots; + return slots; +} + +Array *CloudComm::sendLocalData(Array *sendData, int64_t localSequenceNumber, IoTString *host, int port) { + if (salt == NULL) + return NULL; + try { + printf("Passing Locally\n"); + mac->update(sendData, 0, sendData->length()); + Array *genmac = mac->doFinal(); + Array *totalData = new Array(sendData->length() + genmac->length()); + System_arraycopy(sendData, 0, totalData, 0, sendData->length()); + System_arraycopy(genmac, 0, totalData, sendData->length(), genmac->length()); + + // Encrypt the data for sending + Array *iv = createIV(table->getMachineId(), table->getLocalSequenceNumber()); + Array *encryptedData = encryptSlotAndPrependIV(totalData, iv); + + // Open a TCP socket connection to a local device + int socket = createSocket(host, port); + + timer->startTime(); + // Send data to output (length of data, the data) + writeSocketInt(socket, encryptedData->length()); + writeSocketData(socket, encryptedData); + + int lengthOfReturnData = readSocketInt(socket); + Array *returnData = new Array(lengthOfReturnData); + readSocketData(socket, returnData); + timer->endTime(); + returnData = stripIVAndDecryptSlot(returnData); + + // We are done with this socket + close(socket); + mac->update(returnData, 0, returnData->length() - CloudComm_HMAC_SIZE); + Array *realmac = mac->doFinal(); + Array *recmac = new Array(CloudComm_HMAC_SIZE); + System_arraycopy(returnData, returnData->length() - realmac->length(), recmac, 0, realmac->length()); + + if (!recmac->equals(realmac)) + throw new Error("Local Error: Invalid HMAC! Potential Attack!"); + + Array *returnData2 = new Array(lengthOfReturnData - recmac->length()); + System_arraycopy(returnData, 0, returnData2, 0, returnData2->length()); + + return returnData2; + } catch (Exception *e) { + printf("Exception\n"); + } + + return NULL; +} + +void CloudComm::localServerWorkerFunction() { + int inputSocket = -1; + + try { + // Local server socket + inputSocket = createSocket(listeningPort); + } catch (Exception *e) { + throw new Error("Local server setup failure..."); + } + + while (!doEnd) { + try { + // Accept incoming socket + int socket = acceptSocket(inputSocket); + + // Get the encrypted data from the server + int dataSize = readSocketInt(socket); + Array *readData = new Array(dataSize); + readSocketData(socket, readData); + timer->endTime(); + + // Decrypt the data + readData = stripIVAndDecryptSlot(readData); + mac->update(readData, 0, readData->length() - CloudComm_HMAC_SIZE); + Array *genmac = mac->doFinal(); + Array *recmac = new Array(CloudComm_HMAC_SIZE); + System_arraycopy(readData, readData->length() - recmac->length(), recmac, 0, recmac->length()); + + if (!recmac->equals(genmac)) + throw new Error("Local Error: Invalid HMAC! Potential Attack!"); + + Array *returnData = new Array(readData->length() - recmac->length()); + System_arraycopy(readData, 0, returnData, 0, returnData->length()); + + // Process the data + Array *sendData = table->acceptDataFromLocal(returnData); + mac->update(sendData, 0, sendData->length()); + Array *realmac = mac->doFinal(); + Array *totalData = new Array(sendData->length() + realmac->length()); + System_arraycopy(sendData, 0, totalData, 0, sendData->length()); + System_arraycopy(realmac, 0, totalData, sendData->length(), realmac->length()); + + // Encrypt the data for sending + Array *iv = createIV(table->getMachineId(), table->getLocalSequenceNumber()); + Array *encryptedData = encryptSlotAndPrependIV(totalData, iv); + + timer->startTime(); + // Send data to output (length of data, the data) + writeSocketInt(socket, encryptedData->length()); + writeSocketData(socket, encryptedData); + close(socket); + } catch (Exception *e) { + } + } + + if (inputSocket != -1) { + try { + close(inputSocket); + } catch (Exception *e) { + throw new Error("Local server close failure..."); + } + } +} + +void CloudComm::closeCloud() { + doEnd = true; + + if (listeningPort > 0) { + if (pthread_join(localServerThread, NULL) != 0) + throw new Error("Local Server thread join issue..."); + } +} diff --git a/version2/src/C/Commit.cc b/version2/src/C/Commit.cc deleted file mode 100644 index 7665841..0000000 --- a/version2/src/C/Commit.cc +++ /dev/null @@ -1,300 +0,0 @@ -#include "Commit.h" -#include "CommitPart.h" -#include "ByteBuffer.h" -#include "IoTString.h" - -Commit::Commit() : - parts(new Vector()), - partCount(0), - missingParts(NULL), - fldisComplete(false), - hasLastPart(false), - keyValueUpdateSet(new Hashset()), - isDead(false), - sequenceNumber(-1), - machineId(-1), - transactionSequenceNumber(-1), - dataBytes(NULL), - liveKeys(new Hashset()) { -} - -Commit::Commit(int64_t _sequenceNumber, int64_t _machineId, int64_t _transactionSequenceNumber) : - parts(new Vector()), - partCount(0), - missingParts(NULL), - fldisComplete(true), - hasLastPart(false), - keyValueUpdateSet(new Hashset()), - isDead(false), - sequenceNumber(_sequenceNumber), - machineId(_machineId), - transactionSequenceNumber(_transactionSequenceNumber), - dataBytes(NULL), - liveKeys(new Hashset()) { -} - -Commit::~Commit() { - { - uint Size = parts->size(); - for(uint i=0;iget(i)->releaseRef(); - delete parts; - } - { - SetIterator * keyit = keyValueUpdateSet->iterator(); - while(keyit->hasNext()) { - delete keyit->next(); - } - delete keyit; - delete keyValueUpdateSet; - } - delete liveKeys; - if (missingParts != NULL) - delete missingParts; - if (dataBytes != NULL) - delete dataBytes; -} - -void Commit::addPartDecode(CommitPart *newPart) { - if (isDead) { - // If dead then just kill this part and move on - newPart->setDead(); - return; - } - - newPart->acquireRef(); - CommitPart *previouslySeenPart = parts->setExpand(newPart->getPartNumber(), newPart); - if (previouslySeenPart == NULL) - partCount++; - - if (previouslySeenPart != NULL) { - // Set dead the old one since the new one is a rescued version of this part - previouslySeenPart->setDead(); - previouslySeenPart->releaseRef(); - } else if (newPart->isLastPart()) { - missingParts = new Hashset(); - hasLastPart = true; - - for (int i = 0; i < newPart->getPartNumber(); i++) { - if (parts->get(i) == NULL) { - missingParts->add(i); - } - } - } - - if (!fldisComplete && hasLastPart) { - - // We have seen this part so remove it from the set of missing parts - missingParts->remove(newPart->getPartNumber()); - - // Check if all the parts have been seen - if (missingParts->size() == 0) { - - // We have all the parts - fldisComplete = true; - - // Decode all the parts and create the key value guard and update sets - decodeCommitData(); - - // Get the sequence number and arbitrator of this transaction - sequenceNumber = parts->get(0)->getSequenceNumber(); - machineId = parts->get(0)->getMachineId(); - transactionSequenceNumber = parts->get(0)->getTransactionSequenceNumber(); - } - } -} - -int64_t Commit::getSequenceNumber() { - return sequenceNumber; -} - -int64_t Commit::getTransactionSequenceNumber() { - return transactionSequenceNumber; -} - -Vector *Commit::getParts() { - return parts; -} - -void Commit::addKV(KeyValue *kv) { - KeyValue * kvcopy = kv->getCopy(); - keyValueUpdateSet->add(kvcopy); - liveKeys->add(kvcopy->getKey()); -} - -void Commit::invalidateKey(IoTString *key) { - liveKeys->remove(key); - - if (liveKeys->size() == 0) { - setDead(); - } -} - -Hashset *Commit::getKeyValueUpdateSet() { - return keyValueUpdateSet; -} - -int32_t Commit::getNumberOfParts() { - return partCount; -} - -void Commit::setDead() { - if (!isDead) { - isDead = true; - // Make all the parts of this transaction dead - for (uint32_t partNumber = 0; partNumber < parts->size(); partNumber++) { - CommitPart *part = parts->get(partNumber); - part->setDead(); - } - } -} - -void Commit::createCommitParts() { - uint Size = parts->size(); - for(uint i=0;i < Size; i++) { - Entry * e=parts->get(i); - e->releaseRef(); - } - parts->clear(); - partCount = 0; - // Convert to chars - Array *charData = convertDataToBytes(); - - int commitPartCount = 0; - int currentPosition = 0; - int remaining = charData->length(); - - while (remaining > 0) { - bool isLastPart = false; - // determine how much to copy - int copySize = CommitPart_MAX_NON_HEADER_SIZE; - if (remaining <= CommitPart_MAX_NON_HEADER_SIZE) { - copySize = remaining; - isLastPart = true;// last bit of data so last part - } - - // Copy to a smaller version - Array *partData = new Array(copySize); - System_arraycopy(charData, currentPosition, partData, 0, copySize); - - CommitPart *part = new CommitPart(NULL, machineId, sequenceNumber, transactionSequenceNumber, commitPartCount, partData, isLastPart); - parts->setExpand(part->getPartNumber(), part); - - // Update position, count and remaining - currentPosition += copySize; - commitPartCount++; - remaining -= copySize; - } - delete charData; -} - -void Commit::decodeCommitData() { - // Calculate the size of the data section - int dataSize = 0; - for (uint i = 0; i < parts->size(); i++) { - CommitPart *tp = parts->get(i); - if (tp != NULL) - dataSize += tp->getDataSize(); - } - - Array *combinedData = new Array(dataSize); - int currentPosition = 0; - - // Stitch all the data sections together - for (uint i = 0; i < parts->size(); i++) { - CommitPart *tp = parts->get(i); - if (tp != NULL) { - System_arraycopy(tp->getData(), 0, combinedData, currentPosition, tp->getDataSize()); - currentPosition += tp->getDataSize(); - } - } - - // Decoder Object - ByteBuffer *bbDecode = ByteBuffer_wrap(combinedData); - - // Decode how many key value pairs need to be decoded - int numberOfKVUpdates = bbDecode->getInt(); - - // Decode all the updates key values - for (int i = 0; i < numberOfKVUpdates; i++) { - KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); - keyValueUpdateSet->add(kv); - liveKeys->add(kv->getKey()); - } - delete bbDecode; -} - -Array *Commit::convertDataToBytes() { - // Calculate the size of the data - int sizeOfData = sizeof(int32_t); // Number of Update KV's - SetIterator *kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - sizeOfData += kv->getSize(); - } - delete kvit; - - // Data handlers and storage - Array *dataArray = new Array(sizeOfData); - ByteBuffer *bbEncode = ByteBuffer_wrap(dataArray); - - // Encode the size of the updates and guard sets - bbEncode->putInt(keyValueUpdateSet->size()); - - // Encode all the updates - kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - kv->encode(bbEncode); - } - delete kvit; - Array * array = bbEncode->array(); - bbEncode->releaseArray(); - delete bbEncode; - return array; -} - -void Commit::setKVsMap(Hashset *newKVs) { - keyValueUpdateSet->clear(); - liveKeys->clear(); - SetIterator *kvit = newKVs->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - KeyValue *kvcopy = kv->getCopy(); - liveKeys->add(kvcopy->getKey()); - keyValueUpdateSet->add(kvcopy); - } - delete kvit; -} - -Commit *Commit_merge(Commit *newer, Commit *older, int64_t newSequenceNumber) { - if (older == NULL) { - return newer; - } else if (newer == NULL) { - return older; - } - Hashset *kvSet = new Hashset(); - SetIterator *kvit = older->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - kvSet->add(kv); - } - delete kvit; - kvit = newer->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - kvSet->add(kv); - } - delete kvit; - - int64_t transactionSequenceNumber = newer->getTransactionSequenceNumber(); - if (transactionSequenceNumber == -1) { - transactionSequenceNumber = older->getTransactionSequenceNumber(); - } - - Commit *newCommit = new Commit(newSequenceNumber, newer->getMachineId(), transactionSequenceNumber); - newCommit->setKVsMap(kvSet); - - delete kvSet; - return newCommit; -} diff --git a/version2/src/C/Commit.cpp b/version2/src/C/Commit.cpp new file mode 100644 index 0000000..7665841 --- /dev/null +++ b/version2/src/C/Commit.cpp @@ -0,0 +1,300 @@ +#include "Commit.h" +#include "CommitPart.h" +#include "ByteBuffer.h" +#include "IoTString.h" + +Commit::Commit() : + parts(new Vector()), + partCount(0), + missingParts(NULL), + fldisComplete(false), + hasLastPart(false), + keyValueUpdateSet(new Hashset()), + isDead(false), + sequenceNumber(-1), + machineId(-1), + transactionSequenceNumber(-1), + dataBytes(NULL), + liveKeys(new Hashset()) { +} + +Commit::Commit(int64_t _sequenceNumber, int64_t _machineId, int64_t _transactionSequenceNumber) : + parts(new Vector()), + partCount(0), + missingParts(NULL), + fldisComplete(true), + hasLastPart(false), + keyValueUpdateSet(new Hashset()), + isDead(false), + sequenceNumber(_sequenceNumber), + machineId(_machineId), + transactionSequenceNumber(_transactionSequenceNumber), + dataBytes(NULL), + liveKeys(new Hashset()) { +} + +Commit::~Commit() { + { + uint Size = parts->size(); + for(uint i=0;iget(i)->releaseRef(); + delete parts; + } + { + SetIterator * keyit = keyValueUpdateSet->iterator(); + while(keyit->hasNext()) { + delete keyit->next(); + } + delete keyit; + delete keyValueUpdateSet; + } + delete liveKeys; + if (missingParts != NULL) + delete missingParts; + if (dataBytes != NULL) + delete dataBytes; +} + +void Commit::addPartDecode(CommitPart *newPart) { + if (isDead) { + // If dead then just kill this part and move on + newPart->setDead(); + return; + } + + newPart->acquireRef(); + CommitPart *previouslySeenPart = parts->setExpand(newPart->getPartNumber(), newPart); + if (previouslySeenPart == NULL) + partCount++; + + if (previouslySeenPart != NULL) { + // Set dead the old one since the new one is a rescued version of this part + previouslySeenPart->setDead(); + previouslySeenPart->releaseRef(); + } else if (newPart->isLastPart()) { + missingParts = new Hashset(); + hasLastPart = true; + + for (int i = 0; i < newPart->getPartNumber(); i++) { + if (parts->get(i) == NULL) { + missingParts->add(i); + } + } + } + + if (!fldisComplete && hasLastPart) { + + // We have seen this part so remove it from the set of missing parts + missingParts->remove(newPart->getPartNumber()); + + // Check if all the parts have been seen + if (missingParts->size() == 0) { + + // We have all the parts + fldisComplete = true; + + // Decode all the parts and create the key value guard and update sets + decodeCommitData(); + + // Get the sequence number and arbitrator of this transaction + sequenceNumber = parts->get(0)->getSequenceNumber(); + machineId = parts->get(0)->getMachineId(); + transactionSequenceNumber = parts->get(0)->getTransactionSequenceNumber(); + } + } +} + +int64_t Commit::getSequenceNumber() { + return sequenceNumber; +} + +int64_t Commit::getTransactionSequenceNumber() { + return transactionSequenceNumber; +} + +Vector *Commit::getParts() { + return parts; +} + +void Commit::addKV(KeyValue *kv) { + KeyValue * kvcopy = kv->getCopy(); + keyValueUpdateSet->add(kvcopy); + liveKeys->add(kvcopy->getKey()); +} + +void Commit::invalidateKey(IoTString *key) { + liveKeys->remove(key); + + if (liveKeys->size() == 0) { + setDead(); + } +} + +Hashset *Commit::getKeyValueUpdateSet() { + return keyValueUpdateSet; +} + +int32_t Commit::getNumberOfParts() { + return partCount; +} + +void Commit::setDead() { + if (!isDead) { + isDead = true; + // Make all the parts of this transaction dead + for (uint32_t partNumber = 0; partNumber < parts->size(); partNumber++) { + CommitPart *part = parts->get(partNumber); + part->setDead(); + } + } +} + +void Commit::createCommitParts() { + uint Size = parts->size(); + for(uint i=0;i < Size; i++) { + Entry * e=parts->get(i); + e->releaseRef(); + } + parts->clear(); + partCount = 0; + // Convert to chars + Array *charData = convertDataToBytes(); + + int commitPartCount = 0; + int currentPosition = 0; + int remaining = charData->length(); + + while (remaining > 0) { + bool isLastPart = false; + // determine how much to copy + int copySize = CommitPart_MAX_NON_HEADER_SIZE; + if (remaining <= CommitPart_MAX_NON_HEADER_SIZE) { + copySize = remaining; + isLastPart = true;// last bit of data so last part + } + + // Copy to a smaller version + Array *partData = new Array(copySize); + System_arraycopy(charData, currentPosition, partData, 0, copySize); + + CommitPart *part = new CommitPart(NULL, machineId, sequenceNumber, transactionSequenceNumber, commitPartCount, partData, isLastPart); + parts->setExpand(part->getPartNumber(), part); + + // Update position, count and remaining + currentPosition += copySize; + commitPartCount++; + remaining -= copySize; + } + delete charData; +} + +void Commit::decodeCommitData() { + // Calculate the size of the data section + int dataSize = 0; + for (uint i = 0; i < parts->size(); i++) { + CommitPart *tp = parts->get(i); + if (tp != NULL) + dataSize += tp->getDataSize(); + } + + Array *combinedData = new Array(dataSize); + int currentPosition = 0; + + // Stitch all the data sections together + for (uint i = 0; i < parts->size(); i++) { + CommitPart *tp = parts->get(i); + if (tp != NULL) { + System_arraycopy(tp->getData(), 0, combinedData, currentPosition, tp->getDataSize()); + currentPosition += tp->getDataSize(); + } + } + + // Decoder Object + ByteBuffer *bbDecode = ByteBuffer_wrap(combinedData); + + // Decode how many key value pairs need to be decoded + int numberOfKVUpdates = bbDecode->getInt(); + + // Decode all the updates key values + for (int i = 0; i < numberOfKVUpdates; i++) { + KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); + keyValueUpdateSet->add(kv); + liveKeys->add(kv->getKey()); + } + delete bbDecode; +} + +Array *Commit::convertDataToBytes() { + // Calculate the size of the data + int sizeOfData = sizeof(int32_t); // Number of Update KV's + SetIterator *kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + sizeOfData += kv->getSize(); + } + delete kvit; + + // Data handlers and storage + Array *dataArray = new Array(sizeOfData); + ByteBuffer *bbEncode = ByteBuffer_wrap(dataArray); + + // Encode the size of the updates and guard sets + bbEncode->putInt(keyValueUpdateSet->size()); + + // Encode all the updates + kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + kv->encode(bbEncode); + } + delete kvit; + Array * array = bbEncode->array(); + bbEncode->releaseArray(); + delete bbEncode; + return array; +} + +void Commit::setKVsMap(Hashset *newKVs) { + keyValueUpdateSet->clear(); + liveKeys->clear(); + SetIterator *kvit = newKVs->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + KeyValue *kvcopy = kv->getCopy(); + liveKeys->add(kvcopy->getKey()); + keyValueUpdateSet->add(kvcopy); + } + delete kvit; +} + +Commit *Commit_merge(Commit *newer, Commit *older, int64_t newSequenceNumber) { + if (older == NULL) { + return newer; + } else if (newer == NULL) { + return older; + } + Hashset *kvSet = new Hashset(); + SetIterator *kvit = older->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + kvSet->add(kv); + } + delete kvit; + kvit = newer->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + kvSet->add(kv); + } + delete kvit; + + int64_t transactionSequenceNumber = newer->getTransactionSequenceNumber(); + if (transactionSequenceNumber == -1) { + transactionSequenceNumber = older->getTransactionSequenceNumber(); + } + + Commit *newCommit = new Commit(newSequenceNumber, newer->getMachineId(), transactionSequenceNumber); + newCommit->setKVsMap(kvSet); + + delete kvSet; + return newCommit; +} diff --git a/version2/src/C/CommitPart.cc b/version2/src/C/CommitPart.cc deleted file mode 100644 index 6cb382c..0000000 --- a/version2/src/C/CommitPart.cc +++ /dev/null @@ -1,102 +0,0 @@ -#include "CommitPart.h" -#include "ByteBuffer.h" - -CommitPart::CommitPart(Slot *s, int64_t _machineId, int64_t _sequenceNumber, int64_t _transactionSequenceNumber, int _partNumber, Array *_data, bool _isLastPart) : - Entry(s), - machineId(_machineId), - sequenceNumber(_sequenceNumber), - transactionSequenceNumber(_transactionSequenceNumber), - partNumber(_partNumber), - fldisLastPart(_isLastPart), - refCount(1), - data(_data), - partId(Pair(sequenceNumber, partNumber)), - commitId(Pair(machineId, sequenceNumber)) { -} - -CommitPart::~CommitPart() { - delete data; -} - -int CommitPart::getSize() { - if (data == NULL) { - return (3 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)); - } - return (3 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)) + data->length(); -} - -int CommitPart::getPartNumber() { - return partNumber; -} - -int CommitPart::getDataSize() { - return data->length(); -} - -Array *CommitPart::getData() { - return data; -} - -Pair * CommitPart::getPartId() { - return & partId; -} - -Pair CommitPart::getCommitId() { - return commitId; -} - -bool CommitPart::isLastPart() { - return fldisLastPart; -} - -int64_t CommitPart::getMachineId() { - return machineId; -} - -int64_t CommitPart::getTransactionSequenceNumber() { - return transactionSequenceNumber; -} - -int64_t CommitPart::getSequenceNumber() { - return sequenceNumber; -} - -Entry *CommitPart_decode(Slot *s, ByteBuffer *bb) { - int64_t machineId = bb->getLong(); - int64_t sequenceNumber = bb->getLong(); - int64_t transactionSequenceNumber = bb->getLong(); - int partNumber = bb->getInt(); - int dataSize = bb->getInt(); - bool isLastPart = bb->get() == 1; - - // Get the data - Array *data = new Array(dataSize); - bb->get(data); - - return new CommitPart(s, machineId, sequenceNumber, transactionSequenceNumber, partNumber, data, isLastPart); -} - -void CommitPart::encode(ByteBuffer *bb) { - bb->put(TypeCommitPart); - bb->putLong(machineId); - bb->putLong(sequenceNumber); - bb->putLong(transactionSequenceNumber); - bb->putInt(partNumber); - bb->putInt(data->length()); - - if (fldisLastPart) { - bb->put((char)1); - } else { - bb->put((char)0); - } - - bb->put(data); -} - -char CommitPart::getType() { - return TypeCommitPart; -} - -Entry *CommitPart::getCopy(Slot *s) { - return new CommitPart(s, machineId, sequenceNumber, transactionSequenceNumber, partNumber, new Array(data), fldisLastPart); -} diff --git a/version2/src/C/CommitPart.cpp b/version2/src/C/CommitPart.cpp new file mode 100644 index 0000000..6cb382c --- /dev/null +++ b/version2/src/C/CommitPart.cpp @@ -0,0 +1,102 @@ +#include "CommitPart.h" +#include "ByteBuffer.h" + +CommitPart::CommitPart(Slot *s, int64_t _machineId, int64_t _sequenceNumber, int64_t _transactionSequenceNumber, int _partNumber, Array *_data, bool _isLastPart) : + Entry(s), + machineId(_machineId), + sequenceNumber(_sequenceNumber), + transactionSequenceNumber(_transactionSequenceNumber), + partNumber(_partNumber), + fldisLastPart(_isLastPart), + refCount(1), + data(_data), + partId(Pair(sequenceNumber, partNumber)), + commitId(Pair(machineId, sequenceNumber)) { +} + +CommitPart::~CommitPart() { + delete data; +} + +int CommitPart::getSize() { + if (data == NULL) { + return (3 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)); + } + return (3 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)) + data->length(); +} + +int CommitPart::getPartNumber() { + return partNumber; +} + +int CommitPart::getDataSize() { + return data->length(); +} + +Array *CommitPart::getData() { + return data; +} + +Pair * CommitPart::getPartId() { + return & partId; +} + +Pair CommitPart::getCommitId() { + return commitId; +} + +bool CommitPart::isLastPart() { + return fldisLastPart; +} + +int64_t CommitPart::getMachineId() { + return machineId; +} + +int64_t CommitPart::getTransactionSequenceNumber() { + return transactionSequenceNumber; +} + +int64_t CommitPart::getSequenceNumber() { + return sequenceNumber; +} + +Entry *CommitPart_decode(Slot *s, ByteBuffer *bb) { + int64_t machineId = bb->getLong(); + int64_t sequenceNumber = bb->getLong(); + int64_t transactionSequenceNumber = bb->getLong(); + int partNumber = bb->getInt(); + int dataSize = bb->getInt(); + bool isLastPart = bb->get() == 1; + + // Get the data + Array *data = new Array(dataSize); + bb->get(data); + + return new CommitPart(s, machineId, sequenceNumber, transactionSequenceNumber, partNumber, data, isLastPart); +} + +void CommitPart::encode(ByteBuffer *bb) { + bb->put(TypeCommitPart); + bb->putLong(machineId); + bb->putLong(sequenceNumber); + bb->putLong(transactionSequenceNumber); + bb->putInt(partNumber); + bb->putInt(data->length()); + + if (fldisLastPart) { + bb->put((char)1); + } else { + bb->put((char)0); + } + + bb->put(data); +} + +char CommitPart::getType() { + return TypeCommitPart; +} + +Entry *CommitPart::getCopy(Slot *s) { + return new CommitPart(s, machineId, sequenceNumber, transactionSequenceNumber, partNumber, new Array(data), fldisLastPart); +} diff --git a/version2/src/C/Crypto.cc b/version2/src/C/Crypto.cc deleted file mode 100644 index 9fe1154..0000000 --- a/version2/src/C/Crypto.cc +++ /dev/null @@ -1,24 +0,0 @@ -#include "Crypto.h" -#include "pbkdf2-sha256.h" - -AESKey::AESKey(Array *password, Array *salt, int iterationCount, int keyLength) { - key = new Array(keyLength / 8); - PKCS5_PBKDF2_HMAC((unsigned char *) password->internalArray(), password->length(), - (unsigned char *) salt->internalArray(), salt->length(), - iterationCount, keyLength / 8, (unsigned char *) key->internalArray()); - aes_key_setup((BYTE *)key->internalArray(), key_schedule, keyLength); -} - -AESKey::~AESKey() { - bzero(key->internalArray(), key->length()); - delete key; - bzero(key_schedule, sizeof(key_schedule)); -} - -WORD *AESKey::getKeySchedule() { - return (WORD *) &key_schedule; -} - -Array *AESKey::getKey() { - return key; -} diff --git a/version2/src/C/Crypto.cpp b/version2/src/C/Crypto.cpp new file mode 100644 index 0000000..9fe1154 --- /dev/null +++ b/version2/src/C/Crypto.cpp @@ -0,0 +1,24 @@ +#include "Crypto.h" +#include "pbkdf2-sha256.h" + +AESKey::AESKey(Array *password, Array *salt, int iterationCount, int keyLength) { + key = new Array(keyLength / 8); + PKCS5_PBKDF2_HMAC((unsigned char *) password->internalArray(), password->length(), + (unsigned char *) salt->internalArray(), salt->length(), + iterationCount, keyLength / 8, (unsigned char *) key->internalArray()); + aes_key_setup((BYTE *)key->internalArray(), key_schedule, keyLength); +} + +AESKey::~AESKey() { + bzero(key->internalArray(), key->length()); + delete key; + bzero(key_schedule, sizeof(key_schedule)); +} + +WORD *AESKey::getKeySchedule() { + return (WORD *) &key_schedule; +} + +Array *AESKey::getKey() { + return key; +} diff --git a/version2/src/C/Entry.cc b/version2/src/C/Entry.cc deleted file mode 100644 index 097967b..0000000 --- a/version2/src/C/Entry.cc +++ /dev/null @@ -1,48 +0,0 @@ -#include "Entry.h" -#include "Slot.h" -#include "ByteBuffer.h" -#include "Abort.h" -#include "CommitPart.h" -#include "NewKey.h" -#include "LastMessage.h" -#include "RejectedMessage.h" -#include "TableStatus.h" -#include "TransactionPart.h" -/** - * Generic class that wraps all the different types of information - * that can be stored in a Slot. - * @author Brian Demsky - * @version 1.0 - */ - -Entry *Entry_decode(Slot *slot, ByteBuffer *bb) { - char type = bb->get(); - switch (type) { - case TypeCommitPart: - return CommitPart_decode(slot, bb); - case TypeAbort: - return Abort_decode(slot, bb); - case TypeTransactionPart: - return TransactionPart_decode(slot, bb); - case TypeNewKey: - return NewKey_decode(slot, bb); - case TypeLastMessage: - return LastMessage_decode(slot, bb); - case TypeRejectedMessage: - return RejectedMessage_decode(slot, bb); - case TypeTableStatus: - return TableStatus_decode(slot, bb); - - default: - ASSERT(0); - } -} - -void Entry::setDead() { - if (islive) { - islive = false; - if (parentslot != NULL) { - parentslot->decrementLiveCount(); - } - } -} diff --git a/version2/src/C/Entry.cpp b/version2/src/C/Entry.cpp new file mode 100644 index 0000000..097967b --- /dev/null +++ b/version2/src/C/Entry.cpp @@ -0,0 +1,48 @@ +#include "Entry.h" +#include "Slot.h" +#include "ByteBuffer.h" +#include "Abort.h" +#include "CommitPart.h" +#include "NewKey.h" +#include "LastMessage.h" +#include "RejectedMessage.h" +#include "TableStatus.h" +#include "TransactionPart.h" +/** + * Generic class that wraps all the different types of information + * that can be stored in a Slot. + * @author Brian Demsky + * @version 1.0 + */ + +Entry *Entry_decode(Slot *slot, ByteBuffer *bb) { + char type = bb->get(); + switch (type) { + case TypeCommitPart: + return CommitPart_decode(slot, bb); + case TypeAbort: + return Abort_decode(slot, bb); + case TypeTransactionPart: + return TransactionPart_decode(slot, bb); + case TypeNewKey: + return NewKey_decode(slot, bb); + case TypeLastMessage: + return LastMessage_decode(slot, bb); + case TypeRejectedMessage: + return RejectedMessage_decode(slot, bb); + case TypeTableStatus: + return TableStatus_decode(slot, bb); + + default: + ASSERT(0); + } +} + +void Entry::setDead() { + if (islive) { + islive = false; + if (parentslot != NULL) { + parentslot->decrementLiveCount(); + } + } +} diff --git a/version2/src/C/KeyValue.cc b/version2/src/C/KeyValue.cc deleted file mode 100644 index 8aea6f2..0000000 --- a/version2/src/C/KeyValue.cc +++ /dev/null @@ -1,51 +0,0 @@ -#include "KeyValue.h" -#include "ByteBuffer.h" -#include "IoTString.h" -/** - * KeyValue entry for Slot. - * @author Brian Demsky - * @version 1.0 - */ - -KeyValue::~KeyValue() { - delete key; - delete value; -} - -KeyValue *KeyValue_decode(ByteBuffer *bb) { - int keylength = bb->getInt(); - int valuelength = bb->getInt(); - Array *key = new Array(keylength); - bb->get(key); - - if (valuelength != 0) { - Array *value = new Array(valuelength); - bb->get(value); - return new KeyValue(IoTString_shallow(key), IoTString_shallow(value)); - } - - return new KeyValue(IoTString_shallow(key), NULL); -} - -void KeyValue::encode(ByteBuffer *bb) { - bb->putInt(key->length()); - if (value != NULL) { - bb->putInt(value->length()); - } else { - bb->putInt(0); - } - bb->put(key->internalBytes()); - if (value != NULL) { - bb->put(value->internalBytes()); - } -} - -int KeyValue::getSize() { - if (value != NULL) - return 2 * sizeof(int32_t) + key->length() + value->length(); - return 2 * sizeof(int32_t) + key->length(); -} - -KeyValue *KeyValue::getCopy() { - return new KeyValue(new IoTString(key), new IoTString(value)); -} diff --git a/version2/src/C/KeyValue.cpp b/version2/src/C/KeyValue.cpp new file mode 100644 index 0000000..8aea6f2 --- /dev/null +++ b/version2/src/C/KeyValue.cpp @@ -0,0 +1,51 @@ +#include "KeyValue.h" +#include "ByteBuffer.h" +#include "IoTString.h" +/** + * KeyValue entry for Slot. + * @author Brian Demsky + * @version 1.0 + */ + +KeyValue::~KeyValue() { + delete key; + delete value; +} + +KeyValue *KeyValue_decode(ByteBuffer *bb) { + int keylength = bb->getInt(); + int valuelength = bb->getInt(); + Array *key = new Array(keylength); + bb->get(key); + + if (valuelength != 0) { + Array *value = new Array(valuelength); + bb->get(value); + return new KeyValue(IoTString_shallow(key), IoTString_shallow(value)); + } + + return new KeyValue(IoTString_shallow(key), NULL); +} + +void KeyValue::encode(ByteBuffer *bb) { + bb->putInt(key->length()); + if (value != NULL) { + bb->putInt(value->length()); + } else { + bb->putInt(0); + } + bb->put(key->internalBytes()); + if (value != NULL) { + bb->put(value->internalBytes()); + } +} + +int KeyValue::getSize() { + if (value != NULL) + return 2 * sizeof(int32_t) + key->length() + value->length(); + return 2 * sizeof(int32_t) + key->length(); +} + +KeyValue *KeyValue::getCopy() { + return new KeyValue(new IoTString(key), new IoTString(value)); +} diff --git a/version2/src/C/LastMessage.cc b/version2/src/C/LastMessage.cc deleted file mode 100644 index 92cf608..0000000 --- a/version2/src/C/LastMessage.cc +++ /dev/null @@ -1,21 +0,0 @@ -#include "LastMessage.h" -#include "Slot.h" -#include "ByteBuffer.h" - -/** - * This Entry records the last message sent by a given machine. - * @author Brian Demsky - * @version 1.0 - */ - -Entry *LastMessage_decode(Slot *slot, ByteBuffer *bb) { - int64_t machineid = bb->getLong(); - int64_t seqnum = bb->getLong(); - return new LastMessage(slot, machineid, seqnum); -} - -void LastMessage::encode(ByteBuffer *bb) { - bb->put(TypeLastMessage); - bb->putLong(machineid); - bb->putLong(seqnum); -} diff --git a/version2/src/C/LastMessage.cpp b/version2/src/C/LastMessage.cpp new file mode 100644 index 0000000..92cf608 --- /dev/null +++ b/version2/src/C/LastMessage.cpp @@ -0,0 +1,21 @@ +#include "LastMessage.h" +#include "Slot.h" +#include "ByteBuffer.h" + +/** + * This Entry records the last message sent by a given machine. + * @author Brian Demsky + * @version 1.0 + */ + +Entry *LastMessage_decode(Slot *slot, ByteBuffer *bb) { + int64_t machineid = bb->getLong(); + int64_t seqnum = bb->getLong(); + return new LastMessage(slot, machineid, seqnum); +} + +void LastMessage::encode(ByteBuffer *bb) { + bb->put(TypeLastMessage); + bb->putLong(machineid); + bb->putLong(seqnum); +} diff --git a/version2/src/C/LocalComm.cc b/version2/src/C/LocalComm.cc deleted file mode 100644 index ac4c343..0000000 --- a/version2/src/C/LocalComm.cc +++ /dev/null @@ -1,17 +0,0 @@ -#include "LocalComm.h" -#include "Error.h" -#include "Table.h" - -Array *LocalComm::sendDataToLocalDevice(int64_t deviceId, Array *data) { - printf("Passing Locally\n"); - - if (deviceId == t1->getMachineId()) { - // return t1.localCommInput(data); - } else if (deviceId == t2->getMachineId()) { - // return t2.localCommInput(data); - } else { - throw new Error("Cannot send to deviceId using this local comm"); - } - - return new Array((uint32_t)0); -} diff --git a/version2/src/C/LocalComm.cpp b/version2/src/C/LocalComm.cpp new file mode 100644 index 0000000..ac4c343 --- /dev/null +++ b/version2/src/C/LocalComm.cpp @@ -0,0 +1,17 @@ +#include "LocalComm.h" +#include "Error.h" +#include "Table.h" + +Array *LocalComm::sendDataToLocalDevice(int64_t deviceId, Array *data) { + printf("Passing Locally\n"); + + if (deviceId == t1->getMachineId()) { + // return t1.localCommInput(data); + } else if (deviceId == t2->getMachineId()) { + // return t2.localCommInput(data); + } else { + throw new Error("Cannot send to deviceId using this local comm"); + } + + return new Array((uint32_t)0); +} diff --git a/version2/src/C/Mac.cc b/version2/src/C/Mac.cc deleted file mode 100644 index aec7455..0000000 --- a/version2/src/C/Mac.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include "Mac.h" -#include "Crypto.h" - -Mac::Mac() { -} - -void Mac::update(Array *array, int32_t offset, int32_t len) { - sha2_hmac_update(&ctx, (const unsigned char *) &array->internalArray()[offset], len); -} - -Array *Mac::doFinal() { - Array *hmac = new Array(32); - sha2_hmac_finish(&ctx, (unsigned char *) hmac->internalArray()); - sha2_hmac_reset(&ctx); - return hmac; -} - -void Mac::init(AESKey *key) { - sha2_hmac_starts(&ctx, (const unsigned char *) key->getKey()->internalArray(), key->getKey()->length(), false); -} diff --git a/version2/src/C/Mac.cpp b/version2/src/C/Mac.cpp new file mode 100644 index 0000000..aec7455 --- /dev/null +++ b/version2/src/C/Mac.cpp @@ -0,0 +1,20 @@ +#include "Mac.h" +#include "Crypto.h" + +Mac::Mac() { +} + +void Mac::update(Array *array, int32_t offset, int32_t len) { + sha2_hmac_update(&ctx, (const unsigned char *) &array->internalArray()[offset], len); +} + +Array *Mac::doFinal() { + Array *hmac = new Array(32); + sha2_hmac_finish(&ctx, (unsigned char *) hmac->internalArray()); + sha2_hmac_reset(&ctx); + return hmac; +} + +void Mac::init(AESKey *key) { + sha2_hmac_starts(&ctx, (const unsigned char *) key->getKey()->internalArray(), key->getKey()->length(), false); +} diff --git a/version2/src/C/Makefile b/version2/src/C/Makefile index 66f289f..199196e 100644 --- a/version2/src/C/Makefile +++ b/version2/src/C/Makefile @@ -4,13 +4,13 @@ PHONY += directories MKDIR_P = mkdir -p OBJ_DIR = bin -CPP_SOURCES := $(wildcard *.cc) +CPP_SOURCES := $(wildcard *.cpp) HEADERS := $(wildcard *.h) -OBJECTS := $(CPP_SOURCES:%.cc=$(OBJ_DIR)/%.o) $(C_SOURCES:%.c=$(OBJ_DIR)/%.o) +OBJECTS := $(CPP_SOURCES:%.cpp=$(OBJ_DIR)/%.o) $(C_SOURCES:%.c=$(OBJ_DIR)/%.o) -CFLAGS := -Wall -O0 -g +CFLAGS := -Wall -O3 -g CFLAGS += -I. LDFLAGS := -ldl -lrt -rdynamic -g SHARED := -shared @@ -28,7 +28,7 @@ all: directories ${OBJ_DIR}/$(LIB_SO) test directories: ${OBJ_DIR} test: bin/lib_iotcloud.so - g++ -g -O0 Test.C -L./bin/ -l_iotcloud -lpthread -lbsd -o bin/Test + g++ -g -O3 Test.C -L./bin/ -l_iotcloud -lpthread -lbsd -o bin/Test ${OBJ_DIR}: ${MKDIR_P} ${OBJ_DIR} @@ -44,7 +44,7 @@ docs: $(C_SOURCES) $(HEADERS) ${OBJ_DIR}/$(LIB_SO): $(OBJECTS) $(CXX) -g $(SHARED) -o ${OBJ_DIR}/$(LIB_SO) $+ $(LDFLAGS) -${OBJ_DIR}/%.o: %.cc +${OBJ_DIR}/%.o: %.cpp $(CXX) -fPIC -c $< -o $@ $(CFLAGS) -Wno-unused-variable -include $(OBJECTS:%=$OBJ_DIR/.%.d) @@ -63,10 +63,10 @@ tags: ctags -R tabbing: - uncrustify -c C.cfg --no-backup *.cc + uncrustify -c C.cfg --no-backup *.cpp uncrustify -c C.cfg --no-backup *.h wc: - wc *.cc *.h + wc *.cpp *.h .PHONY: $(PHONY) diff --git a/version2/src/C/NewKey.cc b/version2/src/C/NewKey.cc deleted file mode 100644 index 0b35933..0000000 --- a/version2/src/C/NewKey.cc +++ /dev/null @@ -1,37 +0,0 @@ -#include "NewKey.h" -#include "ByteBuffer.h" -#include "IoTString.h" - -NewKey::NewKey(Slot *slot, IoTString *_key, int64_t _machineid) : - Entry(slot), - key(new IoTString(_key)), - machineid(_machineid) { -} - -NewKey::~NewKey() { - delete key; -} - -Entry *NewKey_decode(Slot *slot, ByteBuffer *bb) { - int keylength = bb->getInt(); - Array *key = new Array(keylength); - bb->get(key); - int64_t machineid = bb->getLong(); - IoTString *str = IoTString_shallow(key); - NewKey *newkey = new NewKey(slot, str, machineid); - delete str; - return newkey; -} - -Entry *NewKey::getCopy(Slot *s) { return new NewKey(s, key, machineid); } - -void NewKey::encode(ByteBuffer *bb) { - bb->put(TypeNewKey); - bb->putInt(key->length()); - bb->put(key->internalBytes()); - bb->putLong(machineid); -} - -int NewKey::getSize() { - return sizeof(int64_t) + sizeof(char) + sizeof(int32_t) + key->length(); -} diff --git a/version2/src/C/NewKey.cpp b/version2/src/C/NewKey.cpp new file mode 100644 index 0000000..0b35933 --- /dev/null +++ b/version2/src/C/NewKey.cpp @@ -0,0 +1,37 @@ +#include "NewKey.h" +#include "ByteBuffer.h" +#include "IoTString.h" + +NewKey::NewKey(Slot *slot, IoTString *_key, int64_t _machineid) : + Entry(slot), + key(new IoTString(_key)), + machineid(_machineid) { +} + +NewKey::~NewKey() { + delete key; +} + +Entry *NewKey_decode(Slot *slot, ByteBuffer *bb) { + int keylength = bb->getInt(); + Array *key = new Array(keylength); + bb->get(key); + int64_t machineid = bb->getLong(); + IoTString *str = IoTString_shallow(key); + NewKey *newkey = new NewKey(slot, str, machineid); + delete str; + return newkey; +} + +Entry *NewKey::getCopy(Slot *s) { return new NewKey(s, key, machineid); } + +void NewKey::encode(ByteBuffer *bb) { + bb->put(TypeNewKey); + bb->putInt(key->length()); + bb->put(key->internalBytes()); + bb->putLong(machineid); +} + +int NewKey::getSize() { + return sizeof(int64_t) + sizeof(char) + sizeof(int32_t) + key->length(); +} diff --git a/version2/src/C/PendingTransaction.cc b/version2/src/C/PendingTransaction.cc deleted file mode 100644 index c0d32a3..0000000 --- a/version2/src/C/PendingTransaction.cc +++ /dev/null @@ -1,194 +0,0 @@ -#include "PendingTransaction.h" -#include "KeyValue.h" -#include "IoTString.h" -#include "Transaction.h" -#include "TransactionPart.h" -#include "ByteBuffer.h" - -PendingTransaction::PendingTransaction(int64_t _machineId) : - keyValueUpdateSet(new Hashset()), - keyValueGuardSet(new Hashset()), - arbitrator(-1), - clientLocalSequenceNumber(-1), - machineId(_machineId), - currentDataSize(0) { -} - -PendingTransaction::~PendingTransaction() { - delete keyValueUpdateSet; - delete keyValueGuardSet; -} - -/** - * Add a new key value to the updates - * - */ -void PendingTransaction::addKV(KeyValue *newKV) { - KeyValue *rmKV = NULL; - - // Make sure there are no duplicates - SetIterator *kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - if (kv->getKey()->equals(newKV->getKey())) { - // Remove key if we are adding a newer version of the same key - rmKV = kv; - break; - } - } - delete kvit; - - // Remove key if we are adding a newer version of the same key - if (rmKV != NULL) { - keyValueUpdateSet->remove(rmKV); - currentDataSize -= rmKV->getSize(); - } - - // Add the key to the hash set - keyValueUpdateSet->add(newKV); - currentDataSize += newKV->getSize(); -} - -/** - * Add a new key value to the guard set - * - */ -void PendingTransaction::addKVGuard(KeyValue *newKV) { - // Add the key to the hash set - keyValueGuardSet->add(newKV); - currentDataSize += newKV->getSize(); -} - -/** - * Checks if the arbitrator is the same - */ -bool PendingTransaction::checkArbitrator(int64_t arb) { - if (arbitrator == -1) { - arbitrator = arb; - return true; - } - return arb == arbitrator; -} - -bool PendingTransaction::evaluateGuard(Hashtable *keyValTableCommitted, Hashtable *keyValTableSpeculative, Hashtable *keyValTablePendingTransSpeculative) { - SetIterator *kvit = keyValueGuardSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kvGuard = kvit->next(); - // First check if the key is in the speculative table, this is the - // value of the latest assumption - KeyValue *kv = keyValTablePendingTransSpeculative->get(kvGuard->getKey()); - - - if (kv == NULL) { - // if it is not in the pending trans table then check the - // speculative table and use that value as our latest assumption - kv = keyValTableSpeculative->get(kvGuard->getKey()); - } - - - if (kv == NULL) { - // if it is not in the speculative table then check the - // committed table and use that value as our latest assumption - kv = keyValTableCommitted->get(kvGuard->getKey()); - } - - if (kvGuard->getValue() != NULL) { - if ((kv == NULL) || (!kvGuard->getValue()->equals(kv->getValue()))) { - delete kvit; - return false; - } - } else { - if (kv != NULL) { - delete kvit; - return false; - } - } - } - delete kvit; - return true; -} - -Transaction *PendingTransaction::createTransaction() { - Transaction *newTransaction = new Transaction(); - int transactionPartCount = 0; - - // Convert all the data into a char array so we can start partitioning - Array *charData = convertDataToBytes(); - - int currentPosition = 0; - for (int remaining = charData->length(); remaining > 0;) { - bool isLastPart = false; - // determine how much to copy - int copySize = TransactionPart_MAX_NON_HEADER_SIZE; - if (remaining <= TransactionPart_MAX_NON_HEADER_SIZE) { - copySize = remaining; - isLastPart = true;//last bit of data so last part - } - - // Copy to a smaller version - Array *partData = new Array(copySize); - System_arraycopy(charData, currentPosition, partData, 0, copySize); - - TransactionPart *part = new TransactionPart(NULL, machineId, arbitrator, clientLocalSequenceNumber, transactionPartCount, partData, isLastPart); - newTransaction->addPartEncode(part); - part->releaseRef(); - - // Update position, count and remaining - currentPosition += copySize; - transactionPartCount++; - remaining -= copySize; - } - delete charData; - - // Add the Guard Conditions - SetIterator *kvit = keyValueGuardSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - newTransaction->addGuardKV(kv); - } - delete kvit; - - // Add the updates - kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - newTransaction->addUpdateKV(kv); - } - delete kvit; - return newTransaction; -} - -Array *PendingTransaction::convertDataToBytes() { - // Calculate the size of the data - int sizeOfData = 2 * sizeof(int32_t); // Number of Update KV's and Guard KV's - sizeOfData += currentDataSize; - - // Data handlers and storage - Array *dataArray = new Array(sizeOfData); - ByteBuffer *bbEncode = ByteBuffer_wrap(dataArray); - - // Encode the size of the updates and guard sets - bbEncode->putInt(keyValueGuardSet->size()); - bbEncode->putInt(keyValueUpdateSet->size()); - - // Encode all the guard conditions - SetIterator *kvit = keyValueGuardSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - kv->encode(bbEncode); - } - delete kvit; - - // Encode all the updates - kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - kv->encode(bbEncode); - } - delete kvit; - - Array *array = bbEncode->array(); - bbEncode->releaseArray(); - delete bbEncode; - return array; -} diff --git a/version2/src/C/PendingTransaction.cpp b/version2/src/C/PendingTransaction.cpp new file mode 100644 index 0000000..c0d32a3 --- /dev/null +++ b/version2/src/C/PendingTransaction.cpp @@ -0,0 +1,194 @@ +#include "PendingTransaction.h" +#include "KeyValue.h" +#include "IoTString.h" +#include "Transaction.h" +#include "TransactionPart.h" +#include "ByteBuffer.h" + +PendingTransaction::PendingTransaction(int64_t _machineId) : + keyValueUpdateSet(new Hashset()), + keyValueGuardSet(new Hashset()), + arbitrator(-1), + clientLocalSequenceNumber(-1), + machineId(_machineId), + currentDataSize(0) { +} + +PendingTransaction::~PendingTransaction() { + delete keyValueUpdateSet; + delete keyValueGuardSet; +} + +/** + * Add a new key value to the updates + * + */ +void PendingTransaction::addKV(KeyValue *newKV) { + KeyValue *rmKV = NULL; + + // Make sure there are no duplicates + SetIterator *kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + if (kv->getKey()->equals(newKV->getKey())) { + // Remove key if we are adding a newer version of the same key + rmKV = kv; + break; + } + } + delete kvit; + + // Remove key if we are adding a newer version of the same key + if (rmKV != NULL) { + keyValueUpdateSet->remove(rmKV); + currentDataSize -= rmKV->getSize(); + } + + // Add the key to the hash set + keyValueUpdateSet->add(newKV); + currentDataSize += newKV->getSize(); +} + +/** + * Add a new key value to the guard set + * + */ +void PendingTransaction::addKVGuard(KeyValue *newKV) { + // Add the key to the hash set + keyValueGuardSet->add(newKV); + currentDataSize += newKV->getSize(); +} + +/** + * Checks if the arbitrator is the same + */ +bool PendingTransaction::checkArbitrator(int64_t arb) { + if (arbitrator == -1) { + arbitrator = arb; + return true; + } + return arb == arbitrator; +} + +bool PendingTransaction::evaluateGuard(Hashtable *keyValTableCommitted, Hashtable *keyValTableSpeculative, Hashtable *keyValTablePendingTransSpeculative) { + SetIterator *kvit = keyValueGuardSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kvGuard = kvit->next(); + // First check if the key is in the speculative table, this is the + // value of the latest assumption + KeyValue *kv = keyValTablePendingTransSpeculative->get(kvGuard->getKey()); + + + if (kv == NULL) { + // if it is not in the pending trans table then check the + // speculative table and use that value as our latest assumption + kv = keyValTableSpeculative->get(kvGuard->getKey()); + } + + + if (kv == NULL) { + // if it is not in the speculative table then check the + // committed table and use that value as our latest assumption + kv = keyValTableCommitted->get(kvGuard->getKey()); + } + + if (kvGuard->getValue() != NULL) { + if ((kv == NULL) || (!kvGuard->getValue()->equals(kv->getValue()))) { + delete kvit; + return false; + } + } else { + if (kv != NULL) { + delete kvit; + return false; + } + } + } + delete kvit; + return true; +} + +Transaction *PendingTransaction::createTransaction() { + Transaction *newTransaction = new Transaction(); + int transactionPartCount = 0; + + // Convert all the data into a char array so we can start partitioning + Array *charData = convertDataToBytes(); + + int currentPosition = 0; + for (int remaining = charData->length(); remaining > 0;) { + bool isLastPart = false; + // determine how much to copy + int copySize = TransactionPart_MAX_NON_HEADER_SIZE; + if (remaining <= TransactionPart_MAX_NON_HEADER_SIZE) { + copySize = remaining; + isLastPart = true;//last bit of data so last part + } + + // Copy to a smaller version + Array *partData = new Array(copySize); + System_arraycopy(charData, currentPosition, partData, 0, copySize); + + TransactionPart *part = new TransactionPart(NULL, machineId, arbitrator, clientLocalSequenceNumber, transactionPartCount, partData, isLastPart); + newTransaction->addPartEncode(part); + part->releaseRef(); + + // Update position, count and remaining + currentPosition += copySize; + transactionPartCount++; + remaining -= copySize; + } + delete charData; + + // Add the Guard Conditions + SetIterator *kvit = keyValueGuardSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + newTransaction->addGuardKV(kv); + } + delete kvit; + + // Add the updates + kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + newTransaction->addUpdateKV(kv); + } + delete kvit; + return newTransaction; +} + +Array *PendingTransaction::convertDataToBytes() { + // Calculate the size of the data + int sizeOfData = 2 * sizeof(int32_t); // Number of Update KV's and Guard KV's + sizeOfData += currentDataSize; + + // Data handlers and storage + Array *dataArray = new Array(sizeOfData); + ByteBuffer *bbEncode = ByteBuffer_wrap(dataArray); + + // Encode the size of the updates and guard sets + bbEncode->putInt(keyValueGuardSet->size()); + bbEncode->putInt(keyValueUpdateSet->size()); + + // Encode all the guard conditions + SetIterator *kvit = keyValueGuardSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + kv->encode(bbEncode); + } + delete kvit; + + // Encode all the updates + kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + kv->encode(bbEncode); + } + delete kvit; + + Array *array = bbEncode->array(); + bbEncode->releaseArray(); + delete bbEncode; + return array; +} diff --git a/version2/src/C/RejectedMessage.cc b/version2/src/C/RejectedMessage.cc deleted file mode 100644 index 261f30c..0000000 --- a/version2/src/C/RejectedMessage.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "RejectedMessage.h" -#include "ByteBuffer.h" - -/** - * Entry for tracking messages that the server rejected. We have to - * make sure that all clients know that this message was rejected to - * prevent the server from reusing these messages in an attack. - * @author Brian Demsky - * @version 1.0 - */ - -Entry *RejectedMessage_decode(Slot *slot, ByteBuffer *bb) { - int64_t sequencenum = bb->getLong(); - int64_t machineid = bb->getLong(); - int64_t oldseqnum = bb->getLong(); - int64_t newseqnum = bb->getLong(); - char equalto = bb->get(); - return new RejectedMessage(slot,sequencenum, machineid, oldseqnum, newseqnum, equalto == 1); -} - -RejectedMessage::~RejectedMessage() { - if (watchset != NULL) - delete watchset; -} - -void RejectedMessage::removeWatcher(int64_t machineid) { - if (watchset->remove(machineid)) - if (watchset->isEmpty()) - setDead(); -} - -void RejectedMessage::encode(ByteBuffer *bb) { - bb->put(TypeRejectedMessage); - bb->putLong(sequencenum); - bb->putLong(machineid); - bb->putLong(oldseqnum); - bb->putLong(newseqnum); - bb->put(equalto ? (char)1 : (char)0); -} diff --git a/version2/src/C/RejectedMessage.cpp b/version2/src/C/RejectedMessage.cpp new file mode 100644 index 0000000..261f30c --- /dev/null +++ b/version2/src/C/RejectedMessage.cpp @@ -0,0 +1,39 @@ +#include "RejectedMessage.h" +#include "ByteBuffer.h" + +/** + * Entry for tracking messages that the server rejected. We have to + * make sure that all clients know that this message was rejected to + * prevent the server from reusing these messages in an attack. + * @author Brian Demsky + * @version 1.0 + */ + +Entry *RejectedMessage_decode(Slot *slot, ByteBuffer *bb) { + int64_t sequencenum = bb->getLong(); + int64_t machineid = bb->getLong(); + int64_t oldseqnum = bb->getLong(); + int64_t newseqnum = bb->getLong(); + char equalto = bb->get(); + return new RejectedMessage(slot,sequencenum, machineid, oldseqnum, newseqnum, equalto == 1); +} + +RejectedMessage::~RejectedMessage() { + if (watchset != NULL) + delete watchset; +} + +void RejectedMessage::removeWatcher(int64_t machineid) { + if (watchset->remove(machineid)) + if (watchset->isEmpty()) + setDead(); +} + +void RejectedMessage::encode(ByteBuffer *bb) { + bb->put(TypeRejectedMessage); + bb->putLong(sequencenum); + bb->putLong(machineid); + bb->putLong(oldseqnum); + bb->putLong(newseqnum); + bb->put(equalto ? (char)1 : (char)0); +} diff --git a/version2/src/C/SecureRandom.cc b/version2/src/C/SecureRandom.cc deleted file mode 100644 index 84f40c3..0000000 --- a/version2/src/C/SecureRandom.cc +++ /dev/null @@ -1,14 +0,0 @@ -#include "SecureRandom.h" -#include -#include - -SecureRandom::SecureRandom() { -} - -void SecureRandom::nextBytes(Array *array) { - arc4random_buf(array->internalArray(), array->length()); -} - -int32_t SecureRandom::nextInt(int32_t val) { - return arc4random_uniform(val); -} diff --git a/version2/src/C/SecureRandom.cpp b/version2/src/C/SecureRandom.cpp new file mode 100644 index 0000000..84f40c3 --- /dev/null +++ b/version2/src/C/SecureRandom.cpp @@ -0,0 +1,14 @@ +#include "SecureRandom.h" +#include +#include + +SecureRandom::SecureRandom() { +} + +void SecureRandom::nextBytes(Array *array) { + arc4random_buf(array->internalArray(), array->length()); +} + +int32_t SecureRandom::nextInt(int32_t val) { + return arc4random_uniform(val); +} diff --git a/version2/src/C/Slot.cc b/version2/src/C/Slot.cc deleted file mode 100644 index 31cce57..0000000 --- a/version2/src/C/Slot.cc +++ /dev/null @@ -1,200 +0,0 @@ -#include "Slot.h" -#include "ByteBuffer.h" -#include "Entry.h" -#include "Error.h" -#include "CloudComm.h" -#include "Table.h" -#include "LastMessage.h" -#include "Mac.h" - -Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, Array *_prevhmac, Array *_hmac, int64_t _localSequenceNumber) : - seqnum(_seqnum), - prevhmac(_prevhmac), - hmac(_hmac), - machineid(_machineid), - entries(new Vector()), - livecount(1), - seqnumlive(true), - freespace(SLOT_SIZE - getBaseSize()), - table(_table), - fakeLastMessage(NULL), - localSequenceNumber(_localSequenceNumber) { -} - -Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, Array *_prevhmac, int64_t _localSequenceNumber) : - seqnum(_seqnum), - prevhmac(_prevhmac), - hmac(NULL), - machineid(_machineid), - entries(new Vector()), - livecount(1), - seqnumlive(true), - freespace(SLOT_SIZE - getBaseSize()), - table(_table), - fakeLastMessage(NULL), - localSequenceNumber(_localSequenceNumber) { -} - -Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, int64_t _localSequenceNumber) : - seqnum(_seqnum), - prevhmac(new Array(HMAC_SIZE)), - hmac(NULL), - machineid(_machineid), - entries(new Vector()), - livecount(1), - seqnumlive(true), - freespace(SLOT_SIZE - getBaseSize()), - table(_table), - fakeLastMessage(NULL), - localSequenceNumber(_localSequenceNumber) { -} - -Slot::~Slot() { - if (hmac != NULL) - delete hmac; - delete prevhmac; - for(uint i=0; i< entries->size(); i++) - entries->get(i)->releaseRef(); - delete entries; - if (fakeLastMessage) - delete fakeLastMessage; -} - -Entry *Slot::addEntry(Entry *e) { - e = e->getCopy(this); - entries->add(e); - livecount++; - freespace -= e->getSize(); - return e; -} - -void Slot::addShallowEntry(Entry *e) { - entries->add(e); - livecount++; - freespace -= e->getSize(); -} - -/** - * Returns true if the slot has free space to hold the entry without - * using its reserved space. */ - -bool Slot::hasSpace(Entry *e) { - int newfreespace = freespace - e->getSize(); - return newfreespace >= 0; -} - -Vector *Slot::getEntries() { - return entries; -} - -Slot *Slot_decode(Table *table, Array *array, Mac *mac) { - mac->update(array, HMAC_SIZE, array->length() - HMAC_SIZE); - Array *realmac = mac->doFinal(); - - ByteBuffer *bb = ByteBuffer_wrap(array); - Array *hmac = new Array(HMAC_SIZE); - Array *prevhmac = new Array(HMAC_SIZE); - bb->get(hmac); - bb->get(prevhmac); - if (!realmac->equals(hmac)) - throw new Error("Server Error: Invalid HMAC! Potential Attack!"); - delete realmac; - - int64_t seqnum = bb->getLong(); - int64_t machineid = bb->getLong(); - int numentries = bb->getInt(); - Slot *slot = new Slot(table, seqnum, machineid, prevhmac, hmac, -1); - - for (int i = 0; i < numentries; i++) { - slot->addShallowEntry(Entry_decode(slot, bb)); - } - bb->releaseArray(); - delete bb; - return slot; -} - -char Slot::getType() { - return TypeSlot; -} - -Array *Slot::encode(Mac *mac) { - Array *array = new Array(SLOT_SIZE); - ByteBuffer *bb = ByteBuffer_wrap(array); - /* Leave space for the slot HMAC. */ - bb->position(HMAC_SIZE); - bb->put(prevhmac); - bb->putLong(seqnum); - bb->putLong(machineid); - bb->putInt(entries->size()); - for (uint ei = 0; ei < entries->size(); ei++) { - Entry *entry = entries->get(ei); - entry->encode(bb); - } - /* Compute our HMAC */ - mac->update(array, HMAC_SIZE, array->length() - HMAC_SIZE); - Array *realmac = mac->doFinal(); - hmac = realmac; - bb->position(0); - bb->put(realmac); - bb->releaseArray(); - delete bb; - return array; -} - - -/** - * Returns the live set of entries for this Slot. Generates a fake - * LastMessage entry to represent the information stored by the slot - * itself. - */ - -Vector *Slot::getLiveEntries(bool resize) { - Vector *liveEntries = new Vector(); - for (uint ei = 0; ei < entries->size(); ei++) { - Entry *entry = entries->get(ei); - if (entry->isLive()) { - if (!resize || entry->getType() != TypeTableStatus) - liveEntries->add(entry); - } - } - - if (seqnumlive && !resize) { - if (! fakeLastMessage) - fakeLastMessage = new LastMessage(this, machineid, seqnum); - liveEntries->add(fakeLastMessage); - } - return liveEntries; -} - - -/** - * Records that a newer slot records the fact that this slot was - * sent by the relevant machine. - */ - -void Slot::setDead() { - seqnumlive = false; - decrementLiveCount(); -} - -/** - * Update the count of live entries. - */ - -void Slot::decrementLiveCount() { - livecount--; - if (livecount == 0) { - table->decrementLiveCount(); - } -} - -Array *Slot::getSlotCryptIV() { - ByteBuffer *buffer = ByteBuffer_allocate(CloudComm_IV_SIZE); - buffer->putLong(machineid); - int64_t localSequenceNumberShift = localSequenceNumber << 16; - buffer->putLong(localSequenceNumberShift); - Array * array = buffer->array(); - buffer->releaseArray(); - delete buffer; - return array; -} diff --git a/version2/src/C/Slot.cpp b/version2/src/C/Slot.cpp new file mode 100644 index 0000000..31cce57 --- /dev/null +++ b/version2/src/C/Slot.cpp @@ -0,0 +1,200 @@ +#include "Slot.h" +#include "ByteBuffer.h" +#include "Entry.h" +#include "Error.h" +#include "CloudComm.h" +#include "Table.h" +#include "LastMessage.h" +#include "Mac.h" + +Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, Array *_prevhmac, Array *_hmac, int64_t _localSequenceNumber) : + seqnum(_seqnum), + prevhmac(_prevhmac), + hmac(_hmac), + machineid(_machineid), + entries(new Vector()), + livecount(1), + seqnumlive(true), + freespace(SLOT_SIZE - getBaseSize()), + table(_table), + fakeLastMessage(NULL), + localSequenceNumber(_localSequenceNumber) { +} + +Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, Array *_prevhmac, int64_t _localSequenceNumber) : + seqnum(_seqnum), + prevhmac(_prevhmac), + hmac(NULL), + machineid(_machineid), + entries(new Vector()), + livecount(1), + seqnumlive(true), + freespace(SLOT_SIZE - getBaseSize()), + table(_table), + fakeLastMessage(NULL), + localSequenceNumber(_localSequenceNumber) { +} + +Slot::Slot(Table *_table, int64_t _seqnum, int64_t _machineid, int64_t _localSequenceNumber) : + seqnum(_seqnum), + prevhmac(new Array(HMAC_SIZE)), + hmac(NULL), + machineid(_machineid), + entries(new Vector()), + livecount(1), + seqnumlive(true), + freespace(SLOT_SIZE - getBaseSize()), + table(_table), + fakeLastMessage(NULL), + localSequenceNumber(_localSequenceNumber) { +} + +Slot::~Slot() { + if (hmac != NULL) + delete hmac; + delete prevhmac; + for(uint i=0; i< entries->size(); i++) + entries->get(i)->releaseRef(); + delete entries; + if (fakeLastMessage) + delete fakeLastMessage; +} + +Entry *Slot::addEntry(Entry *e) { + e = e->getCopy(this); + entries->add(e); + livecount++; + freespace -= e->getSize(); + return e; +} + +void Slot::addShallowEntry(Entry *e) { + entries->add(e); + livecount++; + freespace -= e->getSize(); +} + +/** + * Returns true if the slot has free space to hold the entry without + * using its reserved space. */ + +bool Slot::hasSpace(Entry *e) { + int newfreespace = freespace - e->getSize(); + return newfreespace >= 0; +} + +Vector *Slot::getEntries() { + return entries; +} + +Slot *Slot_decode(Table *table, Array *array, Mac *mac) { + mac->update(array, HMAC_SIZE, array->length() - HMAC_SIZE); + Array *realmac = mac->doFinal(); + + ByteBuffer *bb = ByteBuffer_wrap(array); + Array *hmac = new Array(HMAC_SIZE); + Array *prevhmac = new Array(HMAC_SIZE); + bb->get(hmac); + bb->get(prevhmac); + if (!realmac->equals(hmac)) + throw new Error("Server Error: Invalid HMAC! Potential Attack!"); + delete realmac; + + int64_t seqnum = bb->getLong(); + int64_t machineid = bb->getLong(); + int numentries = bb->getInt(); + Slot *slot = new Slot(table, seqnum, machineid, prevhmac, hmac, -1); + + for (int i = 0; i < numentries; i++) { + slot->addShallowEntry(Entry_decode(slot, bb)); + } + bb->releaseArray(); + delete bb; + return slot; +} + +char Slot::getType() { + return TypeSlot; +} + +Array *Slot::encode(Mac *mac) { + Array *array = new Array(SLOT_SIZE); + ByteBuffer *bb = ByteBuffer_wrap(array); + /* Leave space for the slot HMAC. */ + bb->position(HMAC_SIZE); + bb->put(prevhmac); + bb->putLong(seqnum); + bb->putLong(machineid); + bb->putInt(entries->size()); + for (uint ei = 0; ei < entries->size(); ei++) { + Entry *entry = entries->get(ei); + entry->encode(bb); + } + /* Compute our HMAC */ + mac->update(array, HMAC_SIZE, array->length() - HMAC_SIZE); + Array *realmac = mac->doFinal(); + hmac = realmac; + bb->position(0); + bb->put(realmac); + bb->releaseArray(); + delete bb; + return array; +} + + +/** + * Returns the live set of entries for this Slot. Generates a fake + * LastMessage entry to represent the information stored by the slot + * itself. + */ + +Vector *Slot::getLiveEntries(bool resize) { + Vector *liveEntries = new Vector(); + for (uint ei = 0; ei < entries->size(); ei++) { + Entry *entry = entries->get(ei); + if (entry->isLive()) { + if (!resize || entry->getType() != TypeTableStatus) + liveEntries->add(entry); + } + } + + if (seqnumlive && !resize) { + if (! fakeLastMessage) + fakeLastMessage = new LastMessage(this, machineid, seqnum); + liveEntries->add(fakeLastMessage); + } + return liveEntries; +} + + +/** + * Records that a newer slot records the fact that this slot was + * sent by the relevant machine. + */ + +void Slot::setDead() { + seqnumlive = false; + decrementLiveCount(); +} + +/** + * Update the count of live entries. + */ + +void Slot::decrementLiveCount() { + livecount--; + if (livecount == 0) { + table->decrementLiveCount(); + } +} + +Array *Slot::getSlotCryptIV() { + ByteBuffer *buffer = ByteBuffer_allocate(CloudComm_IV_SIZE); + buffer->putLong(machineid); + int64_t localSequenceNumberShift = localSequenceNumber << 16; + buffer->putLong(localSequenceNumberShift); + Array * array = buffer->array(); + buffer->releaseArray(); + delete buffer; + return array; +} diff --git a/version2/src/C/SlotBuffer.cc b/version2/src/C/SlotBuffer.cc deleted file mode 100644 index 06c46e9..0000000 --- a/version2/src/C/SlotBuffer.cc +++ /dev/null @@ -1,122 +0,0 @@ -#include "SlotBuffer.h" -#include "Slot.h" -/** - * Circular buffer that holds the live set of slots. - * @author Brian Demsky - * @version 1.0 - */ - -SlotBuffer::SlotBuffer() : - array(new Array(SlotBuffer_DEFAULT_SIZE + 1)), - head(0), - tail(0), - oldestseqn(0) { -} - -SlotBuffer::~SlotBuffer() { - int32_t index = tail; - while (index != head) { - delete array->get(index); - index++; - if (index == (int32_t) array->length()) - index = 0; - } - delete array; -} - -int SlotBuffer::size() { - if (head >= tail) - return head - tail; - return (array->length() + head) - tail; -} - -int SlotBuffer::capacity() { - return array->length() - 1; -} - -void SlotBuffer::resize(int newsize) { - if ((uint32_t)newsize == (array->length() - 1)) - return; - - Array *newarray = new Array(newsize + 1); - int currsize = size(); - int index = tail; - for (int i = 0; i < currsize; i++) { - newarray->set(i, array->get(index)); - if (((uint32_t)++ index) == array->length()) - index = 0; - } - array = newarray; - tail = 0; - head = currsize; -} - -void SlotBuffer::incrementHead() { - head++; - if (((uint32_t)head) >= array->length()) - head = 0; -} - -void SlotBuffer::incrementTail() { - delete array->get(tail); - tail++; - if (((uint32_t)tail) >= array->length()) - tail = 0; -} - -void SlotBuffer::putSlot(Slot *s) { - int64_t checkNum = (getNewestSeqNum() + 1); - - if (checkNum != s->getSequenceNumber()) { - int32_t index = tail; - while (index != head) { - delete array->get(index); - index++; - if (index == (int32_t) array->length()) - index = 0; - } - oldestseqn = s->getSequenceNumber(); - tail = 0; - head = 1; - array->set(0, s); - return; - } - - array->set(head, s); - incrementHead(); - - if (oldestseqn == 0) { - oldestseqn = s->getSequenceNumber(); - } - - if (head == tail) { - incrementTail(); - oldestseqn++; - } -} - -Slot *SlotBuffer::getSlot(int64_t seqnum) { - int32_t diff = (int32_t) (seqnum - oldestseqn); - int32_t index = diff + tail; - - if (index < 0) { - return NULL; - } - - if (((uint32_t)index) >= array->length()) { - if (head >= tail) { - return NULL; - } - index -= (int32_t) array->length(); - } - - if (((uint32_t)index) >= array->length()) { - return NULL; - } - - if (head >= tail && index >= head) { - return NULL; - } - - return array->get(index); -} diff --git a/version2/src/C/SlotBuffer.cpp b/version2/src/C/SlotBuffer.cpp new file mode 100644 index 0000000..06c46e9 --- /dev/null +++ b/version2/src/C/SlotBuffer.cpp @@ -0,0 +1,122 @@ +#include "SlotBuffer.h" +#include "Slot.h" +/** + * Circular buffer that holds the live set of slots. + * @author Brian Demsky + * @version 1.0 + */ + +SlotBuffer::SlotBuffer() : + array(new Array(SlotBuffer_DEFAULT_SIZE + 1)), + head(0), + tail(0), + oldestseqn(0) { +} + +SlotBuffer::~SlotBuffer() { + int32_t index = tail; + while (index != head) { + delete array->get(index); + index++; + if (index == (int32_t) array->length()) + index = 0; + } + delete array; +} + +int SlotBuffer::size() { + if (head >= tail) + return head - tail; + return (array->length() + head) - tail; +} + +int SlotBuffer::capacity() { + return array->length() - 1; +} + +void SlotBuffer::resize(int newsize) { + if ((uint32_t)newsize == (array->length() - 1)) + return; + + Array *newarray = new Array(newsize + 1); + int currsize = size(); + int index = tail; + for (int i = 0; i < currsize; i++) { + newarray->set(i, array->get(index)); + if (((uint32_t)++ index) == array->length()) + index = 0; + } + array = newarray; + tail = 0; + head = currsize; +} + +void SlotBuffer::incrementHead() { + head++; + if (((uint32_t)head) >= array->length()) + head = 0; +} + +void SlotBuffer::incrementTail() { + delete array->get(tail); + tail++; + if (((uint32_t)tail) >= array->length()) + tail = 0; +} + +void SlotBuffer::putSlot(Slot *s) { + int64_t checkNum = (getNewestSeqNum() + 1); + + if (checkNum != s->getSequenceNumber()) { + int32_t index = tail; + while (index != head) { + delete array->get(index); + index++; + if (index == (int32_t) array->length()) + index = 0; + } + oldestseqn = s->getSequenceNumber(); + tail = 0; + head = 1; + array->set(0, s); + return; + } + + array->set(head, s); + incrementHead(); + + if (oldestseqn == 0) { + oldestseqn = s->getSequenceNumber(); + } + + if (head == tail) { + incrementTail(); + oldestseqn++; + } +} + +Slot *SlotBuffer::getSlot(int64_t seqnum) { + int32_t diff = (int32_t) (seqnum - oldestseqn); + int32_t index = diff + tail; + + if (index < 0) { + return NULL; + } + + if (((uint32_t)index) >= array->length()) { + if (head >= tail) { + return NULL; + } + index -= (int32_t) array->length(); + } + + if (((uint32_t)index) >= array->length()) { + return NULL; + } + + if (head >= tail && index >= head) { + return NULL; + } + + return array->get(index); +} diff --git a/version2/src/C/SlotIndexer.cc b/version2/src/C/SlotIndexer.cc deleted file mode 100644 index b49ca45..0000000 --- a/version2/src/C/SlotIndexer.cc +++ /dev/null @@ -1,27 +0,0 @@ -#include "SlotIndexer.h" -#include "Slot.h" -#include "Error.h" -#include "SlotBuffer.h" -/** - * Slot indexer allows slots in both the slot buffer and the new - * server response to looked up in a consistent fashion. - * @author Brian Demsky - * @version 1.0 - */ - -SlotIndexer::SlotIndexer(Array *_updates, SlotBuffer *_buffer) : - updates(_updates), - buffer(_buffer), - firstslotseqnum(updates->get(0)->getSequenceNumber()) { -} - -Slot *SlotIndexer::getSlot(int64_t seqnum) { - if (seqnum >= firstslotseqnum) { - int32_t offset = (int32_t) (seqnum - firstslotseqnum); - if (((uint32_t)offset) >= updates->length()) - throw new Error("Invalid Slot Sequence Number Reference"); - else - return updates->get(offset); - } else - return buffer->getSlot(seqnum); -} diff --git a/version2/src/C/SlotIndexer.cpp b/version2/src/C/SlotIndexer.cpp new file mode 100644 index 0000000..b49ca45 --- /dev/null +++ b/version2/src/C/SlotIndexer.cpp @@ -0,0 +1,27 @@ +#include "SlotIndexer.h" +#include "Slot.h" +#include "Error.h" +#include "SlotBuffer.h" +/** + * Slot indexer allows slots in both the slot buffer and the new + * server response to looked up in a consistent fashion. + * @author Brian Demsky + * @version 1.0 + */ + +SlotIndexer::SlotIndexer(Array *_updates, SlotBuffer *_buffer) : + updates(_updates), + buffer(_buffer), + firstslotseqnum(updates->get(0)->getSequenceNumber()) { +} + +Slot *SlotIndexer::getSlot(int64_t seqnum) { + if (seqnum >= firstslotseqnum) { + int32_t offset = (int32_t) (seqnum - firstslotseqnum); + if (((uint32_t)offset) >= updates->length()) + throw new Error("Invalid Slot Sequence Number Reference"); + else + return updates->get(offset); + } else + return buffer->getSlot(seqnum); +} diff --git a/version2/src/C/Table.cc b/version2/src/C/Table.cc deleted file mode 100644 index 255ba3c..0000000 --- a/version2/src/C/Table.cc +++ /dev/null @@ -1,2863 +0,0 @@ -#include "Table.h" -#include "CloudComm.h" -#include "SlotBuffer.h" -#include "NewKey.h" -#include "Slot.h" -#include "KeyValue.h" -#include "Error.h" -#include "PendingTransaction.h" -#include "TableStatus.h" -#include "TransactionStatus.h" -#include "Transaction.h" -#include "LastMessage.h" -#include "SecureRandom.h" -#include "ByteBuffer.h" -#include "Abort.h" -#include "CommitPart.h" -#include "ArbitrationRound.h" -#include "TransactionPart.h" -#include "Commit.h" -#include "RejectedMessage.h" -#include "SlotIndexer.h" -#include - -int compareInt64(const void *a, const void *b) { - const int64_t *pa = (const int64_t *) a; - const int64_t *pb = (const int64_t *) b; - if (*pa < *pb) - return -1; - else if (*pa > *pb) - return 1; - else - return 0; -} - -Table::Table(IoTString *baseurl, IoTString *password, int64_t _localMachineId, int listeningPort) : - buffer(NULL), - cloud(new CloudComm(this, baseurl, password, listeningPort)), - random(NULL), - liveTableStatus(NULL), - pendingTransactionBuilder(NULL), - lastPendingTransactionSpeculatedOn(NULL), - firstPendingTransaction(NULL), - numberOfSlots(0), - bufferResizeThreshold(0), - liveSlotCount(0), - oldestLiveSlotSequenceNumver(1), - localMachineId(_localMachineId), - sequenceNumber(0), - localSequenceNumber(0), - localTransactionSequenceNumber(0), - lastTransactionSequenceNumberSpeculatedOn(0), - oldestTransactionSequenceNumberSpeculatedOn(0), - localArbitrationSequenceNumber(0), - hadPartialSendToServer(false), - attemptedToSendToServer(false), - expectedsize(0), - didFindTableStatus(false), - currMaxSize(0), - lastSlotAttemptedToSend(NULL), - lastIsNewKey(false), - lastNewSize(0), - lastTransactionPartsSent(NULL), - lastNewKey(NULL), - committedKeyValueTable(NULL), - speculatedKeyValueTable(NULL), - pendingTransactionSpeculatedKeyValueTable(NULL), - liveNewKeyTable(NULL), - lastMessageTable(NULL), - rejectedMessageWatchVectorTable(NULL), - arbitratorTable(NULL), - liveAbortTable(NULL), - newTransactionParts(NULL), - newCommitParts(NULL), - lastArbitratedTransactionNumberByArbitratorTable(NULL), - liveTransactionBySequenceNumberTable(NULL), - liveTransactionByTransactionIdTable(NULL), - liveCommitsTable(NULL), - liveCommitsByKeyTable(NULL), - lastCommitSeenSequenceNumberByArbitratorTable(NULL), - rejectedSlotVector(NULL), - pendingTransactionQueue(NULL), - pendingSendArbitrationRounds(NULL), - pendingSendArbitrationEntriesToDelete(NULL), - transactionPartsSent(NULL), - outstandingTransactionStatus(NULL), - liveAbortsGeneratedByLocal(NULL), - offlineTransactionsCommittedAndAtServer(NULL), - localCommunicationTable(NULL), - lastTransactionSeenFromMachineFromServer(NULL), - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator(NULL), - lastInsertedNewKey(false), - lastSeqNumArbOn(0) -{ - init(); -} - -Table::Table(CloudComm *_cloud, int64_t _localMachineId) : - buffer(NULL), - cloud(_cloud), - random(NULL), - liveTableStatus(NULL), - pendingTransactionBuilder(NULL), - lastPendingTransactionSpeculatedOn(NULL), - firstPendingTransaction(NULL), - numberOfSlots(0), - bufferResizeThreshold(0), - liveSlotCount(0), - oldestLiveSlotSequenceNumver(1), - localMachineId(_localMachineId), - sequenceNumber(0), - localSequenceNumber(0), - localTransactionSequenceNumber(0), - lastTransactionSequenceNumberSpeculatedOn(0), - oldestTransactionSequenceNumberSpeculatedOn(0), - localArbitrationSequenceNumber(0), - hadPartialSendToServer(false), - attemptedToSendToServer(false), - expectedsize(0), - didFindTableStatus(false), - currMaxSize(0), - lastSlotAttemptedToSend(NULL), - lastIsNewKey(false), - lastNewSize(0), - lastTransactionPartsSent(NULL), - lastNewKey(NULL), - committedKeyValueTable(NULL), - speculatedKeyValueTable(NULL), - pendingTransactionSpeculatedKeyValueTable(NULL), - liveNewKeyTable(NULL), - lastMessageTable(NULL), - rejectedMessageWatchVectorTable(NULL), - arbitratorTable(NULL), - liveAbortTable(NULL), - newTransactionParts(NULL), - newCommitParts(NULL), - lastArbitratedTransactionNumberByArbitratorTable(NULL), - liveTransactionBySequenceNumberTable(NULL), - liveTransactionByTransactionIdTable(NULL), - liveCommitsTable(NULL), - liveCommitsByKeyTable(NULL), - lastCommitSeenSequenceNumberByArbitratorTable(NULL), - rejectedSlotVector(NULL), - pendingTransactionQueue(NULL), - pendingSendArbitrationRounds(NULL), - pendingSendArbitrationEntriesToDelete(NULL), - transactionPartsSent(NULL), - outstandingTransactionStatus(NULL), - liveAbortsGeneratedByLocal(NULL), - offlineTransactionsCommittedAndAtServer(NULL), - localCommunicationTable(NULL), - lastTransactionSeenFromMachineFromServer(NULL), - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator(NULL), - lastInsertedNewKey(false), - lastSeqNumArbOn(0) -{ - init(); -} - -Table::~Table() { - delete cloud; - delete random; - delete buffer; - // init data structs - delete committedKeyValueTable; - delete speculatedKeyValueTable; - delete pendingTransactionSpeculatedKeyValueTable; - delete liveNewKeyTable; - { - SetIterator *> *lmit = getKeyIterator(lastMessageTable); - while (lmit->hasNext()) { - Pair * pair = lastMessageTable->get(lmit->next()); - delete pair; - } - delete lmit; - delete lastMessageTable; - } - if (pendingTransactionBuilder != NULL) - delete pendingTransactionBuilder; - { - SetIterator *> *rmit = getKeyIterator(rejectedMessageWatchVectorTable); - while(rmit->hasNext()) { - int64_t machineid = rmit->next(); - Hashset * rmset = rejectedMessageWatchVectorTable->get(machineid); - SetIterator * mit = rmset->iterator(); - while (mit->hasNext()) { - RejectedMessage * rm = mit->next(); - delete rm; - } - delete mit; - delete rmset; - } - delete rmit; - delete rejectedMessageWatchVectorTable; - } - delete arbitratorTable; - delete liveAbortTable; - { - SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newTransactionParts); - while (partsit->hasNext()) { - int64_t machineId = partsit->next(); - Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = partsit->currVal(); - SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pit = getKeyIterator(parts); - while(pit->hasNext()) { - Pair * pair=pit->next(); - pit->currVal()->releaseRef(); - } - delete pit; - - delete parts; - } - delete partsit; - delete newTransactionParts; - } - { - SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newCommitParts); - while (partsit->hasNext()) { - int64_t machineId = partsit->next(); - Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = partsit->currVal(); - SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pit = getKeyIterator(parts); - while(pit->hasNext()) { - Pair * pair=pit->next(); - pit->currVal()->releaseRef(); - } - delete pit; - delete parts; - } - delete partsit; - delete newCommitParts; - } - delete lastArbitratedTransactionNumberByArbitratorTable; - delete liveTransactionBySequenceNumberTable; - delete liveTransactionByTransactionIdTable; - { - SetIterator *> *liveit = getKeyIterator(liveCommitsTable); - while (liveit->hasNext()) { - int64_t arbitratorId = liveit->next(); - - // Get all the commits for a specific arbitrator - Hashtable *commitForClientTable = liveit->currVal(); - { - SetIterator *clientit = getKeyIterator(commitForClientTable); - while (clientit->hasNext()) { - int64_t id = clientit->next(); - delete commitForClientTable->get(id); - } - delete clientit; - } - - delete commitForClientTable; - } - delete liveit; - delete liveCommitsTable; - } - delete liveCommitsByKeyTable; - delete lastCommitSeenSequenceNumberByArbitratorTable; - delete rejectedSlotVector; - { - uint size = pendingTransactionQueue->size(); - for (uint iter = 0; iter < size; iter++) { - delete pendingTransactionQueue->get(iter); - } - delete pendingTransactionQueue; - } - delete pendingSendArbitrationEntriesToDelete; - { - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - delete trit->currVal(); - } - delete trit; - delete transactionPartsSent; - } - delete outstandingTransactionStatus; - delete liveAbortsGeneratedByLocal; - delete offlineTransactionsCommittedAndAtServer; - delete localCommunicationTable; - delete lastTransactionSeenFromMachineFromServer; - { - for(uint i = 0; i < pendingSendArbitrationRounds->size(); i++) { - delete pendingSendArbitrationRounds->get(i); - } - delete pendingSendArbitrationRounds; - } - if (lastTransactionPartsSent != NULL) - delete lastTransactionPartsSent; - delete lastArbitrationDataLocalSequenceNumberSeenFromArbitrator; - if (lastNewKey) - delete lastNewKey; -} - -/** - * Init all the stuff needed for for table usage - */ -void Table::init() { - // Init helper objects - random = new SecureRandom(); - buffer = new SlotBuffer(); - - // init data structs - committedKeyValueTable = new Hashtable(); - speculatedKeyValueTable = new Hashtable(); - pendingTransactionSpeculatedKeyValueTable = new Hashtable(); - liveNewKeyTable = new Hashtable(); - lastMessageTable = new Hashtable * >(); - rejectedMessageWatchVectorTable = new Hashtable * >(); - arbitratorTable = new Hashtable(); - liveAbortTable = new Hashtable *, Abort *, uintptr_t, 0, pairHashFunction, pairEquals>(); - newTransactionParts = new Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *>(); - newCommitParts = new Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *>(); - lastArbitratedTransactionNumberByArbitratorTable = new Hashtable(); - liveTransactionBySequenceNumberTable = new Hashtable(); - liveTransactionByTransactionIdTable = new Hashtable *, Transaction *, uintptr_t, 0, pairHashFunction, pairEquals>(); - liveCommitsTable = new Hashtable * >(); - liveCommitsByKeyTable = new Hashtable(); - lastCommitSeenSequenceNumberByArbitratorTable = new Hashtable(); - rejectedSlotVector = new Vector(); - pendingTransactionQueue = new Vector(); - pendingSendArbitrationEntriesToDelete = new Vector(); - transactionPartsSent = new Hashtable *>(); - outstandingTransactionStatus = new Hashtable(); - liveAbortsGeneratedByLocal = new Hashtable(); - offlineTransactionsCommittedAndAtServer = new Hashset *, uintptr_t, 0, pairHashFunction, pairEquals>(); - localCommunicationTable = new Hashtable *>(); - lastTransactionSeenFromMachineFromServer = new Hashtable(); - pendingSendArbitrationRounds = new Vector(); - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator = new Hashtable(); - - // Other init stuff - numberOfSlots = buffer->capacity(); - setResizeThreshold(); -} - -/** - * Initialize the table by inserting a table status as the first entry - * into the table status also initialize the crypto stuff. - */ -void Table::initTable() { - cloud->initSecurity(); - - // Create the first insertion into the block chain which is the table status - Slot *s = new Slot(this, 1, localMachineId, localSequenceNumber); - localSequenceNumber++; - TableStatus *status = new TableStatus(s, numberOfSlots); - s->addShallowEntry(status); - Array *array = cloud->putSlot(s, numberOfSlots); - - if (array == NULL) { - array = new Array(1); - array->set(0, s); - // update local block chain - validateAndUpdate(array, true); - delete array; - } else if (array->length() == 1) { - // in case we did push the slot BUT we failed to init it - validateAndUpdate(array, true); - delete s; - delete array; - } else { - delete s; - delete array; - throw new Error("Error on initialization"); - } -} - -/** - * Rebuild the table from scratch by pulling the latest block chain - * from the server. - */ -void Table::rebuild() { - // Just pull the latest slots from the server - Array *newslots = cloud->getSlots(sequenceNumber + 1); - validateAndUpdate(newslots, true); - delete newslots; - sendToServer(NULL); - updateLiveTransactionsAndStatus(); -} - -void Table::addLocalCommunication(int64_t arbitrator, IoTString *hostName, int portNumber) { - localCommunicationTable->put(arbitrator, new Pair(hostName, portNumber)); -} - -int64_t Table::getArbitrator(IoTString *key) { - return arbitratorTable->get(key); -} - -void Table::close() { - cloud->closeCloud(); -} - -IoTString *Table::getCommitted(IoTString *key) { - KeyValue *kv = committedKeyValueTable->get(key); - - if (kv != NULL) { - return new IoTString(kv->getValue()); - } else { - return NULL; - } -} - -IoTString *Table::getSpeculative(IoTString *key) { - KeyValue *kv = pendingTransactionSpeculatedKeyValueTable->get(key); - - if (kv == NULL) { - kv = speculatedKeyValueTable->get(key); - } - - if (kv == NULL) { - kv = committedKeyValueTable->get(key); - } - - if (kv != NULL) { - return new IoTString(kv->getValue()); - } else { - return NULL; - } -} - -IoTString *Table::getCommittedAtomic(IoTString *key) { - KeyValue *kv = committedKeyValueTable->get(key); - - if (!arbitratorTable->contains(key)) { - throw new Error("Key not Found."); - } - - // Make sure new key value pair matches the current arbitrator - if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { - // TODO: Maybe not throw en error - throw new Error("Not all Key Values Match Arbitrator."); - } - - if (kv != NULL) { - pendingTransactionBuilder->addKVGuard(new KeyValue(key, kv->getValue())); - return new IoTString(kv->getValue()); - } else { - pendingTransactionBuilder->addKVGuard(new KeyValue(key, NULL)); - return NULL; - } -} - -IoTString *Table::getSpeculativeAtomic(IoTString *key) { - if (!arbitratorTable->contains(key)) { - throw new Error("Key not Found."); - } - - // Make sure new key value pair matches the current arbitrator - if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { - // TODO: Maybe not throw en error - throw new Error("Not all Key Values Match Arbitrator."); - } - - KeyValue *kv = pendingTransactionSpeculatedKeyValueTable->get(key); - - if (kv == NULL) { - kv = speculatedKeyValueTable->get(key); - } - - if (kv == NULL) { - kv = committedKeyValueTable->get(key); - } - - if (kv != NULL) { - pendingTransactionBuilder->addKVGuard(new KeyValue(key, kv->getValue())); - return new IoTString(kv->getValue()); - } else { - pendingTransactionBuilder->addKVGuard(new KeyValue(key, NULL)); - return NULL; - } -} - -bool Table::update() { - try { - Array *newSlots = cloud->getSlots(sequenceNumber + 1); - validateAndUpdate(newSlots, false); - delete newSlots; - sendToServer(NULL); - updateLiveTransactionsAndStatus(); - return true; - } catch (Exception *e) { - SetIterator *> *kit = getKeyIterator(localCommunicationTable); - while (kit->hasNext()) { - int64_t m = kit->next(); - updateFromLocal(m); - } - delete kit; - } - - return false; -} - -bool Table::createNewKey(IoTString *keyName, int64_t machineId) { - while (true) { - if (arbitratorTable->contains(keyName)) { - // There is already an arbitrator - return false; - } - NewKey *newKey = new NewKey(NULL, keyName, machineId); - - if (sendToServer(newKey)) { - // If successfully inserted - return true; - } - } -} - -void Table::startTransaction() { - // Create a new transaction, invalidates any old pending transactions. - if (pendingTransactionBuilder != NULL) - delete pendingTransactionBuilder; - pendingTransactionBuilder = new PendingTransaction(localMachineId); -} - -void Table::put(IoTString *key, IoTString *value) { - // Make sure it is a valid key - if (!arbitratorTable->contains(key)) { - throw new Error("Key not Found."); - } - - // Make sure new key value pair matches the current arbitrator - if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { - // TODO: Maybe not throw en error - throw new Error("Not all Key Values Match Arbitrator."); - } - - // Add the key value to this transaction - KeyValue *kv = new KeyValue(new IoTString(key), new IoTString(value)); - pendingTransactionBuilder->addKV(kv); -} - -TransactionStatus *Table::commitTransaction() { - if (pendingTransactionBuilder->getKVUpdates()->size() == 0) { - // transaction with no updates will have no effect on the system - return new TransactionStatus(TransactionStatus_StatusNoEffect, -1); - } - - // Set the local transaction sequence number and increment - pendingTransactionBuilder->setClientLocalSequenceNumber(localTransactionSequenceNumber); - localTransactionSequenceNumber++; - - // Create the transaction status - TransactionStatus *transactionStatus = new TransactionStatus(TransactionStatus_StatusPending, pendingTransactionBuilder->getArbitrator()); - - // Create the new transaction - Transaction *newTransaction = pendingTransactionBuilder->createTransaction(); - newTransaction->setTransactionStatus(transactionStatus); - - if (pendingTransactionBuilder->getArbitrator() != localMachineId) { - // Add it to the queue and invalidate the builder for safety - pendingTransactionQueue->add(newTransaction); - } else { - arbitrateOnLocalTransaction(newTransaction); - delete newTransaction; - updateLiveStateFromLocal(); - } - if (pendingTransactionBuilder != NULL) - delete pendingTransactionBuilder; - - pendingTransactionBuilder = new PendingTransaction(localMachineId); - - try { - sendToServer(NULL); - } catch (ServerException *e) { - - Hashset *arbitratorTriedAndFailed = new Hashset(); - uint size = pendingTransactionQueue->size(); - uint oldindex = 0; - for (uint iter = 0; iter < size; iter++) { - Transaction *transaction = pendingTransactionQueue->get(iter); - pendingTransactionQueue->set(oldindex++, pendingTransactionQueue->get(iter)); - - if (arbitratorTriedAndFailed->contains(transaction->getArbitrator())) { - // Already contacted this client so ignore all attempts to contact this client - // to preserve ordering for arbitrator - continue; - } - - Pair sendReturn = sendTransactionToLocal(transaction); - - if (sendReturn.getFirst()) { - // Failed to contact over local - arbitratorTriedAndFailed->add(transaction->getArbitrator()); - } else { - // Successful contact or should not contact - - if (sendReturn.getSecond()) { - // did arbitrate - delete transaction; - oldindex--; - } - } - } - pendingTransactionQueue->setSize(oldindex); - } - - updateLiveStateFromLocal(); - - return transactionStatus; -} - -/** - * Recalculate the new resize threshold - */ -void Table::setResizeThreshold() { - int resizeLower = (int) (Table_RESIZE_THRESHOLD * numberOfSlots); - bufferResizeThreshold = resizeLower - 1 + random->nextInt(numberOfSlots - resizeLower); -} - -int64_t Table::getLocalSequenceNumber() { - return localSequenceNumber; -} - -void Table::processTransactionList(bool handlePartial) { - SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetServerFailure(); - // Update which transactions parts still need to be sent - transaction->removeSentParts(lastTransactionPartsSent->get(transaction)); - // Add the transaction status to the outstanding list - outstandingTransactionStatus->put(transaction->getSequenceNumber(), transaction->getTransactionStatus()); - - // Update the transaction status - transaction->getTransactionStatus()->setStatus(TransactionStatus_StatusSentPartial); - - // Check if all the transaction parts were successfully - // sent and if so then remove it from pending - if (transaction->didSendAllParts()) { - transaction->getTransactionStatus()->setStatus(TransactionStatus_StatusSentFully); - pendingTransactionQueue->remove(transaction); - delete transaction; - } else if (handlePartial) { - transaction->resetServerFailure(); - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer()) { - transaction->setSequenceNumber(-1); - } - } - } - delete trit; -} - -NewKey * Table::handlePartialSend(NewKey * newKey) { - //Didn't receive acknowledgement for last send - //See if the server has received a newer slot - - Array *newSlots = cloud->getSlots(sequenceNumber + 1); - if (newSlots->length() == 0) { - //Retry sending old slot - bool wasInserted = false; - bool sendSlotsReturn = sendSlotsToServer(lastSlotAttemptedToSend, lastNewSize, lastIsNewKey, &wasInserted, &newSlots); - - if (sendSlotsReturn) { - lastSlotAttemptedToSend = NULL; - if (newKey != NULL) { - if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { - delete newKey; - newKey = NULL; - } - } - processTransactionList(false); - } else { - if (checkSend(newSlots, lastSlotAttemptedToSend)) { - if (newKey != NULL) { - if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { - delete newKey; - newKey = NULL; - } - } - processTransactionList(true); - } - } - - SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetServerFailure(); - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer()) { - transaction->setSequenceNumber(-1); - } - } - delete trit; - - if (newSlots->length() != 0) { - // insert into the local block chain - validateAndUpdate(newSlots, true); - } - } else { - if (checkSend(newSlots, lastSlotAttemptedToSend)) { - if (newKey != NULL) { - if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { - delete newKey; - newKey = NULL; - } - } - - processTransactionList(true); - } else { - SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetServerFailure(); - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer()) { - transaction->setSequenceNumber(-1); - } - } - delete trit; - } - - // insert into the local block chain - validateAndUpdate(newSlots, true); - } - delete newSlots; - return newKey; -} - -void Table::clearSentParts() { - // Clear the sent data since we are trying again - pendingSendArbitrationEntriesToDelete->clear(); - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - delete trit->currVal(); - } - delete trit; - transactionPartsSent->clear(); -} - -bool Table::sendToServer(NewKey *newKey) { - if (hadPartialSendToServer) { - newKey = handlePartialSend(newKey); - } - - try { - // While we have stuff that needs inserting into the block chain - while ((pendingTransactionQueue->size() > 0) || (pendingSendArbitrationRounds->size() > 0) || (newKey != NULL)) { - if (hadPartialSendToServer) { - throw new Error("Should Be error free"); - } - - // If there is a new key with same name then end - if ((newKey != NULL) && arbitratorTable->contains(newKey->getKey())) { - delete newKey; - return false; - } - - // Create the slot - Slot *slot = new Slot(this, sequenceNumber + 1, localMachineId, new Array(buffer->getSlot(sequenceNumber)->getHMAC()), localSequenceNumber); - localSequenceNumber++; - - // Try to fill the slot with data - int newSize = 0; - bool insertedNewKey = false; - bool needsResize = fillSlot(slot, false, newKey, newSize, insertedNewKey); - - if (needsResize) { - // Reset which transaction to send - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetNextPartToSend(); - - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { - transaction->setSequenceNumber(-1); - } - } - delete trit; - - // Clear the sent data since we are trying again - clearSentParts(); - - // We needed a resize so try again - fillSlot(slot, true, newKey, newSize, insertedNewKey); - } - if (lastSlotAttemptedToSend != NULL) - delete lastSlotAttemptedToSend; - - lastSlotAttemptedToSend = slot; - lastIsNewKey = (newKey != NULL); - lastInsertedNewKey = insertedNewKey; - lastNewSize = newSize; - if (( newKey != lastNewKey) && (lastNewKey != NULL)) - delete lastNewKey; - lastNewKey = newKey; - if (lastTransactionPartsSent != NULL) - delete lastTransactionPartsSent; - lastTransactionPartsSent = transactionPartsSent->clone(); - - Array * newSlots = NULL; - bool wasInserted = false; - bool sendSlotsReturn = sendSlotsToServer(slot, newSize, newKey != NULL, &wasInserted, &newSlots); - - if (sendSlotsReturn) { - lastSlotAttemptedToSend = NULL; - // Did insert into the block chain - if (insertedNewKey) { - // This slot was what was inserted not a previous slot - // New Key was successfully inserted into the block chain so dont want to insert it again - newKey = NULL; - } - - // Remove the aborts and commit parts that were sent from the pending to send queue - uint size = pendingSendArbitrationRounds->size(); - uint oldcount = 0; - for (uint i = 0; i < size; i++) { - ArbitrationRound *round = pendingSendArbitrationRounds->get(i); - round->removeParts(pendingSendArbitrationEntriesToDelete); - - if (!round->isDoneSending()) { - //Add part back in - pendingSendArbitrationRounds->set(oldcount++, - pendingSendArbitrationRounds->get(i)); - } else - delete pendingSendArbitrationRounds->get(i); - } - pendingSendArbitrationRounds->setSize(oldcount); - processTransactionList(false); - } else { - // Reset which transaction to send - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetNextPartToSend(); - - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { - transaction->setSequenceNumber(-1); - } - } - delete trit; - } - - // Clear the sent data in preparation for next send - clearSentParts(); - - if (newSlots->length() != 0) { - // insert into the local block chain - validateAndUpdate(newSlots, true); - } - delete newSlots; - } - } catch (ServerException *e) { - if (e->getType() != ServerException_TypeInputTimeout) { - // Nothing was able to be sent to the server so just clear these data structures - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetNextPartToSend(); - - // Set the transaction sequence number back to nothing - if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { - transaction->setSequenceNumber(-1); - } - } - delete trit; - } else { - // There was a partial send to the server - hadPartialSendToServer = true; - - // Nothing was able to be sent to the server so just clear these data structures - SetIterator *> *trit = getKeyIterator(transactionPartsSent); - while (trit->hasNext()) { - Transaction *transaction = trit->next(); - transaction->resetNextPartToSend(); - transaction->setServerFailure(); - } - delete trit; - } - - clearSentParts(); - - throw e; - } - - return newKey == NULL; -} - -bool Table::updateFromLocal(int64_t machineId) { - if (!localCommunicationTable->contains(machineId)) - return false; - - Pair *localCommunicationInformation = localCommunicationTable->get(machineId); - - // Get the size of the send data - int sendDataSize = sizeof(int32_t) + sizeof(int64_t); - - int64_t lastArbitrationDataLocalSequenceNumber = (int64_t) -1; - if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(machineId)) { - lastArbitrationDataLocalSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(machineId); - } - - Array *sendData = new Array(sendDataSize); - ByteBuffer *bbEncode = ByteBuffer_wrap(sendData); - - // Encode the data - bbEncode->putLong(lastArbitrationDataLocalSequenceNumber); - bbEncode->putInt(0); - - // Send by local - Array *returnData = cloud->sendLocalData(sendData, localSequenceNumber, localCommunicationInformation->getFirst(), localCommunicationInformation->getSecond()); - localSequenceNumber++; - - if (returnData == NULL) { - // Could not contact server - return false; - } - - // Decode the data - ByteBuffer *bbDecode = ByteBuffer_wrap(returnData); - int numberOfEntries = bbDecode->getInt(); - - for (int i = 0; i < numberOfEntries; i++) { - char type = bbDecode->get(); - if (type == TypeAbort) { - Abort *abort = (Abort *)Abort_decode(NULL, bbDecode); - processEntry(abort); - } else if (type == TypeCommitPart) { - CommitPart *commitPart = (CommitPart *)CommitPart_decode(NULL, bbDecode); - processEntry(commitPart); - } - } - - updateLiveStateFromLocal(); - - return true; -} - -Pair Table::sendTransactionToLocal(Transaction *transaction) { - - // Get the devices local communications - if (!localCommunicationTable->contains(transaction->getArbitrator())) - return Pair(true, false); - - Pair *localCommunicationInformation = localCommunicationTable->get(transaction->getArbitrator()); - - // Get the size of the send data - int sendDataSize = sizeof(int32_t) + sizeof(int64_t); - { - Vector *tParts = transaction->getParts(); - uint tPartsSize = tParts->size(); - for (uint i = 0; i < tPartsSize; i++) { - TransactionPart *part = tParts->get(i); - sendDataSize += part->getSize(); - } - } - - int64_t lastArbitrationDataLocalSequenceNumber = (int64_t) -1; - if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(transaction->getArbitrator())) { - lastArbitrationDataLocalSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(transaction->getArbitrator()); - } - - // Make the send data size - Array *sendData = new Array(sendDataSize); - ByteBuffer *bbEncode = ByteBuffer_wrap(sendData); - - // Encode the data - bbEncode->putLong(lastArbitrationDataLocalSequenceNumber); - bbEncode->putInt(transaction->getParts()->size()); - { - Vector *tParts = transaction->getParts(); - uint tPartsSize = tParts->size(); - for (uint i = 0; i < tPartsSize; i++) { - TransactionPart *part = tParts->get(i); - part->encode(bbEncode); - } - } - - // Send by local - Array *returnData = cloud->sendLocalData(sendData, localSequenceNumber, localCommunicationInformation->getFirst(), localCommunicationInformation->getSecond()); - localSequenceNumber++; - - if (returnData == NULL) { - // Could not contact server - return Pair(true, false); - } - - // Decode the data - ByteBuffer *bbDecode = ByteBuffer_wrap(returnData); - bool didCommit = bbDecode->get() == 1; - bool couldArbitrate = bbDecode->get() == 1; - int numberOfEntries = bbDecode->getInt(); - bool foundAbort = false; - - for (int i = 0; i < numberOfEntries; i++) { - char type = bbDecode->get(); - if (type == TypeAbort) { - Abort *abort = (Abort *)Abort_decode(NULL, bbDecode); - - if ((abort->getTransactionMachineId() == localMachineId) && (abort->getTransactionClientLocalSequenceNumber() == transaction->getClientLocalSequenceNumber())) { - foundAbort = true; - } - - processEntry(abort); - } else if (type == TypeCommitPart) { - CommitPart *commitPart = (CommitPart *)CommitPart_decode(NULL, bbDecode); - processEntry(commitPart); - } - } - - updateLiveStateFromLocal(); - - if (couldArbitrate) { - TransactionStatus *status = transaction->getTransactionStatus(); - if (didCommit) { - status->setStatus(TransactionStatus_StatusCommitted); - } else { - status->setStatus(TransactionStatus_StatusAborted); - } - } else { - TransactionStatus *status = transaction->getTransactionStatus(); - if (foundAbort) { - status->setStatus(TransactionStatus_StatusAborted); - } else { - status->setStatus(TransactionStatus_StatusCommitted); - } - } - - return Pair(false, true); -} - -Array *Table::acceptDataFromLocal(Array *data) { - // Decode the data - ByteBuffer *bbDecode = ByteBuffer_wrap(data); - int64_t lastArbitratedSequenceNumberSeen = bbDecode->getLong(); - int numberOfParts = bbDecode->getInt(); - - // If we did commit a transaction or not - bool didCommit = false; - bool couldArbitrate = false; - - if (numberOfParts != 0) { - - // decode the transaction - Transaction *transaction = new Transaction(); - for (int i = 0; i < numberOfParts; i++) { - bbDecode->get(); - TransactionPart *newPart = (TransactionPart *)TransactionPart_decode(NULL, bbDecode); - transaction->addPartDecode(newPart); - } - - // Arbitrate on transaction and pull relevant return data - Pair localArbitrateReturn = arbitrateOnLocalTransaction(transaction); - couldArbitrate = localArbitrateReturn.getFirst(); - didCommit = localArbitrateReturn.getSecond(); - - updateLiveStateFromLocal(); - - // Transaction was sent to the server so keep track of it to prevent double commit - if (transaction->getSequenceNumber() != -1) { - offlineTransactionsCommittedAndAtServer->add(new Pair(transaction->getId())); - } - } - - // The data to send back - int returnDataSize = 0; - Vector *unseenArbitrations = new Vector(); - - // Get the aborts to send back - Vector *abortLocalSequenceNumbers = new Vector(); - { - SetIterator *abortit = getKeyIterator(liveAbortsGeneratedByLocal); - while (abortit->hasNext()) - abortLocalSequenceNumbers->add(abortit->next()); - delete abortit; - } - - qsort(abortLocalSequenceNumbers->expose(), abortLocalSequenceNumbers->size(), sizeof(int64_t), compareInt64); - - uint asize = abortLocalSequenceNumbers->size(); - for (uint i = 0; i < asize; i++) { - int64_t localSequenceNumber = abortLocalSequenceNumbers->get(i); - if (localSequenceNumber <= lastArbitratedSequenceNumberSeen) { - continue; - } - - Abort *abort = liveAbortsGeneratedByLocal->get(localSequenceNumber); - unseenArbitrations->add(abort); - returnDataSize += abort->getSize(); - } - - // Get the commits to send back - Hashtable *commitForClientTable = liveCommitsTable->get(localMachineId); - if (commitForClientTable != NULL) { - Vector *commitLocalSequenceNumbers = new Vector(); - { - SetIterator *commitit = getKeyIterator(commitForClientTable); - while (commitit->hasNext()) - commitLocalSequenceNumbers->add(commitit->next()); - delete commitit; - } - qsort(commitLocalSequenceNumbers->expose(), commitLocalSequenceNumbers->size(), sizeof(int64_t), compareInt64); - - uint clsSize = commitLocalSequenceNumbers->size(); - for (uint clsi = 0; clsi < clsSize; clsi++) { - int64_t localSequenceNumber = commitLocalSequenceNumbers->get(clsi); - Commit *commit = commitForClientTable->get(localSequenceNumber); - - if (localSequenceNumber <= lastArbitratedSequenceNumberSeen) { - continue; - } - - { - Vector *parts = commit->getParts(); - uint nParts = parts->size(); - for (uint i = 0; i < nParts; i++) { - CommitPart *commitPart = parts->get(i); - unseenArbitrations->add(commitPart); - returnDataSize += commitPart->getSize(); - } - } - } - } - - // Number of arbitration entries to decode - returnDataSize += 2 * sizeof(int32_t); - - // bool of did commit or not - if (numberOfParts != 0) { - returnDataSize += sizeof(char); - } - - // Data to send Back - Array *returnData = new Array(returnDataSize); - ByteBuffer *bbEncode = ByteBuffer_wrap(returnData); - - if (numberOfParts != 0) { - if (didCommit) { - bbEncode->put((char)1); - } else { - bbEncode->put((char)0); - } - if (couldArbitrate) { - bbEncode->put((char)1); - } else { - bbEncode->put((char)0); - } - } - - bbEncode->putInt(unseenArbitrations->size()); - uint size = unseenArbitrations->size(); - for (uint i = 0; i < size; i++) { - Entry *entry = unseenArbitrations->get(i); - entry->encode(bbEncode); - } - - localSequenceNumber++; - return returnData; -} - -/** Checks whether a given slot was sent using new slots in - array. Returns true if sent and false otherwise. */ - -bool Table::checkSend(Array * array, Slot *checkSlot) { - uint size = array->length(); - for (uint i = 0; i < size; i++) { - Slot *s = array->get(i); - if ((s->getSequenceNumber() == checkSlot->getSequenceNumber()) && (s->getMachineID() == localMachineId)) { - return true; - } - } - - //Also need to see if other machines acknowledged our message - for (uint i = 0; i < size; i++) { - Slot *s = array->get(i); - - // Process each entry in the slot - Vector *entries = s->getEntries(); - uint eSize = entries->size(); - for (uint ei = 0; ei < eSize; ei++) { - Entry *entry = entries->get(ei); - - if (entry->getType() == TypeLastMessage) { - LastMessage *lastMessage = (LastMessage *)entry; - - if ((lastMessage->getMachineID() == localMachineId) && (lastMessage->getSequenceNumber() == checkSlot->getSequenceNumber())) { - return true; - } - } - } - } - //Not found - return false; -} - -/** Method tries to send slot to server. Returns status in tuple. - isInserted returns whether last un-acked send (if any) was - successful. Returns whether send was confirmed.x - */ - -bool Table::sendSlotsToServer(Slot *slot, int newSize, bool isNewKey, bool *isInserted, Array **array) { - attemptedToSendToServer = true; - - *array = cloud->putSlot(slot, newSize); - if (*array == NULL) { - *array = new Array(1); - (*array)->set(0, slot); - rejectedSlotVector->clear(); - *isInserted = false; - return true; - } else { - if ((*array)->length() == 0) { - throw new Error("Server Error: Did not send any slots"); - } - - if (hadPartialSendToServer) { - *isInserted = checkSend(*array, slot); - - if (!(*isInserted)) { - rejectedSlotVector->add(slot->getSequenceNumber()); - } - - return false; - } else { - rejectedSlotVector->add(slot->getSequenceNumber()); - *isInserted = false; - return false; - } - } -} - -/** - * Returns true if a resize was needed but not done. - */ -bool Table::fillSlot(Slot *slot, bool resize, NewKey *newKeyEntry, int & newSize, bool & insertedKey) { - newSize = 0;//special value to indicate no resize - if (liveSlotCount > bufferResizeThreshold) { - resize = true;//Resize is forced - } - - if (resize) { - newSize = (int) (numberOfSlots * Table_RESIZE_MULTIPLE); - TableStatus *status = new TableStatus(slot, newSize); - slot->addShallowEntry(status); - } - - // Fill with rejected slots first before doing anything else - doRejectedMessages(slot); - - // Do mandatory rescue of entries - ThreeTuple mandatoryRescueReturn = doMandatoryRescue(slot, resize); - - // Extract working variables - bool needsResize = mandatoryRescueReturn.getFirst(); - bool seenLiveSlot = mandatoryRescueReturn.getSecond(); - int64_t currentRescueSequenceNumber = mandatoryRescueReturn.getThird(); - - if (needsResize && !resize) { - // We need to resize but we are not resizing so return true to force on retry - return true; - } - - insertedKey = false; - if (newKeyEntry != NULL) { - newKeyEntry->setSlot(slot); - if (slot->hasSpace(newKeyEntry)) { - slot->addEntry(newKeyEntry); - insertedKey = true; - } - } - - // Clear the transactions, aborts and commits that were sent previously - clearSentParts(); - uint size = pendingSendArbitrationRounds->size(); - for (uint i = 0; i < size; i++) { - ArbitrationRound *round = pendingSendArbitrationRounds->get(i); - bool isFull = false; - round->generateParts(); - Vector *parts = round->getParts(); - - // Insert pending arbitration data - uint vsize = parts->size(); - for (uint vi = 0; vi < vsize; vi++) { - Entry *arbitrationData = parts->get(vi); - - // If it is an abort then we need to set some information - if (arbitrationData->getType() == TypeAbort) { - ((Abort *)arbitrationData)->setSequenceNumber(slot->getSequenceNumber()); - } - - if (!slot->hasSpace(arbitrationData)) { - // No space so cant do anything else with these data entries - isFull = true; - break; - } - - // Add to this current slot and add it to entries to delete - slot->addEntry(arbitrationData); - pendingSendArbitrationEntriesToDelete->add(arbitrationData); - } - - if (isFull) { - break; - } - } - - if (pendingTransactionQueue->size() > 0) { - Transaction *transaction = pendingTransactionQueue->get(0); - // Set the transaction sequence number if it has yet to be inserted into the block chain - if ((!transaction->didSendAPartToServer()) || (transaction->getSequenceNumber() == -1)) { - transaction->setSequenceNumber(slot->getSequenceNumber()); - } - - while (true) { - TransactionPart *part = transaction->getNextPartToSend(); - if (part == NULL) { - // Ran out of parts to send for this transaction so move on - break; - } - - if (slot->hasSpace(part)) { - slot->addEntry(part); - Vector *partsSent = transactionPartsSent->get(transaction); - if (partsSent == NULL) { - partsSent = new Vector(); - transactionPartsSent->put(transaction, partsSent); - } - partsSent->add(part->getPartNumber()); - transactionPartsSent->put(transaction, partsSent); - } else { - break; - } - } - } - - // Fill the remainder of the slot with rescue data - doOptionalRescue(slot, seenLiveSlot, currentRescueSequenceNumber, resize); - - return false; -} - -void Table::doRejectedMessages(Slot *s) { - if (!rejectedSlotVector->isEmpty()) { - /* TODO: We should avoid generating a rejected message entry if - * there is already a sufficient entry in the queue (e->g->, - * equalsto value of true and same sequence number)-> */ - - int64_t old_seqn = rejectedSlotVector->get(0); - if (rejectedSlotVector->size() > Table_REJECTED_THRESHOLD) { - int64_t new_seqn = rejectedSlotVector->lastElement(); - RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), localMachineId, old_seqn, new_seqn, false); - s->addShallowEntry(rm); - } else { - int64_t prev_seqn = -1; - uint i = 0; - /* Go through list of missing messages */ - for (; i < rejectedSlotVector->size(); i++) { - int64_t curr_seqn = rejectedSlotVector->get(i); - Slot *s_msg = buffer->getSlot(curr_seqn); - if (s_msg != NULL) - break; - prev_seqn = curr_seqn; - } - /* Generate rejected message entry for missing messages */ - if (prev_seqn != -1) { - RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), localMachineId, old_seqn, prev_seqn, false); - s->addShallowEntry(rm); - } - /* Generate rejected message entries for present messages */ - for (; i < rejectedSlotVector->size(); i++) { - int64_t curr_seqn = rejectedSlotVector->get(i); - Slot *s_msg = buffer->getSlot(curr_seqn); - int64_t machineid = s_msg->getMachineID(); - RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), machineid, curr_seqn, curr_seqn, true); - s->addShallowEntry(rm); - } - } - } -} - -ThreeTuple Table::doMandatoryRescue(Slot *slot, bool resize) { - int64_t newestSequenceNumber = buffer->getNewestSeqNum(); - int64_t oldestSequenceNumber = buffer->getOldestSeqNum(); - if (oldestLiveSlotSequenceNumver < oldestSequenceNumber) { - oldestLiveSlotSequenceNumver = oldestSequenceNumber; - } - - int64_t currentSequenceNumber = oldestLiveSlotSequenceNumver; - bool seenLiveSlot = false; - int64_t firstIfFull = newestSequenceNumber + 1 - numberOfSlots; // smallest seq number in the buffer if it is full - int64_t threshold = firstIfFull + Table_FREE_SLOTS; // we want the buffer to be clear of live entries up to this point - - - // Mandatory Rescue - for (; currentSequenceNumber < threshold; currentSequenceNumber++) { - Slot *previousSlot = buffer->getSlot(currentSequenceNumber); - // Push slot number forward - if (!seenLiveSlot) { - oldestLiveSlotSequenceNumver = currentSequenceNumber; - } - - if (!previousSlot->isLive()) { - continue; - } - - // We have seen a live slot - seenLiveSlot = true; - - // Get all the live entries for a slot - Vector *liveEntries = previousSlot->getLiveEntries(resize); - - // Iterate over all the live entries and try to rescue them - uint lESize = liveEntries->size(); - for (uint i = 0; i < lESize; i++) { - Entry *liveEntry = liveEntries->get(i); - if (slot->hasSpace(liveEntry)) { - // Enough space to rescue the entry - slot->addEntry(liveEntry); - } else if (currentSequenceNumber == firstIfFull) { - //if there's no space but the entry is about to fall off the queue - return ThreeTuple(true, seenLiveSlot, currentSequenceNumber); - } - } - } - - // Did not resize - return ThreeTuple(false, seenLiveSlot, currentSequenceNumber); -} - -void Table::doOptionalRescue(Slot *s, bool seenliveslot, int64_t seqn, bool resize) { - /* now go through live entries from least to greatest sequence number until - * either all live slots added, or the slot doesn't have enough room - * for SKIP_THRESHOLD consecutive entries*/ - int skipcount = 0; - int64_t newestseqnum = buffer->getNewestSeqNum(); - for (; seqn <= newestseqnum; seqn++) { - Slot *prevslot = buffer->getSlot(seqn); - //Push slot number forward - if (!seenliveslot) - oldestLiveSlotSequenceNumver = seqn; - - if (!prevslot->isLive()) - continue; - seenliveslot = true; - Vector *liveentries = prevslot->getLiveEntries(resize); - uint lESize = liveentries->size(); - for (uint i = 0; i < lESize; i++) { - Entry *liveentry = liveentries->get(i); - if (s->hasSpace(liveentry)) - s->addEntry(liveentry); - else { - skipcount++; - if (skipcount > Table_SKIP_THRESHOLD) { - delete liveentries; - goto donesearch; - } - } - } - delete liveentries; - } -donesearch: - ; -} - -/** - * Checks for malicious activity and updates the local copy of the block chain-> - */ -void Table::validateAndUpdate(Array *newSlots, bool acceptUpdatesToLocal) { - // The cloud communication layer has checked slot HMACs already - // before decoding - if (newSlots->length() == 0) { - return; - } - - // Make sure all slots are newer than the last largest slot this - // client has seen - int64_t firstSeqNum = newSlots->get(0)->getSequenceNumber(); - if (firstSeqNum <= sequenceNumber) { - throw new Error("Server Error: Sent older slots!"); - } - - // Create an object that can access both new slots and slots in our - // local chain without committing slots to our local chain - SlotIndexer *indexer = new SlotIndexer(newSlots, buffer); - - // Check that the HMAC chain is not broken - checkHMACChain(indexer, newSlots); - - // Set to keep track of messages from clients - Hashset *machineSet = new Hashset(); - { - SetIterator *> *lmit = getKeyIterator(lastMessageTable); - while (lmit->hasNext()) - machineSet->add(lmit->next()); - delete lmit; - } - - // Process each slots data - { - uint numSlots = newSlots->length(); - for (uint i = 0; i < numSlots; i++) { - Slot *slot = newSlots->get(i); - processSlot(indexer, slot, acceptUpdatesToLocal, machineSet); - updateExpectedSize(); - } - } - delete indexer; - - // If there is a gap, check to see if the server sent us - // everything-> - if (firstSeqNum != (sequenceNumber + 1)) { - - // Check the size of the slots that were sent down by the server-> - // Can only check the size if there was a gap - checkNumSlots(newSlots->length()); - - // Since there was a gap every machine must have pushed a slot or - // must have a last message message-> If not then the server is - // hiding slots - if (!machineSet->isEmpty()) { - delete machineSet; - throw new Error("Missing record for machines: "); - } - } - delete machineSet; - // Update the size of our local block chain-> - commitNewMaxSize(); - - // Commit new to slots to the local block chain-> - { - uint numSlots = newSlots->length(); - for (uint i = 0; i < numSlots; i++) { - Slot *slot = newSlots->get(i); - - // Insert this slot into our local block chain copy-> - buffer->putSlot(slot); - - // Keep track of how many slots are currently live (have live data - // in them)-> - liveSlotCount++; - } - } - // Get the sequence number of the latest slot in the system - sequenceNumber = newSlots->get(newSlots->length() - 1)->getSequenceNumber(); - updateLiveStateFromServer(); - - // No Need to remember after we pulled from the server - offlineTransactionsCommittedAndAtServer->clear(); - - // This is invalidated now - hadPartialSendToServer = false; -} - -void Table::updateLiveStateFromServer() { - // Process the new transaction parts - processNewTransactionParts(); - - // Do arbitration on new transactions that were received - arbitrateFromServer(); - - // Update all the committed keys - bool didCommitOrSpeculate = updateCommittedTable(); - - // Delete the transactions that are now dead - updateLiveTransactionsAndStatus(); - - // Do speculations - didCommitOrSpeculate |= updateSpeculativeTable(didCommitOrSpeculate); - updatePendingTransactionSpeculativeTable(didCommitOrSpeculate); -} - -void Table::updateLiveStateFromLocal() { - // Update all the committed keys - bool didCommitOrSpeculate = updateCommittedTable(); - - // Delete the transactions that are now dead - updateLiveTransactionsAndStatus(); - - // Do speculations - didCommitOrSpeculate |= updateSpeculativeTable(didCommitOrSpeculate); - updatePendingTransactionSpeculativeTable(didCommitOrSpeculate); -} - -void Table::initExpectedSize(int64_t firstSequenceNumber, int64_t numberOfSlots) { - int64_t prevslots = firstSequenceNumber; - - if (didFindTableStatus) { - } else { - expectedsize = (prevslots < ((int64_t) numberOfSlots)) ? (int) prevslots : numberOfSlots; - } - - didFindTableStatus = true; - currMaxSize = numberOfSlots; -} - -void Table::updateExpectedSize() { - expectedsize++; - - if (expectedsize > currMaxSize) { - expectedsize = currMaxSize; - } -} - - -/** - * Check the size of the block chain to make sure there are enough - * slots sent back by the server-> This is only called when we have a - * gap between the slots that we have locally and the slots sent by - * the server therefore in the slots sent by the server there will be - * at least 1 Table status message - */ -void Table::checkNumSlots(int numberOfSlots) { - if (numberOfSlots != expectedsize) { - throw new Error("Server Error: Server did not send all slots-> Expected: "); - } -} - -/** - * Update the size of of the local buffer if it is needed-> - */ -void Table::commitNewMaxSize() { - didFindTableStatus = false; - - // Resize the local slot buffer - if (numberOfSlots != currMaxSize) { - buffer->resize((int32_t)currMaxSize); - } - - // Change the number of local slots to the new size - numberOfSlots = (int32_t)currMaxSize; - - // Recalculate the resize threshold since the size of the local - // buffer has changed - setResizeThreshold(); -} - -/** - * Process the new transaction parts from this latest round of slots - * received from the server - */ -void Table::processNewTransactionParts() { - - if (newTransactionParts->size() == 0) { - // Nothing new to process - return; - } - - // Iterate through all the machine Ids that we received new parts - // for - SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *tpit = getKeyIterator(newTransactionParts); - while (tpit->hasNext()) { - int64_t machineId = tpit->next(); - Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = tpit->currVal(); - - SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *ptit = getKeyIterator(parts); - // Iterate through all the parts for that machine Id - while (ptit->hasNext()) { - Pair *partId = ptit->next(); - TransactionPart *part = parts->get(partId); - - if (lastArbitratedTransactionNumberByArbitratorTable->contains(part->getArbitratorId())) { - int64_t lastTransactionNumber = lastArbitratedTransactionNumberByArbitratorTable->get(part->getArbitratorId()); - if (lastTransactionNumber >= part->getSequenceNumber()) { - // Set dead the transaction part - part->setDead(); - part->releaseRef(); - continue; - } - } - - // Get the transaction object for that sequence number - Transaction *transaction = liveTransactionBySequenceNumberTable->get(part->getSequenceNumber()); - - if (transaction == NULL) { - // This is a new transaction that we dont have so make a new one - transaction = new Transaction(); - - // Add that part to the transaction - transaction->addPartDecode(part); - - // Insert this new transaction into the live tables - liveTransactionBySequenceNumberTable->put(part->getSequenceNumber(), transaction); - liveTransactionByTransactionIdTable->put(transaction->getId(), transaction); - } - part->releaseRef(); - } - delete ptit; - } - delete tpit; - // Clear all the new transaction parts in preparation for the next - // time the server sends slots - { - SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newTransactionParts); - while (partsit->hasNext()) { - int64_t machineId = partsit->next(); - Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = newTransactionParts->get(machineId); - delete parts; - } - delete partsit; - newTransactionParts->clear(); - } -} - -void Table::arbitrateFromServer() { - if (liveTransactionBySequenceNumberTable->size() == 0) { - // Nothing to arbitrate on so move on - return; - } - - // Get the transaction sequence numbers and sort from oldest to newest - Vector *transactionSequenceNumbers = new Vector(); - { - SetIterator *trit = getKeyIterator(liveTransactionBySequenceNumberTable); - while (trit->hasNext()) - transactionSequenceNumbers->add(trit->next()); - delete trit; - } - qsort(transactionSequenceNumbers->expose(), transactionSequenceNumbers->size(), sizeof(int64_t), compareInt64); - - // Collection of key value pairs that are - Hashtable *speculativeTableTmp = new Hashtable(); - - // The last transaction arbitrated on - int64_t lastTransactionCommitted = -1; - Hashset *generatedAborts = new Hashset(); - uint tsnSize = transactionSequenceNumbers->size(); - for (uint i = 0; i < tsnSize; i++) { - int64_t transactionSequenceNumber = transactionSequenceNumbers->get(i); - Transaction *transaction = liveTransactionBySequenceNumberTable->get(transactionSequenceNumber); - - // Check if this machine arbitrates for this transaction if not - // then we cant arbitrate this transaction - if (transaction->getArbitrator() != localMachineId) { - continue; - } - - if (transactionSequenceNumber < lastSeqNumArbOn) { - continue; - } - - if (offlineTransactionsCommittedAndAtServer->contains(transaction->getId())) { - // We have seen this already locally so dont commit again - continue; - } - - if (!transaction->isComplete()) { - // Will arbitrate in incorrect order if we continue so just break - // Most likely this - break; - } - - // update the largest transaction seen by arbitrator from server - if (!lastTransactionSeenFromMachineFromServer->contains(transaction->getMachineId())) { - lastTransactionSeenFromMachineFromServer->put(transaction->getMachineId(), transaction->getClientLocalSequenceNumber()); - } else { - int64_t lastTransactionSeenFromMachine = lastTransactionSeenFromMachineFromServer->get(transaction->getMachineId()); - if (transaction->getClientLocalSequenceNumber() > lastTransactionSeenFromMachine) { - lastTransactionSeenFromMachineFromServer->put(transaction->getMachineId(), transaction->getClientLocalSequenceNumber()); - } - } - - if (transaction->evaluateGuard(committedKeyValueTable, speculativeTableTmp, NULL)) { - // Guard evaluated as true - // Update the local changes so we can make the commit - SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - speculativeTableTmp->put(kv->getKey(), kv); - } - delete kvit; - - // Update what the last transaction committed was for use in batch commit - lastTransactionCommitted = transactionSequenceNumber; - } else { - // Guard evaluated was false so create abort - // Create the abort - Abort *newAbort = new Abort(NULL, - transaction->getClientLocalSequenceNumber(), - transaction->getSequenceNumber(), - transaction->getMachineId(), - transaction->getArbitrator(), - localArbitrationSequenceNumber); - localArbitrationSequenceNumber++; - generatedAborts->add(newAbort); - - // Insert the abort so we can process - processEntry(newAbort); - } - - lastSeqNumArbOn = transactionSequenceNumber; - } - - delete transactionSequenceNumbers; - - Commit *newCommit = NULL; - - // If there is something to commit - if (speculativeTableTmp->size() != 0) { - // Create the commit and increment the commit sequence number - newCommit = new Commit(localArbitrationSequenceNumber, localMachineId, lastTransactionCommitted); - localArbitrationSequenceNumber++; - - // Add all the new keys to the commit - SetIterator *spit = getKeyIterator(speculativeTableTmp); - while (spit->hasNext()) { - IoTString *string = spit->next(); - KeyValue *kv = speculativeTableTmp->get(string); - newCommit->addKV(kv); - } - delete spit; - - // create the commit parts - newCommit->createCommitParts(); - - // Append all the commit parts to the end of the pending queue - // waiting for sending to the server - // Insert the commit so we can process it - Vector *parts = newCommit->getParts(); - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - CommitPart *commitPart = parts->get(i); - processEntry(commitPart); - } - } - delete speculativeTableTmp; - - if ((newCommit != NULL) || (generatedAborts->size() > 0)) { - ArbitrationRound *arbitrationRound = new ArbitrationRound(newCommit, generatedAborts); - pendingSendArbitrationRounds->add(arbitrationRound); - - if (compactArbitrationData()) { - ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); - if (newArbitrationRound->getCommit() != NULL) { - Vector *parts = newArbitrationRound->getCommit()->getParts(); - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - CommitPart *commitPart = parts->get(i); - processEntry(commitPart); - } - } - } - } else { - delete generatedAborts; - } -} - -Pair Table::arbitrateOnLocalTransaction(Transaction *transaction) { - - // Check if this machine arbitrates for this transaction if not then - // we cant arbitrate this transaction - if (transaction->getArbitrator() != localMachineId) { - return Pair(false, false); - } - - if (!transaction->isComplete()) { - // Will arbitrate in incorrect order if we continue so just break - // Most likely this - return Pair(false, false); - } - - if (transaction->getMachineId() != localMachineId) { - // dont do this check for local transactions - if (lastTransactionSeenFromMachineFromServer->contains(transaction->getMachineId())) { - if (lastTransactionSeenFromMachineFromServer->get(transaction->getMachineId()) > transaction->getClientLocalSequenceNumber()) { - // We've have already seen this from the server - return Pair(false, false); - } - } - } - - if (transaction->evaluateGuard(committedKeyValueTable, NULL, NULL)) { - // Guard evaluated as true Create the commit and increment the - // commit sequence number - Commit *newCommit = new Commit(localArbitrationSequenceNumber, localMachineId, -1); - localArbitrationSequenceNumber++; - - // Update the local changes so we can make the commit - SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - newCommit->addKV(kv); - } - delete kvit; - - // create the commit parts - newCommit->createCommitParts(); - - // Append all the commit parts to the end of the pending queue - // waiting for sending to the server - ArbitrationRound *arbitrationRound = new ArbitrationRound(newCommit, new Hashset()); - pendingSendArbitrationRounds->add(arbitrationRound); - - if (compactArbitrationData()) { - ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); - Vector *parts = newArbitrationRound->getCommit()->getParts(); - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - CommitPart *commitPart = parts->get(i); - processEntry(commitPart); - } - } else { - // Insert the commit so we can process it - Vector *parts = newCommit->getParts(); - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - CommitPart *commitPart = parts->get(i); - processEntry(commitPart); - } - } - - if (transaction->getMachineId() == localMachineId) { - TransactionStatus *status = transaction->getTransactionStatus(); - if (status != NULL) { - status->setStatus(TransactionStatus_StatusCommitted); - } - } - - updateLiveStateFromLocal(); - return Pair(true, true); - } else { - if (transaction->getMachineId() == localMachineId) { - // For locally created messages update the status - // Guard evaluated was false so create abort - TransactionStatus *status = transaction->getTransactionStatus(); - if (status != NULL) { - status->setStatus(TransactionStatus_StatusAborted); - } - } else { - Hashset *addAbortSet = new Hashset(); - - // Create the abort - Abort *newAbort = new Abort(NULL, - transaction->getClientLocalSequenceNumber(), - -1, - transaction->getMachineId(), - transaction->getArbitrator(), - localArbitrationSequenceNumber); - localArbitrationSequenceNumber++; - addAbortSet->add(newAbort); - - // Append all the commit parts to the end of the pending queue - // waiting for sending to the server - ArbitrationRound *arbitrationRound = new ArbitrationRound(NULL, addAbortSet); - pendingSendArbitrationRounds->add(arbitrationRound); - - if (compactArbitrationData()) { - ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); - - Vector *parts = newArbitrationRound->getCommit()->getParts(); - uint partsSize = parts->size(); - for (uint i = 0; i < partsSize; i++) { - CommitPart *commitPart = parts->get(i); - processEntry(commitPart); - } - } - } - - updateLiveStateFromLocal(); - return Pair(true, false); - } -} - -/** - * Compacts the arbitration data by merging commits and aggregating - * aborts so that a single large push of commits can be done instead - * of many small updates - */ -bool Table::compactArbitrationData() { - if (pendingSendArbitrationRounds->size() < 2) { - // Nothing to compact so do nothing - return false; - } - - ArbitrationRound *lastRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); - if (lastRound->getDidSendPart()) { - return false; - } - - bool hadCommit = (lastRound->getCommit() == NULL); - bool gotNewCommit = false; - - uint numberToDelete = 1; - - while (numberToDelete < pendingSendArbitrationRounds->size()) { - ArbitrationRound *round = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - numberToDelete - 1); - - if (round->isFull() || round->getDidSendPart()) { - // Stop since there is a part that cannot be compacted and we - // need to compact in order - break; - } - - if (round->getCommit() == NULL) { - // Try compacting aborts only - int newSize = round->getCurrentSize() + lastRound->getAbortsCount(); - if (newSize > ArbitrationRound_MAX_PARTS) { - // Cant compact since it would be too large - break; - } - lastRound->addAborts(round->getAborts()); - } else { - // Create a new larger commit - Commit *newCommit = Commit_merge(lastRound->getCommit(), round->getCommit(), localArbitrationSequenceNumber); - localArbitrationSequenceNumber++; - - // Create the commit parts so that we can count them - newCommit->createCommitParts(); - - // Calculate the new size of the parts - int newSize = newCommit->getNumberOfParts(); - newSize += lastRound->getAbortsCount(); - newSize += round->getAbortsCount(); - - if (newSize > ArbitrationRound_MAX_PARTS) { - // Can't compact since it would be too large - if (lastRound->getCommit() != newCommit && - round->getCommit() != newCommit) - delete newCommit; - break; - } - // Set the new compacted part - if (lastRound->getCommit() == newCommit) - lastRound->setCommit(NULL); - if (round->getCommit() == newCommit) - round->setCommit(NULL); - - if (lastRound->getCommit() != NULL) { - Commit * oldcommit = lastRound->getCommit(); - lastRound->setCommit(NULL); - delete oldcommit; - } - lastRound->setCommit(newCommit); - lastRound->addAborts(round->getAborts()); - gotNewCommit = true; - } - - numberToDelete++; - } - - if (numberToDelete != 1) { - // If there is a compaction - // Delete the previous pieces that are now in the new compacted piece - for (uint i = 2; i <= numberToDelete; i++) { - delete pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size()-i); - } - pendingSendArbitrationRounds->setSize(pendingSendArbitrationRounds->size() - numberToDelete); - - pendingSendArbitrationRounds->add(lastRound); - - // Should reinsert into the commit processor - if (hadCommit && gotNewCommit) { - return true; - } - } - - return false; -} - -/** - * Update all the commits and the committed tables, sets dead the dead - * transactions - */ -bool Table::updateCommittedTable() { - if (newCommitParts->size() == 0) { - // Nothing new to process - return false; - } - - // Iterate through all the machine Ids that we received new parts for - SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newCommitParts); - while (partsit->hasNext()) { - int64_t machineId = partsit->next(); - Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = newCommitParts->get(machineId); - - // Iterate through all the parts for that machine Id - SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pairit = getKeyIterator(parts); - while (pairit->hasNext()) { - Pair *partId = pairit->next(); - CommitPart *part = pairit->currVal(); - - // Get the transaction object for that sequence number - Hashtable *commitForClientTable = liveCommitsTable->get(part->getMachineId()); - - if (commitForClientTable == NULL) { - // This is the first commit from this device - commitForClientTable = new Hashtable(); - liveCommitsTable->put(part->getMachineId(), commitForClientTable); - } - - Commit *commit = commitForClientTable->get(part->getSequenceNumber()); - - if (commit == NULL) { - // This is a new commit that we dont have so make a new one - commit = new Commit(); - - // Insert this new commit into the live tables - commitForClientTable->put(part->getSequenceNumber(), commit); - } - - // Add that part to the commit - commit->addPartDecode(part); - part->releaseRef(); - } - delete pairit; - delete parts; - } - delete partsit; - - // Clear all the new commits parts in preparation for the next time - // the server sends slots - newCommitParts->clear(); - - // If we process a new commit keep track of it for future use - bool didProcessANewCommit = false; - - // Process the commits one by one - SetIterator *> *liveit = getKeyIterator(liveCommitsTable); - while (liveit->hasNext()) { - int64_t arbitratorId = liveit->next(); - // Get all the commits for a specific arbitrator - Hashtable *commitForClientTable = liveCommitsTable->get(arbitratorId); - - // Sort the commits in order - Vector *commitSequenceNumbers = new Vector(); - { - SetIterator *clientit = getKeyIterator(commitForClientTable); - while (clientit->hasNext()) - commitSequenceNumbers->add(clientit->next()); - delete clientit; - } - - qsort(commitSequenceNumbers->expose(), commitSequenceNumbers->size(), sizeof(int64_t), compareInt64); - - // Get the last commit seen from this arbitrator - int64_t lastCommitSeenSequenceNumber = -1; - if (lastCommitSeenSequenceNumberByArbitratorTable->contains(arbitratorId)) { - lastCommitSeenSequenceNumber = lastCommitSeenSequenceNumberByArbitratorTable->get(arbitratorId); - } - - // Go through each new commit one by one - for (uint i = 0; i < commitSequenceNumbers->size(); i++) { - int64_t commitSequenceNumber = commitSequenceNumbers->get(i); - Commit *commit = commitForClientTable->get(commitSequenceNumber); - // Special processing if a commit is not complete - if (!commit->isComplete()) { - if (i == (commitSequenceNumbers->size() - 1)) { - // If there is an incomplete commit and this commit is the - // latest one seen then this commit cannot be processed and - // there are no other commits - break; - } else { - // This is a commit that was already dead but parts of it - // are still in the block chain (not flushed out yet)-> - // Delete it and move on - commit->setDead(); - commitForClientTable->remove(commit->getSequenceNumber()); - delete commit; - continue; - } - } - - // Update the last transaction that was updated if we can - if (commit->getTransactionSequenceNumber() != -1) { - // Update the last transaction sequence number that the arbitrator arbitrated on1 - if (!lastArbitratedTransactionNumberByArbitratorTable->contains(commit->getMachineId()) || lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()) < commit->getTransactionSequenceNumber()) { - lastArbitratedTransactionNumberByArbitratorTable->put(commit->getMachineId(), commit->getTransactionSequenceNumber()); - } - } - - // Update the last arbitration data that we have seen so far - if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(commit->getMachineId())) { - int64_t lastArbitrationSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(commit->getMachineId()); - if (commit->getSequenceNumber() > lastArbitrationSequenceNumber) { - // Is larger - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(commit->getMachineId(), commit->getSequenceNumber()); - } - } else { - // Never seen any data from this arbitrator so record the first one - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(commit->getMachineId(), commit->getSequenceNumber()); - } - - // We have already seen this commit before so need to do the - // full processing on this commit - if (commit->getSequenceNumber() <= lastCommitSeenSequenceNumber) { - // Update the last transaction that was updated if we can - if (commit->getTransactionSequenceNumber() != -1) { - int64_t lastTransactionNumber = lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()); - if (!lastArbitratedTransactionNumberByArbitratorTable->contains(commit->getMachineId()) || - lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()) < commit->getTransactionSequenceNumber()) { - lastArbitratedTransactionNumberByArbitratorTable->put(commit->getMachineId(), commit->getTransactionSequenceNumber()); - } - } - continue; - } - - // If we got here then this is a brand new commit and needs full - // processing - // Get what commits should be edited, these are the commits that - // have live values for their keys - Hashset *commitsToEdit = new Hashset(); - { - SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - Commit *commit = liveCommitsByKeyTable->get(kv->getKey()); - if (commit != NULL) - commitsToEdit->add(commit); - } - delete kvit; - } - - // Update each previous commit that needs to be updated - SetIterator *commitit = commitsToEdit->iterator(); - while (commitit->hasNext()) { - Commit *previousCommit = commitit->next(); - - // Only bother with live commits (TODO: Maybe remove this check) - if (previousCommit->isLive()) { - - // Update which keys in the old commits are still live - { - SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - previousCommit->invalidateKey(kv->getKey()); - } - delete kvit; - } - - // if the commit is now dead then remove it - if (!previousCommit->isLive()) { - commitForClientTable->remove(previousCommit->getSequenceNumber()); - delete previousCommit; - } - } - } - delete commitit; - delete commitsToEdit; - - // Update the last seen sequence number from this arbitrator - if (lastCommitSeenSequenceNumberByArbitratorTable->contains(commit->getMachineId())) { - if (commit->getSequenceNumber() > lastCommitSeenSequenceNumberByArbitratorTable->get(commit->getMachineId())) { - lastCommitSeenSequenceNumberByArbitratorTable->put(commit->getMachineId(), commit->getSequenceNumber()); - } - } else { - lastCommitSeenSequenceNumberByArbitratorTable->put(commit->getMachineId(), commit->getSequenceNumber()); - } - - // We processed a new commit that we havent seen before - didProcessANewCommit = true; - - // Update the committed table of keys and which commit is using which key - { - SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - committedKeyValueTable->put(kv->getKey(), kv); - liveCommitsByKeyTable->put(kv->getKey(), commit); - } - delete kvit; - } - } - delete commitSequenceNumbers; - } - delete liveit; - - return didProcessANewCommit; -} - -/** - * Create the speculative table from transactions that are still live - * and have come from the cloud - */ -bool Table::updateSpeculativeTable(bool didProcessNewCommits) { - if (liveTransactionBySequenceNumberTable->size() == 0) { - // There is nothing to speculate on - return false; - } - - // Create a list of the transaction sequence numbers and sort them - // from oldest to newest - Vector *transactionSequenceNumbersSorted = new Vector(); - { - SetIterator *trit = getKeyIterator(liveTransactionBySequenceNumberTable); - while (trit->hasNext()) - transactionSequenceNumbersSorted->add(trit->next()); - delete trit; - } - - qsort(transactionSequenceNumbersSorted->expose(), transactionSequenceNumbersSorted->size(), sizeof(int64_t), compareInt64); - - bool hasGapInTransactionSequenceNumbers = transactionSequenceNumbersSorted->get(0) != oldestTransactionSequenceNumberSpeculatedOn; - - - if (hasGapInTransactionSequenceNumbers || didProcessNewCommits) { - // If there is a gap in the transaction sequence numbers then - // there was a commit or an abort of a transaction OR there was a - // new commit (Could be from offline commit) so a redo the - // speculation from scratch - - // Start from scratch - speculatedKeyValueTable->clear(); - lastTransactionSequenceNumberSpeculatedOn = -1; - oldestTransactionSequenceNumberSpeculatedOn = -1; - } - - // Remember the front of the transaction list - oldestTransactionSequenceNumberSpeculatedOn = transactionSequenceNumbersSorted->get(0); - - // Find where to start arbitration from - uint startIndex = 0; - - for (; startIndex < transactionSequenceNumbersSorted->size(); startIndex++) - if (transactionSequenceNumbersSorted->get(startIndex) == lastTransactionSequenceNumberSpeculatedOn) - break; - startIndex++; - - if (startIndex >= transactionSequenceNumbersSorted->size()) { - // Make sure we are not out of bounds - delete transactionSequenceNumbersSorted; - return false; // did not speculate - } - - Hashset *incompleteTransactionArbitrator = new Hashset(); - bool didSkip = true; - - for (uint i = startIndex; i < transactionSequenceNumbersSorted->size(); i++) { - int64_t transactionSequenceNumber = transactionSequenceNumbersSorted->get(i); - Transaction *transaction = liveTransactionBySequenceNumberTable->get(transactionSequenceNumber); - - if (!transaction->isComplete()) { - // If there is an incomplete transaction then there is nothing - // we can do add this transactions arbitrator to the list of - // arbitrators we should ignore - incompleteTransactionArbitrator->add(transaction->getArbitrator()); - didSkip = true; - continue; - } - - if (incompleteTransactionArbitrator->contains(transaction->getArbitrator())) { - continue; - } - - lastTransactionSequenceNumberSpeculatedOn = transactionSequenceNumber; - - if (transaction->evaluateGuard(committedKeyValueTable, speculatedKeyValueTable, NULL)) { - // Guard evaluated to true so update the speculative table - { - SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - speculatedKeyValueTable->put(kv->getKey(), kv); - } - delete kvit; - } - } - } - - delete transactionSequenceNumbersSorted; - - if (didSkip) { - // Since there was a skip we need to redo the speculation next time around - lastTransactionSequenceNumberSpeculatedOn = -1; - oldestTransactionSequenceNumberSpeculatedOn = -1; - } - - // We did some speculation - return true; -} - -/** - * Create the pending transaction speculative table from transactions - * that are still in the pending transaction buffer - */ -void Table::updatePendingTransactionSpeculativeTable(bool didProcessNewCommitsOrSpeculate) { - if (pendingTransactionQueue->size() == 0) { - // There is nothing to speculate on - return; - } - - if (didProcessNewCommitsOrSpeculate || (firstPendingTransaction != pendingTransactionQueue->get(0))) { - // need to reset on the pending speculation - lastPendingTransactionSpeculatedOn = NULL; - firstPendingTransaction = pendingTransactionQueue->get(0); - pendingTransactionSpeculatedKeyValueTable->clear(); - } - - // Find where to start arbitration from - uint startIndex = 0; - - for (; startIndex < pendingTransactionQueue->size(); startIndex++) - if (pendingTransactionQueue->get(startIndex) == firstPendingTransaction) - break; - - if (startIndex >= pendingTransactionQueue->size()) { - // Make sure we are not out of bounds - return; - } - - for (uint i = startIndex; i < pendingTransactionQueue->size(); i++) { - Transaction *transaction = pendingTransactionQueue->get(i); - - lastPendingTransactionSpeculatedOn = transaction; - - if (transaction->evaluateGuard(committedKeyValueTable, speculatedKeyValueTable, pendingTransactionSpeculatedKeyValueTable)) { - // Guard evaluated to true so update the speculative table - SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); - while (kvit->hasNext()) { - KeyValue *kv = kvit->next(); - pendingTransactionSpeculatedKeyValueTable->put(kv->getKey(), kv); - } - delete kvit; - } - } -} - -/** - * Set dead and remove from the live transaction tables the - * transactions that are dead - */ -void Table::updateLiveTransactionsAndStatus() { - // Go through each of the transactions - { - SetIterator *iter = getKeyIterator(liveTransactionBySequenceNumberTable); - while (iter->hasNext()) { - int64_t key = iter->next(); - Transaction *transaction = liveTransactionBySequenceNumberTable->get(key); - - // Check if the transaction is dead - if (lastArbitratedTransactionNumberByArbitratorTable->contains(transaction->getArbitrator()) - && lastArbitratedTransactionNumberByArbitratorTable->get(transaction->getArbitrator()) >= transaction->getSequenceNumber()) { - // Set dead the transaction - transaction->setDead(); - - // Remove the transaction from the live table - iter->remove(); - liveTransactionByTransactionIdTable->remove(transaction->getId()); - delete transaction; - } - } - delete iter; - } - - // Go through each of the transactions - { - SetIterator *iter = getKeyIterator(outstandingTransactionStatus); - while (iter->hasNext()) { - int64_t key = iter->next(); - TransactionStatus *status = outstandingTransactionStatus->get(key); - - // Check if the transaction is dead - if (lastArbitratedTransactionNumberByArbitratorTable->contains(status->getTransactionArbitrator()) - && (lastArbitratedTransactionNumberByArbitratorTable->get(status->getTransactionArbitrator()) >= status->getTransactionSequenceNumber())) { - // Set committed - status->setStatus(TransactionStatus_StatusCommitted); - - // Remove - iter->remove(); - } - } - delete iter; - } -} - -/** - * Process this slot, entry by entry-> Also update the latest message sent by slot - */ -void Table::processSlot(SlotIndexer *indexer, Slot *slot, bool acceptUpdatesToLocal, Hashset *machineSet) { - - // Update the last message seen - updateLastMessage(slot->getMachineID(), slot->getSequenceNumber(), slot, acceptUpdatesToLocal, machineSet); - - // Process each entry in the slot - Vector *entries = slot->getEntries(); - uint eSize = entries->size(); - for (uint ei = 0; ei < eSize; ei++) { - Entry *entry = entries->get(ei); - switch (entry->getType()) { - case TypeCommitPart: - processEntry((CommitPart *)entry); - break; - case TypeAbort: - processEntry((Abort *)entry); - break; - case TypeTransactionPart: - processEntry((TransactionPart *)entry); - break; - case TypeNewKey: - processEntry((NewKey *)entry); - break; - case TypeLastMessage: - processEntry((LastMessage *)entry, machineSet); - break; - case TypeRejectedMessage: - processEntry((RejectedMessage *)entry, indexer); - break; - case TypeTableStatus: - processEntry((TableStatus *)entry, slot->getSequenceNumber()); - break; - default: - throw new Error("Unrecognized type: "); - } - } -} - -/** - * Update the last message that was sent for a machine Id - */ -void Table::processEntry(LastMessage *entry, Hashset *machineSet) { - // Update what the last message received by a machine was - updateLastMessage(entry->getMachineID(), entry->getSequenceNumber(), entry, false, machineSet); -} - -/** - * Add the new key to the arbitrators table and update the set of live - * new keys (in case of a rescued new key message) - */ -void Table::processEntry(NewKey *entry) { - // Update the arbitrator table with the new key information - arbitratorTable->put(entry->getKey(), entry->getMachineID()); - - // Update what the latest live new key is - NewKey *oldNewKey = liveNewKeyTable->put(entry->getKey(), entry); - if (oldNewKey != NULL) { - // Delete the old new key messages - oldNewKey->setDead(); - } -} - -/** - * Process new table status entries and set dead the old ones as new - * ones come in-> keeps track of the largest and smallest table status - * seen in this current round of updating the local copy of the block - * chain - */ -void Table::processEntry(TableStatus *entry, int64_t seq) { - int newNumSlots = entry->getMaxSlots(); - updateCurrMaxSize(newNumSlots); - initExpectedSize(seq, newNumSlots); - - if (liveTableStatus != NULL) { - // We have a larger table status so the old table status is no - // int64_ter alive - liveTableStatus->setDead(); - } - - // Make this new table status the latest alive table status - liveTableStatus = entry; -} - -/** - * Check old messages to see if there is a block chain violation-> - * Also - */ -void Table::processEntry(RejectedMessage *entry, SlotIndexer *indexer) { - int64_t oldSeqNum = entry->getOldSeqNum(); - int64_t newSeqNum = entry->getNewSeqNum(); - bool isequal = entry->getEqual(); - int64_t machineId = entry->getMachineID(); - int64_t seq = entry->getSequenceNumber(); - - // Check if we have messages that were supposed to be rejected in - // our local block chain - for (int64_t seqNum = oldSeqNum; seqNum <= newSeqNum; seqNum++) { - // Get the slot - Slot *slot = indexer->getSlot(seqNum); - - if (slot != NULL) { - // If we have this slot make sure that it was not supposed to be - // a rejected slot - int64_t slotMachineId = slot->getMachineID(); - if (isequal != (slotMachineId == machineId)) { - throw new Error("Server Error: Trying to insert rejected message for slot "); - } - } - } - - // Create a list of clients to watch until they see this rejected - // message entry-> - Hashset *deviceWatchSet = new Hashset(); - SetIterator *> *iter = getKeyIterator(lastMessageTable); - while (iter->hasNext()) { - // Machine ID for the last message entry - int64_t lastMessageEntryMachineId = iter->next(); - - // We've seen it, don't need to continue to watch-> Our next - // message will implicitly acknowledge it-> - if (lastMessageEntryMachineId == localMachineId) { - continue; - } - - Pair *lastMessageValue = lastMessageTable->get(lastMessageEntryMachineId); - int64_t entrySequenceNumber = lastMessageValue->getFirst(); - - if (entrySequenceNumber < seq) { - // Add this rejected message to the set of messages that this - // machine ID did not see yet - addWatchVector(lastMessageEntryMachineId, entry); - // This client did not see this rejected message yet so add it - // to the watch set to monitor - deviceWatchSet->add(lastMessageEntryMachineId); - } - } - delete iter; - - if (deviceWatchSet->isEmpty()) { - // This rejected message has been seen by all the clients so - entry->setDead(); - delete deviceWatchSet; - } else { - // We need to watch this rejected message - entry->setWatchSet(deviceWatchSet); - } -} - -/** - * Check if this abort is live, if not then save it so we can kill it - * later-> update the last transaction number that was arbitrated on-> - */ -void Table::processEntry(Abort *entry) { - if (entry->getTransactionSequenceNumber() != -1) { - // update the transaction status if it was sent to the server - TransactionStatus *status = outstandingTransactionStatus->remove(entry->getTransactionSequenceNumber()); - if (status != NULL) { - status->setStatus(TransactionStatus_StatusAborted); - } - } - - // Abort has not been seen by the client it is for yet so we need to - // keep track of it - - Abort *previouslySeenAbort = liveAbortTable->put(new Pair(entry->getAbortId()), entry); - if (previouslySeenAbort != NULL) { - previouslySeenAbort->setDead(); // Delete old version of the abort since we got a rescued newer version - } - - if (entry->getTransactionArbitrator() == localMachineId) { - liveAbortsGeneratedByLocal->put(entry->getArbitratorLocalSequenceNumber(), entry); - } - - if ((entry->getSequenceNumber() != -1) && (lastMessageTable->get(entry->getTransactionMachineId())->getFirst() >= entry->getSequenceNumber())) { - // The machine already saw this so it is dead - entry->setDead(); - Pair abortid = entry->getAbortId(); - liveAbortTable->remove(&abortid); - - if (entry->getTransactionArbitrator() == localMachineId) { - liveAbortsGeneratedByLocal->remove(entry->getArbitratorLocalSequenceNumber()); - } - return; - } - - // Update the last arbitration data that we have seen so far - if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(entry->getTransactionArbitrator())) { - int64_t lastArbitrationSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(entry->getTransactionArbitrator()); - if (entry->getSequenceNumber() > lastArbitrationSequenceNumber) { - // Is larger - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(entry->getTransactionArbitrator(), entry->getSequenceNumber()); - } - } else { - // Never seen any data from this arbitrator so record the first one - lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(entry->getTransactionArbitrator(), entry->getSequenceNumber()); - } - - // Set dead a transaction if we can - Pair deadPair = Pair(entry->getTransactionMachineId(), entry->getTransactionClientLocalSequenceNumber()); - - Transaction *transactionToSetDead = liveTransactionByTransactionIdTable->remove(&deadPair); - if (transactionToSetDead != NULL) { - liveTransactionBySequenceNumberTable->remove(transactionToSetDead->getSequenceNumber()); - } - - // Update the last transaction sequence number that the arbitrator - // arbitrated on - if (!lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getTransactionArbitrator()) || - (lastArbitratedTransactionNumberByArbitratorTable->get(entry->getTransactionArbitrator()) < entry->getTransactionSequenceNumber())) { - // Is a valid one - if (entry->getTransactionSequenceNumber() != -1) { - lastArbitratedTransactionNumberByArbitratorTable->put(entry->getTransactionArbitrator(), entry->getTransactionSequenceNumber()); - } - } -} - -/** - * Set dead the transaction part if that transaction is dead and keep - * track of all new parts - */ -void Table::processEntry(TransactionPart *entry) { - // Check if we have already seen this transaction and set it dead OR - // if it is not alive - if (lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getArbitratorId()) && (lastArbitratedTransactionNumberByArbitratorTable->get(entry->getArbitratorId()) >= entry->getSequenceNumber())) { - // This transaction is dead, it was already committed or aborted - entry->setDead(); - return; - } - - // This part is still alive - Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *transactionPart = newTransactionParts->get(entry->getMachineId()); - - if (transactionPart == NULL) { - // Dont have a table for this machine Id yet so make one - transactionPart = new Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals>(); - newTransactionParts->put(entry->getMachineId(), transactionPart); - } - - // Update the part and set dead ones we have already seen (got a - // rescued version) - entry->acquireRef(); - TransactionPart *previouslySeenPart = transactionPart->put(entry->getPartId(), entry); - if (previouslySeenPart != NULL) { - previouslySeenPart->releaseRef(); - previouslySeenPart->setDead(); - } -} - -/** - * Process new commit entries and save them for future use-> Delete duplicates - */ -void Table::processEntry(CommitPart *entry) { - // Update the last transaction that was updated if we can - if (entry->getTransactionSequenceNumber() != -1) { - if (!lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getMachineId()) || - lastArbitratedTransactionNumberByArbitratorTable->get(entry->getMachineId()) < entry->getTransactionSequenceNumber()) { - lastArbitratedTransactionNumberByArbitratorTable->put(entry->getMachineId(), entry->getTransactionSequenceNumber()); - } - } - - Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *commitPart = newCommitParts->get(entry->getMachineId()); - if (commitPart == NULL) { - // Don't have a table for this machine Id yet so make one - commitPart = new Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals>(); - newCommitParts->put(entry->getMachineId(), commitPart); - } - // Update the part and set dead ones we have already seen (got a - // rescued version) - entry->acquireRef(); - CommitPart *previouslySeenPart = commitPart->put(entry->getPartId(), entry); - if (previouslySeenPart != NULL) { - previouslySeenPart->setDead(); - previouslySeenPart->releaseRef(); - } -} - -/** - * Update the last message seen table-> Update and set dead the - * appropriate RejectedMessages as clients see them-> Updates the live - * aborts, removes those that are dead and sets them dead-> Check that - * the last message seen is correct and that there is no mismatch of - * our own last message or that other clients have not had a rollback - * on the last message-> - */ -void Table::updateLastMessage(int64_t machineId, int64_t seqNum, Liveness *liveness, bool acceptUpdatesToLocal, Hashset *machineSet) { - // We have seen this machine ID - machineSet->remove(machineId); - - // Get the set of rejected messages that this machine Id is has not seen yet - Hashset *watchset = rejectedMessageWatchVectorTable->get(machineId); - // If there is a rejected message that this machine Id has not seen yet - if (watchset != NULL) { - // Go through each rejected message that this machine Id has not - // seen yet - - SetIterator *rmit = watchset->iterator(); - while (rmit->hasNext()) { - RejectedMessage *rm = rmit->next(); - // If this machine Id has seen this rejected message->->-> - if (rm->getSequenceNumber() <= seqNum) { - // Remove it from our watchlist - rmit->remove(); - // Decrement machines that need to see this notification - rm->removeWatcher(machineId); - } - } - delete rmit; - } - - // Set dead the abort - SetIterator *, Abort *, uintptr_t, 0, pairHashFunction, pairEquals> *abortit = getKeyIterator(liveAbortTable); - - while (abortit->hasNext()) { - Pair *key = abortit->next(); - Abort *abort = liveAbortTable->get(key); - if ((abort->getTransactionMachineId() == machineId) && (abort->getSequenceNumber() <= seqNum)) { - abort->setDead(); - abortit->remove(); - if (abort->getTransactionArbitrator() == localMachineId) { - liveAbortsGeneratedByLocal->remove(abort->getArbitratorLocalSequenceNumber()); - } - } - } - delete abortit; - if (machineId == localMachineId) { - // Our own messages are immediately dead-> - char livenessType = liveness->getType(); - if (livenessType == TypeLastMessage) { - ((LastMessage *)liveness)->setDead(); - } else if (livenessType == TypeSlot) { - ((Slot *)liveness)->setDead(); - } else { - throw new Error("Unrecognized type"); - } - } - // Get the old last message for this device - Pair *lastMessageEntry = lastMessageTable->put(machineId, new Pair(seqNum, liveness)); - if (lastMessageEntry == NULL) { - // If no last message then there is nothing else to process - return; - } - - int64_t lastMessageSeqNum = lastMessageEntry->getFirst(); - Liveness *lastEntry = lastMessageEntry->getSecond(); - delete lastMessageEntry; - - // If it is not our machine Id since we already set ours to dead - if (machineId != localMachineId) { - char lastEntryType = lastEntry->getType(); - - if (lastEntryType == TypeLastMessage) { - ((LastMessage *)lastEntry)->setDead(); - } else if (lastEntryType == TypeSlot) { - ((Slot *)lastEntry)->setDead(); - } else { - throw new Error("Unrecognized type"); - } - } - // Make sure the server is not playing any games - if (machineId == localMachineId) { - if (hadPartialSendToServer) { - // We were not making any updates and we had a machine mismatch - if (lastMessageSeqNum > seqNum && !acceptUpdatesToLocal) { - throw new Error("Server Error: Mismatch on local machine sequence number, needed at least: "); - } - } else { - // We were not making any updates and we had a machine mismatch - if (lastMessageSeqNum != seqNum && !acceptUpdatesToLocal) { - throw new Error("Server Error: Mismatch on local machine sequence number, needed: "); - } - } - } else { - if (lastMessageSeqNum > seqNum) { - throw new Error("Server Error: Rollback on remote machine sequence number"); - } - } -} - -/** - * Add a rejected message entry to the watch set to keep track of - * which clients have seen that rejected message entry and which have - * not. - */ -void Table::addWatchVector(int64_t machineId, RejectedMessage *entry) { - Hashset *entries = rejectedMessageWatchVectorTable->get(machineId); - if (entries == NULL) { - // There is no set for this machine ID yet so create one - entries = new Hashset(); - rejectedMessageWatchVectorTable->put(machineId, entries); - } - entries->add(entry); -} - -/** - * Check if the HMAC chain is not violated - */ -void Table::checkHMACChain(SlotIndexer *indexer, Array *newSlots) { - for (uint i = 0; i < newSlots->length(); i++) { - Slot *currSlot = newSlots->get(i); - Slot *prevSlot = indexer->getSlot(currSlot->getSequenceNumber() - 1); - if (prevSlot != NULL && - !prevSlot->getHMAC()->equals(currSlot->getPrevHMAC())) - throw new Error("Server Error: Invalid HMAC Chain"); - } -} diff --git a/version2/src/C/Table.cpp b/version2/src/C/Table.cpp new file mode 100644 index 0000000..255ba3c --- /dev/null +++ b/version2/src/C/Table.cpp @@ -0,0 +1,2863 @@ +#include "Table.h" +#include "CloudComm.h" +#include "SlotBuffer.h" +#include "NewKey.h" +#include "Slot.h" +#include "KeyValue.h" +#include "Error.h" +#include "PendingTransaction.h" +#include "TableStatus.h" +#include "TransactionStatus.h" +#include "Transaction.h" +#include "LastMessage.h" +#include "SecureRandom.h" +#include "ByteBuffer.h" +#include "Abort.h" +#include "CommitPart.h" +#include "ArbitrationRound.h" +#include "TransactionPart.h" +#include "Commit.h" +#include "RejectedMessage.h" +#include "SlotIndexer.h" +#include + +int compareInt64(const void *a, const void *b) { + const int64_t *pa = (const int64_t *) a; + const int64_t *pb = (const int64_t *) b; + if (*pa < *pb) + return -1; + else if (*pa > *pb) + return 1; + else + return 0; +} + +Table::Table(IoTString *baseurl, IoTString *password, int64_t _localMachineId, int listeningPort) : + buffer(NULL), + cloud(new CloudComm(this, baseurl, password, listeningPort)), + random(NULL), + liveTableStatus(NULL), + pendingTransactionBuilder(NULL), + lastPendingTransactionSpeculatedOn(NULL), + firstPendingTransaction(NULL), + numberOfSlots(0), + bufferResizeThreshold(0), + liveSlotCount(0), + oldestLiveSlotSequenceNumver(1), + localMachineId(_localMachineId), + sequenceNumber(0), + localSequenceNumber(0), + localTransactionSequenceNumber(0), + lastTransactionSequenceNumberSpeculatedOn(0), + oldestTransactionSequenceNumberSpeculatedOn(0), + localArbitrationSequenceNumber(0), + hadPartialSendToServer(false), + attemptedToSendToServer(false), + expectedsize(0), + didFindTableStatus(false), + currMaxSize(0), + lastSlotAttemptedToSend(NULL), + lastIsNewKey(false), + lastNewSize(0), + lastTransactionPartsSent(NULL), + lastNewKey(NULL), + committedKeyValueTable(NULL), + speculatedKeyValueTable(NULL), + pendingTransactionSpeculatedKeyValueTable(NULL), + liveNewKeyTable(NULL), + lastMessageTable(NULL), + rejectedMessageWatchVectorTable(NULL), + arbitratorTable(NULL), + liveAbortTable(NULL), + newTransactionParts(NULL), + newCommitParts(NULL), + lastArbitratedTransactionNumberByArbitratorTable(NULL), + liveTransactionBySequenceNumberTable(NULL), + liveTransactionByTransactionIdTable(NULL), + liveCommitsTable(NULL), + liveCommitsByKeyTable(NULL), + lastCommitSeenSequenceNumberByArbitratorTable(NULL), + rejectedSlotVector(NULL), + pendingTransactionQueue(NULL), + pendingSendArbitrationRounds(NULL), + pendingSendArbitrationEntriesToDelete(NULL), + transactionPartsSent(NULL), + outstandingTransactionStatus(NULL), + liveAbortsGeneratedByLocal(NULL), + offlineTransactionsCommittedAndAtServer(NULL), + localCommunicationTable(NULL), + lastTransactionSeenFromMachineFromServer(NULL), + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator(NULL), + lastInsertedNewKey(false), + lastSeqNumArbOn(0) +{ + init(); +} + +Table::Table(CloudComm *_cloud, int64_t _localMachineId) : + buffer(NULL), + cloud(_cloud), + random(NULL), + liveTableStatus(NULL), + pendingTransactionBuilder(NULL), + lastPendingTransactionSpeculatedOn(NULL), + firstPendingTransaction(NULL), + numberOfSlots(0), + bufferResizeThreshold(0), + liveSlotCount(0), + oldestLiveSlotSequenceNumver(1), + localMachineId(_localMachineId), + sequenceNumber(0), + localSequenceNumber(0), + localTransactionSequenceNumber(0), + lastTransactionSequenceNumberSpeculatedOn(0), + oldestTransactionSequenceNumberSpeculatedOn(0), + localArbitrationSequenceNumber(0), + hadPartialSendToServer(false), + attemptedToSendToServer(false), + expectedsize(0), + didFindTableStatus(false), + currMaxSize(0), + lastSlotAttemptedToSend(NULL), + lastIsNewKey(false), + lastNewSize(0), + lastTransactionPartsSent(NULL), + lastNewKey(NULL), + committedKeyValueTable(NULL), + speculatedKeyValueTable(NULL), + pendingTransactionSpeculatedKeyValueTable(NULL), + liveNewKeyTable(NULL), + lastMessageTable(NULL), + rejectedMessageWatchVectorTable(NULL), + arbitratorTable(NULL), + liveAbortTable(NULL), + newTransactionParts(NULL), + newCommitParts(NULL), + lastArbitratedTransactionNumberByArbitratorTable(NULL), + liveTransactionBySequenceNumberTable(NULL), + liveTransactionByTransactionIdTable(NULL), + liveCommitsTable(NULL), + liveCommitsByKeyTable(NULL), + lastCommitSeenSequenceNumberByArbitratorTable(NULL), + rejectedSlotVector(NULL), + pendingTransactionQueue(NULL), + pendingSendArbitrationRounds(NULL), + pendingSendArbitrationEntriesToDelete(NULL), + transactionPartsSent(NULL), + outstandingTransactionStatus(NULL), + liveAbortsGeneratedByLocal(NULL), + offlineTransactionsCommittedAndAtServer(NULL), + localCommunicationTable(NULL), + lastTransactionSeenFromMachineFromServer(NULL), + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator(NULL), + lastInsertedNewKey(false), + lastSeqNumArbOn(0) +{ + init(); +} + +Table::~Table() { + delete cloud; + delete random; + delete buffer; + // init data structs + delete committedKeyValueTable; + delete speculatedKeyValueTable; + delete pendingTransactionSpeculatedKeyValueTable; + delete liveNewKeyTable; + { + SetIterator *> *lmit = getKeyIterator(lastMessageTable); + while (lmit->hasNext()) { + Pair * pair = lastMessageTable->get(lmit->next()); + delete pair; + } + delete lmit; + delete lastMessageTable; + } + if (pendingTransactionBuilder != NULL) + delete pendingTransactionBuilder; + { + SetIterator *> *rmit = getKeyIterator(rejectedMessageWatchVectorTable); + while(rmit->hasNext()) { + int64_t machineid = rmit->next(); + Hashset * rmset = rejectedMessageWatchVectorTable->get(machineid); + SetIterator * mit = rmset->iterator(); + while (mit->hasNext()) { + RejectedMessage * rm = mit->next(); + delete rm; + } + delete mit; + delete rmset; + } + delete rmit; + delete rejectedMessageWatchVectorTable; + } + delete arbitratorTable; + delete liveAbortTable; + { + SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newTransactionParts); + while (partsit->hasNext()) { + int64_t machineId = partsit->next(); + Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = partsit->currVal(); + SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pit = getKeyIterator(parts); + while(pit->hasNext()) { + Pair * pair=pit->next(); + pit->currVal()->releaseRef(); + } + delete pit; + + delete parts; + } + delete partsit; + delete newTransactionParts; + } + { + SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newCommitParts); + while (partsit->hasNext()) { + int64_t machineId = partsit->next(); + Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = partsit->currVal(); + SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pit = getKeyIterator(parts); + while(pit->hasNext()) { + Pair * pair=pit->next(); + pit->currVal()->releaseRef(); + } + delete pit; + delete parts; + } + delete partsit; + delete newCommitParts; + } + delete lastArbitratedTransactionNumberByArbitratorTable; + delete liveTransactionBySequenceNumberTable; + delete liveTransactionByTransactionIdTable; + { + SetIterator *> *liveit = getKeyIterator(liveCommitsTable); + while (liveit->hasNext()) { + int64_t arbitratorId = liveit->next(); + + // Get all the commits for a specific arbitrator + Hashtable *commitForClientTable = liveit->currVal(); + { + SetIterator *clientit = getKeyIterator(commitForClientTable); + while (clientit->hasNext()) { + int64_t id = clientit->next(); + delete commitForClientTable->get(id); + } + delete clientit; + } + + delete commitForClientTable; + } + delete liveit; + delete liveCommitsTable; + } + delete liveCommitsByKeyTable; + delete lastCommitSeenSequenceNumberByArbitratorTable; + delete rejectedSlotVector; + { + uint size = pendingTransactionQueue->size(); + for (uint iter = 0; iter < size; iter++) { + delete pendingTransactionQueue->get(iter); + } + delete pendingTransactionQueue; + } + delete pendingSendArbitrationEntriesToDelete; + { + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + delete trit->currVal(); + } + delete trit; + delete transactionPartsSent; + } + delete outstandingTransactionStatus; + delete liveAbortsGeneratedByLocal; + delete offlineTransactionsCommittedAndAtServer; + delete localCommunicationTable; + delete lastTransactionSeenFromMachineFromServer; + { + for(uint i = 0; i < pendingSendArbitrationRounds->size(); i++) { + delete pendingSendArbitrationRounds->get(i); + } + delete pendingSendArbitrationRounds; + } + if (lastTransactionPartsSent != NULL) + delete lastTransactionPartsSent; + delete lastArbitrationDataLocalSequenceNumberSeenFromArbitrator; + if (lastNewKey) + delete lastNewKey; +} + +/** + * Init all the stuff needed for for table usage + */ +void Table::init() { + // Init helper objects + random = new SecureRandom(); + buffer = new SlotBuffer(); + + // init data structs + committedKeyValueTable = new Hashtable(); + speculatedKeyValueTable = new Hashtable(); + pendingTransactionSpeculatedKeyValueTable = new Hashtable(); + liveNewKeyTable = new Hashtable(); + lastMessageTable = new Hashtable * >(); + rejectedMessageWatchVectorTable = new Hashtable * >(); + arbitratorTable = new Hashtable(); + liveAbortTable = new Hashtable *, Abort *, uintptr_t, 0, pairHashFunction, pairEquals>(); + newTransactionParts = new Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *>(); + newCommitParts = new Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *>(); + lastArbitratedTransactionNumberByArbitratorTable = new Hashtable(); + liveTransactionBySequenceNumberTable = new Hashtable(); + liveTransactionByTransactionIdTable = new Hashtable *, Transaction *, uintptr_t, 0, pairHashFunction, pairEquals>(); + liveCommitsTable = new Hashtable * >(); + liveCommitsByKeyTable = new Hashtable(); + lastCommitSeenSequenceNumberByArbitratorTable = new Hashtable(); + rejectedSlotVector = new Vector(); + pendingTransactionQueue = new Vector(); + pendingSendArbitrationEntriesToDelete = new Vector(); + transactionPartsSent = new Hashtable *>(); + outstandingTransactionStatus = new Hashtable(); + liveAbortsGeneratedByLocal = new Hashtable(); + offlineTransactionsCommittedAndAtServer = new Hashset *, uintptr_t, 0, pairHashFunction, pairEquals>(); + localCommunicationTable = new Hashtable *>(); + lastTransactionSeenFromMachineFromServer = new Hashtable(); + pendingSendArbitrationRounds = new Vector(); + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator = new Hashtable(); + + // Other init stuff + numberOfSlots = buffer->capacity(); + setResizeThreshold(); +} + +/** + * Initialize the table by inserting a table status as the first entry + * into the table status also initialize the crypto stuff. + */ +void Table::initTable() { + cloud->initSecurity(); + + // Create the first insertion into the block chain which is the table status + Slot *s = new Slot(this, 1, localMachineId, localSequenceNumber); + localSequenceNumber++; + TableStatus *status = new TableStatus(s, numberOfSlots); + s->addShallowEntry(status); + Array *array = cloud->putSlot(s, numberOfSlots); + + if (array == NULL) { + array = new Array(1); + array->set(0, s); + // update local block chain + validateAndUpdate(array, true); + delete array; + } else if (array->length() == 1) { + // in case we did push the slot BUT we failed to init it + validateAndUpdate(array, true); + delete s; + delete array; + } else { + delete s; + delete array; + throw new Error("Error on initialization"); + } +} + +/** + * Rebuild the table from scratch by pulling the latest block chain + * from the server. + */ +void Table::rebuild() { + // Just pull the latest slots from the server + Array *newslots = cloud->getSlots(sequenceNumber + 1); + validateAndUpdate(newslots, true); + delete newslots; + sendToServer(NULL); + updateLiveTransactionsAndStatus(); +} + +void Table::addLocalCommunication(int64_t arbitrator, IoTString *hostName, int portNumber) { + localCommunicationTable->put(arbitrator, new Pair(hostName, portNumber)); +} + +int64_t Table::getArbitrator(IoTString *key) { + return arbitratorTable->get(key); +} + +void Table::close() { + cloud->closeCloud(); +} + +IoTString *Table::getCommitted(IoTString *key) { + KeyValue *kv = committedKeyValueTable->get(key); + + if (kv != NULL) { + return new IoTString(kv->getValue()); + } else { + return NULL; + } +} + +IoTString *Table::getSpeculative(IoTString *key) { + KeyValue *kv = pendingTransactionSpeculatedKeyValueTable->get(key); + + if (kv == NULL) { + kv = speculatedKeyValueTable->get(key); + } + + if (kv == NULL) { + kv = committedKeyValueTable->get(key); + } + + if (kv != NULL) { + return new IoTString(kv->getValue()); + } else { + return NULL; + } +} + +IoTString *Table::getCommittedAtomic(IoTString *key) { + KeyValue *kv = committedKeyValueTable->get(key); + + if (!arbitratorTable->contains(key)) { + throw new Error("Key not Found."); + } + + // Make sure new key value pair matches the current arbitrator + if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { + // TODO: Maybe not throw en error + throw new Error("Not all Key Values Match Arbitrator."); + } + + if (kv != NULL) { + pendingTransactionBuilder->addKVGuard(new KeyValue(key, kv->getValue())); + return new IoTString(kv->getValue()); + } else { + pendingTransactionBuilder->addKVGuard(new KeyValue(key, NULL)); + return NULL; + } +} + +IoTString *Table::getSpeculativeAtomic(IoTString *key) { + if (!arbitratorTable->contains(key)) { + throw new Error("Key not Found."); + } + + // Make sure new key value pair matches the current arbitrator + if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { + // TODO: Maybe not throw en error + throw new Error("Not all Key Values Match Arbitrator."); + } + + KeyValue *kv = pendingTransactionSpeculatedKeyValueTable->get(key); + + if (kv == NULL) { + kv = speculatedKeyValueTable->get(key); + } + + if (kv == NULL) { + kv = committedKeyValueTable->get(key); + } + + if (kv != NULL) { + pendingTransactionBuilder->addKVGuard(new KeyValue(key, kv->getValue())); + return new IoTString(kv->getValue()); + } else { + pendingTransactionBuilder->addKVGuard(new KeyValue(key, NULL)); + return NULL; + } +} + +bool Table::update() { + try { + Array *newSlots = cloud->getSlots(sequenceNumber + 1); + validateAndUpdate(newSlots, false); + delete newSlots; + sendToServer(NULL); + updateLiveTransactionsAndStatus(); + return true; + } catch (Exception *e) { + SetIterator *> *kit = getKeyIterator(localCommunicationTable); + while (kit->hasNext()) { + int64_t m = kit->next(); + updateFromLocal(m); + } + delete kit; + } + + return false; +} + +bool Table::createNewKey(IoTString *keyName, int64_t machineId) { + while (true) { + if (arbitratorTable->contains(keyName)) { + // There is already an arbitrator + return false; + } + NewKey *newKey = new NewKey(NULL, keyName, machineId); + + if (sendToServer(newKey)) { + // If successfully inserted + return true; + } + } +} + +void Table::startTransaction() { + // Create a new transaction, invalidates any old pending transactions. + if (pendingTransactionBuilder != NULL) + delete pendingTransactionBuilder; + pendingTransactionBuilder = new PendingTransaction(localMachineId); +} + +void Table::put(IoTString *key, IoTString *value) { + // Make sure it is a valid key + if (!arbitratorTable->contains(key)) { + throw new Error("Key not Found."); + } + + // Make sure new key value pair matches the current arbitrator + if (!pendingTransactionBuilder->checkArbitrator(arbitratorTable->get(key))) { + // TODO: Maybe not throw en error + throw new Error("Not all Key Values Match Arbitrator."); + } + + // Add the key value to this transaction + KeyValue *kv = new KeyValue(new IoTString(key), new IoTString(value)); + pendingTransactionBuilder->addKV(kv); +} + +TransactionStatus *Table::commitTransaction() { + if (pendingTransactionBuilder->getKVUpdates()->size() == 0) { + // transaction with no updates will have no effect on the system + return new TransactionStatus(TransactionStatus_StatusNoEffect, -1); + } + + // Set the local transaction sequence number and increment + pendingTransactionBuilder->setClientLocalSequenceNumber(localTransactionSequenceNumber); + localTransactionSequenceNumber++; + + // Create the transaction status + TransactionStatus *transactionStatus = new TransactionStatus(TransactionStatus_StatusPending, pendingTransactionBuilder->getArbitrator()); + + // Create the new transaction + Transaction *newTransaction = pendingTransactionBuilder->createTransaction(); + newTransaction->setTransactionStatus(transactionStatus); + + if (pendingTransactionBuilder->getArbitrator() != localMachineId) { + // Add it to the queue and invalidate the builder for safety + pendingTransactionQueue->add(newTransaction); + } else { + arbitrateOnLocalTransaction(newTransaction); + delete newTransaction; + updateLiveStateFromLocal(); + } + if (pendingTransactionBuilder != NULL) + delete pendingTransactionBuilder; + + pendingTransactionBuilder = new PendingTransaction(localMachineId); + + try { + sendToServer(NULL); + } catch (ServerException *e) { + + Hashset *arbitratorTriedAndFailed = new Hashset(); + uint size = pendingTransactionQueue->size(); + uint oldindex = 0; + for (uint iter = 0; iter < size; iter++) { + Transaction *transaction = pendingTransactionQueue->get(iter); + pendingTransactionQueue->set(oldindex++, pendingTransactionQueue->get(iter)); + + if (arbitratorTriedAndFailed->contains(transaction->getArbitrator())) { + // Already contacted this client so ignore all attempts to contact this client + // to preserve ordering for arbitrator + continue; + } + + Pair sendReturn = sendTransactionToLocal(transaction); + + if (sendReturn.getFirst()) { + // Failed to contact over local + arbitratorTriedAndFailed->add(transaction->getArbitrator()); + } else { + // Successful contact or should not contact + + if (sendReturn.getSecond()) { + // did arbitrate + delete transaction; + oldindex--; + } + } + } + pendingTransactionQueue->setSize(oldindex); + } + + updateLiveStateFromLocal(); + + return transactionStatus; +} + +/** + * Recalculate the new resize threshold + */ +void Table::setResizeThreshold() { + int resizeLower = (int) (Table_RESIZE_THRESHOLD * numberOfSlots); + bufferResizeThreshold = resizeLower - 1 + random->nextInt(numberOfSlots - resizeLower); +} + +int64_t Table::getLocalSequenceNumber() { + return localSequenceNumber; +} + +void Table::processTransactionList(bool handlePartial) { + SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetServerFailure(); + // Update which transactions parts still need to be sent + transaction->removeSentParts(lastTransactionPartsSent->get(transaction)); + // Add the transaction status to the outstanding list + outstandingTransactionStatus->put(transaction->getSequenceNumber(), transaction->getTransactionStatus()); + + // Update the transaction status + transaction->getTransactionStatus()->setStatus(TransactionStatus_StatusSentPartial); + + // Check if all the transaction parts were successfully + // sent and if so then remove it from pending + if (transaction->didSendAllParts()) { + transaction->getTransactionStatus()->setStatus(TransactionStatus_StatusSentFully); + pendingTransactionQueue->remove(transaction); + delete transaction; + } else if (handlePartial) { + transaction->resetServerFailure(); + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer()) { + transaction->setSequenceNumber(-1); + } + } + } + delete trit; +} + +NewKey * Table::handlePartialSend(NewKey * newKey) { + //Didn't receive acknowledgement for last send + //See if the server has received a newer slot + + Array *newSlots = cloud->getSlots(sequenceNumber + 1); + if (newSlots->length() == 0) { + //Retry sending old slot + bool wasInserted = false; + bool sendSlotsReturn = sendSlotsToServer(lastSlotAttemptedToSend, lastNewSize, lastIsNewKey, &wasInserted, &newSlots); + + if (sendSlotsReturn) { + lastSlotAttemptedToSend = NULL; + if (newKey != NULL) { + if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { + delete newKey; + newKey = NULL; + } + } + processTransactionList(false); + } else { + if (checkSend(newSlots, lastSlotAttemptedToSend)) { + if (newKey != NULL) { + if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { + delete newKey; + newKey = NULL; + } + } + processTransactionList(true); + } + } + + SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetServerFailure(); + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer()) { + transaction->setSequenceNumber(-1); + } + } + delete trit; + + if (newSlots->length() != 0) { + // insert into the local block chain + validateAndUpdate(newSlots, true); + } + } else { + if (checkSend(newSlots, lastSlotAttemptedToSend)) { + if (newKey != NULL) { + if (lastInsertedNewKey && (lastNewKey->getKey() == newKey->getKey()) && (lastNewKey->getMachineID() == newKey->getMachineID())) { + delete newKey; + newKey = NULL; + } + } + + processTransactionList(true); + } else { + SetIterator *> *trit = getKeyIterator(lastTransactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetServerFailure(); + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer()) { + transaction->setSequenceNumber(-1); + } + } + delete trit; + } + + // insert into the local block chain + validateAndUpdate(newSlots, true); + } + delete newSlots; + return newKey; +} + +void Table::clearSentParts() { + // Clear the sent data since we are trying again + pendingSendArbitrationEntriesToDelete->clear(); + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + delete trit->currVal(); + } + delete trit; + transactionPartsSent->clear(); +} + +bool Table::sendToServer(NewKey *newKey) { + if (hadPartialSendToServer) { + newKey = handlePartialSend(newKey); + } + + try { + // While we have stuff that needs inserting into the block chain + while ((pendingTransactionQueue->size() > 0) || (pendingSendArbitrationRounds->size() > 0) || (newKey != NULL)) { + if (hadPartialSendToServer) { + throw new Error("Should Be error free"); + } + + // If there is a new key with same name then end + if ((newKey != NULL) && arbitratorTable->contains(newKey->getKey())) { + delete newKey; + return false; + } + + // Create the slot + Slot *slot = new Slot(this, sequenceNumber + 1, localMachineId, new Array(buffer->getSlot(sequenceNumber)->getHMAC()), localSequenceNumber); + localSequenceNumber++; + + // Try to fill the slot with data + int newSize = 0; + bool insertedNewKey = false; + bool needsResize = fillSlot(slot, false, newKey, newSize, insertedNewKey); + + if (needsResize) { + // Reset which transaction to send + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetNextPartToSend(); + + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { + transaction->setSequenceNumber(-1); + } + } + delete trit; + + // Clear the sent data since we are trying again + clearSentParts(); + + // We needed a resize so try again + fillSlot(slot, true, newKey, newSize, insertedNewKey); + } + if (lastSlotAttemptedToSend != NULL) + delete lastSlotAttemptedToSend; + + lastSlotAttemptedToSend = slot; + lastIsNewKey = (newKey != NULL); + lastInsertedNewKey = insertedNewKey; + lastNewSize = newSize; + if (( newKey != lastNewKey) && (lastNewKey != NULL)) + delete lastNewKey; + lastNewKey = newKey; + if (lastTransactionPartsSent != NULL) + delete lastTransactionPartsSent; + lastTransactionPartsSent = transactionPartsSent->clone(); + + Array * newSlots = NULL; + bool wasInserted = false; + bool sendSlotsReturn = sendSlotsToServer(slot, newSize, newKey != NULL, &wasInserted, &newSlots); + + if (sendSlotsReturn) { + lastSlotAttemptedToSend = NULL; + // Did insert into the block chain + if (insertedNewKey) { + // This slot was what was inserted not a previous slot + // New Key was successfully inserted into the block chain so dont want to insert it again + newKey = NULL; + } + + // Remove the aborts and commit parts that were sent from the pending to send queue + uint size = pendingSendArbitrationRounds->size(); + uint oldcount = 0; + for (uint i = 0; i < size; i++) { + ArbitrationRound *round = pendingSendArbitrationRounds->get(i); + round->removeParts(pendingSendArbitrationEntriesToDelete); + + if (!round->isDoneSending()) { + //Add part back in + pendingSendArbitrationRounds->set(oldcount++, + pendingSendArbitrationRounds->get(i)); + } else + delete pendingSendArbitrationRounds->get(i); + } + pendingSendArbitrationRounds->setSize(oldcount); + processTransactionList(false); + } else { + // Reset which transaction to send + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetNextPartToSend(); + + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { + transaction->setSequenceNumber(-1); + } + } + delete trit; + } + + // Clear the sent data in preparation for next send + clearSentParts(); + + if (newSlots->length() != 0) { + // insert into the local block chain + validateAndUpdate(newSlots, true); + } + delete newSlots; + } + } catch (ServerException *e) { + if (e->getType() != ServerException_TypeInputTimeout) { + // Nothing was able to be sent to the server so just clear these data structures + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetNextPartToSend(); + + // Set the transaction sequence number back to nothing + if (!transaction->didSendAPartToServer() && !transaction->getServerFailure()) { + transaction->setSequenceNumber(-1); + } + } + delete trit; + } else { + // There was a partial send to the server + hadPartialSendToServer = true; + + // Nothing was able to be sent to the server so just clear these data structures + SetIterator *> *trit = getKeyIterator(transactionPartsSent); + while (trit->hasNext()) { + Transaction *transaction = trit->next(); + transaction->resetNextPartToSend(); + transaction->setServerFailure(); + } + delete trit; + } + + clearSentParts(); + + throw e; + } + + return newKey == NULL; +} + +bool Table::updateFromLocal(int64_t machineId) { + if (!localCommunicationTable->contains(machineId)) + return false; + + Pair *localCommunicationInformation = localCommunicationTable->get(machineId); + + // Get the size of the send data + int sendDataSize = sizeof(int32_t) + sizeof(int64_t); + + int64_t lastArbitrationDataLocalSequenceNumber = (int64_t) -1; + if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(machineId)) { + lastArbitrationDataLocalSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(machineId); + } + + Array *sendData = new Array(sendDataSize); + ByteBuffer *bbEncode = ByteBuffer_wrap(sendData); + + // Encode the data + bbEncode->putLong(lastArbitrationDataLocalSequenceNumber); + bbEncode->putInt(0); + + // Send by local + Array *returnData = cloud->sendLocalData(sendData, localSequenceNumber, localCommunicationInformation->getFirst(), localCommunicationInformation->getSecond()); + localSequenceNumber++; + + if (returnData == NULL) { + // Could not contact server + return false; + } + + // Decode the data + ByteBuffer *bbDecode = ByteBuffer_wrap(returnData); + int numberOfEntries = bbDecode->getInt(); + + for (int i = 0; i < numberOfEntries; i++) { + char type = bbDecode->get(); + if (type == TypeAbort) { + Abort *abort = (Abort *)Abort_decode(NULL, bbDecode); + processEntry(abort); + } else if (type == TypeCommitPart) { + CommitPart *commitPart = (CommitPart *)CommitPart_decode(NULL, bbDecode); + processEntry(commitPart); + } + } + + updateLiveStateFromLocal(); + + return true; +} + +Pair Table::sendTransactionToLocal(Transaction *transaction) { + + // Get the devices local communications + if (!localCommunicationTable->contains(transaction->getArbitrator())) + return Pair(true, false); + + Pair *localCommunicationInformation = localCommunicationTable->get(transaction->getArbitrator()); + + // Get the size of the send data + int sendDataSize = sizeof(int32_t) + sizeof(int64_t); + { + Vector *tParts = transaction->getParts(); + uint tPartsSize = tParts->size(); + for (uint i = 0; i < tPartsSize; i++) { + TransactionPart *part = tParts->get(i); + sendDataSize += part->getSize(); + } + } + + int64_t lastArbitrationDataLocalSequenceNumber = (int64_t) -1; + if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(transaction->getArbitrator())) { + lastArbitrationDataLocalSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(transaction->getArbitrator()); + } + + // Make the send data size + Array *sendData = new Array(sendDataSize); + ByteBuffer *bbEncode = ByteBuffer_wrap(sendData); + + // Encode the data + bbEncode->putLong(lastArbitrationDataLocalSequenceNumber); + bbEncode->putInt(transaction->getParts()->size()); + { + Vector *tParts = transaction->getParts(); + uint tPartsSize = tParts->size(); + for (uint i = 0; i < tPartsSize; i++) { + TransactionPart *part = tParts->get(i); + part->encode(bbEncode); + } + } + + // Send by local + Array *returnData = cloud->sendLocalData(sendData, localSequenceNumber, localCommunicationInformation->getFirst(), localCommunicationInformation->getSecond()); + localSequenceNumber++; + + if (returnData == NULL) { + // Could not contact server + return Pair(true, false); + } + + // Decode the data + ByteBuffer *bbDecode = ByteBuffer_wrap(returnData); + bool didCommit = bbDecode->get() == 1; + bool couldArbitrate = bbDecode->get() == 1; + int numberOfEntries = bbDecode->getInt(); + bool foundAbort = false; + + for (int i = 0; i < numberOfEntries; i++) { + char type = bbDecode->get(); + if (type == TypeAbort) { + Abort *abort = (Abort *)Abort_decode(NULL, bbDecode); + + if ((abort->getTransactionMachineId() == localMachineId) && (abort->getTransactionClientLocalSequenceNumber() == transaction->getClientLocalSequenceNumber())) { + foundAbort = true; + } + + processEntry(abort); + } else if (type == TypeCommitPart) { + CommitPart *commitPart = (CommitPart *)CommitPart_decode(NULL, bbDecode); + processEntry(commitPart); + } + } + + updateLiveStateFromLocal(); + + if (couldArbitrate) { + TransactionStatus *status = transaction->getTransactionStatus(); + if (didCommit) { + status->setStatus(TransactionStatus_StatusCommitted); + } else { + status->setStatus(TransactionStatus_StatusAborted); + } + } else { + TransactionStatus *status = transaction->getTransactionStatus(); + if (foundAbort) { + status->setStatus(TransactionStatus_StatusAborted); + } else { + status->setStatus(TransactionStatus_StatusCommitted); + } + } + + return Pair(false, true); +} + +Array *Table::acceptDataFromLocal(Array *data) { + // Decode the data + ByteBuffer *bbDecode = ByteBuffer_wrap(data); + int64_t lastArbitratedSequenceNumberSeen = bbDecode->getLong(); + int numberOfParts = bbDecode->getInt(); + + // If we did commit a transaction or not + bool didCommit = false; + bool couldArbitrate = false; + + if (numberOfParts != 0) { + + // decode the transaction + Transaction *transaction = new Transaction(); + for (int i = 0; i < numberOfParts; i++) { + bbDecode->get(); + TransactionPart *newPart = (TransactionPart *)TransactionPart_decode(NULL, bbDecode); + transaction->addPartDecode(newPart); + } + + // Arbitrate on transaction and pull relevant return data + Pair localArbitrateReturn = arbitrateOnLocalTransaction(transaction); + couldArbitrate = localArbitrateReturn.getFirst(); + didCommit = localArbitrateReturn.getSecond(); + + updateLiveStateFromLocal(); + + // Transaction was sent to the server so keep track of it to prevent double commit + if (transaction->getSequenceNumber() != -1) { + offlineTransactionsCommittedAndAtServer->add(new Pair(transaction->getId())); + } + } + + // The data to send back + int returnDataSize = 0; + Vector *unseenArbitrations = new Vector(); + + // Get the aborts to send back + Vector *abortLocalSequenceNumbers = new Vector(); + { + SetIterator *abortit = getKeyIterator(liveAbortsGeneratedByLocal); + while (abortit->hasNext()) + abortLocalSequenceNumbers->add(abortit->next()); + delete abortit; + } + + qsort(abortLocalSequenceNumbers->expose(), abortLocalSequenceNumbers->size(), sizeof(int64_t), compareInt64); + + uint asize = abortLocalSequenceNumbers->size(); + for (uint i = 0; i < asize; i++) { + int64_t localSequenceNumber = abortLocalSequenceNumbers->get(i); + if (localSequenceNumber <= lastArbitratedSequenceNumberSeen) { + continue; + } + + Abort *abort = liveAbortsGeneratedByLocal->get(localSequenceNumber); + unseenArbitrations->add(abort); + returnDataSize += abort->getSize(); + } + + // Get the commits to send back + Hashtable *commitForClientTable = liveCommitsTable->get(localMachineId); + if (commitForClientTable != NULL) { + Vector *commitLocalSequenceNumbers = new Vector(); + { + SetIterator *commitit = getKeyIterator(commitForClientTable); + while (commitit->hasNext()) + commitLocalSequenceNumbers->add(commitit->next()); + delete commitit; + } + qsort(commitLocalSequenceNumbers->expose(), commitLocalSequenceNumbers->size(), sizeof(int64_t), compareInt64); + + uint clsSize = commitLocalSequenceNumbers->size(); + for (uint clsi = 0; clsi < clsSize; clsi++) { + int64_t localSequenceNumber = commitLocalSequenceNumbers->get(clsi); + Commit *commit = commitForClientTable->get(localSequenceNumber); + + if (localSequenceNumber <= lastArbitratedSequenceNumberSeen) { + continue; + } + + { + Vector *parts = commit->getParts(); + uint nParts = parts->size(); + for (uint i = 0; i < nParts; i++) { + CommitPart *commitPart = parts->get(i); + unseenArbitrations->add(commitPart); + returnDataSize += commitPart->getSize(); + } + } + } + } + + // Number of arbitration entries to decode + returnDataSize += 2 * sizeof(int32_t); + + // bool of did commit or not + if (numberOfParts != 0) { + returnDataSize += sizeof(char); + } + + // Data to send Back + Array *returnData = new Array(returnDataSize); + ByteBuffer *bbEncode = ByteBuffer_wrap(returnData); + + if (numberOfParts != 0) { + if (didCommit) { + bbEncode->put((char)1); + } else { + bbEncode->put((char)0); + } + if (couldArbitrate) { + bbEncode->put((char)1); + } else { + bbEncode->put((char)0); + } + } + + bbEncode->putInt(unseenArbitrations->size()); + uint size = unseenArbitrations->size(); + for (uint i = 0; i < size; i++) { + Entry *entry = unseenArbitrations->get(i); + entry->encode(bbEncode); + } + + localSequenceNumber++; + return returnData; +} + +/** Checks whether a given slot was sent using new slots in + array. Returns true if sent and false otherwise. */ + +bool Table::checkSend(Array * array, Slot *checkSlot) { + uint size = array->length(); + for (uint i = 0; i < size; i++) { + Slot *s = array->get(i); + if ((s->getSequenceNumber() == checkSlot->getSequenceNumber()) && (s->getMachineID() == localMachineId)) { + return true; + } + } + + //Also need to see if other machines acknowledged our message + for (uint i = 0; i < size; i++) { + Slot *s = array->get(i); + + // Process each entry in the slot + Vector *entries = s->getEntries(); + uint eSize = entries->size(); + for (uint ei = 0; ei < eSize; ei++) { + Entry *entry = entries->get(ei); + + if (entry->getType() == TypeLastMessage) { + LastMessage *lastMessage = (LastMessage *)entry; + + if ((lastMessage->getMachineID() == localMachineId) && (lastMessage->getSequenceNumber() == checkSlot->getSequenceNumber())) { + return true; + } + } + } + } + //Not found + return false; +} + +/** Method tries to send slot to server. Returns status in tuple. + isInserted returns whether last un-acked send (if any) was + successful. Returns whether send was confirmed.x + */ + +bool Table::sendSlotsToServer(Slot *slot, int newSize, bool isNewKey, bool *isInserted, Array **array) { + attemptedToSendToServer = true; + + *array = cloud->putSlot(slot, newSize); + if (*array == NULL) { + *array = new Array(1); + (*array)->set(0, slot); + rejectedSlotVector->clear(); + *isInserted = false; + return true; + } else { + if ((*array)->length() == 0) { + throw new Error("Server Error: Did not send any slots"); + } + + if (hadPartialSendToServer) { + *isInserted = checkSend(*array, slot); + + if (!(*isInserted)) { + rejectedSlotVector->add(slot->getSequenceNumber()); + } + + return false; + } else { + rejectedSlotVector->add(slot->getSequenceNumber()); + *isInserted = false; + return false; + } + } +} + +/** + * Returns true if a resize was needed but not done. + */ +bool Table::fillSlot(Slot *slot, bool resize, NewKey *newKeyEntry, int & newSize, bool & insertedKey) { + newSize = 0;//special value to indicate no resize + if (liveSlotCount > bufferResizeThreshold) { + resize = true;//Resize is forced + } + + if (resize) { + newSize = (int) (numberOfSlots * Table_RESIZE_MULTIPLE); + TableStatus *status = new TableStatus(slot, newSize); + slot->addShallowEntry(status); + } + + // Fill with rejected slots first before doing anything else + doRejectedMessages(slot); + + // Do mandatory rescue of entries + ThreeTuple mandatoryRescueReturn = doMandatoryRescue(slot, resize); + + // Extract working variables + bool needsResize = mandatoryRescueReturn.getFirst(); + bool seenLiveSlot = mandatoryRescueReturn.getSecond(); + int64_t currentRescueSequenceNumber = mandatoryRescueReturn.getThird(); + + if (needsResize && !resize) { + // We need to resize but we are not resizing so return true to force on retry + return true; + } + + insertedKey = false; + if (newKeyEntry != NULL) { + newKeyEntry->setSlot(slot); + if (slot->hasSpace(newKeyEntry)) { + slot->addEntry(newKeyEntry); + insertedKey = true; + } + } + + // Clear the transactions, aborts and commits that were sent previously + clearSentParts(); + uint size = pendingSendArbitrationRounds->size(); + for (uint i = 0; i < size; i++) { + ArbitrationRound *round = pendingSendArbitrationRounds->get(i); + bool isFull = false; + round->generateParts(); + Vector *parts = round->getParts(); + + // Insert pending arbitration data + uint vsize = parts->size(); + for (uint vi = 0; vi < vsize; vi++) { + Entry *arbitrationData = parts->get(vi); + + // If it is an abort then we need to set some information + if (arbitrationData->getType() == TypeAbort) { + ((Abort *)arbitrationData)->setSequenceNumber(slot->getSequenceNumber()); + } + + if (!slot->hasSpace(arbitrationData)) { + // No space so cant do anything else with these data entries + isFull = true; + break; + } + + // Add to this current slot and add it to entries to delete + slot->addEntry(arbitrationData); + pendingSendArbitrationEntriesToDelete->add(arbitrationData); + } + + if (isFull) { + break; + } + } + + if (pendingTransactionQueue->size() > 0) { + Transaction *transaction = pendingTransactionQueue->get(0); + // Set the transaction sequence number if it has yet to be inserted into the block chain + if ((!transaction->didSendAPartToServer()) || (transaction->getSequenceNumber() == -1)) { + transaction->setSequenceNumber(slot->getSequenceNumber()); + } + + while (true) { + TransactionPart *part = transaction->getNextPartToSend(); + if (part == NULL) { + // Ran out of parts to send for this transaction so move on + break; + } + + if (slot->hasSpace(part)) { + slot->addEntry(part); + Vector *partsSent = transactionPartsSent->get(transaction); + if (partsSent == NULL) { + partsSent = new Vector(); + transactionPartsSent->put(transaction, partsSent); + } + partsSent->add(part->getPartNumber()); + transactionPartsSent->put(transaction, partsSent); + } else { + break; + } + } + } + + // Fill the remainder of the slot with rescue data + doOptionalRescue(slot, seenLiveSlot, currentRescueSequenceNumber, resize); + + return false; +} + +void Table::doRejectedMessages(Slot *s) { + if (!rejectedSlotVector->isEmpty()) { + /* TODO: We should avoid generating a rejected message entry if + * there is already a sufficient entry in the queue (e->g->, + * equalsto value of true and same sequence number)-> */ + + int64_t old_seqn = rejectedSlotVector->get(0); + if (rejectedSlotVector->size() > Table_REJECTED_THRESHOLD) { + int64_t new_seqn = rejectedSlotVector->lastElement(); + RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), localMachineId, old_seqn, new_seqn, false); + s->addShallowEntry(rm); + } else { + int64_t prev_seqn = -1; + uint i = 0; + /* Go through list of missing messages */ + for (; i < rejectedSlotVector->size(); i++) { + int64_t curr_seqn = rejectedSlotVector->get(i); + Slot *s_msg = buffer->getSlot(curr_seqn); + if (s_msg != NULL) + break; + prev_seqn = curr_seqn; + } + /* Generate rejected message entry for missing messages */ + if (prev_seqn != -1) { + RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), localMachineId, old_seqn, prev_seqn, false); + s->addShallowEntry(rm); + } + /* Generate rejected message entries for present messages */ + for (; i < rejectedSlotVector->size(); i++) { + int64_t curr_seqn = rejectedSlotVector->get(i); + Slot *s_msg = buffer->getSlot(curr_seqn); + int64_t machineid = s_msg->getMachineID(); + RejectedMessage *rm = new RejectedMessage(s, s->getSequenceNumber(), machineid, curr_seqn, curr_seqn, true); + s->addShallowEntry(rm); + } + } + } +} + +ThreeTuple Table::doMandatoryRescue(Slot *slot, bool resize) { + int64_t newestSequenceNumber = buffer->getNewestSeqNum(); + int64_t oldestSequenceNumber = buffer->getOldestSeqNum(); + if (oldestLiveSlotSequenceNumver < oldestSequenceNumber) { + oldestLiveSlotSequenceNumver = oldestSequenceNumber; + } + + int64_t currentSequenceNumber = oldestLiveSlotSequenceNumver; + bool seenLiveSlot = false; + int64_t firstIfFull = newestSequenceNumber + 1 - numberOfSlots; // smallest seq number in the buffer if it is full + int64_t threshold = firstIfFull + Table_FREE_SLOTS; // we want the buffer to be clear of live entries up to this point + + + // Mandatory Rescue + for (; currentSequenceNumber < threshold; currentSequenceNumber++) { + Slot *previousSlot = buffer->getSlot(currentSequenceNumber); + // Push slot number forward + if (!seenLiveSlot) { + oldestLiveSlotSequenceNumver = currentSequenceNumber; + } + + if (!previousSlot->isLive()) { + continue; + } + + // We have seen a live slot + seenLiveSlot = true; + + // Get all the live entries for a slot + Vector *liveEntries = previousSlot->getLiveEntries(resize); + + // Iterate over all the live entries and try to rescue them + uint lESize = liveEntries->size(); + for (uint i = 0; i < lESize; i++) { + Entry *liveEntry = liveEntries->get(i); + if (slot->hasSpace(liveEntry)) { + // Enough space to rescue the entry + slot->addEntry(liveEntry); + } else if (currentSequenceNumber == firstIfFull) { + //if there's no space but the entry is about to fall off the queue + return ThreeTuple(true, seenLiveSlot, currentSequenceNumber); + } + } + } + + // Did not resize + return ThreeTuple(false, seenLiveSlot, currentSequenceNumber); +} + +void Table::doOptionalRescue(Slot *s, bool seenliveslot, int64_t seqn, bool resize) { + /* now go through live entries from least to greatest sequence number until + * either all live slots added, or the slot doesn't have enough room + * for SKIP_THRESHOLD consecutive entries*/ + int skipcount = 0; + int64_t newestseqnum = buffer->getNewestSeqNum(); + for (; seqn <= newestseqnum; seqn++) { + Slot *prevslot = buffer->getSlot(seqn); + //Push slot number forward + if (!seenliveslot) + oldestLiveSlotSequenceNumver = seqn; + + if (!prevslot->isLive()) + continue; + seenliveslot = true; + Vector *liveentries = prevslot->getLiveEntries(resize); + uint lESize = liveentries->size(); + for (uint i = 0; i < lESize; i++) { + Entry *liveentry = liveentries->get(i); + if (s->hasSpace(liveentry)) + s->addEntry(liveentry); + else { + skipcount++; + if (skipcount > Table_SKIP_THRESHOLD) { + delete liveentries; + goto donesearch; + } + } + } + delete liveentries; + } +donesearch: + ; +} + +/** + * Checks for malicious activity and updates the local copy of the block chain-> + */ +void Table::validateAndUpdate(Array *newSlots, bool acceptUpdatesToLocal) { + // The cloud communication layer has checked slot HMACs already + // before decoding + if (newSlots->length() == 0) { + return; + } + + // Make sure all slots are newer than the last largest slot this + // client has seen + int64_t firstSeqNum = newSlots->get(0)->getSequenceNumber(); + if (firstSeqNum <= sequenceNumber) { + throw new Error("Server Error: Sent older slots!"); + } + + // Create an object that can access both new slots and slots in our + // local chain without committing slots to our local chain + SlotIndexer *indexer = new SlotIndexer(newSlots, buffer); + + // Check that the HMAC chain is not broken + checkHMACChain(indexer, newSlots); + + // Set to keep track of messages from clients + Hashset *machineSet = new Hashset(); + { + SetIterator *> *lmit = getKeyIterator(lastMessageTable); + while (lmit->hasNext()) + machineSet->add(lmit->next()); + delete lmit; + } + + // Process each slots data + { + uint numSlots = newSlots->length(); + for (uint i = 0; i < numSlots; i++) { + Slot *slot = newSlots->get(i); + processSlot(indexer, slot, acceptUpdatesToLocal, machineSet); + updateExpectedSize(); + } + } + delete indexer; + + // If there is a gap, check to see if the server sent us + // everything-> + if (firstSeqNum != (sequenceNumber + 1)) { + + // Check the size of the slots that were sent down by the server-> + // Can only check the size if there was a gap + checkNumSlots(newSlots->length()); + + // Since there was a gap every machine must have pushed a slot or + // must have a last message message-> If not then the server is + // hiding slots + if (!machineSet->isEmpty()) { + delete machineSet; + throw new Error("Missing record for machines: "); + } + } + delete machineSet; + // Update the size of our local block chain-> + commitNewMaxSize(); + + // Commit new to slots to the local block chain-> + { + uint numSlots = newSlots->length(); + for (uint i = 0; i < numSlots; i++) { + Slot *slot = newSlots->get(i); + + // Insert this slot into our local block chain copy-> + buffer->putSlot(slot); + + // Keep track of how many slots are currently live (have live data + // in them)-> + liveSlotCount++; + } + } + // Get the sequence number of the latest slot in the system + sequenceNumber = newSlots->get(newSlots->length() - 1)->getSequenceNumber(); + updateLiveStateFromServer(); + + // No Need to remember after we pulled from the server + offlineTransactionsCommittedAndAtServer->clear(); + + // This is invalidated now + hadPartialSendToServer = false; +} + +void Table::updateLiveStateFromServer() { + // Process the new transaction parts + processNewTransactionParts(); + + // Do arbitration on new transactions that were received + arbitrateFromServer(); + + // Update all the committed keys + bool didCommitOrSpeculate = updateCommittedTable(); + + // Delete the transactions that are now dead + updateLiveTransactionsAndStatus(); + + // Do speculations + didCommitOrSpeculate |= updateSpeculativeTable(didCommitOrSpeculate); + updatePendingTransactionSpeculativeTable(didCommitOrSpeculate); +} + +void Table::updateLiveStateFromLocal() { + // Update all the committed keys + bool didCommitOrSpeculate = updateCommittedTable(); + + // Delete the transactions that are now dead + updateLiveTransactionsAndStatus(); + + // Do speculations + didCommitOrSpeculate |= updateSpeculativeTable(didCommitOrSpeculate); + updatePendingTransactionSpeculativeTable(didCommitOrSpeculate); +} + +void Table::initExpectedSize(int64_t firstSequenceNumber, int64_t numberOfSlots) { + int64_t prevslots = firstSequenceNumber; + + if (didFindTableStatus) { + } else { + expectedsize = (prevslots < ((int64_t) numberOfSlots)) ? (int) prevslots : numberOfSlots; + } + + didFindTableStatus = true; + currMaxSize = numberOfSlots; +} + +void Table::updateExpectedSize() { + expectedsize++; + + if (expectedsize > currMaxSize) { + expectedsize = currMaxSize; + } +} + + +/** + * Check the size of the block chain to make sure there are enough + * slots sent back by the server-> This is only called when we have a + * gap between the slots that we have locally and the slots sent by + * the server therefore in the slots sent by the server there will be + * at least 1 Table status message + */ +void Table::checkNumSlots(int numberOfSlots) { + if (numberOfSlots != expectedsize) { + throw new Error("Server Error: Server did not send all slots-> Expected: "); + } +} + +/** + * Update the size of of the local buffer if it is needed-> + */ +void Table::commitNewMaxSize() { + didFindTableStatus = false; + + // Resize the local slot buffer + if (numberOfSlots != currMaxSize) { + buffer->resize((int32_t)currMaxSize); + } + + // Change the number of local slots to the new size + numberOfSlots = (int32_t)currMaxSize; + + // Recalculate the resize threshold since the size of the local + // buffer has changed + setResizeThreshold(); +} + +/** + * Process the new transaction parts from this latest round of slots + * received from the server + */ +void Table::processNewTransactionParts() { + + if (newTransactionParts->size() == 0) { + // Nothing new to process + return; + } + + // Iterate through all the machine Ids that we received new parts + // for + SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *tpit = getKeyIterator(newTransactionParts); + while (tpit->hasNext()) { + int64_t machineId = tpit->next(); + Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = tpit->currVal(); + + SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *ptit = getKeyIterator(parts); + // Iterate through all the parts for that machine Id + while (ptit->hasNext()) { + Pair *partId = ptit->next(); + TransactionPart *part = parts->get(partId); + + if (lastArbitratedTransactionNumberByArbitratorTable->contains(part->getArbitratorId())) { + int64_t lastTransactionNumber = lastArbitratedTransactionNumberByArbitratorTable->get(part->getArbitratorId()); + if (lastTransactionNumber >= part->getSequenceNumber()) { + // Set dead the transaction part + part->setDead(); + part->releaseRef(); + continue; + } + } + + // Get the transaction object for that sequence number + Transaction *transaction = liveTransactionBySequenceNumberTable->get(part->getSequenceNumber()); + + if (transaction == NULL) { + // This is a new transaction that we dont have so make a new one + transaction = new Transaction(); + + // Add that part to the transaction + transaction->addPartDecode(part); + + // Insert this new transaction into the live tables + liveTransactionBySequenceNumberTable->put(part->getSequenceNumber(), transaction); + liveTransactionByTransactionIdTable->put(transaction->getId(), transaction); + } + part->releaseRef(); + } + delete ptit; + } + delete tpit; + // Clear all the new transaction parts in preparation for the next + // time the server sends slots + { + SetIterator *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newTransactionParts); + while (partsit->hasNext()) { + int64_t machineId = partsit->next(); + Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = newTransactionParts->get(machineId); + delete parts; + } + delete partsit; + newTransactionParts->clear(); + } +} + +void Table::arbitrateFromServer() { + if (liveTransactionBySequenceNumberTable->size() == 0) { + // Nothing to arbitrate on so move on + return; + } + + // Get the transaction sequence numbers and sort from oldest to newest + Vector *transactionSequenceNumbers = new Vector(); + { + SetIterator *trit = getKeyIterator(liveTransactionBySequenceNumberTable); + while (trit->hasNext()) + transactionSequenceNumbers->add(trit->next()); + delete trit; + } + qsort(transactionSequenceNumbers->expose(), transactionSequenceNumbers->size(), sizeof(int64_t), compareInt64); + + // Collection of key value pairs that are + Hashtable *speculativeTableTmp = new Hashtable(); + + // The last transaction arbitrated on + int64_t lastTransactionCommitted = -1; + Hashset *generatedAborts = new Hashset(); + uint tsnSize = transactionSequenceNumbers->size(); + for (uint i = 0; i < tsnSize; i++) { + int64_t transactionSequenceNumber = transactionSequenceNumbers->get(i); + Transaction *transaction = liveTransactionBySequenceNumberTable->get(transactionSequenceNumber); + + // Check if this machine arbitrates for this transaction if not + // then we cant arbitrate this transaction + if (transaction->getArbitrator() != localMachineId) { + continue; + } + + if (transactionSequenceNumber < lastSeqNumArbOn) { + continue; + } + + if (offlineTransactionsCommittedAndAtServer->contains(transaction->getId())) { + // We have seen this already locally so dont commit again + continue; + } + + if (!transaction->isComplete()) { + // Will arbitrate in incorrect order if we continue so just break + // Most likely this + break; + } + + // update the largest transaction seen by arbitrator from server + if (!lastTransactionSeenFromMachineFromServer->contains(transaction->getMachineId())) { + lastTransactionSeenFromMachineFromServer->put(transaction->getMachineId(), transaction->getClientLocalSequenceNumber()); + } else { + int64_t lastTransactionSeenFromMachine = lastTransactionSeenFromMachineFromServer->get(transaction->getMachineId()); + if (transaction->getClientLocalSequenceNumber() > lastTransactionSeenFromMachine) { + lastTransactionSeenFromMachineFromServer->put(transaction->getMachineId(), transaction->getClientLocalSequenceNumber()); + } + } + + if (transaction->evaluateGuard(committedKeyValueTable, speculativeTableTmp, NULL)) { + // Guard evaluated as true + // Update the local changes so we can make the commit + SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + speculativeTableTmp->put(kv->getKey(), kv); + } + delete kvit; + + // Update what the last transaction committed was for use in batch commit + lastTransactionCommitted = transactionSequenceNumber; + } else { + // Guard evaluated was false so create abort + // Create the abort + Abort *newAbort = new Abort(NULL, + transaction->getClientLocalSequenceNumber(), + transaction->getSequenceNumber(), + transaction->getMachineId(), + transaction->getArbitrator(), + localArbitrationSequenceNumber); + localArbitrationSequenceNumber++; + generatedAborts->add(newAbort); + + // Insert the abort so we can process + processEntry(newAbort); + } + + lastSeqNumArbOn = transactionSequenceNumber; + } + + delete transactionSequenceNumbers; + + Commit *newCommit = NULL; + + // If there is something to commit + if (speculativeTableTmp->size() != 0) { + // Create the commit and increment the commit sequence number + newCommit = new Commit(localArbitrationSequenceNumber, localMachineId, lastTransactionCommitted); + localArbitrationSequenceNumber++; + + // Add all the new keys to the commit + SetIterator *spit = getKeyIterator(speculativeTableTmp); + while (spit->hasNext()) { + IoTString *string = spit->next(); + KeyValue *kv = speculativeTableTmp->get(string); + newCommit->addKV(kv); + } + delete spit; + + // create the commit parts + newCommit->createCommitParts(); + + // Append all the commit parts to the end of the pending queue + // waiting for sending to the server + // Insert the commit so we can process it + Vector *parts = newCommit->getParts(); + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + CommitPart *commitPart = parts->get(i); + processEntry(commitPart); + } + } + delete speculativeTableTmp; + + if ((newCommit != NULL) || (generatedAborts->size() > 0)) { + ArbitrationRound *arbitrationRound = new ArbitrationRound(newCommit, generatedAborts); + pendingSendArbitrationRounds->add(arbitrationRound); + + if (compactArbitrationData()) { + ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); + if (newArbitrationRound->getCommit() != NULL) { + Vector *parts = newArbitrationRound->getCommit()->getParts(); + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + CommitPart *commitPart = parts->get(i); + processEntry(commitPart); + } + } + } + } else { + delete generatedAborts; + } +} + +Pair Table::arbitrateOnLocalTransaction(Transaction *transaction) { + + // Check if this machine arbitrates for this transaction if not then + // we cant arbitrate this transaction + if (transaction->getArbitrator() != localMachineId) { + return Pair(false, false); + } + + if (!transaction->isComplete()) { + // Will arbitrate in incorrect order if we continue so just break + // Most likely this + return Pair(false, false); + } + + if (transaction->getMachineId() != localMachineId) { + // dont do this check for local transactions + if (lastTransactionSeenFromMachineFromServer->contains(transaction->getMachineId())) { + if (lastTransactionSeenFromMachineFromServer->get(transaction->getMachineId()) > transaction->getClientLocalSequenceNumber()) { + // We've have already seen this from the server + return Pair(false, false); + } + } + } + + if (transaction->evaluateGuard(committedKeyValueTable, NULL, NULL)) { + // Guard evaluated as true Create the commit and increment the + // commit sequence number + Commit *newCommit = new Commit(localArbitrationSequenceNumber, localMachineId, -1); + localArbitrationSequenceNumber++; + + // Update the local changes so we can make the commit + SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + newCommit->addKV(kv); + } + delete kvit; + + // create the commit parts + newCommit->createCommitParts(); + + // Append all the commit parts to the end of the pending queue + // waiting for sending to the server + ArbitrationRound *arbitrationRound = new ArbitrationRound(newCommit, new Hashset()); + pendingSendArbitrationRounds->add(arbitrationRound); + + if (compactArbitrationData()) { + ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); + Vector *parts = newArbitrationRound->getCommit()->getParts(); + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + CommitPart *commitPart = parts->get(i); + processEntry(commitPart); + } + } else { + // Insert the commit so we can process it + Vector *parts = newCommit->getParts(); + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + CommitPart *commitPart = parts->get(i); + processEntry(commitPart); + } + } + + if (transaction->getMachineId() == localMachineId) { + TransactionStatus *status = transaction->getTransactionStatus(); + if (status != NULL) { + status->setStatus(TransactionStatus_StatusCommitted); + } + } + + updateLiveStateFromLocal(); + return Pair(true, true); + } else { + if (transaction->getMachineId() == localMachineId) { + // For locally created messages update the status + // Guard evaluated was false so create abort + TransactionStatus *status = transaction->getTransactionStatus(); + if (status != NULL) { + status->setStatus(TransactionStatus_StatusAborted); + } + } else { + Hashset *addAbortSet = new Hashset(); + + // Create the abort + Abort *newAbort = new Abort(NULL, + transaction->getClientLocalSequenceNumber(), + -1, + transaction->getMachineId(), + transaction->getArbitrator(), + localArbitrationSequenceNumber); + localArbitrationSequenceNumber++; + addAbortSet->add(newAbort); + + // Append all the commit parts to the end of the pending queue + // waiting for sending to the server + ArbitrationRound *arbitrationRound = new ArbitrationRound(NULL, addAbortSet); + pendingSendArbitrationRounds->add(arbitrationRound); + + if (compactArbitrationData()) { + ArbitrationRound *newArbitrationRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); + + Vector *parts = newArbitrationRound->getCommit()->getParts(); + uint partsSize = parts->size(); + for (uint i = 0; i < partsSize; i++) { + CommitPart *commitPart = parts->get(i); + processEntry(commitPart); + } + } + } + + updateLiveStateFromLocal(); + return Pair(true, false); + } +} + +/** + * Compacts the arbitration data by merging commits and aggregating + * aborts so that a single large push of commits can be done instead + * of many small updates + */ +bool Table::compactArbitrationData() { + if (pendingSendArbitrationRounds->size() < 2) { + // Nothing to compact so do nothing + return false; + } + + ArbitrationRound *lastRound = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - 1); + if (lastRound->getDidSendPart()) { + return false; + } + + bool hadCommit = (lastRound->getCommit() == NULL); + bool gotNewCommit = false; + + uint numberToDelete = 1; + + while (numberToDelete < pendingSendArbitrationRounds->size()) { + ArbitrationRound *round = pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size() - numberToDelete - 1); + + if (round->isFull() || round->getDidSendPart()) { + // Stop since there is a part that cannot be compacted and we + // need to compact in order + break; + } + + if (round->getCommit() == NULL) { + // Try compacting aborts only + int newSize = round->getCurrentSize() + lastRound->getAbortsCount(); + if (newSize > ArbitrationRound_MAX_PARTS) { + // Cant compact since it would be too large + break; + } + lastRound->addAborts(round->getAborts()); + } else { + // Create a new larger commit + Commit *newCommit = Commit_merge(lastRound->getCommit(), round->getCommit(), localArbitrationSequenceNumber); + localArbitrationSequenceNumber++; + + // Create the commit parts so that we can count them + newCommit->createCommitParts(); + + // Calculate the new size of the parts + int newSize = newCommit->getNumberOfParts(); + newSize += lastRound->getAbortsCount(); + newSize += round->getAbortsCount(); + + if (newSize > ArbitrationRound_MAX_PARTS) { + // Can't compact since it would be too large + if (lastRound->getCommit() != newCommit && + round->getCommit() != newCommit) + delete newCommit; + break; + } + // Set the new compacted part + if (lastRound->getCommit() == newCommit) + lastRound->setCommit(NULL); + if (round->getCommit() == newCommit) + round->setCommit(NULL); + + if (lastRound->getCommit() != NULL) { + Commit * oldcommit = lastRound->getCommit(); + lastRound->setCommit(NULL); + delete oldcommit; + } + lastRound->setCommit(newCommit); + lastRound->addAborts(round->getAborts()); + gotNewCommit = true; + } + + numberToDelete++; + } + + if (numberToDelete != 1) { + // If there is a compaction + // Delete the previous pieces that are now in the new compacted piece + for (uint i = 2; i <= numberToDelete; i++) { + delete pendingSendArbitrationRounds->get(pendingSendArbitrationRounds->size()-i); + } + pendingSendArbitrationRounds->setSize(pendingSendArbitrationRounds->size() - numberToDelete); + + pendingSendArbitrationRounds->add(lastRound); + + // Should reinsert into the commit processor + if (hadCommit && gotNewCommit) { + return true; + } + } + + return false; +} + +/** + * Update all the commits and the committed tables, sets dead the dead + * transactions + */ +bool Table::updateCommittedTable() { + if (newCommitParts->size() == 0) { + // Nothing new to process + return false; + } + + // Iterate through all the machine Ids that we received new parts for + SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *> *partsit = getKeyIterator(newCommitParts); + while (partsit->hasNext()) { + int64_t machineId = partsit->next(); + Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *parts = newCommitParts->get(machineId); + + // Iterate through all the parts for that machine Id + SetIterator *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *pairit = getKeyIterator(parts); + while (pairit->hasNext()) { + Pair *partId = pairit->next(); + CommitPart *part = pairit->currVal(); + + // Get the transaction object for that sequence number + Hashtable *commitForClientTable = liveCommitsTable->get(part->getMachineId()); + + if (commitForClientTable == NULL) { + // This is the first commit from this device + commitForClientTable = new Hashtable(); + liveCommitsTable->put(part->getMachineId(), commitForClientTable); + } + + Commit *commit = commitForClientTable->get(part->getSequenceNumber()); + + if (commit == NULL) { + // This is a new commit that we dont have so make a new one + commit = new Commit(); + + // Insert this new commit into the live tables + commitForClientTable->put(part->getSequenceNumber(), commit); + } + + // Add that part to the commit + commit->addPartDecode(part); + part->releaseRef(); + } + delete pairit; + delete parts; + } + delete partsit; + + // Clear all the new commits parts in preparation for the next time + // the server sends slots + newCommitParts->clear(); + + // If we process a new commit keep track of it for future use + bool didProcessANewCommit = false; + + // Process the commits one by one + SetIterator *> *liveit = getKeyIterator(liveCommitsTable); + while (liveit->hasNext()) { + int64_t arbitratorId = liveit->next(); + // Get all the commits for a specific arbitrator + Hashtable *commitForClientTable = liveCommitsTable->get(arbitratorId); + + // Sort the commits in order + Vector *commitSequenceNumbers = new Vector(); + { + SetIterator *clientit = getKeyIterator(commitForClientTable); + while (clientit->hasNext()) + commitSequenceNumbers->add(clientit->next()); + delete clientit; + } + + qsort(commitSequenceNumbers->expose(), commitSequenceNumbers->size(), sizeof(int64_t), compareInt64); + + // Get the last commit seen from this arbitrator + int64_t lastCommitSeenSequenceNumber = -1; + if (lastCommitSeenSequenceNumberByArbitratorTable->contains(arbitratorId)) { + lastCommitSeenSequenceNumber = lastCommitSeenSequenceNumberByArbitratorTable->get(arbitratorId); + } + + // Go through each new commit one by one + for (uint i = 0; i < commitSequenceNumbers->size(); i++) { + int64_t commitSequenceNumber = commitSequenceNumbers->get(i); + Commit *commit = commitForClientTable->get(commitSequenceNumber); + // Special processing if a commit is not complete + if (!commit->isComplete()) { + if (i == (commitSequenceNumbers->size() - 1)) { + // If there is an incomplete commit and this commit is the + // latest one seen then this commit cannot be processed and + // there are no other commits + break; + } else { + // This is a commit that was already dead but parts of it + // are still in the block chain (not flushed out yet)-> + // Delete it and move on + commit->setDead(); + commitForClientTable->remove(commit->getSequenceNumber()); + delete commit; + continue; + } + } + + // Update the last transaction that was updated if we can + if (commit->getTransactionSequenceNumber() != -1) { + // Update the last transaction sequence number that the arbitrator arbitrated on1 + if (!lastArbitratedTransactionNumberByArbitratorTable->contains(commit->getMachineId()) || lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()) < commit->getTransactionSequenceNumber()) { + lastArbitratedTransactionNumberByArbitratorTable->put(commit->getMachineId(), commit->getTransactionSequenceNumber()); + } + } + + // Update the last arbitration data that we have seen so far + if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(commit->getMachineId())) { + int64_t lastArbitrationSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(commit->getMachineId()); + if (commit->getSequenceNumber() > lastArbitrationSequenceNumber) { + // Is larger + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(commit->getMachineId(), commit->getSequenceNumber()); + } + } else { + // Never seen any data from this arbitrator so record the first one + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(commit->getMachineId(), commit->getSequenceNumber()); + } + + // We have already seen this commit before so need to do the + // full processing on this commit + if (commit->getSequenceNumber() <= lastCommitSeenSequenceNumber) { + // Update the last transaction that was updated if we can + if (commit->getTransactionSequenceNumber() != -1) { + int64_t lastTransactionNumber = lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()); + if (!lastArbitratedTransactionNumberByArbitratorTable->contains(commit->getMachineId()) || + lastArbitratedTransactionNumberByArbitratorTable->get(commit->getMachineId()) < commit->getTransactionSequenceNumber()) { + lastArbitratedTransactionNumberByArbitratorTable->put(commit->getMachineId(), commit->getTransactionSequenceNumber()); + } + } + continue; + } + + // If we got here then this is a brand new commit and needs full + // processing + // Get what commits should be edited, these are the commits that + // have live values for their keys + Hashset *commitsToEdit = new Hashset(); + { + SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + Commit *commit = liveCommitsByKeyTable->get(kv->getKey()); + if (commit != NULL) + commitsToEdit->add(commit); + } + delete kvit; + } + + // Update each previous commit that needs to be updated + SetIterator *commitit = commitsToEdit->iterator(); + while (commitit->hasNext()) { + Commit *previousCommit = commitit->next(); + + // Only bother with live commits (TODO: Maybe remove this check) + if (previousCommit->isLive()) { + + // Update which keys in the old commits are still live + { + SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + previousCommit->invalidateKey(kv->getKey()); + } + delete kvit; + } + + // if the commit is now dead then remove it + if (!previousCommit->isLive()) { + commitForClientTable->remove(previousCommit->getSequenceNumber()); + delete previousCommit; + } + } + } + delete commitit; + delete commitsToEdit; + + // Update the last seen sequence number from this arbitrator + if (lastCommitSeenSequenceNumberByArbitratorTable->contains(commit->getMachineId())) { + if (commit->getSequenceNumber() > lastCommitSeenSequenceNumberByArbitratorTable->get(commit->getMachineId())) { + lastCommitSeenSequenceNumberByArbitratorTable->put(commit->getMachineId(), commit->getSequenceNumber()); + } + } else { + lastCommitSeenSequenceNumberByArbitratorTable->put(commit->getMachineId(), commit->getSequenceNumber()); + } + + // We processed a new commit that we havent seen before + didProcessANewCommit = true; + + // Update the committed table of keys and which commit is using which key + { + SetIterator *kvit = commit->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + committedKeyValueTable->put(kv->getKey(), kv); + liveCommitsByKeyTable->put(kv->getKey(), commit); + } + delete kvit; + } + } + delete commitSequenceNumbers; + } + delete liveit; + + return didProcessANewCommit; +} + +/** + * Create the speculative table from transactions that are still live + * and have come from the cloud + */ +bool Table::updateSpeculativeTable(bool didProcessNewCommits) { + if (liveTransactionBySequenceNumberTable->size() == 0) { + // There is nothing to speculate on + return false; + } + + // Create a list of the transaction sequence numbers and sort them + // from oldest to newest + Vector *transactionSequenceNumbersSorted = new Vector(); + { + SetIterator *trit = getKeyIterator(liveTransactionBySequenceNumberTable); + while (trit->hasNext()) + transactionSequenceNumbersSorted->add(trit->next()); + delete trit; + } + + qsort(transactionSequenceNumbersSorted->expose(), transactionSequenceNumbersSorted->size(), sizeof(int64_t), compareInt64); + + bool hasGapInTransactionSequenceNumbers = transactionSequenceNumbersSorted->get(0) != oldestTransactionSequenceNumberSpeculatedOn; + + + if (hasGapInTransactionSequenceNumbers || didProcessNewCommits) { + // If there is a gap in the transaction sequence numbers then + // there was a commit or an abort of a transaction OR there was a + // new commit (Could be from offline commit) so a redo the + // speculation from scratch + + // Start from scratch + speculatedKeyValueTable->clear(); + lastTransactionSequenceNumberSpeculatedOn = -1; + oldestTransactionSequenceNumberSpeculatedOn = -1; + } + + // Remember the front of the transaction list + oldestTransactionSequenceNumberSpeculatedOn = transactionSequenceNumbersSorted->get(0); + + // Find where to start arbitration from + uint startIndex = 0; + + for (; startIndex < transactionSequenceNumbersSorted->size(); startIndex++) + if (transactionSequenceNumbersSorted->get(startIndex) == lastTransactionSequenceNumberSpeculatedOn) + break; + startIndex++; + + if (startIndex >= transactionSequenceNumbersSorted->size()) { + // Make sure we are not out of bounds + delete transactionSequenceNumbersSorted; + return false; // did not speculate + } + + Hashset *incompleteTransactionArbitrator = new Hashset(); + bool didSkip = true; + + for (uint i = startIndex; i < transactionSequenceNumbersSorted->size(); i++) { + int64_t transactionSequenceNumber = transactionSequenceNumbersSorted->get(i); + Transaction *transaction = liveTransactionBySequenceNumberTable->get(transactionSequenceNumber); + + if (!transaction->isComplete()) { + // If there is an incomplete transaction then there is nothing + // we can do add this transactions arbitrator to the list of + // arbitrators we should ignore + incompleteTransactionArbitrator->add(transaction->getArbitrator()); + didSkip = true; + continue; + } + + if (incompleteTransactionArbitrator->contains(transaction->getArbitrator())) { + continue; + } + + lastTransactionSequenceNumberSpeculatedOn = transactionSequenceNumber; + + if (transaction->evaluateGuard(committedKeyValueTable, speculatedKeyValueTable, NULL)) { + // Guard evaluated to true so update the speculative table + { + SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + speculatedKeyValueTable->put(kv->getKey(), kv); + } + delete kvit; + } + } + } + + delete transactionSequenceNumbersSorted; + + if (didSkip) { + // Since there was a skip we need to redo the speculation next time around + lastTransactionSequenceNumberSpeculatedOn = -1; + oldestTransactionSequenceNumberSpeculatedOn = -1; + } + + // We did some speculation + return true; +} + +/** + * Create the pending transaction speculative table from transactions + * that are still in the pending transaction buffer + */ +void Table::updatePendingTransactionSpeculativeTable(bool didProcessNewCommitsOrSpeculate) { + if (pendingTransactionQueue->size() == 0) { + // There is nothing to speculate on + return; + } + + if (didProcessNewCommitsOrSpeculate || (firstPendingTransaction != pendingTransactionQueue->get(0))) { + // need to reset on the pending speculation + lastPendingTransactionSpeculatedOn = NULL; + firstPendingTransaction = pendingTransactionQueue->get(0); + pendingTransactionSpeculatedKeyValueTable->clear(); + } + + // Find where to start arbitration from + uint startIndex = 0; + + for (; startIndex < pendingTransactionQueue->size(); startIndex++) + if (pendingTransactionQueue->get(startIndex) == firstPendingTransaction) + break; + + if (startIndex >= pendingTransactionQueue->size()) { + // Make sure we are not out of bounds + return; + } + + for (uint i = startIndex; i < pendingTransactionQueue->size(); i++) { + Transaction *transaction = pendingTransactionQueue->get(i); + + lastPendingTransactionSpeculatedOn = transaction; + + if (transaction->evaluateGuard(committedKeyValueTable, speculatedKeyValueTable, pendingTransactionSpeculatedKeyValueTable)) { + // Guard evaluated to true so update the speculative table + SetIterator *kvit = transaction->getKeyValueUpdateSet()->iterator(); + while (kvit->hasNext()) { + KeyValue *kv = kvit->next(); + pendingTransactionSpeculatedKeyValueTable->put(kv->getKey(), kv); + } + delete kvit; + } + } +} + +/** + * Set dead and remove from the live transaction tables the + * transactions that are dead + */ +void Table::updateLiveTransactionsAndStatus() { + // Go through each of the transactions + { + SetIterator *iter = getKeyIterator(liveTransactionBySequenceNumberTable); + while (iter->hasNext()) { + int64_t key = iter->next(); + Transaction *transaction = liveTransactionBySequenceNumberTable->get(key); + + // Check if the transaction is dead + if (lastArbitratedTransactionNumberByArbitratorTable->contains(transaction->getArbitrator()) + && lastArbitratedTransactionNumberByArbitratorTable->get(transaction->getArbitrator()) >= transaction->getSequenceNumber()) { + // Set dead the transaction + transaction->setDead(); + + // Remove the transaction from the live table + iter->remove(); + liveTransactionByTransactionIdTable->remove(transaction->getId()); + delete transaction; + } + } + delete iter; + } + + // Go through each of the transactions + { + SetIterator *iter = getKeyIterator(outstandingTransactionStatus); + while (iter->hasNext()) { + int64_t key = iter->next(); + TransactionStatus *status = outstandingTransactionStatus->get(key); + + // Check if the transaction is dead + if (lastArbitratedTransactionNumberByArbitratorTable->contains(status->getTransactionArbitrator()) + && (lastArbitratedTransactionNumberByArbitratorTable->get(status->getTransactionArbitrator()) >= status->getTransactionSequenceNumber())) { + // Set committed + status->setStatus(TransactionStatus_StatusCommitted); + + // Remove + iter->remove(); + } + } + delete iter; + } +} + +/** + * Process this slot, entry by entry-> Also update the latest message sent by slot + */ +void Table::processSlot(SlotIndexer *indexer, Slot *slot, bool acceptUpdatesToLocal, Hashset *machineSet) { + + // Update the last message seen + updateLastMessage(slot->getMachineID(), slot->getSequenceNumber(), slot, acceptUpdatesToLocal, machineSet); + + // Process each entry in the slot + Vector *entries = slot->getEntries(); + uint eSize = entries->size(); + for (uint ei = 0; ei < eSize; ei++) { + Entry *entry = entries->get(ei); + switch (entry->getType()) { + case TypeCommitPart: + processEntry((CommitPart *)entry); + break; + case TypeAbort: + processEntry((Abort *)entry); + break; + case TypeTransactionPart: + processEntry((TransactionPart *)entry); + break; + case TypeNewKey: + processEntry((NewKey *)entry); + break; + case TypeLastMessage: + processEntry((LastMessage *)entry, machineSet); + break; + case TypeRejectedMessage: + processEntry((RejectedMessage *)entry, indexer); + break; + case TypeTableStatus: + processEntry((TableStatus *)entry, slot->getSequenceNumber()); + break; + default: + throw new Error("Unrecognized type: "); + } + } +} + +/** + * Update the last message that was sent for a machine Id + */ +void Table::processEntry(LastMessage *entry, Hashset *machineSet) { + // Update what the last message received by a machine was + updateLastMessage(entry->getMachineID(), entry->getSequenceNumber(), entry, false, machineSet); +} + +/** + * Add the new key to the arbitrators table and update the set of live + * new keys (in case of a rescued new key message) + */ +void Table::processEntry(NewKey *entry) { + // Update the arbitrator table with the new key information + arbitratorTable->put(entry->getKey(), entry->getMachineID()); + + // Update what the latest live new key is + NewKey *oldNewKey = liveNewKeyTable->put(entry->getKey(), entry); + if (oldNewKey != NULL) { + // Delete the old new key messages + oldNewKey->setDead(); + } +} + +/** + * Process new table status entries and set dead the old ones as new + * ones come in-> keeps track of the largest and smallest table status + * seen in this current round of updating the local copy of the block + * chain + */ +void Table::processEntry(TableStatus *entry, int64_t seq) { + int newNumSlots = entry->getMaxSlots(); + updateCurrMaxSize(newNumSlots); + initExpectedSize(seq, newNumSlots); + + if (liveTableStatus != NULL) { + // We have a larger table status so the old table status is no + // int64_ter alive + liveTableStatus->setDead(); + } + + // Make this new table status the latest alive table status + liveTableStatus = entry; +} + +/** + * Check old messages to see if there is a block chain violation-> + * Also + */ +void Table::processEntry(RejectedMessage *entry, SlotIndexer *indexer) { + int64_t oldSeqNum = entry->getOldSeqNum(); + int64_t newSeqNum = entry->getNewSeqNum(); + bool isequal = entry->getEqual(); + int64_t machineId = entry->getMachineID(); + int64_t seq = entry->getSequenceNumber(); + + // Check if we have messages that were supposed to be rejected in + // our local block chain + for (int64_t seqNum = oldSeqNum; seqNum <= newSeqNum; seqNum++) { + // Get the slot + Slot *slot = indexer->getSlot(seqNum); + + if (slot != NULL) { + // If we have this slot make sure that it was not supposed to be + // a rejected slot + int64_t slotMachineId = slot->getMachineID(); + if (isequal != (slotMachineId == machineId)) { + throw new Error("Server Error: Trying to insert rejected message for slot "); + } + } + } + + // Create a list of clients to watch until they see this rejected + // message entry-> + Hashset *deviceWatchSet = new Hashset(); + SetIterator *> *iter = getKeyIterator(lastMessageTable); + while (iter->hasNext()) { + // Machine ID for the last message entry + int64_t lastMessageEntryMachineId = iter->next(); + + // We've seen it, don't need to continue to watch-> Our next + // message will implicitly acknowledge it-> + if (lastMessageEntryMachineId == localMachineId) { + continue; + } + + Pair *lastMessageValue = lastMessageTable->get(lastMessageEntryMachineId); + int64_t entrySequenceNumber = lastMessageValue->getFirst(); + + if (entrySequenceNumber < seq) { + // Add this rejected message to the set of messages that this + // machine ID did not see yet + addWatchVector(lastMessageEntryMachineId, entry); + // This client did not see this rejected message yet so add it + // to the watch set to monitor + deviceWatchSet->add(lastMessageEntryMachineId); + } + } + delete iter; + + if (deviceWatchSet->isEmpty()) { + // This rejected message has been seen by all the clients so + entry->setDead(); + delete deviceWatchSet; + } else { + // We need to watch this rejected message + entry->setWatchSet(deviceWatchSet); + } +} + +/** + * Check if this abort is live, if not then save it so we can kill it + * later-> update the last transaction number that was arbitrated on-> + */ +void Table::processEntry(Abort *entry) { + if (entry->getTransactionSequenceNumber() != -1) { + // update the transaction status if it was sent to the server + TransactionStatus *status = outstandingTransactionStatus->remove(entry->getTransactionSequenceNumber()); + if (status != NULL) { + status->setStatus(TransactionStatus_StatusAborted); + } + } + + // Abort has not been seen by the client it is for yet so we need to + // keep track of it + + Abort *previouslySeenAbort = liveAbortTable->put(new Pair(entry->getAbortId()), entry); + if (previouslySeenAbort != NULL) { + previouslySeenAbort->setDead(); // Delete old version of the abort since we got a rescued newer version + } + + if (entry->getTransactionArbitrator() == localMachineId) { + liveAbortsGeneratedByLocal->put(entry->getArbitratorLocalSequenceNumber(), entry); + } + + if ((entry->getSequenceNumber() != -1) && (lastMessageTable->get(entry->getTransactionMachineId())->getFirst() >= entry->getSequenceNumber())) { + // The machine already saw this so it is dead + entry->setDead(); + Pair abortid = entry->getAbortId(); + liveAbortTable->remove(&abortid); + + if (entry->getTransactionArbitrator() == localMachineId) { + liveAbortsGeneratedByLocal->remove(entry->getArbitratorLocalSequenceNumber()); + } + return; + } + + // Update the last arbitration data that we have seen so far + if (lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->contains(entry->getTransactionArbitrator())) { + int64_t lastArbitrationSequenceNumber = lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->get(entry->getTransactionArbitrator()); + if (entry->getSequenceNumber() > lastArbitrationSequenceNumber) { + // Is larger + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(entry->getTransactionArbitrator(), entry->getSequenceNumber()); + } + } else { + // Never seen any data from this arbitrator so record the first one + lastArbitrationDataLocalSequenceNumberSeenFromArbitrator->put(entry->getTransactionArbitrator(), entry->getSequenceNumber()); + } + + // Set dead a transaction if we can + Pair deadPair = Pair(entry->getTransactionMachineId(), entry->getTransactionClientLocalSequenceNumber()); + + Transaction *transactionToSetDead = liveTransactionByTransactionIdTable->remove(&deadPair); + if (transactionToSetDead != NULL) { + liveTransactionBySequenceNumberTable->remove(transactionToSetDead->getSequenceNumber()); + } + + // Update the last transaction sequence number that the arbitrator + // arbitrated on + if (!lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getTransactionArbitrator()) || + (lastArbitratedTransactionNumberByArbitratorTable->get(entry->getTransactionArbitrator()) < entry->getTransactionSequenceNumber())) { + // Is a valid one + if (entry->getTransactionSequenceNumber() != -1) { + lastArbitratedTransactionNumberByArbitratorTable->put(entry->getTransactionArbitrator(), entry->getTransactionSequenceNumber()); + } + } +} + +/** + * Set dead the transaction part if that transaction is dead and keep + * track of all new parts + */ +void Table::processEntry(TransactionPart *entry) { + // Check if we have already seen this transaction and set it dead OR + // if it is not alive + if (lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getArbitratorId()) && (lastArbitratedTransactionNumberByArbitratorTable->get(entry->getArbitratorId()) >= entry->getSequenceNumber())) { + // This transaction is dead, it was already committed or aborted + entry->setDead(); + return; + } + + // This part is still alive + Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals> *transactionPart = newTransactionParts->get(entry->getMachineId()); + + if (transactionPart == NULL) { + // Dont have a table for this machine Id yet so make one + transactionPart = new Hashtable *, TransactionPart *, uintptr_t, 0, pairHashFunction, pairEquals>(); + newTransactionParts->put(entry->getMachineId(), transactionPart); + } + + // Update the part and set dead ones we have already seen (got a + // rescued version) + entry->acquireRef(); + TransactionPart *previouslySeenPart = transactionPart->put(entry->getPartId(), entry); + if (previouslySeenPart != NULL) { + previouslySeenPart->releaseRef(); + previouslySeenPart->setDead(); + } +} + +/** + * Process new commit entries and save them for future use-> Delete duplicates + */ +void Table::processEntry(CommitPart *entry) { + // Update the last transaction that was updated if we can + if (entry->getTransactionSequenceNumber() != -1) { + if (!lastArbitratedTransactionNumberByArbitratorTable->contains(entry->getMachineId()) || + lastArbitratedTransactionNumberByArbitratorTable->get(entry->getMachineId()) < entry->getTransactionSequenceNumber()) { + lastArbitratedTransactionNumberByArbitratorTable->put(entry->getMachineId(), entry->getTransactionSequenceNumber()); + } + } + + Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals> *commitPart = newCommitParts->get(entry->getMachineId()); + if (commitPart == NULL) { + // Don't have a table for this machine Id yet so make one + commitPart = new Hashtable *, CommitPart *, uintptr_t, 0, pairHashFunction, pairEquals>(); + newCommitParts->put(entry->getMachineId(), commitPart); + } + // Update the part and set dead ones we have already seen (got a + // rescued version) + entry->acquireRef(); + CommitPart *previouslySeenPart = commitPart->put(entry->getPartId(), entry); + if (previouslySeenPart != NULL) { + previouslySeenPart->setDead(); + previouslySeenPart->releaseRef(); + } +} + +/** + * Update the last message seen table-> Update and set dead the + * appropriate RejectedMessages as clients see them-> Updates the live + * aborts, removes those that are dead and sets them dead-> Check that + * the last message seen is correct and that there is no mismatch of + * our own last message or that other clients have not had a rollback + * on the last message-> + */ +void Table::updateLastMessage(int64_t machineId, int64_t seqNum, Liveness *liveness, bool acceptUpdatesToLocal, Hashset *machineSet) { + // We have seen this machine ID + machineSet->remove(machineId); + + // Get the set of rejected messages that this machine Id is has not seen yet + Hashset *watchset = rejectedMessageWatchVectorTable->get(machineId); + // If there is a rejected message that this machine Id has not seen yet + if (watchset != NULL) { + // Go through each rejected message that this machine Id has not + // seen yet + + SetIterator *rmit = watchset->iterator(); + while (rmit->hasNext()) { + RejectedMessage *rm = rmit->next(); + // If this machine Id has seen this rejected message->->-> + if (rm->getSequenceNumber() <= seqNum) { + // Remove it from our watchlist + rmit->remove(); + // Decrement machines that need to see this notification + rm->removeWatcher(machineId); + } + } + delete rmit; + } + + // Set dead the abort + SetIterator *, Abort *, uintptr_t, 0, pairHashFunction, pairEquals> *abortit = getKeyIterator(liveAbortTable); + + while (abortit->hasNext()) { + Pair *key = abortit->next(); + Abort *abort = liveAbortTable->get(key); + if ((abort->getTransactionMachineId() == machineId) && (abort->getSequenceNumber() <= seqNum)) { + abort->setDead(); + abortit->remove(); + if (abort->getTransactionArbitrator() == localMachineId) { + liveAbortsGeneratedByLocal->remove(abort->getArbitratorLocalSequenceNumber()); + } + } + } + delete abortit; + if (machineId == localMachineId) { + // Our own messages are immediately dead-> + char livenessType = liveness->getType(); + if (livenessType == TypeLastMessage) { + ((LastMessage *)liveness)->setDead(); + } else if (livenessType == TypeSlot) { + ((Slot *)liveness)->setDead(); + } else { + throw new Error("Unrecognized type"); + } + } + // Get the old last message for this device + Pair *lastMessageEntry = lastMessageTable->put(machineId, new Pair(seqNum, liveness)); + if (lastMessageEntry == NULL) { + // If no last message then there is nothing else to process + return; + } + + int64_t lastMessageSeqNum = lastMessageEntry->getFirst(); + Liveness *lastEntry = lastMessageEntry->getSecond(); + delete lastMessageEntry; + + // If it is not our machine Id since we already set ours to dead + if (machineId != localMachineId) { + char lastEntryType = lastEntry->getType(); + + if (lastEntryType == TypeLastMessage) { + ((LastMessage *)lastEntry)->setDead(); + } else if (lastEntryType == TypeSlot) { + ((Slot *)lastEntry)->setDead(); + } else { + throw new Error("Unrecognized type"); + } + } + // Make sure the server is not playing any games + if (machineId == localMachineId) { + if (hadPartialSendToServer) { + // We were not making any updates and we had a machine mismatch + if (lastMessageSeqNum > seqNum && !acceptUpdatesToLocal) { + throw new Error("Server Error: Mismatch on local machine sequence number, needed at least: "); + } + } else { + // We were not making any updates and we had a machine mismatch + if (lastMessageSeqNum != seqNum && !acceptUpdatesToLocal) { + throw new Error("Server Error: Mismatch on local machine sequence number, needed: "); + } + } + } else { + if (lastMessageSeqNum > seqNum) { + throw new Error("Server Error: Rollback on remote machine sequence number"); + } + } +} + +/** + * Add a rejected message entry to the watch set to keep track of + * which clients have seen that rejected message entry and which have + * not. + */ +void Table::addWatchVector(int64_t machineId, RejectedMessage *entry) { + Hashset *entries = rejectedMessageWatchVectorTable->get(machineId); + if (entries == NULL) { + // There is no set for this machine ID yet so create one + entries = new Hashset(); + rejectedMessageWatchVectorTable->put(machineId, entries); + } + entries->add(entry); +} + +/** + * Check if the HMAC chain is not violated + */ +void Table::checkHMACChain(SlotIndexer *indexer, Array *newSlots) { + for (uint i = 0; i < newSlots->length(); i++) { + Slot *currSlot = newSlots->get(i); + Slot *prevSlot = indexer->getSlot(currSlot->getSequenceNumber() - 1); + if (prevSlot != NULL && + !prevSlot->getHMAC()->equals(currSlot->getPrevHMAC())) + throw new Error("Server Error: Invalid HMAC Chain"); + } +} diff --git a/version2/src/C/TableStatus.cc b/version2/src/C/TableStatus.cc deleted file mode 100644 index f61a6ae..0000000 --- a/version2/src/C/TableStatus.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "TableStatus.h" -#include "ByteBuffer.h" - -Entry *TableStatus_decode(Slot *slot, ByteBuffer *bb) { - int maxslots = bb->getInt(); - return new TableStatus(slot, maxslots); -} - -void TableStatus::encode(ByteBuffer *bb) { - bb->put(TypeTableStatus); - bb->putInt(maxslots); -} diff --git a/version2/src/C/TableStatus.cpp b/version2/src/C/TableStatus.cpp new file mode 100644 index 0000000..f61a6ae --- /dev/null +++ b/version2/src/C/TableStatus.cpp @@ -0,0 +1,12 @@ +#include "TableStatus.h" +#include "ByteBuffer.h" + +Entry *TableStatus_decode(Slot *slot, ByteBuffer *bb) { + int maxslots = bb->getInt(); + return new TableStatus(slot, maxslots); +} + +void TableStatus::encode(ByteBuffer *bb) { + bb->put(TypeTableStatus); + bb->putInt(maxslots); +} diff --git a/version2/src/C/Transaction.cc b/version2/src/C/Transaction.cc deleted file mode 100644 index b0d3c5b..0000000 --- a/version2/src/C/Transaction.cc +++ /dev/null @@ -1,347 +0,0 @@ -#include "Transaction.h" -#include "TransactionPart.h" -#include "KeyValue.h" -#include "ByteBuffer.h" -#include "IoTString.h" -#include "TransactionStatus.h" - -Transaction::Transaction() : - parts(new Vector()), - partCount(0), - missingParts(NULL), - partsPendingSend(new Vector()), - fldisComplete(false), - hasLastPart(false), - keyValueGuardSet(new Hashset()), - keyValueUpdateSet(new Hashset()), - isDead(false), - sequenceNumber(-1), - clientLocalSequenceNumber(-1), - arbitratorId(-1), - machineId(-1), - transactionId(Pair(0,0)), - nextPartToSend(0), - flddidSendAPartToServer(false), - transactionStatus(NULL), - hadServerFailure(false) { -} - -Transaction::~Transaction() { - if (missingParts) - delete missingParts; - { - uint Size = parts->size(); - for(uint i=0; iget(i)->releaseRef(); - delete parts; - } - { - SetIterator *kvit = keyValueGuardSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kvGuard = kvit->next(); - delete kvGuard; - } - delete kvit; - delete keyValueGuardSet; - } - { - SetIterator *kvit = keyValueUpdateSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kvUpdate = kvit->next(); - delete kvUpdate; - } - delete kvit; - delete keyValueUpdateSet; - } - delete partsPendingSend; -} - -void Transaction::addPartEncode(TransactionPart *newPart) { - newPart->acquireRef(); - printf("Add part %d\n", newPart->getPartNumber()); - TransactionPart *old = parts->setExpand(newPart->getPartNumber(), newPart); - if (old == NULL) { - partCount++; - } else { - old->releaseRef(); - } - partsPendingSend->add(newPart->getPartNumber()); - - sequenceNumber = newPart->getSequenceNumber(); - arbitratorId = newPart->getArbitratorId(); - transactionId = newPart->getTransactionId(); - clientLocalSequenceNumber = newPart->getClientLocalSequenceNumber(); - machineId = newPart->getMachineId(); - fldisComplete = true; -} - -void Transaction::addPartDecode(TransactionPart *newPart) { - if (isDead) { - // If dead then just kill this part and move on - newPart->setDead(); - return; - } - newPart->acquireRef(); - sequenceNumber = newPart->getSequenceNumber(); - arbitratorId = newPart->getArbitratorId(); - transactionId = newPart->getTransactionId(); - clientLocalSequenceNumber = newPart->getClientLocalSequenceNumber(); - machineId = newPart->getMachineId(); - - TransactionPart *previouslySeenPart = parts->setExpand(newPart->getPartNumber(), newPart); - if (previouslySeenPart == NULL) - partCount++; - - if (previouslySeenPart != NULL) { - // Set dead the old one since the new one is a rescued version of this part - previouslySeenPart->releaseRef(); - previouslySeenPart->setDead(); - } else if (newPart->isLastPart()) { - missingParts = new Hashset(); - hasLastPart = true; - - for (int i = 0; i < newPart->getPartNumber(); i++) { - if (parts->get(i) == NULL) { - missingParts->add(i); - } - } - } - - if (!fldisComplete && hasLastPart) { - - // We have seen this part so remove it from the set of missing parts - missingParts->remove(newPart->getPartNumber()); - - // Check if all the parts have been seen - if (missingParts->size() == 0) { - - // We have all the parts - fldisComplete = true; - - // Decode all the parts and create the key value guard and update sets - decodeTransactionData(); - } - } -} - -void Transaction::addUpdateKV(KeyValue *kv) { - keyValueUpdateSet->add(kv); -} - -void Transaction::addGuardKV(KeyValue *kv) { - keyValueGuardSet->add(kv); -} - - -int64_t Transaction::getSequenceNumber() { - return sequenceNumber; -} - -void Transaction::setSequenceNumber(int64_t _sequenceNumber) { - sequenceNumber = _sequenceNumber; - - for (uint32_t i = 0; i < parts->size(); i++) { - TransactionPart *tp = parts->get(i); - if (tp != NULL) - tp->setSequenceNumber(sequenceNumber); - } -} - -int64_t Transaction::getClientLocalSequenceNumber() { - return clientLocalSequenceNumber; -} - -Vector *Transaction::getParts() { - return parts; -} - -bool Transaction::didSendAPartToServer() { - return flddidSendAPartToServer; -} - -void Transaction::resetNextPartToSend() { - nextPartToSend = 0; -} - -TransactionPart *Transaction::getNextPartToSend() { - if ((partsPendingSend->size() == 0) || (partsPendingSend->size() == nextPartToSend)) { - return NULL; - } - TransactionPart *part = parts->get(partsPendingSend->get(nextPartToSend)); - nextPartToSend++; - return part; -} - - -void Transaction::setServerFailure() { - hadServerFailure = true; -} - -bool Transaction::getServerFailure() { - return hadServerFailure; -} - - -void Transaction::resetServerFailure() { - hadServerFailure = false; -} - - -void Transaction::setTransactionStatus(TransactionStatus *_transactionStatus) { - transactionStatus = _transactionStatus; -} - -TransactionStatus *Transaction::getTransactionStatus() { - return transactionStatus; -} - -void Transaction::removeSentParts(Vector *sentParts) { - nextPartToSend = 0; - bool changed = false; - uint lastusedindex = 0; - for (uint i = 0; i < partsPendingSend->size(); i++) { - int32_t parti = partsPendingSend->get(i); - for (uint j = 0; j < sentParts->size(); j++) { - int32_t partj = sentParts->get(j); - if (parti == partj) { - changed = true; - goto NextElement; - } - } - partsPendingSend->set(lastusedindex++, parti); -NextElement: - ; - } - if (changed) { - partsPendingSend->setSize(lastusedindex); - flddidSendAPartToServer = true; - transactionStatus->setTransactionSequenceNumber(sequenceNumber); - } -} - -bool Transaction::didSendAllParts() { - return partsPendingSend->isEmpty(); -} - -Hashset *Transaction::getKeyValueUpdateSet() { - return keyValueUpdateSet; -} - -int Transaction::getNumberOfParts() { - return partCount; -} - -int64_t Transaction::getMachineId() { - return machineId; -} - -int64_t Transaction::getArbitrator() { - return arbitratorId; -} - -bool Transaction::isComplete() { - return fldisComplete; -} - -Pair *Transaction::getId() { - return &transactionId; -} - -void Transaction::setDead() { - if (!isDead) { - // Set dead - isDead = true; - // Make all the parts of this transaction dead - for (uint32_t partNumber = 0; partNumber < parts->size(); partNumber++) { - TransactionPart *part = parts->get(partNumber); - if (part != NULL) - part->setDead(); - } - } -} - -void Transaction::decodeTransactionData() { - // Calculate the size of the data section - int dataSize = 0; - for (uint i = 0; i < parts->size(); i++) { - TransactionPart *tp = parts->get(i); - dataSize += tp->getDataSize(); - } - - Array *combinedData = new Array(dataSize); - int currentPosition = 0; - - // Stitch all the data sections together - for (uint i = 0; i < parts->size(); i++) { - TransactionPart *tp = parts->get(i); - System_arraycopy(tp->getData(), 0, combinedData, currentPosition, tp->getDataSize()); - currentPosition += tp->getDataSize(); - } - - // Decoder Object - ByteBuffer *bbDecode = ByteBuffer_wrap(combinedData); - - // Decode how many key value pairs need to be decoded - int numberOfKVGuards = bbDecode->getInt(); - int numberOfKVUpdates = bbDecode->getInt(); - - // Decode all the guard key values - for (int i = 0; i < numberOfKVGuards; i++) { - KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); - keyValueGuardSet->add(kv); - } - - // Decode all the updates key values - for (int i = 0; i < numberOfKVUpdates; i++) { - KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); - keyValueUpdateSet->add(kv); - } - delete bbDecode; -} - -bool Transaction::evaluateGuard(Hashtable *committedKeyValueTable, Hashtable *speculatedKeyValueTable, Hashtable *pendingTransactionSpeculatedKeyValueTable) { - SetIterator *kvit = keyValueGuardSet->iterator(); - while (kvit->hasNext()) { - KeyValue *kvGuard = kvit->next(); - // First check if the key is in the speculative table, this is the value of the latest assumption - KeyValue *kv = NULL; - - // If we have a speculation table then use it first - if (pendingTransactionSpeculatedKeyValueTable != NULL) { - kv = pendingTransactionSpeculatedKeyValueTable->get(kvGuard->getKey()); - } - - // If we have a speculation table then use it first - if ((kv == NULL) && (speculatedKeyValueTable != NULL)) { - kv = speculatedKeyValueTable->get(kvGuard->getKey()); - } - - if (kv == NULL) { - // if it is not in the speculative table then check the committed table and use that - // value as our latest assumption - kv = committedKeyValueTable->get(kvGuard->getKey()); - } - - if (kvGuard->getValue() != NULL) { - if ((kv == NULL) || (!kvGuard->getValue()->equals(kv->getValue()))) { - - - if (kv != NULL) { - printf("%s %s\n", kvGuard->getKey()->internalBytes()->internalArray(), kv->getValue()->internalBytes()->internalArray()); - } else { - printf("%s null\n", kvGuard->getValue()->internalBytes()->internalArray()); - } - delete kvit; - return false; - } - } else { - if (kv != NULL) { - delete kvit; - return false; - } - } - } - delete kvit; - return true; -} - diff --git a/version2/src/C/Transaction.cpp b/version2/src/C/Transaction.cpp new file mode 100644 index 0000000..b0d3c5b --- /dev/null +++ b/version2/src/C/Transaction.cpp @@ -0,0 +1,347 @@ +#include "Transaction.h" +#include "TransactionPart.h" +#include "KeyValue.h" +#include "ByteBuffer.h" +#include "IoTString.h" +#include "TransactionStatus.h" + +Transaction::Transaction() : + parts(new Vector()), + partCount(0), + missingParts(NULL), + partsPendingSend(new Vector()), + fldisComplete(false), + hasLastPart(false), + keyValueGuardSet(new Hashset()), + keyValueUpdateSet(new Hashset()), + isDead(false), + sequenceNumber(-1), + clientLocalSequenceNumber(-1), + arbitratorId(-1), + machineId(-1), + transactionId(Pair(0,0)), + nextPartToSend(0), + flddidSendAPartToServer(false), + transactionStatus(NULL), + hadServerFailure(false) { +} + +Transaction::~Transaction() { + if (missingParts) + delete missingParts; + { + uint Size = parts->size(); + for(uint i=0; iget(i)->releaseRef(); + delete parts; + } + { + SetIterator *kvit = keyValueGuardSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kvGuard = kvit->next(); + delete kvGuard; + } + delete kvit; + delete keyValueGuardSet; + } + { + SetIterator *kvit = keyValueUpdateSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kvUpdate = kvit->next(); + delete kvUpdate; + } + delete kvit; + delete keyValueUpdateSet; + } + delete partsPendingSend; +} + +void Transaction::addPartEncode(TransactionPart *newPart) { + newPart->acquireRef(); + printf("Add part %d\n", newPart->getPartNumber()); + TransactionPart *old = parts->setExpand(newPart->getPartNumber(), newPart); + if (old == NULL) { + partCount++; + } else { + old->releaseRef(); + } + partsPendingSend->add(newPart->getPartNumber()); + + sequenceNumber = newPart->getSequenceNumber(); + arbitratorId = newPart->getArbitratorId(); + transactionId = newPart->getTransactionId(); + clientLocalSequenceNumber = newPart->getClientLocalSequenceNumber(); + machineId = newPart->getMachineId(); + fldisComplete = true; +} + +void Transaction::addPartDecode(TransactionPart *newPart) { + if (isDead) { + // If dead then just kill this part and move on + newPart->setDead(); + return; + } + newPart->acquireRef(); + sequenceNumber = newPart->getSequenceNumber(); + arbitratorId = newPart->getArbitratorId(); + transactionId = newPart->getTransactionId(); + clientLocalSequenceNumber = newPart->getClientLocalSequenceNumber(); + machineId = newPart->getMachineId(); + + TransactionPart *previouslySeenPart = parts->setExpand(newPart->getPartNumber(), newPart); + if (previouslySeenPart == NULL) + partCount++; + + if (previouslySeenPart != NULL) { + // Set dead the old one since the new one is a rescued version of this part + previouslySeenPart->releaseRef(); + previouslySeenPart->setDead(); + } else if (newPart->isLastPart()) { + missingParts = new Hashset(); + hasLastPart = true; + + for (int i = 0; i < newPart->getPartNumber(); i++) { + if (parts->get(i) == NULL) { + missingParts->add(i); + } + } + } + + if (!fldisComplete && hasLastPart) { + + // We have seen this part so remove it from the set of missing parts + missingParts->remove(newPart->getPartNumber()); + + // Check if all the parts have been seen + if (missingParts->size() == 0) { + + // We have all the parts + fldisComplete = true; + + // Decode all the parts and create the key value guard and update sets + decodeTransactionData(); + } + } +} + +void Transaction::addUpdateKV(KeyValue *kv) { + keyValueUpdateSet->add(kv); +} + +void Transaction::addGuardKV(KeyValue *kv) { + keyValueGuardSet->add(kv); +} + + +int64_t Transaction::getSequenceNumber() { + return sequenceNumber; +} + +void Transaction::setSequenceNumber(int64_t _sequenceNumber) { + sequenceNumber = _sequenceNumber; + + for (uint32_t i = 0; i < parts->size(); i++) { + TransactionPart *tp = parts->get(i); + if (tp != NULL) + tp->setSequenceNumber(sequenceNumber); + } +} + +int64_t Transaction::getClientLocalSequenceNumber() { + return clientLocalSequenceNumber; +} + +Vector *Transaction::getParts() { + return parts; +} + +bool Transaction::didSendAPartToServer() { + return flddidSendAPartToServer; +} + +void Transaction::resetNextPartToSend() { + nextPartToSend = 0; +} + +TransactionPart *Transaction::getNextPartToSend() { + if ((partsPendingSend->size() == 0) || (partsPendingSend->size() == nextPartToSend)) { + return NULL; + } + TransactionPart *part = parts->get(partsPendingSend->get(nextPartToSend)); + nextPartToSend++; + return part; +} + + +void Transaction::setServerFailure() { + hadServerFailure = true; +} + +bool Transaction::getServerFailure() { + return hadServerFailure; +} + + +void Transaction::resetServerFailure() { + hadServerFailure = false; +} + + +void Transaction::setTransactionStatus(TransactionStatus *_transactionStatus) { + transactionStatus = _transactionStatus; +} + +TransactionStatus *Transaction::getTransactionStatus() { + return transactionStatus; +} + +void Transaction::removeSentParts(Vector *sentParts) { + nextPartToSend = 0; + bool changed = false; + uint lastusedindex = 0; + for (uint i = 0; i < partsPendingSend->size(); i++) { + int32_t parti = partsPendingSend->get(i); + for (uint j = 0; j < sentParts->size(); j++) { + int32_t partj = sentParts->get(j); + if (parti == partj) { + changed = true; + goto NextElement; + } + } + partsPendingSend->set(lastusedindex++, parti); +NextElement: + ; + } + if (changed) { + partsPendingSend->setSize(lastusedindex); + flddidSendAPartToServer = true; + transactionStatus->setTransactionSequenceNumber(sequenceNumber); + } +} + +bool Transaction::didSendAllParts() { + return partsPendingSend->isEmpty(); +} + +Hashset *Transaction::getKeyValueUpdateSet() { + return keyValueUpdateSet; +} + +int Transaction::getNumberOfParts() { + return partCount; +} + +int64_t Transaction::getMachineId() { + return machineId; +} + +int64_t Transaction::getArbitrator() { + return arbitratorId; +} + +bool Transaction::isComplete() { + return fldisComplete; +} + +Pair *Transaction::getId() { + return &transactionId; +} + +void Transaction::setDead() { + if (!isDead) { + // Set dead + isDead = true; + // Make all the parts of this transaction dead + for (uint32_t partNumber = 0; partNumber < parts->size(); partNumber++) { + TransactionPart *part = parts->get(partNumber); + if (part != NULL) + part->setDead(); + } + } +} + +void Transaction::decodeTransactionData() { + // Calculate the size of the data section + int dataSize = 0; + for (uint i = 0; i < parts->size(); i++) { + TransactionPart *tp = parts->get(i); + dataSize += tp->getDataSize(); + } + + Array *combinedData = new Array(dataSize); + int currentPosition = 0; + + // Stitch all the data sections together + for (uint i = 0; i < parts->size(); i++) { + TransactionPart *tp = parts->get(i); + System_arraycopy(tp->getData(), 0, combinedData, currentPosition, tp->getDataSize()); + currentPosition += tp->getDataSize(); + } + + // Decoder Object + ByteBuffer *bbDecode = ByteBuffer_wrap(combinedData); + + // Decode how many key value pairs need to be decoded + int numberOfKVGuards = bbDecode->getInt(); + int numberOfKVUpdates = bbDecode->getInt(); + + // Decode all the guard key values + for (int i = 0; i < numberOfKVGuards; i++) { + KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); + keyValueGuardSet->add(kv); + } + + // Decode all the updates key values + for (int i = 0; i < numberOfKVUpdates; i++) { + KeyValue *kv = (KeyValue *)KeyValue_decode(bbDecode); + keyValueUpdateSet->add(kv); + } + delete bbDecode; +} + +bool Transaction::evaluateGuard(Hashtable *committedKeyValueTable, Hashtable *speculatedKeyValueTable, Hashtable *pendingTransactionSpeculatedKeyValueTable) { + SetIterator *kvit = keyValueGuardSet->iterator(); + while (kvit->hasNext()) { + KeyValue *kvGuard = kvit->next(); + // First check if the key is in the speculative table, this is the value of the latest assumption + KeyValue *kv = NULL; + + // If we have a speculation table then use it first + if (pendingTransactionSpeculatedKeyValueTable != NULL) { + kv = pendingTransactionSpeculatedKeyValueTable->get(kvGuard->getKey()); + } + + // If we have a speculation table then use it first + if ((kv == NULL) && (speculatedKeyValueTable != NULL)) { + kv = speculatedKeyValueTable->get(kvGuard->getKey()); + } + + if (kv == NULL) { + // if it is not in the speculative table then check the committed table and use that + // value as our latest assumption + kv = committedKeyValueTable->get(kvGuard->getKey()); + } + + if (kvGuard->getValue() != NULL) { + if ((kv == NULL) || (!kvGuard->getValue()->equals(kv->getValue()))) { + + + if (kv != NULL) { + printf("%s %s\n", kvGuard->getKey()->internalBytes()->internalArray(), kv->getValue()->internalBytes()->internalArray()); + } else { + printf("%s null\n", kvGuard->getValue()->internalBytes()->internalArray()); + } + delete kvit; + return false; + } + } else { + if (kv != NULL) { + delete kvit; + return false; + } + } + } + delete kvit; + return true; +} + diff --git a/version2/src/C/TransactionPart.cc b/version2/src/C/TransactionPart.cc deleted file mode 100644 index e8f4356..0000000 --- a/version2/src/C/TransactionPart.cc +++ /dev/null @@ -1,104 +0,0 @@ -#include "TransactionPart.h" -#include "ByteBuffer.h" - -int TransactionPart::getSize() { - if (data == NULL) { - return (4 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)); - } - return (4 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)) + data->length(); -} - -TransactionPart::~TransactionPart() { - delete data; -} - -Pair TransactionPart::getTransactionId() { - return transactionId; -} - -int64_t TransactionPart::getArbitratorId() { - return arbitratorId; -} - -Pair * TransactionPart::getPartId() { - return & partId; -} - -int TransactionPart::getPartNumber() { - return partNumber; -} - -int TransactionPart::getDataSize() { - return data->length(); -} - -Array *TransactionPart::getData() { - return data; -} - -bool TransactionPart::isLastPart() { - return fldisLastPart; -} - -int64_t TransactionPart::getMachineId() { - return machineId; -} - -int64_t TransactionPart::getClientLocalSequenceNumber() { - return clientLocalSequenceNumber; -} - -int64_t TransactionPart::getSequenceNumber() { - return sequenceNumber; -} - -void TransactionPart::setSequenceNumber(int64_t _sequenceNumber) { - sequenceNumber = _sequenceNumber; -} - -Entry *TransactionPart_decode(Slot *s, ByteBuffer *bb) { - int64_t sequenceNumber = bb->getLong(); - int64_t machineId = bb->getLong(); - int64_t arbitratorId = bb->getLong(); - int64_t clientLocalSequenceNumber = bb->getLong(); - int partNumber = bb->getInt(); - int dataSize = bb->getInt(); - bool isLastPart = (bb->get() == 1); - // Get the data - Array *data = new Array(dataSize); - bb->get(data); - - TransactionPart *returnTransactionPart = new TransactionPart(s, machineId, arbitratorId, clientLocalSequenceNumber, partNumber, data, isLastPart); - returnTransactionPart->setSequenceNumber(sequenceNumber); - - return returnTransactionPart; -} - -void TransactionPart::encode(ByteBuffer *bb) { - bb->put(TypeTransactionPart); - bb->putLong(sequenceNumber); - bb->putLong(machineId); - bb->putLong(arbitratorId); - bb->putLong(clientLocalSequenceNumber); - bb->putInt(partNumber); - bb->putInt(data->length()); - - if (fldisLastPart) { - bb->put((char)1); - } else { - bb->put((char)0); - } - - bb->put(data); -} - -char TransactionPart::getType() { - return TypeTransactionPart; -} - -Entry *TransactionPart::getCopy(Slot *s) { - TransactionPart *copyTransaction = new TransactionPart(s, machineId, arbitratorId, clientLocalSequenceNumber, partNumber, new Array(data), fldisLastPart); - copyTransaction->setSequenceNumber(sequenceNumber); - - return copyTransaction; -} diff --git a/version2/src/C/TransactionPart.cpp b/version2/src/C/TransactionPart.cpp new file mode 100644 index 0000000..e8f4356 --- /dev/null +++ b/version2/src/C/TransactionPart.cpp @@ -0,0 +1,104 @@ +#include "TransactionPart.h" +#include "ByteBuffer.h" + +int TransactionPart::getSize() { + if (data == NULL) { + return (4 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)); + } + return (4 * sizeof(int64_t)) + (2 * sizeof(int32_t)) + (2 * sizeof(char)) + data->length(); +} + +TransactionPart::~TransactionPart() { + delete data; +} + +Pair TransactionPart::getTransactionId() { + return transactionId; +} + +int64_t TransactionPart::getArbitratorId() { + return arbitratorId; +} + +Pair * TransactionPart::getPartId() { + return & partId; +} + +int TransactionPart::getPartNumber() { + return partNumber; +} + +int TransactionPart::getDataSize() { + return data->length(); +} + +Array *TransactionPart::getData() { + return data; +} + +bool TransactionPart::isLastPart() { + return fldisLastPart; +} + +int64_t TransactionPart::getMachineId() { + return machineId; +} + +int64_t TransactionPart::getClientLocalSequenceNumber() { + return clientLocalSequenceNumber; +} + +int64_t TransactionPart::getSequenceNumber() { + return sequenceNumber; +} + +void TransactionPart::setSequenceNumber(int64_t _sequenceNumber) { + sequenceNumber = _sequenceNumber; +} + +Entry *TransactionPart_decode(Slot *s, ByteBuffer *bb) { + int64_t sequenceNumber = bb->getLong(); + int64_t machineId = bb->getLong(); + int64_t arbitratorId = bb->getLong(); + int64_t clientLocalSequenceNumber = bb->getLong(); + int partNumber = bb->getInt(); + int dataSize = bb->getInt(); + bool isLastPart = (bb->get() == 1); + // Get the data + Array *data = new Array(dataSize); + bb->get(data); + + TransactionPart *returnTransactionPart = new TransactionPart(s, machineId, arbitratorId, clientLocalSequenceNumber, partNumber, data, isLastPart); + returnTransactionPart->setSequenceNumber(sequenceNumber); + + return returnTransactionPart; +} + +void TransactionPart::encode(ByteBuffer *bb) { + bb->put(TypeTransactionPart); + bb->putLong(sequenceNumber); + bb->putLong(machineId); + bb->putLong(arbitratorId); + bb->putLong(clientLocalSequenceNumber); + bb->putInt(partNumber); + bb->putInt(data->length()); + + if (fldisLastPart) { + bb->put((char)1); + } else { + bb->put((char)0); + } + + bb->put(data); +} + +char TransactionPart::getType() { + return TypeTransactionPart; +} + +Entry *TransactionPart::getCopy(Slot *s) { + TransactionPart *copyTransaction = new TransactionPart(s, machineId, arbitratorId, clientLocalSequenceNumber, partNumber, new Array(data), fldisLastPart); + copyTransaction->setSequenceNumber(sequenceNumber); + + return copyTransaction; +} diff --git a/version2/src/C/aes.cc b/version2/src/C/aes.cc deleted file mode 100644 index a917c7f..0000000 --- a/version2/src/C/aes.cc +++ /dev/null @@ -1,1095 +0,0 @@ -/********************************************************************* -* Filename: aes.c -* Author: Brad Conte (brad AT bradconte.com) -* Copyright: -* Disclaimer: This code is presented "as is" without any guarantees. -* Details: This code is the implementation of the AES algorithm and - the CTR, CBC, and CCM modes of operation it can be used in. - AES is, specified by the NIST in in publication FIPS PUB 197, - availible at: -* http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf . - The CBC and CTR modes of operation are specified by - NIST SP 800-38 A, available at: -* http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf . - The CCM mode of operation is specified by NIST SP80-38 C, available at: -* http://csrc.nist.gov/publications/nistpubs/800-38C/SP800-38C_updated-July20_2007.pdf -*********************************************************************/ - -/*************************** HEADER FILES ***************************/ -#include -#include -#include "aes.h" - -#include - -/****************************** MACROS ******************************/ -// The least significant byte of the word is rotated to the end. -#define KE_ROTWORD(x) (((x) << 8) | ((x) >> 24)) - -#define TRUE 1 -#define FALSE 0 - -/**************************** DATA TYPES ****************************/ -#define AES_128_ROUNDS 10 -#define AES_192_ROUNDS 12 -#define AES_256_ROUNDS 14 - -/*********************** FUNCTION DECLARATIONS **********************/ -void ccm_prepare_first_ctr_blk(BYTE counter[], const BYTE nonce[], int nonce_len, int payload_len_store_size); -void ccm_prepare_first_format_blk(BYTE buf[], int assoc_len, int payload_len, int payload_len_store_size, int mac_len, const BYTE nonce[], int nonce_len); -void ccm_format_assoc_data(BYTE buf[], int *end_of_buf, const BYTE assoc[], int assoc_len); -void ccm_format_payload_data(BYTE buf[], int *end_of_buf, const BYTE payload[], int payload_len); - -/**************************** VARIABLES *****************************/ -// This is the specified AES SBox. To look up a substitution value, put the first -// nibble in the first index (row) and the second nibble in the second index (column). -static const BYTE aes_sbox[16][16] = { - {0x63,0x7C,0x77,0x7B,0xF2,0x6B,0x6F,0xC5,0x30,0x01,0x67,0x2B,0xFE,0xD7,0xAB,0x76}, - {0xCA,0x82,0xC9,0x7D,0xFA,0x59,0x47,0xF0,0xAD,0xD4,0xA2,0xAF,0x9C,0xA4,0x72,0xC0}, - {0xB7,0xFD,0x93,0x26,0x36,0x3F,0xF7,0xCC,0x34,0xA5,0xE5,0xF1,0x71,0xD8,0x31,0x15}, - {0x04,0xC7,0x23,0xC3,0x18,0x96,0x05,0x9A,0x07,0x12,0x80,0xE2,0xEB,0x27,0xB2,0x75}, - {0x09,0x83,0x2C,0x1A,0x1B,0x6E,0x5A,0xA0,0x52,0x3B,0xD6,0xB3,0x29,0xE3,0x2F,0x84}, - {0x53,0xD1,0x00,0xED,0x20,0xFC,0xB1,0x5B,0x6A,0xCB,0xBE,0x39,0x4A,0x4C,0x58,0xCF}, - {0xD0,0xEF,0xAA,0xFB,0x43,0x4D,0x33,0x85,0x45,0xF9,0x02,0x7F,0x50,0x3C,0x9F,0xA8}, - {0x51,0xA3,0x40,0x8F,0x92,0x9D,0x38,0xF5,0xBC,0xB6,0xDA,0x21,0x10,0xFF,0xF3,0xD2}, - {0xCD,0x0C,0x13,0xEC,0x5F,0x97,0x44,0x17,0xC4,0xA7,0x7E,0x3D,0x64,0x5D,0x19,0x73}, - {0x60,0x81,0x4F,0xDC,0x22,0x2A,0x90,0x88,0x46,0xEE,0xB8,0x14,0xDE,0x5E,0x0B,0xDB}, - {0xE0,0x32,0x3A,0x0A,0x49,0x06,0x24,0x5C,0xC2,0xD3,0xAC,0x62,0x91,0x95,0xE4,0x79}, - {0xE7,0xC8,0x37,0x6D,0x8D,0xD5,0x4E,0xA9,0x6C,0x56,0xF4,0xEA,0x65,0x7A,0xAE,0x08}, - {0xBA,0x78,0x25,0x2E,0x1C,0xA6,0xB4,0xC6,0xE8,0xDD,0x74,0x1F,0x4B,0xBD,0x8B,0x8A}, - {0x70,0x3E,0xB5,0x66,0x48,0x03,0xF6,0x0E,0x61,0x35,0x57,0xB9,0x86,0xC1,0x1D,0x9E}, - {0xE1,0xF8,0x98,0x11,0x69,0xD9,0x8E,0x94,0x9B,0x1E,0x87,0xE9,0xCE,0x55,0x28,0xDF}, - {0x8C,0xA1,0x89,0x0D,0xBF,0xE6,0x42,0x68,0x41,0x99,0x2D,0x0F,0xB0,0x54,0xBB,0x16} -}; - -static const BYTE aes_invsbox[16][16] = { - {0x52,0x09,0x6A,0xD5,0x30,0x36,0xA5,0x38,0xBF,0x40,0xA3,0x9E,0x81,0xF3,0xD7,0xFB}, - {0x7C,0xE3,0x39,0x82,0x9B,0x2F,0xFF,0x87,0x34,0x8E,0x43,0x44,0xC4,0xDE,0xE9,0xCB}, - {0x54,0x7B,0x94,0x32,0xA6,0xC2,0x23,0x3D,0xEE,0x4C,0x95,0x0B,0x42,0xFA,0xC3,0x4E}, - {0x08,0x2E,0xA1,0x66,0x28,0xD9,0x24,0xB2,0x76,0x5B,0xA2,0x49,0x6D,0x8B,0xD1,0x25}, - {0x72,0xF8,0xF6,0x64,0x86,0x68,0x98,0x16,0xD4,0xA4,0x5C,0xCC,0x5D,0x65,0xB6,0x92}, - {0x6C,0x70,0x48,0x50,0xFD,0xED,0xB9,0xDA,0x5E,0x15,0x46,0x57,0xA7,0x8D,0x9D,0x84}, - {0x90,0xD8,0xAB,0x00,0x8C,0xBC,0xD3,0x0A,0xF7,0xE4,0x58,0x05,0xB8,0xB3,0x45,0x06}, - {0xD0,0x2C,0x1E,0x8F,0xCA,0x3F,0x0F,0x02,0xC1,0xAF,0xBD,0x03,0x01,0x13,0x8A,0x6B}, - {0x3A,0x91,0x11,0x41,0x4F,0x67,0xDC,0xEA,0x97,0xF2,0xCF,0xCE,0xF0,0xB4,0xE6,0x73}, - {0x96,0xAC,0x74,0x22,0xE7,0xAD,0x35,0x85,0xE2,0xF9,0x37,0xE8,0x1C,0x75,0xDF,0x6E}, - {0x47,0xF1,0x1A,0x71,0x1D,0x29,0xC5,0x89,0x6F,0xB7,0x62,0x0E,0xAA,0x18,0xBE,0x1B}, - {0xFC,0x56,0x3E,0x4B,0xC6,0xD2,0x79,0x20,0x9A,0xDB,0xC0,0xFE,0x78,0xCD,0x5A,0xF4}, - {0x1F,0xDD,0xA8,0x33,0x88,0x07,0xC7,0x31,0xB1,0x12,0x10,0x59,0x27,0x80,0xEC,0x5F}, - {0x60,0x51,0x7F,0xA9,0x19,0xB5,0x4A,0x0D,0x2D,0xE5,0x7A,0x9F,0x93,0xC9,0x9C,0xEF}, - {0xA0,0xE0,0x3B,0x4D,0xAE,0x2A,0xF5,0xB0,0xC8,0xEB,0xBB,0x3C,0x83,0x53,0x99,0x61}, - {0x17,0x2B,0x04,0x7E,0xBA,0x77,0xD6,0x26,0xE1,0x69,0x14,0x63,0x55,0x21,0x0C,0x7D} -}; - -// This table stores pre-calculated values for all possible GF(2^8) calculations.This -// table is only used by the (Inv)MixColumns steps. -// USAGE: The second index (column) is the coefficient of multiplication. Only 7 different -// coefficients are used: 0x01, 0x02, 0x03, 0x09, 0x0b, 0x0d, 0x0e, but multiplication by -// 1 is negligible leaving only 6 coefficients. Each column of the table is devoted to one -// of these coefficients, in the ascending order of value, from values 0x00 to 0xFF. -static const BYTE gf_mul[256][6] = { - {0x00,0x00,0x00,0x00,0x00,0x00},{0x02,0x03,0x09,0x0b,0x0d,0x0e}, - {0x04,0x06,0x12,0x16,0x1a,0x1c},{0x06,0x05,0x1b,0x1d,0x17,0x12}, - {0x08,0x0c,0x24,0x2c,0x34,0x38},{0x0a,0x0f,0x2d,0x27,0x39,0x36}, - {0x0c,0x0a,0x36,0x3a,0x2e,0x24},{0x0e,0x09,0x3f,0x31,0x23,0x2a}, - {0x10,0x18,0x48,0x58,0x68,0x70},{0x12,0x1b,0x41,0x53,0x65,0x7e}, - {0x14,0x1e,0x5a,0x4e,0x72,0x6c},{0x16,0x1d,0x53,0x45,0x7f,0x62}, - {0x18,0x14,0x6c,0x74,0x5c,0x48},{0x1a,0x17,0x65,0x7f,0x51,0x46}, - {0x1c,0x12,0x7e,0x62,0x46,0x54},{0x1e,0x11,0x77,0x69,0x4b,0x5a}, - {0x20,0x30,0x90,0xb0,0xd0,0xe0},{0x22,0x33,0x99,0xbb,0xdd,0xee}, - {0x24,0x36,0x82,0xa6,0xca,0xfc},{0x26,0x35,0x8b,0xad,0xc7,0xf2}, - {0x28,0x3c,0xb4,0x9c,0xe4,0xd8},{0x2a,0x3f,0xbd,0x97,0xe9,0xd6}, - {0x2c,0x3a,0xa6,0x8a,0xfe,0xc4},{0x2e,0x39,0xaf,0x81,0xf3,0xca}, - {0x30,0x28,0xd8,0xe8,0xb8,0x90},{0x32,0x2b,0xd1,0xe3,0xb5,0x9e}, - {0x34,0x2e,0xca,0xfe,0xa2,0x8c},{0x36,0x2d,0xc3,0xf5,0xaf,0x82}, - {0x38,0x24,0xfc,0xc4,0x8c,0xa8},{0x3a,0x27,0xf5,0xcf,0x81,0xa6}, - {0x3c,0x22,0xee,0xd2,0x96,0xb4},{0x3e,0x21,0xe7,0xd9,0x9b,0xba}, - {0x40,0x60,0x3b,0x7b,0xbb,0xdb},{0x42,0x63,0x32,0x70,0xb6,0xd5}, - {0x44,0x66,0x29,0x6d,0xa1,0xc7},{0x46,0x65,0x20,0x66,0xac,0xc9}, - {0x48,0x6c,0x1f,0x57,0x8f,0xe3},{0x4a,0x6f,0x16,0x5c,0x82,0xed}, - {0x4c,0x6a,0x0d,0x41,0x95,0xff},{0x4e,0x69,0x04,0x4a,0x98,0xf1}, - {0x50,0x78,0x73,0x23,0xd3,0xab},{0x52,0x7b,0x7a,0x28,0xde,0xa5}, - {0x54,0x7e,0x61,0x35,0xc9,0xb7},{0x56,0x7d,0x68,0x3e,0xc4,0xb9}, - {0x58,0x74,0x57,0x0f,0xe7,0x93},{0x5a,0x77,0x5e,0x04,0xea,0x9d}, - {0x5c,0x72,0x45,0x19,0xfd,0x8f},{0x5e,0x71,0x4c,0x12,0xf0,0x81}, - {0x60,0x50,0xab,0xcb,0x6b,0x3b},{0x62,0x53,0xa2,0xc0,0x66,0x35}, - {0x64,0x56,0xb9,0xdd,0x71,0x27},{0x66,0x55,0xb0,0xd6,0x7c,0x29}, - {0x68,0x5c,0x8f,0xe7,0x5f,0x03},{0x6a,0x5f,0x86,0xec,0x52,0x0d}, - {0x6c,0x5a,0x9d,0xf1,0x45,0x1f},{0x6e,0x59,0x94,0xfa,0x48,0x11}, - {0x70,0x48,0xe3,0x93,0x03,0x4b},{0x72,0x4b,0xea,0x98,0x0e,0x45}, - {0x74,0x4e,0xf1,0x85,0x19,0x57},{0x76,0x4d,0xf8,0x8e,0x14,0x59}, - {0x78,0x44,0xc7,0xbf,0x37,0x73},{0x7a,0x47,0xce,0xb4,0x3a,0x7d}, - {0x7c,0x42,0xd5,0xa9,0x2d,0x6f},{0x7e,0x41,0xdc,0xa2,0x20,0x61}, - {0x80,0xc0,0x76,0xf6,0x6d,0xad},{0x82,0xc3,0x7f,0xfd,0x60,0xa3}, - {0x84,0xc6,0x64,0xe0,0x77,0xb1},{0x86,0xc5,0x6d,0xeb,0x7a,0xbf}, - {0x88,0xcc,0x52,0xda,0x59,0x95},{0x8a,0xcf,0x5b,0xd1,0x54,0x9b}, - {0x8c,0xca,0x40,0xcc,0x43,0x89},{0x8e,0xc9,0x49,0xc7,0x4e,0x87}, - {0x90,0xd8,0x3e,0xae,0x05,0xdd},{0x92,0xdb,0x37,0xa5,0x08,0xd3}, - {0x94,0xde,0x2c,0xb8,0x1f,0xc1},{0x96,0xdd,0x25,0xb3,0x12,0xcf}, - {0x98,0xd4,0x1a,0x82,0x31,0xe5},{0x9a,0xd7,0x13,0x89,0x3c,0xeb}, - {0x9c,0xd2,0x08,0x94,0x2b,0xf9},{0x9e,0xd1,0x01,0x9f,0x26,0xf7}, - {0xa0,0xf0,0xe6,0x46,0xbd,0x4d},{0xa2,0xf3,0xef,0x4d,0xb0,0x43}, - {0xa4,0xf6,0xf4,0x50,0xa7,0x51},{0xa6,0xf5,0xfd,0x5b,0xaa,0x5f}, - {0xa8,0xfc,0xc2,0x6a,0x89,0x75},{0xaa,0xff,0xcb,0x61,0x84,0x7b}, - {0xac,0xfa,0xd0,0x7c,0x93,0x69},{0xae,0xf9,0xd9,0x77,0x9e,0x67}, - {0xb0,0xe8,0xae,0x1e,0xd5,0x3d},{0xb2,0xeb,0xa7,0x15,0xd8,0x33}, - {0xb4,0xee,0xbc,0x08,0xcf,0x21},{0xb6,0xed,0xb5,0x03,0xc2,0x2f}, - {0xb8,0xe4,0x8a,0x32,0xe1,0x05},{0xba,0xe7,0x83,0x39,0xec,0x0b}, - {0xbc,0xe2,0x98,0x24,0xfb,0x19},{0xbe,0xe1,0x91,0x2f,0xf6,0x17}, - {0xc0,0xa0,0x4d,0x8d,0xd6,0x76},{0xc2,0xa3,0x44,0x86,0xdb,0x78}, - {0xc4,0xa6,0x5f,0x9b,0xcc,0x6a},{0xc6,0xa5,0x56,0x90,0xc1,0x64}, - {0xc8,0xac,0x69,0xa1,0xe2,0x4e},{0xca,0xaf,0x60,0xaa,0xef,0x40}, - {0xcc,0xaa,0x7b,0xb7,0xf8,0x52},{0xce,0xa9,0x72,0xbc,0xf5,0x5c}, - {0xd0,0xb8,0x05,0xd5,0xbe,0x06},{0xd2,0xbb,0x0c,0xde,0xb3,0x08}, - {0xd4,0xbe,0x17,0xc3,0xa4,0x1a},{0xd6,0xbd,0x1e,0xc8,0xa9,0x14}, - {0xd8,0xb4,0x21,0xf9,0x8a,0x3e},{0xda,0xb7,0x28,0xf2,0x87,0x30}, - {0xdc,0xb2,0x33,0xef,0x90,0x22},{0xde,0xb1,0x3a,0xe4,0x9d,0x2c}, - {0xe0,0x90,0xdd,0x3d,0x06,0x96},{0xe2,0x93,0xd4,0x36,0x0b,0x98}, - {0xe4,0x96,0xcf,0x2b,0x1c,0x8a},{0xe6,0x95,0xc6,0x20,0x11,0x84}, - {0xe8,0x9c,0xf9,0x11,0x32,0xae},{0xea,0x9f,0xf0,0x1a,0x3f,0xa0}, - {0xec,0x9a,0xeb,0x07,0x28,0xb2},{0xee,0x99,0xe2,0x0c,0x25,0xbc}, - {0xf0,0x88,0x95,0x65,0x6e,0xe6},{0xf2,0x8b,0x9c,0x6e,0x63,0xe8}, - {0xf4,0x8e,0x87,0x73,0x74,0xfa},{0xf6,0x8d,0x8e,0x78,0x79,0xf4}, - {0xf8,0x84,0xb1,0x49,0x5a,0xde},{0xfa,0x87,0xb8,0x42,0x57,0xd0}, - {0xfc,0x82,0xa3,0x5f,0x40,0xc2},{0xfe,0x81,0xaa,0x54,0x4d,0xcc}, - {0x1b,0x9b,0xec,0xf7,0xda,0x41},{0x19,0x98,0xe5,0xfc,0xd7,0x4f}, - {0x1f,0x9d,0xfe,0xe1,0xc0,0x5d},{0x1d,0x9e,0xf7,0xea,0xcd,0x53}, - {0x13,0x97,0xc8,0xdb,0xee,0x79},{0x11,0x94,0xc1,0xd0,0xe3,0x77}, - {0x17,0x91,0xda,0xcd,0xf4,0x65},{0x15,0x92,0xd3,0xc6,0xf9,0x6b}, - {0x0b,0x83,0xa4,0xaf,0xb2,0x31},{0x09,0x80,0xad,0xa4,0xbf,0x3f}, - {0x0f,0x85,0xb6,0xb9,0xa8,0x2d},{0x0d,0x86,0xbf,0xb2,0xa5,0x23}, - {0x03,0x8f,0x80,0x83,0x86,0x09},{0x01,0x8c,0x89,0x88,0x8b,0x07}, - {0x07,0x89,0x92,0x95,0x9c,0x15},{0x05,0x8a,0x9b,0x9e,0x91,0x1b}, - {0x3b,0xab,0x7c,0x47,0x0a,0xa1},{0x39,0xa8,0x75,0x4c,0x07,0xaf}, - {0x3f,0xad,0x6e,0x51,0x10,0xbd},{0x3d,0xae,0x67,0x5a,0x1d,0xb3}, - {0x33,0xa7,0x58,0x6b,0x3e,0x99},{0x31,0xa4,0x51,0x60,0x33,0x97}, - {0x37,0xa1,0x4a,0x7d,0x24,0x85},{0x35,0xa2,0x43,0x76,0x29,0x8b}, - {0x2b,0xb3,0x34,0x1f,0x62,0xd1},{0x29,0xb0,0x3d,0x14,0x6f,0xdf}, - {0x2f,0xb5,0x26,0x09,0x78,0xcd},{0x2d,0xb6,0x2f,0x02,0x75,0xc3}, - {0x23,0xbf,0x10,0x33,0x56,0xe9},{0x21,0xbc,0x19,0x38,0x5b,0xe7}, - {0x27,0xb9,0x02,0x25,0x4c,0xf5},{0x25,0xba,0x0b,0x2e,0x41,0xfb}, - {0x5b,0xfb,0xd7,0x8c,0x61,0x9a},{0x59,0xf8,0xde,0x87,0x6c,0x94}, - {0x5f,0xfd,0xc5,0x9a,0x7b,0x86},{0x5d,0xfe,0xcc,0x91,0x76,0x88}, - {0x53,0xf7,0xf3,0xa0,0x55,0xa2},{0x51,0xf4,0xfa,0xab,0x58,0xac}, - {0x57,0xf1,0xe1,0xb6,0x4f,0xbe},{0x55,0xf2,0xe8,0xbd,0x42,0xb0}, - {0x4b,0xe3,0x9f,0xd4,0x09,0xea},{0x49,0xe0,0x96,0xdf,0x04,0xe4}, - {0x4f,0xe5,0x8d,0xc2,0x13,0xf6},{0x4d,0xe6,0x84,0xc9,0x1e,0xf8}, - {0x43,0xef,0xbb,0xf8,0x3d,0xd2},{0x41,0xec,0xb2,0xf3,0x30,0xdc}, - {0x47,0xe9,0xa9,0xee,0x27,0xce},{0x45,0xea,0xa0,0xe5,0x2a,0xc0}, - {0x7b,0xcb,0x47,0x3c,0xb1,0x7a},{0x79,0xc8,0x4e,0x37,0xbc,0x74}, - {0x7f,0xcd,0x55,0x2a,0xab,0x66},{0x7d,0xce,0x5c,0x21,0xa6,0x68}, - {0x73,0xc7,0x63,0x10,0x85,0x42},{0x71,0xc4,0x6a,0x1b,0x88,0x4c}, - {0x77,0xc1,0x71,0x06,0x9f,0x5e},{0x75,0xc2,0x78,0x0d,0x92,0x50}, - {0x6b,0xd3,0x0f,0x64,0xd9,0x0a},{0x69,0xd0,0x06,0x6f,0xd4,0x04}, - {0x6f,0xd5,0x1d,0x72,0xc3,0x16},{0x6d,0xd6,0x14,0x79,0xce,0x18}, - {0x63,0xdf,0x2b,0x48,0xed,0x32},{0x61,0xdc,0x22,0x43,0xe0,0x3c}, - {0x67,0xd9,0x39,0x5e,0xf7,0x2e},{0x65,0xda,0x30,0x55,0xfa,0x20}, - {0x9b,0x5b,0x9a,0x01,0xb7,0xec},{0x99,0x58,0x93,0x0a,0xba,0xe2}, - {0x9f,0x5d,0x88,0x17,0xad,0xf0},{0x9d,0x5e,0x81,0x1c,0xa0,0xfe}, - {0x93,0x57,0xbe,0x2d,0x83,0xd4},{0x91,0x54,0xb7,0x26,0x8e,0xda}, - {0x97,0x51,0xac,0x3b,0x99,0xc8},{0x95,0x52,0xa5,0x30,0x94,0xc6}, - {0x8b,0x43,0xd2,0x59,0xdf,0x9c},{0x89,0x40,0xdb,0x52,0xd2,0x92}, - {0x8f,0x45,0xc0,0x4f,0xc5,0x80},{0x8d,0x46,0xc9,0x44,0xc8,0x8e}, - {0x83,0x4f,0xf6,0x75,0xeb,0xa4},{0x81,0x4c,0xff,0x7e,0xe6,0xaa}, - {0x87,0x49,0xe4,0x63,0xf1,0xb8},{0x85,0x4a,0xed,0x68,0xfc,0xb6}, - {0xbb,0x6b,0x0a,0xb1,0x67,0x0c},{0xb9,0x68,0x03,0xba,0x6a,0x02}, - {0xbf,0x6d,0x18,0xa7,0x7d,0x10},{0xbd,0x6e,0x11,0xac,0x70,0x1e}, - {0xb3,0x67,0x2e,0x9d,0x53,0x34},{0xb1,0x64,0x27,0x96,0x5e,0x3a}, - {0xb7,0x61,0x3c,0x8b,0x49,0x28},{0xb5,0x62,0x35,0x80,0x44,0x26}, - {0xab,0x73,0x42,0xe9,0x0f,0x7c},{0xa9,0x70,0x4b,0xe2,0x02,0x72}, - {0xaf,0x75,0x50,0xff,0x15,0x60},{0xad,0x76,0x59,0xf4,0x18,0x6e}, - {0xa3,0x7f,0x66,0xc5,0x3b,0x44},{0xa1,0x7c,0x6f,0xce,0x36,0x4a}, - {0xa7,0x79,0x74,0xd3,0x21,0x58},{0xa5,0x7a,0x7d,0xd8,0x2c,0x56}, - {0xdb,0x3b,0xa1,0x7a,0x0c,0x37},{0xd9,0x38,0xa8,0x71,0x01,0x39}, - {0xdf,0x3d,0xb3,0x6c,0x16,0x2b},{0xdd,0x3e,0xba,0x67,0x1b,0x25}, - {0xd3,0x37,0x85,0x56,0x38,0x0f},{0xd1,0x34,0x8c,0x5d,0x35,0x01}, - {0xd7,0x31,0x97,0x40,0x22,0x13},{0xd5,0x32,0x9e,0x4b,0x2f,0x1d}, - {0xcb,0x23,0xe9,0x22,0x64,0x47},{0xc9,0x20,0xe0,0x29,0x69,0x49}, - {0xcf,0x25,0xfb,0x34,0x7e,0x5b},{0xcd,0x26,0xf2,0x3f,0x73,0x55}, - {0xc3,0x2f,0xcd,0x0e,0x50,0x7f},{0xc1,0x2c,0xc4,0x05,0x5d,0x71}, - {0xc7,0x29,0xdf,0x18,0x4a,0x63},{0xc5,0x2a,0xd6,0x13,0x47,0x6d}, - {0xfb,0x0b,0x31,0xca,0xdc,0xd7},{0xf9,0x08,0x38,0xc1,0xd1,0xd9}, - {0xff,0x0d,0x23,0xdc,0xc6,0xcb},{0xfd,0x0e,0x2a,0xd7,0xcb,0xc5}, - {0xf3,0x07,0x15,0xe6,0xe8,0xef},{0xf1,0x04,0x1c,0xed,0xe5,0xe1}, - {0xf7,0x01,0x07,0xf0,0xf2,0xf3},{0xf5,0x02,0x0e,0xfb,0xff,0xfd}, - {0xeb,0x13,0x79,0x92,0xb4,0xa7},{0xe9,0x10,0x70,0x99,0xb9,0xa9}, - {0xef,0x15,0x6b,0x84,0xae,0xbb},{0xed,0x16,0x62,0x8f,0xa3,0xb5}, - {0xe3,0x1f,0x5d,0xbe,0x80,0x9f},{0xe1,0x1c,0x54,0xb5,0x8d,0x91}, - {0xe7,0x19,0x4f,0xa8,0x9a,0x83},{0xe5,0x1a,0x46,0xa3,0x97,0x8d} -}; - -/*********************** FUNCTION DEFINITIONS ***********************/ -// XORs the in and out buffers, storing the result in out. Length is in bytes. -void xor_buf(const BYTE in[], BYTE out[], size_t len) -{ - size_t idx; - - for (idx = 0; idx < len; idx++) - out[idx] ^= in[idx]; -} - -/******************* -* AES - CBC -*******************/ -int aes_encrypt_cbc(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) -{ - BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; - int blocks, idx; - - if (in_len % AES_BLOCK_SIZE != 0) - return(FALSE); - - blocks = in_len / AES_BLOCK_SIZE; - - memcpy(iv_buf, iv, AES_BLOCK_SIZE); - - for (idx = 0; idx < blocks; idx++) { - memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); - xor_buf(iv_buf, buf_in, AES_BLOCK_SIZE); - aes_encrypt(buf_in, buf_out, key, keysize); - memcpy(&out[idx * AES_BLOCK_SIZE], buf_out, AES_BLOCK_SIZE); - memcpy(iv_buf, buf_out, AES_BLOCK_SIZE); - } - - return(TRUE); -} - -int aes_encrypt_cbc_mac(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) -{ - BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; - int blocks, idx; - - if (in_len % AES_BLOCK_SIZE != 0) - return(FALSE); - - blocks = in_len / AES_BLOCK_SIZE; - - memcpy(iv_buf, iv, AES_BLOCK_SIZE); - - for (idx = 0; idx < blocks; idx++) { - memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); - xor_buf(iv_buf, buf_in, AES_BLOCK_SIZE); - aes_encrypt(buf_in, buf_out, key, keysize); - memcpy(iv_buf, buf_out, AES_BLOCK_SIZE); - // Do not output all encrypted blocks. - } - - memcpy(out, buf_out, AES_BLOCK_SIZE); // Only output the last block. - - return(TRUE); -} - -int aes_decrypt_cbc(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) -{ - BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; - int blocks, idx; - - if (in_len % AES_BLOCK_SIZE != 0) - return(FALSE); - - blocks = in_len / AES_BLOCK_SIZE; - - memcpy(iv_buf, iv, AES_BLOCK_SIZE); - - for (idx = 0; idx < blocks; idx++) { - memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); - aes_decrypt(buf_in, buf_out, key, keysize); - xor_buf(iv_buf, buf_out, AES_BLOCK_SIZE); - memcpy(&out[idx * AES_BLOCK_SIZE], buf_out, AES_BLOCK_SIZE); - memcpy(iv_buf, buf_in, AES_BLOCK_SIZE); - } - - return(TRUE); -} - -/******************* -* AES - CTR -*******************/ -void increment_iv(BYTE iv[], int counter_size) -{ - int idx; - - // Use counter_size bytes at the end of the IV as the big-endian integer to increment. - for (idx = AES_BLOCK_SIZE - 1; idx >= AES_BLOCK_SIZE - counter_size; idx--) { - iv[idx]++; - if (iv[idx] != 0 || idx == AES_BLOCK_SIZE - counter_size) - break; - } -} - -// Performs the encryption in-place, the input and output buffers may be the same. -// Input may be an arbitrary length (in bytes). -void aes_encrypt_ctr(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) -{ - size_t idx = 0, last_block_length; - BYTE iv_buf[AES_BLOCK_SIZE], out_buf[AES_BLOCK_SIZE]; - - if (in != out) - memcpy(out, in, in_len); - - memcpy(iv_buf, iv, AES_BLOCK_SIZE); - last_block_length = in_len - AES_BLOCK_SIZE; - - if (in_len > AES_BLOCK_SIZE) { - for (idx = 0; idx < last_block_length; idx += AES_BLOCK_SIZE) { - aes_encrypt(iv_buf, out_buf, key, keysize); - xor_buf(out_buf, &out[idx], AES_BLOCK_SIZE); - increment_iv(iv_buf, AES_BLOCK_SIZE); - } - } - - aes_encrypt(iv_buf, out_buf, key, keysize); - xor_buf(out_buf, &out[idx], in_len - idx); // Use the Most Significant bytes. -} - -void aes_decrypt_ctr(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) -{ - // CTR encryption is its own inverse function. - aes_encrypt_ctr(in, in_len, out, key, keysize, iv); -} - -/******************* -* AES - CCM -*******************/ -// out_len = payload_len + assoc_len -int aes_encrypt_ccm(const BYTE payload[], WORD payload_len, const BYTE assoc[], unsigned short assoc_len, - const BYTE nonce[], unsigned short nonce_len, BYTE out[], WORD *out_len, - WORD mac_len, const BYTE key_str[], int keysize) -{ - BYTE temp_iv[AES_BLOCK_SIZE], counter[AES_BLOCK_SIZE], mac[16], *buf; - int end_of_buf, payload_len_store_size; - WORD key[60]; - - if (mac_len != 4 && mac_len != 6 && mac_len != 8 && mac_len != 10 && - mac_len != 12 && mac_len != 14 && mac_len != 16) - return(FALSE); - - if (nonce_len < 7 || nonce_len > 13) - return(FALSE); - - if (assoc_len > 32768 /* = 2^15 */) - return(FALSE); - - buf = (BYTE *)malloc(payload_len + assoc_len + 48 /*Round both payload and associated data up a block size and add an extra block.*/); - if (!buf) - return(FALSE); - - // Prepare the key for usage. - aes_key_setup(key_str, key, keysize); - - // Format the first block of the formatted data. - payload_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; - ccm_prepare_first_format_blk(buf, assoc_len, payload_len, payload_len_store_size, mac_len, nonce, nonce_len); - end_of_buf = AES_BLOCK_SIZE; - - // Format the Associated Data, aka, assoc[]. - ccm_format_assoc_data(buf, &end_of_buf, assoc, assoc_len); - - // Format the Payload, aka payload[]. - ccm_format_payload_data(buf, &end_of_buf, payload, payload_len); - - // Create the first counter block. - ccm_prepare_first_ctr_blk(counter, nonce, nonce_len, payload_len_store_size); - - // Perform the CBC operation with an IV of zeros on the formatted buffer to calculate the MAC. - memset(temp_iv, 0, AES_BLOCK_SIZE); - aes_encrypt_cbc_mac(buf, end_of_buf, mac, key, keysize, temp_iv); - - // Copy the Payload and MAC to the output buffer. - memcpy(out, payload, payload_len); - memcpy(&out[payload_len], mac, mac_len); - - // Encrypt the Payload with CTR mode with a counter starting at 1. - memcpy(temp_iv, counter, AES_BLOCK_SIZE); - increment_iv(temp_iv, AES_BLOCK_SIZE - 1 - mac_len); // Last argument is the byte size of the counting portion of the counter block. /*BUG?*/ - aes_encrypt_ctr(out, payload_len, out, key, keysize, temp_iv); - - // Encrypt the MAC with CTR mode with a counter starting at 0. - aes_encrypt_ctr(&out[payload_len], mac_len, &out[payload_len], key, keysize, counter); - - free(buf); - *out_len = payload_len + mac_len; - - return(TRUE); -} - -// plaintext_len = ciphertext_len - mac_len -// Needs a flag for whether the MAC matches. -int aes_decrypt_ccm(const BYTE ciphertext[], WORD ciphertext_len, const BYTE assoc[], unsigned short assoc_len, - const BYTE nonce[], unsigned short nonce_len, BYTE plaintext[], WORD *plaintext_len, - WORD mac_len, int *mac_auth, const BYTE key_str[], int keysize) -{ - BYTE temp_iv[AES_BLOCK_SIZE], counter[AES_BLOCK_SIZE], mac[16], mac_buf[16], *buf; - int end_of_buf, plaintext_len_store_size; - WORD key[60]; - - if (ciphertext_len <= mac_len) - return(FALSE); - - buf = (BYTE *)malloc(assoc_len + ciphertext_len /*ciphertext_len = plaintext_len + mac_len*/ + 48); - if (!buf) - return(FALSE); - - // Prepare the key for usage. - aes_key_setup(key_str, key, keysize); - - // Copy the plaintext and MAC to the output buffers. - *plaintext_len = ciphertext_len - mac_len; - plaintext_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; - memcpy(plaintext, ciphertext, *plaintext_len); - memcpy(mac, &ciphertext[*plaintext_len], mac_len); - - // Prepare the first counter block for use in decryption. - ccm_prepare_first_ctr_blk(counter, nonce, nonce_len, plaintext_len_store_size); - - // Decrypt the Payload with CTR mode with a counter starting at 1. - memcpy(temp_iv, counter, AES_BLOCK_SIZE); - increment_iv(temp_iv, AES_BLOCK_SIZE - 1 - mac_len); // (AES_BLOCK_SIZE - 1 - mac_len) is the byte size of the counting portion of the counter block. - aes_decrypt_ctr(plaintext, *plaintext_len, plaintext, key, keysize, temp_iv); - - // Setting mac_auth to NULL disables the authentication check. - if (mac_auth != NULL) { - // Decrypt the MAC with CTR mode with a counter starting at 0. - aes_decrypt_ctr(mac, mac_len, mac, key, keysize, counter); - - // Format the first block of the formatted data. - plaintext_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; - ccm_prepare_first_format_blk(buf, assoc_len, *plaintext_len, plaintext_len_store_size, mac_len, nonce, nonce_len); - end_of_buf = AES_BLOCK_SIZE; - - // Format the Associated Data into the authentication buffer. - ccm_format_assoc_data(buf, &end_of_buf, assoc, assoc_len); - - // Format the Payload into the authentication buffer. - ccm_format_payload_data(buf, &end_of_buf, plaintext, *plaintext_len); - - // Perform the CBC operation with an IV of zeros on the formatted buffer to calculate the MAC. - memset(temp_iv, 0, AES_BLOCK_SIZE); - aes_encrypt_cbc_mac(buf, end_of_buf, mac_buf, key, keysize, temp_iv); - - // Compare the calculated MAC against the MAC embedded in the ciphertext to see if they are the same. - if (!memcmp(mac, mac_buf, mac_len)) { - *mac_auth = TRUE; - } - else { - *mac_auth = FALSE; - memset(plaintext, 0, *plaintext_len); - } - } - - free(buf); - - return(TRUE); -} - -// Creates the first counter block. First byte is flags, then the nonce, then the incremented part. -void ccm_prepare_first_ctr_blk(BYTE counter[], const BYTE nonce[], int nonce_len, int payload_len_store_size) -{ - memset(counter, 0, AES_BLOCK_SIZE); - counter[0] = (payload_len_store_size - 1) & 0x07; - memcpy(&counter[1], nonce, nonce_len); -} - -void ccm_prepare_first_format_blk(BYTE buf[], int assoc_len, int payload_len, int payload_len_store_size, int mac_len, const BYTE nonce[], int nonce_len) -{ - // Set the flags for the first byte of the first block. - buf[0] = ((((mac_len - 2) / 2) & 0x07) << 3) | ((payload_len_store_size - 1) & 0x07); - if (assoc_len > 0) - buf[0] += 0x40; - // Format the rest of the first block, storing the nonce and the size of the payload. - memcpy(&buf[1], nonce, nonce_len); - memset(&buf[1 + nonce_len], 0, AES_BLOCK_SIZE - 1 - nonce_len); - buf[15] = payload_len & 0x000000FF; - buf[14] = (payload_len >> 8) & 0x000000FF; -} - -void ccm_format_assoc_data(BYTE buf[], int *end_of_buf, const BYTE assoc[], int assoc_len) -{ - int pad; - - buf[*end_of_buf + 1] = assoc_len & 0x00FF; - buf[*end_of_buf] = (assoc_len >> 8) & 0x00FF; - *end_of_buf += 2; - memcpy(&buf[*end_of_buf], assoc, assoc_len); - *end_of_buf += assoc_len; - pad = AES_BLOCK_SIZE - (*end_of_buf % AES_BLOCK_SIZE);/*BUG?*/ - memset(&buf[*end_of_buf], 0, pad); - *end_of_buf += pad; -} - -void ccm_format_payload_data(BYTE buf[], int *end_of_buf, const BYTE payload[], int payload_len) -{ - int pad; - - memcpy(&buf[*end_of_buf], payload, payload_len); - *end_of_buf += payload_len; - pad = *end_of_buf % AES_BLOCK_SIZE; - if (pad != 0) - pad = AES_BLOCK_SIZE - pad; - memset(&buf[*end_of_buf], 0, pad); - *end_of_buf += pad; -} - -/******************* -* AES -*******************/ -///////////////// -// KEY EXPANSION -///////////////// - -// Substitutes a word using the AES S-Box. -WORD SubWord(WORD word) -{ - unsigned int result; - - result = (int)aes_sbox[(word >> 4) & 0x0000000F][word & 0x0000000F]; - result += (int)aes_sbox[(word >> 12) & 0x0000000F][(word >> 8) & 0x0000000F] << 8; - result += (int)aes_sbox[(word >> 20) & 0x0000000F][(word >> 16) & 0x0000000F] << 16; - result += (int)aes_sbox[(word >> 28) & 0x0000000F][(word >> 24) & 0x0000000F] << 24; - return(result); -} - -// Performs the action of generating the keys that will be used in every round of -// encryption. "key" is the user-supplied input key, "w" is the output key schedule, -// "keysize" is the length in bits of "key", must be 128, 192, or 256. -void aes_key_setup(const BYTE key[], WORD w[], int keysize) -{ - int Nb = 4,Nr,Nk,idx; - WORD temp,Rcon[] = {0x01000000,0x02000000,0x04000000,0x08000000,0x10000000,0x20000000, - 0x40000000,0x80000000,0x1b000000,0x36000000,0x6c000000,0xd8000000, - 0xab000000,0x4d000000,0x9a000000}; - - switch (keysize) { - case 128: Nr = 10; Nk = 4; break; - case 192: Nr = 12; Nk = 6; break; - case 256: Nr = 14; Nk = 8; break; - default: return; - } - - for (idx = 0; idx < Nk; ++idx) { - w[idx] = ((key[4 * idx]) << 24) | ((key[4 * idx + 1]) << 16) | - ((key[4 * idx + 2]) << 8) | ((key[4 * idx + 3])); - } - - for (idx = Nk; idx < Nb * (Nr + 1); ++idx) { - temp = w[idx - 1]; - if ((idx % Nk) == 0) - temp = SubWord(KE_ROTWORD(temp)) ^ Rcon[(idx - 1) / Nk]; - else if (Nk > 6 && (idx % Nk) == 4) - temp = SubWord(temp); - w[idx] = w[idx - Nk] ^ temp; - } -} - -///////////////// -// ADD ROUND KEY -///////////////// - -// Performs the AddRoundKey step. Each round has its own pre-generated 16-byte key in the -// form of 4 integers (the "w" array). Each integer is XOR'd by one column of the state. -// Also performs the job of InvAddRoundKey(); since the function is a simple XOR process, -// it is its own inverse. -void AddRoundKey(BYTE state[][4], const WORD w[]) -{ - BYTE subkey[4]; - - // memcpy(subkey,&w[idx],4); // Not accurate for big endian machines - // Subkey 1 - subkey[0] = w[0] >> 24; - subkey[1] = w[0] >> 16; - subkey[2] = w[0] >> 8; - subkey[3] = w[0]; - state[0][0] ^= subkey[0]; - state[1][0] ^= subkey[1]; - state[2][0] ^= subkey[2]; - state[3][0] ^= subkey[3]; - // Subkey 2 - subkey[0] = w[1] >> 24; - subkey[1] = w[1] >> 16; - subkey[2] = w[1] >> 8; - subkey[3] = w[1]; - state[0][1] ^= subkey[0]; - state[1][1] ^= subkey[1]; - state[2][1] ^= subkey[2]; - state[3][1] ^= subkey[3]; - // Subkey 3 - subkey[0] = w[2] >> 24; - subkey[1] = w[2] >> 16; - subkey[2] = w[2] >> 8; - subkey[3] = w[2]; - state[0][2] ^= subkey[0]; - state[1][2] ^= subkey[1]; - state[2][2] ^= subkey[2]; - state[3][2] ^= subkey[3]; - // Subkey 4 - subkey[0] = w[3] >> 24; - subkey[1] = w[3] >> 16; - subkey[2] = w[3] >> 8; - subkey[3] = w[3]; - state[0][3] ^= subkey[0]; - state[1][3] ^= subkey[1]; - state[2][3] ^= subkey[2]; - state[3][3] ^= subkey[3]; -} - -///////////////// -// (Inv)SubBytes -///////////////// - -// Performs the SubBytes step. All bytes in the state are substituted with a -// pre-calculated value from a lookup table. -void SubBytes(BYTE state[][4]) -{ - state[0][0] = aes_sbox[state[0][0] >> 4][state[0][0] & 0x0F]; - state[0][1] = aes_sbox[state[0][1] >> 4][state[0][1] & 0x0F]; - state[0][2] = aes_sbox[state[0][2] >> 4][state[0][2] & 0x0F]; - state[0][3] = aes_sbox[state[0][3] >> 4][state[0][3] & 0x0F]; - state[1][0] = aes_sbox[state[1][0] >> 4][state[1][0] & 0x0F]; - state[1][1] = aes_sbox[state[1][1] >> 4][state[1][1] & 0x0F]; - state[1][2] = aes_sbox[state[1][2] >> 4][state[1][2] & 0x0F]; - state[1][3] = aes_sbox[state[1][3] >> 4][state[1][3] & 0x0F]; - state[2][0] = aes_sbox[state[2][0] >> 4][state[2][0] & 0x0F]; - state[2][1] = aes_sbox[state[2][1] >> 4][state[2][1] & 0x0F]; - state[2][2] = aes_sbox[state[2][2] >> 4][state[2][2] & 0x0F]; - state[2][3] = aes_sbox[state[2][3] >> 4][state[2][3] & 0x0F]; - state[3][0] = aes_sbox[state[3][0] >> 4][state[3][0] & 0x0F]; - state[3][1] = aes_sbox[state[3][1] >> 4][state[3][1] & 0x0F]; - state[3][2] = aes_sbox[state[3][2] >> 4][state[3][2] & 0x0F]; - state[3][3] = aes_sbox[state[3][3] >> 4][state[3][3] & 0x0F]; -} - -void InvSubBytes(BYTE state[][4]) -{ - state[0][0] = aes_invsbox[state[0][0] >> 4][state[0][0] & 0x0F]; - state[0][1] = aes_invsbox[state[0][1] >> 4][state[0][1] & 0x0F]; - state[0][2] = aes_invsbox[state[0][2] >> 4][state[0][2] & 0x0F]; - state[0][3] = aes_invsbox[state[0][3] >> 4][state[0][3] & 0x0F]; - state[1][0] = aes_invsbox[state[1][0] >> 4][state[1][0] & 0x0F]; - state[1][1] = aes_invsbox[state[1][1] >> 4][state[1][1] & 0x0F]; - state[1][2] = aes_invsbox[state[1][2] >> 4][state[1][2] & 0x0F]; - state[1][3] = aes_invsbox[state[1][3] >> 4][state[1][3] & 0x0F]; - state[2][0] = aes_invsbox[state[2][0] >> 4][state[2][0] & 0x0F]; - state[2][1] = aes_invsbox[state[2][1] >> 4][state[2][1] & 0x0F]; - state[2][2] = aes_invsbox[state[2][2] >> 4][state[2][2] & 0x0F]; - state[2][3] = aes_invsbox[state[2][3] >> 4][state[2][3] & 0x0F]; - state[3][0] = aes_invsbox[state[3][0] >> 4][state[3][0] & 0x0F]; - state[3][1] = aes_invsbox[state[3][1] >> 4][state[3][1] & 0x0F]; - state[3][2] = aes_invsbox[state[3][2] >> 4][state[3][2] & 0x0F]; - state[3][3] = aes_invsbox[state[3][3] >> 4][state[3][3] & 0x0F]; -} - -///////////////// -// (Inv)ShiftRows -///////////////// - -// Performs the ShiftRows step. All rows are shifted cylindrically to the left. -void ShiftRows(BYTE state[][4]) -{ - int t; - - // Shift left by 1 - t = state[1][0]; - state[1][0] = state[1][1]; - state[1][1] = state[1][2]; - state[1][2] = state[1][3]; - state[1][3] = t; - // Shift left by 2 - t = state[2][0]; - state[2][0] = state[2][2]; - state[2][2] = t; - t = state[2][1]; - state[2][1] = state[2][3]; - state[2][3] = t; - // Shift left by 3 - t = state[3][0]; - state[3][0] = state[3][3]; - state[3][3] = state[3][2]; - state[3][2] = state[3][1]; - state[3][1] = t; -} - -// All rows are shifted cylindrically to the right. -void InvShiftRows(BYTE state[][4]) -{ - int t; - - // Shift right by 1 - t = state[1][3]; - state[1][3] = state[1][2]; - state[1][2] = state[1][1]; - state[1][1] = state[1][0]; - state[1][0] = t; - // Shift right by 2 - t = state[2][3]; - state[2][3] = state[2][1]; - state[2][1] = t; - t = state[2][2]; - state[2][2] = state[2][0]; - state[2][0] = t; - // Shift right by 3 - t = state[3][3]; - state[3][3] = state[3][0]; - state[3][0] = state[3][1]; - state[3][1] = state[3][2]; - state[3][2] = t; -} - -///////////////// -// (Inv)MixColumns -///////////////// - -// Performs the MixColums step. The state is multiplied by itself using matrix -// multiplication in a Galios Field 2^8. All multiplication is pre-computed in a table. -// Addition is equivilent to XOR. (Must always make a copy of the column as the original -// values will be destoyed.) -void MixColumns(BYTE state[][4]) -{ - BYTE col[4]; - - // Column 1 - col[0] = state[0][0]; - col[1] = state[1][0]; - col[2] = state[2][0]; - col[3] = state[3][0]; - state[0][0] = gf_mul[col[0]][0]; - state[0][0] ^= gf_mul[col[1]][1]; - state[0][0] ^= col[2]; - state[0][0] ^= col[3]; - state[1][0] = col[0]; - state[1][0] ^= gf_mul[col[1]][0]; - state[1][0] ^= gf_mul[col[2]][1]; - state[1][0] ^= col[3]; - state[2][0] = col[0]; - state[2][0] ^= col[1]; - state[2][0] ^= gf_mul[col[2]][0]; - state[2][0] ^= gf_mul[col[3]][1]; - state[3][0] = gf_mul[col[0]][1]; - state[3][0] ^= col[1]; - state[3][0] ^= col[2]; - state[3][0] ^= gf_mul[col[3]][0]; - // Column 2 - col[0] = state[0][1]; - col[1] = state[1][1]; - col[2] = state[2][1]; - col[3] = state[3][1]; - state[0][1] = gf_mul[col[0]][0]; - state[0][1] ^= gf_mul[col[1]][1]; - state[0][1] ^= col[2]; - state[0][1] ^= col[3]; - state[1][1] = col[0]; - state[1][1] ^= gf_mul[col[1]][0]; - state[1][1] ^= gf_mul[col[2]][1]; - state[1][1] ^= col[3]; - state[2][1] = col[0]; - state[2][1] ^= col[1]; - state[2][1] ^= gf_mul[col[2]][0]; - state[2][1] ^= gf_mul[col[3]][1]; - state[3][1] = gf_mul[col[0]][1]; - state[3][1] ^= col[1]; - state[3][1] ^= col[2]; - state[3][1] ^= gf_mul[col[3]][0]; - // Column 3 - col[0] = state[0][2]; - col[1] = state[1][2]; - col[2] = state[2][2]; - col[3] = state[3][2]; - state[0][2] = gf_mul[col[0]][0]; - state[0][2] ^= gf_mul[col[1]][1]; - state[0][2] ^= col[2]; - state[0][2] ^= col[3]; - state[1][2] = col[0]; - state[1][2] ^= gf_mul[col[1]][0]; - state[1][2] ^= gf_mul[col[2]][1]; - state[1][2] ^= col[3]; - state[2][2] = col[0]; - state[2][2] ^= col[1]; - state[2][2] ^= gf_mul[col[2]][0]; - state[2][2] ^= gf_mul[col[3]][1]; - state[3][2] = gf_mul[col[0]][1]; - state[3][2] ^= col[1]; - state[3][2] ^= col[2]; - state[3][2] ^= gf_mul[col[3]][0]; - // Column 4 - col[0] = state[0][3]; - col[1] = state[1][3]; - col[2] = state[2][3]; - col[3] = state[3][3]; - state[0][3] = gf_mul[col[0]][0]; - state[0][3] ^= gf_mul[col[1]][1]; - state[0][3] ^= col[2]; - state[0][3] ^= col[3]; - state[1][3] = col[0]; - state[1][3] ^= gf_mul[col[1]][0]; - state[1][3] ^= gf_mul[col[2]][1]; - state[1][3] ^= col[3]; - state[2][3] = col[0]; - state[2][3] ^= col[1]; - state[2][3] ^= gf_mul[col[2]][0]; - state[2][3] ^= gf_mul[col[3]][1]; - state[3][3] = gf_mul[col[0]][1]; - state[3][3] ^= col[1]; - state[3][3] ^= col[2]; - state[3][3] ^= gf_mul[col[3]][0]; -} - -void InvMixColumns(BYTE state[][4]) -{ - BYTE col[4]; - - // Column 1 - col[0] = state[0][0]; - col[1] = state[1][0]; - col[2] = state[2][0]; - col[3] = state[3][0]; - state[0][0] = gf_mul[col[0]][5]; - state[0][0] ^= gf_mul[col[1]][3]; - state[0][0] ^= gf_mul[col[2]][4]; - state[0][0] ^= gf_mul[col[3]][2]; - state[1][0] = gf_mul[col[0]][2]; - state[1][0] ^= gf_mul[col[1]][5]; - state[1][0] ^= gf_mul[col[2]][3]; - state[1][0] ^= gf_mul[col[3]][4]; - state[2][0] = gf_mul[col[0]][4]; - state[2][0] ^= gf_mul[col[1]][2]; - state[2][0] ^= gf_mul[col[2]][5]; - state[2][0] ^= gf_mul[col[3]][3]; - state[3][0] = gf_mul[col[0]][3]; - state[3][0] ^= gf_mul[col[1]][4]; - state[3][0] ^= gf_mul[col[2]][2]; - state[3][0] ^= gf_mul[col[3]][5]; - // Column 2 - col[0] = state[0][1]; - col[1] = state[1][1]; - col[2] = state[2][1]; - col[3] = state[3][1]; - state[0][1] = gf_mul[col[0]][5]; - state[0][1] ^= gf_mul[col[1]][3]; - state[0][1] ^= gf_mul[col[2]][4]; - state[0][1] ^= gf_mul[col[3]][2]; - state[1][1] = gf_mul[col[0]][2]; - state[1][1] ^= gf_mul[col[1]][5]; - state[1][1] ^= gf_mul[col[2]][3]; - state[1][1] ^= gf_mul[col[3]][4]; - state[2][1] = gf_mul[col[0]][4]; - state[2][1] ^= gf_mul[col[1]][2]; - state[2][1] ^= gf_mul[col[2]][5]; - state[2][1] ^= gf_mul[col[3]][3]; - state[3][1] = gf_mul[col[0]][3]; - state[3][1] ^= gf_mul[col[1]][4]; - state[3][1] ^= gf_mul[col[2]][2]; - state[3][1] ^= gf_mul[col[3]][5]; - // Column 3 - col[0] = state[0][2]; - col[1] = state[1][2]; - col[2] = state[2][2]; - col[3] = state[3][2]; - state[0][2] = gf_mul[col[0]][5]; - state[0][2] ^= gf_mul[col[1]][3]; - state[0][2] ^= gf_mul[col[2]][4]; - state[0][2] ^= gf_mul[col[3]][2]; - state[1][2] = gf_mul[col[0]][2]; - state[1][2] ^= gf_mul[col[1]][5]; - state[1][2] ^= gf_mul[col[2]][3]; - state[1][2] ^= gf_mul[col[3]][4]; - state[2][2] = gf_mul[col[0]][4]; - state[2][2] ^= gf_mul[col[1]][2]; - state[2][2] ^= gf_mul[col[2]][5]; - state[2][2] ^= gf_mul[col[3]][3]; - state[3][2] = gf_mul[col[0]][3]; - state[3][2] ^= gf_mul[col[1]][4]; - state[3][2] ^= gf_mul[col[2]][2]; - state[3][2] ^= gf_mul[col[3]][5]; - // Column 4 - col[0] = state[0][3]; - col[1] = state[1][3]; - col[2] = state[2][3]; - col[3] = state[3][3]; - state[0][3] = gf_mul[col[0]][5]; - state[0][3] ^= gf_mul[col[1]][3]; - state[0][3] ^= gf_mul[col[2]][4]; - state[0][3] ^= gf_mul[col[3]][2]; - state[1][3] = gf_mul[col[0]][2]; - state[1][3] ^= gf_mul[col[1]][5]; - state[1][3] ^= gf_mul[col[2]][3]; - state[1][3] ^= gf_mul[col[3]][4]; - state[2][3] = gf_mul[col[0]][4]; - state[2][3] ^= gf_mul[col[1]][2]; - state[2][3] ^= gf_mul[col[2]][5]; - state[2][3] ^= gf_mul[col[3]][3]; - state[3][3] = gf_mul[col[0]][3]; - state[3][3] ^= gf_mul[col[1]][4]; - state[3][3] ^= gf_mul[col[2]][2]; - state[3][3] ^= gf_mul[col[3]][5]; -} - -///////////////// -// (En/De)Crypt -///////////////// - -void aes_encrypt(const BYTE in[], BYTE out[], const WORD key[], int keysize) -{ - BYTE state[4][4]; - - // Copy input array (should be 16 bytes long) to a matrix (sequential bytes are ordered - // by row, not col) called "state" for processing. - // *** Implementation note: The official AES documentation references the state by - // column, then row. Accessing an element in C requires row then column. Thus, all state - // references in AES must have the column and row indexes reversed for C implementation. - state[0][0] = in[0]; - state[1][0] = in[1]; - state[2][0] = in[2]; - state[3][0] = in[3]; - state[0][1] = in[4]; - state[1][1] = in[5]; - state[2][1] = in[6]; - state[3][1] = in[7]; - state[0][2] = in[8]; - state[1][2] = in[9]; - state[2][2] = in[10]; - state[3][2] = in[11]; - state[0][3] = in[12]; - state[1][3] = in[13]; - state[2][3] = in[14]; - state[3][3] = in[15]; - - // Perform the necessary number of rounds. The round key is added first. - // The last round does not perform the MixColumns step. - AddRoundKey(state,&key[0]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[4]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[8]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[12]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[16]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[20]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[24]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[28]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[32]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[36]); - if (keysize != 128) { - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[40]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[44]); - if (keysize != 192) { - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[48]); - SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[52]); - SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[56]); - } - else { - SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[48]); - } - } - else { - SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[40]); - } - - // Copy the state to the output array. - out[0] = state[0][0]; - out[1] = state[1][0]; - out[2] = state[2][0]; - out[3] = state[3][0]; - out[4] = state[0][1]; - out[5] = state[1][1]; - out[6] = state[2][1]; - out[7] = state[3][1]; - out[8] = state[0][2]; - out[9] = state[1][2]; - out[10] = state[2][2]; - out[11] = state[3][2]; - out[12] = state[0][3]; - out[13] = state[1][3]; - out[14] = state[2][3]; - out[15] = state[3][3]; -} - -void aes_decrypt(const BYTE in[], BYTE out[], const WORD key[], int keysize) -{ - BYTE state[4][4]; - - // Copy the input to the state. - state[0][0] = in[0]; - state[1][0] = in[1]; - state[2][0] = in[2]; - state[3][0] = in[3]; - state[0][1] = in[4]; - state[1][1] = in[5]; - state[2][1] = in[6]; - state[3][1] = in[7]; - state[0][2] = in[8]; - state[1][2] = in[9]; - state[2][2] = in[10]; - state[3][2] = in[11]; - state[0][3] = in[12]; - state[1][3] = in[13]; - state[2][3] = in[14]; - state[3][3] = in[15]; - - // Perform the necessary number of rounds. The round key is added first. - // The last round does not perform the MixColumns step. - if (keysize > 128) { - if (keysize > 192) { - AddRoundKey(state,&key[56]); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[52]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[48]);InvMixColumns(state); - } - else { - AddRoundKey(state,&key[48]); - } - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[44]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[40]);InvMixColumns(state); - } - else { - AddRoundKey(state,&key[40]); - } - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[36]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[32]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[28]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[24]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[20]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[16]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[12]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[8]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[4]);InvMixColumns(state); - InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[0]); - - // Copy the state to the output array. - out[0] = state[0][0]; - out[1] = state[1][0]; - out[2] = state[2][0]; - out[3] = state[3][0]; - out[4] = state[0][1]; - out[5] = state[1][1]; - out[6] = state[2][1]; - out[7] = state[3][1]; - out[8] = state[0][2]; - out[9] = state[1][2]; - out[10] = state[2][2]; - out[11] = state[3][2]; - out[12] = state[0][3]; - out[13] = state[1][3]; - out[14] = state[2][3]; - out[15] = state[3][3]; -} - -/******************* -** AES DEBUGGING FUNCTIONS -*******************/ -/* - // This prints the "state" grid as a linear hex string. - void print_state(BYTE state[][4]) - { - int idx,idx2; - - for (idx=0; idx < 4; idx++) - for (idx2=0; idx2 < 4; idx2++) - printf("%02x",state[idx2][idx]); - printf("\n"); - } - - // This prints the key (4 consecutive ints) used for a given round as a linear hex string. - void print_rnd_key(WORD key[]) - { - int idx; - - for (idx=0; idx < 4; idx++) - printf("%08x",key[idx]); - printf("\n"); - } - */ diff --git a/version2/src/C/aes.cpp b/version2/src/C/aes.cpp new file mode 100644 index 0000000..a917c7f --- /dev/null +++ b/version2/src/C/aes.cpp @@ -0,0 +1,1095 @@ +/********************************************************************* +* Filename: aes.c +* Author: Brad Conte (brad AT bradconte.com) +* Copyright: +* Disclaimer: This code is presented "as is" without any guarantees. +* Details: This code is the implementation of the AES algorithm and + the CTR, CBC, and CCM modes of operation it can be used in. + AES is, specified by the NIST in in publication FIPS PUB 197, + availible at: +* http://csrc.nist.gov/publications/fips/fips197/fips-197.pdf . + The CBC and CTR modes of operation are specified by + NIST SP 800-38 A, available at: +* http://csrc.nist.gov/publications/nistpubs/800-38a/sp800-38a.pdf . + The CCM mode of operation is specified by NIST SP80-38 C, available at: +* http://csrc.nist.gov/publications/nistpubs/800-38C/SP800-38C_updated-July20_2007.pdf +*********************************************************************/ + +/*************************** HEADER FILES ***************************/ +#include +#include +#include "aes.h" + +#include + +/****************************** MACROS ******************************/ +// The least significant byte of the word is rotated to the end. +#define KE_ROTWORD(x) (((x) << 8) | ((x) >> 24)) + +#define TRUE 1 +#define FALSE 0 + +/**************************** DATA TYPES ****************************/ +#define AES_128_ROUNDS 10 +#define AES_192_ROUNDS 12 +#define AES_256_ROUNDS 14 + +/*********************** FUNCTION DECLARATIONS **********************/ +void ccm_prepare_first_ctr_blk(BYTE counter[], const BYTE nonce[], int nonce_len, int payload_len_store_size); +void ccm_prepare_first_format_blk(BYTE buf[], int assoc_len, int payload_len, int payload_len_store_size, int mac_len, const BYTE nonce[], int nonce_len); +void ccm_format_assoc_data(BYTE buf[], int *end_of_buf, const BYTE assoc[], int assoc_len); +void ccm_format_payload_data(BYTE buf[], int *end_of_buf, const BYTE payload[], int payload_len); + +/**************************** VARIABLES *****************************/ +// This is the specified AES SBox. To look up a substitution value, put the first +// nibble in the first index (row) and the second nibble in the second index (column). +static const BYTE aes_sbox[16][16] = { + {0x63,0x7C,0x77,0x7B,0xF2,0x6B,0x6F,0xC5,0x30,0x01,0x67,0x2B,0xFE,0xD7,0xAB,0x76}, + {0xCA,0x82,0xC9,0x7D,0xFA,0x59,0x47,0xF0,0xAD,0xD4,0xA2,0xAF,0x9C,0xA4,0x72,0xC0}, + {0xB7,0xFD,0x93,0x26,0x36,0x3F,0xF7,0xCC,0x34,0xA5,0xE5,0xF1,0x71,0xD8,0x31,0x15}, + {0x04,0xC7,0x23,0xC3,0x18,0x96,0x05,0x9A,0x07,0x12,0x80,0xE2,0xEB,0x27,0xB2,0x75}, + {0x09,0x83,0x2C,0x1A,0x1B,0x6E,0x5A,0xA0,0x52,0x3B,0xD6,0xB3,0x29,0xE3,0x2F,0x84}, + {0x53,0xD1,0x00,0xED,0x20,0xFC,0xB1,0x5B,0x6A,0xCB,0xBE,0x39,0x4A,0x4C,0x58,0xCF}, + {0xD0,0xEF,0xAA,0xFB,0x43,0x4D,0x33,0x85,0x45,0xF9,0x02,0x7F,0x50,0x3C,0x9F,0xA8}, + {0x51,0xA3,0x40,0x8F,0x92,0x9D,0x38,0xF5,0xBC,0xB6,0xDA,0x21,0x10,0xFF,0xF3,0xD2}, + {0xCD,0x0C,0x13,0xEC,0x5F,0x97,0x44,0x17,0xC4,0xA7,0x7E,0x3D,0x64,0x5D,0x19,0x73}, + {0x60,0x81,0x4F,0xDC,0x22,0x2A,0x90,0x88,0x46,0xEE,0xB8,0x14,0xDE,0x5E,0x0B,0xDB}, + {0xE0,0x32,0x3A,0x0A,0x49,0x06,0x24,0x5C,0xC2,0xD3,0xAC,0x62,0x91,0x95,0xE4,0x79}, + {0xE7,0xC8,0x37,0x6D,0x8D,0xD5,0x4E,0xA9,0x6C,0x56,0xF4,0xEA,0x65,0x7A,0xAE,0x08}, + {0xBA,0x78,0x25,0x2E,0x1C,0xA6,0xB4,0xC6,0xE8,0xDD,0x74,0x1F,0x4B,0xBD,0x8B,0x8A}, + {0x70,0x3E,0xB5,0x66,0x48,0x03,0xF6,0x0E,0x61,0x35,0x57,0xB9,0x86,0xC1,0x1D,0x9E}, + {0xE1,0xF8,0x98,0x11,0x69,0xD9,0x8E,0x94,0x9B,0x1E,0x87,0xE9,0xCE,0x55,0x28,0xDF}, + {0x8C,0xA1,0x89,0x0D,0xBF,0xE6,0x42,0x68,0x41,0x99,0x2D,0x0F,0xB0,0x54,0xBB,0x16} +}; + +static const BYTE aes_invsbox[16][16] = { + {0x52,0x09,0x6A,0xD5,0x30,0x36,0xA5,0x38,0xBF,0x40,0xA3,0x9E,0x81,0xF3,0xD7,0xFB}, + {0x7C,0xE3,0x39,0x82,0x9B,0x2F,0xFF,0x87,0x34,0x8E,0x43,0x44,0xC4,0xDE,0xE9,0xCB}, + {0x54,0x7B,0x94,0x32,0xA6,0xC2,0x23,0x3D,0xEE,0x4C,0x95,0x0B,0x42,0xFA,0xC3,0x4E}, + {0x08,0x2E,0xA1,0x66,0x28,0xD9,0x24,0xB2,0x76,0x5B,0xA2,0x49,0x6D,0x8B,0xD1,0x25}, + {0x72,0xF8,0xF6,0x64,0x86,0x68,0x98,0x16,0xD4,0xA4,0x5C,0xCC,0x5D,0x65,0xB6,0x92}, + {0x6C,0x70,0x48,0x50,0xFD,0xED,0xB9,0xDA,0x5E,0x15,0x46,0x57,0xA7,0x8D,0x9D,0x84}, + {0x90,0xD8,0xAB,0x00,0x8C,0xBC,0xD3,0x0A,0xF7,0xE4,0x58,0x05,0xB8,0xB3,0x45,0x06}, + {0xD0,0x2C,0x1E,0x8F,0xCA,0x3F,0x0F,0x02,0xC1,0xAF,0xBD,0x03,0x01,0x13,0x8A,0x6B}, + {0x3A,0x91,0x11,0x41,0x4F,0x67,0xDC,0xEA,0x97,0xF2,0xCF,0xCE,0xF0,0xB4,0xE6,0x73}, + {0x96,0xAC,0x74,0x22,0xE7,0xAD,0x35,0x85,0xE2,0xF9,0x37,0xE8,0x1C,0x75,0xDF,0x6E}, + {0x47,0xF1,0x1A,0x71,0x1D,0x29,0xC5,0x89,0x6F,0xB7,0x62,0x0E,0xAA,0x18,0xBE,0x1B}, + {0xFC,0x56,0x3E,0x4B,0xC6,0xD2,0x79,0x20,0x9A,0xDB,0xC0,0xFE,0x78,0xCD,0x5A,0xF4}, + {0x1F,0xDD,0xA8,0x33,0x88,0x07,0xC7,0x31,0xB1,0x12,0x10,0x59,0x27,0x80,0xEC,0x5F}, + {0x60,0x51,0x7F,0xA9,0x19,0xB5,0x4A,0x0D,0x2D,0xE5,0x7A,0x9F,0x93,0xC9,0x9C,0xEF}, + {0xA0,0xE0,0x3B,0x4D,0xAE,0x2A,0xF5,0xB0,0xC8,0xEB,0xBB,0x3C,0x83,0x53,0x99,0x61}, + {0x17,0x2B,0x04,0x7E,0xBA,0x77,0xD6,0x26,0xE1,0x69,0x14,0x63,0x55,0x21,0x0C,0x7D} +}; + +// This table stores pre-calculated values for all possible GF(2^8) calculations.This +// table is only used by the (Inv)MixColumns steps. +// USAGE: The second index (column) is the coefficient of multiplication. Only 7 different +// coefficients are used: 0x01, 0x02, 0x03, 0x09, 0x0b, 0x0d, 0x0e, but multiplication by +// 1 is negligible leaving only 6 coefficients. Each column of the table is devoted to one +// of these coefficients, in the ascending order of value, from values 0x00 to 0xFF. +static const BYTE gf_mul[256][6] = { + {0x00,0x00,0x00,0x00,0x00,0x00},{0x02,0x03,0x09,0x0b,0x0d,0x0e}, + {0x04,0x06,0x12,0x16,0x1a,0x1c},{0x06,0x05,0x1b,0x1d,0x17,0x12}, + {0x08,0x0c,0x24,0x2c,0x34,0x38},{0x0a,0x0f,0x2d,0x27,0x39,0x36}, + {0x0c,0x0a,0x36,0x3a,0x2e,0x24},{0x0e,0x09,0x3f,0x31,0x23,0x2a}, + {0x10,0x18,0x48,0x58,0x68,0x70},{0x12,0x1b,0x41,0x53,0x65,0x7e}, + {0x14,0x1e,0x5a,0x4e,0x72,0x6c},{0x16,0x1d,0x53,0x45,0x7f,0x62}, + {0x18,0x14,0x6c,0x74,0x5c,0x48},{0x1a,0x17,0x65,0x7f,0x51,0x46}, + {0x1c,0x12,0x7e,0x62,0x46,0x54},{0x1e,0x11,0x77,0x69,0x4b,0x5a}, + {0x20,0x30,0x90,0xb0,0xd0,0xe0},{0x22,0x33,0x99,0xbb,0xdd,0xee}, + {0x24,0x36,0x82,0xa6,0xca,0xfc},{0x26,0x35,0x8b,0xad,0xc7,0xf2}, + {0x28,0x3c,0xb4,0x9c,0xe4,0xd8},{0x2a,0x3f,0xbd,0x97,0xe9,0xd6}, + {0x2c,0x3a,0xa6,0x8a,0xfe,0xc4},{0x2e,0x39,0xaf,0x81,0xf3,0xca}, + {0x30,0x28,0xd8,0xe8,0xb8,0x90},{0x32,0x2b,0xd1,0xe3,0xb5,0x9e}, + {0x34,0x2e,0xca,0xfe,0xa2,0x8c},{0x36,0x2d,0xc3,0xf5,0xaf,0x82}, + {0x38,0x24,0xfc,0xc4,0x8c,0xa8},{0x3a,0x27,0xf5,0xcf,0x81,0xa6}, + {0x3c,0x22,0xee,0xd2,0x96,0xb4},{0x3e,0x21,0xe7,0xd9,0x9b,0xba}, + {0x40,0x60,0x3b,0x7b,0xbb,0xdb},{0x42,0x63,0x32,0x70,0xb6,0xd5}, + {0x44,0x66,0x29,0x6d,0xa1,0xc7},{0x46,0x65,0x20,0x66,0xac,0xc9}, + {0x48,0x6c,0x1f,0x57,0x8f,0xe3},{0x4a,0x6f,0x16,0x5c,0x82,0xed}, + {0x4c,0x6a,0x0d,0x41,0x95,0xff},{0x4e,0x69,0x04,0x4a,0x98,0xf1}, + {0x50,0x78,0x73,0x23,0xd3,0xab},{0x52,0x7b,0x7a,0x28,0xde,0xa5}, + {0x54,0x7e,0x61,0x35,0xc9,0xb7},{0x56,0x7d,0x68,0x3e,0xc4,0xb9}, + {0x58,0x74,0x57,0x0f,0xe7,0x93},{0x5a,0x77,0x5e,0x04,0xea,0x9d}, + {0x5c,0x72,0x45,0x19,0xfd,0x8f},{0x5e,0x71,0x4c,0x12,0xf0,0x81}, + {0x60,0x50,0xab,0xcb,0x6b,0x3b},{0x62,0x53,0xa2,0xc0,0x66,0x35}, + {0x64,0x56,0xb9,0xdd,0x71,0x27},{0x66,0x55,0xb0,0xd6,0x7c,0x29}, + {0x68,0x5c,0x8f,0xe7,0x5f,0x03},{0x6a,0x5f,0x86,0xec,0x52,0x0d}, + {0x6c,0x5a,0x9d,0xf1,0x45,0x1f},{0x6e,0x59,0x94,0xfa,0x48,0x11}, + {0x70,0x48,0xe3,0x93,0x03,0x4b},{0x72,0x4b,0xea,0x98,0x0e,0x45}, + {0x74,0x4e,0xf1,0x85,0x19,0x57},{0x76,0x4d,0xf8,0x8e,0x14,0x59}, + {0x78,0x44,0xc7,0xbf,0x37,0x73},{0x7a,0x47,0xce,0xb4,0x3a,0x7d}, + {0x7c,0x42,0xd5,0xa9,0x2d,0x6f},{0x7e,0x41,0xdc,0xa2,0x20,0x61}, + {0x80,0xc0,0x76,0xf6,0x6d,0xad},{0x82,0xc3,0x7f,0xfd,0x60,0xa3}, + {0x84,0xc6,0x64,0xe0,0x77,0xb1},{0x86,0xc5,0x6d,0xeb,0x7a,0xbf}, + {0x88,0xcc,0x52,0xda,0x59,0x95},{0x8a,0xcf,0x5b,0xd1,0x54,0x9b}, + {0x8c,0xca,0x40,0xcc,0x43,0x89},{0x8e,0xc9,0x49,0xc7,0x4e,0x87}, + {0x90,0xd8,0x3e,0xae,0x05,0xdd},{0x92,0xdb,0x37,0xa5,0x08,0xd3}, + {0x94,0xde,0x2c,0xb8,0x1f,0xc1},{0x96,0xdd,0x25,0xb3,0x12,0xcf}, + {0x98,0xd4,0x1a,0x82,0x31,0xe5},{0x9a,0xd7,0x13,0x89,0x3c,0xeb}, + {0x9c,0xd2,0x08,0x94,0x2b,0xf9},{0x9e,0xd1,0x01,0x9f,0x26,0xf7}, + {0xa0,0xf0,0xe6,0x46,0xbd,0x4d},{0xa2,0xf3,0xef,0x4d,0xb0,0x43}, + {0xa4,0xf6,0xf4,0x50,0xa7,0x51},{0xa6,0xf5,0xfd,0x5b,0xaa,0x5f}, + {0xa8,0xfc,0xc2,0x6a,0x89,0x75},{0xaa,0xff,0xcb,0x61,0x84,0x7b}, + {0xac,0xfa,0xd0,0x7c,0x93,0x69},{0xae,0xf9,0xd9,0x77,0x9e,0x67}, + {0xb0,0xe8,0xae,0x1e,0xd5,0x3d},{0xb2,0xeb,0xa7,0x15,0xd8,0x33}, + {0xb4,0xee,0xbc,0x08,0xcf,0x21},{0xb6,0xed,0xb5,0x03,0xc2,0x2f}, + {0xb8,0xe4,0x8a,0x32,0xe1,0x05},{0xba,0xe7,0x83,0x39,0xec,0x0b}, + {0xbc,0xe2,0x98,0x24,0xfb,0x19},{0xbe,0xe1,0x91,0x2f,0xf6,0x17}, + {0xc0,0xa0,0x4d,0x8d,0xd6,0x76},{0xc2,0xa3,0x44,0x86,0xdb,0x78}, + {0xc4,0xa6,0x5f,0x9b,0xcc,0x6a},{0xc6,0xa5,0x56,0x90,0xc1,0x64}, + {0xc8,0xac,0x69,0xa1,0xe2,0x4e},{0xca,0xaf,0x60,0xaa,0xef,0x40}, + {0xcc,0xaa,0x7b,0xb7,0xf8,0x52},{0xce,0xa9,0x72,0xbc,0xf5,0x5c}, + {0xd0,0xb8,0x05,0xd5,0xbe,0x06},{0xd2,0xbb,0x0c,0xde,0xb3,0x08}, + {0xd4,0xbe,0x17,0xc3,0xa4,0x1a},{0xd6,0xbd,0x1e,0xc8,0xa9,0x14}, + {0xd8,0xb4,0x21,0xf9,0x8a,0x3e},{0xda,0xb7,0x28,0xf2,0x87,0x30}, + {0xdc,0xb2,0x33,0xef,0x90,0x22},{0xde,0xb1,0x3a,0xe4,0x9d,0x2c}, + {0xe0,0x90,0xdd,0x3d,0x06,0x96},{0xe2,0x93,0xd4,0x36,0x0b,0x98}, + {0xe4,0x96,0xcf,0x2b,0x1c,0x8a},{0xe6,0x95,0xc6,0x20,0x11,0x84}, + {0xe8,0x9c,0xf9,0x11,0x32,0xae},{0xea,0x9f,0xf0,0x1a,0x3f,0xa0}, + {0xec,0x9a,0xeb,0x07,0x28,0xb2},{0xee,0x99,0xe2,0x0c,0x25,0xbc}, + {0xf0,0x88,0x95,0x65,0x6e,0xe6},{0xf2,0x8b,0x9c,0x6e,0x63,0xe8}, + {0xf4,0x8e,0x87,0x73,0x74,0xfa},{0xf6,0x8d,0x8e,0x78,0x79,0xf4}, + {0xf8,0x84,0xb1,0x49,0x5a,0xde},{0xfa,0x87,0xb8,0x42,0x57,0xd0}, + {0xfc,0x82,0xa3,0x5f,0x40,0xc2},{0xfe,0x81,0xaa,0x54,0x4d,0xcc}, + {0x1b,0x9b,0xec,0xf7,0xda,0x41},{0x19,0x98,0xe5,0xfc,0xd7,0x4f}, + {0x1f,0x9d,0xfe,0xe1,0xc0,0x5d},{0x1d,0x9e,0xf7,0xea,0xcd,0x53}, + {0x13,0x97,0xc8,0xdb,0xee,0x79},{0x11,0x94,0xc1,0xd0,0xe3,0x77}, + {0x17,0x91,0xda,0xcd,0xf4,0x65},{0x15,0x92,0xd3,0xc6,0xf9,0x6b}, + {0x0b,0x83,0xa4,0xaf,0xb2,0x31},{0x09,0x80,0xad,0xa4,0xbf,0x3f}, + {0x0f,0x85,0xb6,0xb9,0xa8,0x2d},{0x0d,0x86,0xbf,0xb2,0xa5,0x23}, + {0x03,0x8f,0x80,0x83,0x86,0x09},{0x01,0x8c,0x89,0x88,0x8b,0x07}, + {0x07,0x89,0x92,0x95,0x9c,0x15},{0x05,0x8a,0x9b,0x9e,0x91,0x1b}, + {0x3b,0xab,0x7c,0x47,0x0a,0xa1},{0x39,0xa8,0x75,0x4c,0x07,0xaf}, + {0x3f,0xad,0x6e,0x51,0x10,0xbd},{0x3d,0xae,0x67,0x5a,0x1d,0xb3}, + {0x33,0xa7,0x58,0x6b,0x3e,0x99},{0x31,0xa4,0x51,0x60,0x33,0x97}, + {0x37,0xa1,0x4a,0x7d,0x24,0x85},{0x35,0xa2,0x43,0x76,0x29,0x8b}, + {0x2b,0xb3,0x34,0x1f,0x62,0xd1},{0x29,0xb0,0x3d,0x14,0x6f,0xdf}, + {0x2f,0xb5,0x26,0x09,0x78,0xcd},{0x2d,0xb6,0x2f,0x02,0x75,0xc3}, + {0x23,0xbf,0x10,0x33,0x56,0xe9},{0x21,0xbc,0x19,0x38,0x5b,0xe7}, + {0x27,0xb9,0x02,0x25,0x4c,0xf5},{0x25,0xba,0x0b,0x2e,0x41,0xfb}, + {0x5b,0xfb,0xd7,0x8c,0x61,0x9a},{0x59,0xf8,0xde,0x87,0x6c,0x94}, + {0x5f,0xfd,0xc5,0x9a,0x7b,0x86},{0x5d,0xfe,0xcc,0x91,0x76,0x88}, + {0x53,0xf7,0xf3,0xa0,0x55,0xa2},{0x51,0xf4,0xfa,0xab,0x58,0xac}, + {0x57,0xf1,0xe1,0xb6,0x4f,0xbe},{0x55,0xf2,0xe8,0xbd,0x42,0xb0}, + {0x4b,0xe3,0x9f,0xd4,0x09,0xea},{0x49,0xe0,0x96,0xdf,0x04,0xe4}, + {0x4f,0xe5,0x8d,0xc2,0x13,0xf6},{0x4d,0xe6,0x84,0xc9,0x1e,0xf8}, + {0x43,0xef,0xbb,0xf8,0x3d,0xd2},{0x41,0xec,0xb2,0xf3,0x30,0xdc}, + {0x47,0xe9,0xa9,0xee,0x27,0xce},{0x45,0xea,0xa0,0xe5,0x2a,0xc0}, + {0x7b,0xcb,0x47,0x3c,0xb1,0x7a},{0x79,0xc8,0x4e,0x37,0xbc,0x74}, + {0x7f,0xcd,0x55,0x2a,0xab,0x66},{0x7d,0xce,0x5c,0x21,0xa6,0x68}, + {0x73,0xc7,0x63,0x10,0x85,0x42},{0x71,0xc4,0x6a,0x1b,0x88,0x4c}, + {0x77,0xc1,0x71,0x06,0x9f,0x5e},{0x75,0xc2,0x78,0x0d,0x92,0x50}, + {0x6b,0xd3,0x0f,0x64,0xd9,0x0a},{0x69,0xd0,0x06,0x6f,0xd4,0x04}, + {0x6f,0xd5,0x1d,0x72,0xc3,0x16},{0x6d,0xd6,0x14,0x79,0xce,0x18}, + {0x63,0xdf,0x2b,0x48,0xed,0x32},{0x61,0xdc,0x22,0x43,0xe0,0x3c}, + {0x67,0xd9,0x39,0x5e,0xf7,0x2e},{0x65,0xda,0x30,0x55,0xfa,0x20}, + {0x9b,0x5b,0x9a,0x01,0xb7,0xec},{0x99,0x58,0x93,0x0a,0xba,0xe2}, + {0x9f,0x5d,0x88,0x17,0xad,0xf0},{0x9d,0x5e,0x81,0x1c,0xa0,0xfe}, + {0x93,0x57,0xbe,0x2d,0x83,0xd4},{0x91,0x54,0xb7,0x26,0x8e,0xda}, + {0x97,0x51,0xac,0x3b,0x99,0xc8},{0x95,0x52,0xa5,0x30,0x94,0xc6}, + {0x8b,0x43,0xd2,0x59,0xdf,0x9c},{0x89,0x40,0xdb,0x52,0xd2,0x92}, + {0x8f,0x45,0xc0,0x4f,0xc5,0x80},{0x8d,0x46,0xc9,0x44,0xc8,0x8e}, + {0x83,0x4f,0xf6,0x75,0xeb,0xa4},{0x81,0x4c,0xff,0x7e,0xe6,0xaa}, + {0x87,0x49,0xe4,0x63,0xf1,0xb8},{0x85,0x4a,0xed,0x68,0xfc,0xb6}, + {0xbb,0x6b,0x0a,0xb1,0x67,0x0c},{0xb9,0x68,0x03,0xba,0x6a,0x02}, + {0xbf,0x6d,0x18,0xa7,0x7d,0x10},{0xbd,0x6e,0x11,0xac,0x70,0x1e}, + {0xb3,0x67,0x2e,0x9d,0x53,0x34},{0xb1,0x64,0x27,0x96,0x5e,0x3a}, + {0xb7,0x61,0x3c,0x8b,0x49,0x28},{0xb5,0x62,0x35,0x80,0x44,0x26}, + {0xab,0x73,0x42,0xe9,0x0f,0x7c},{0xa9,0x70,0x4b,0xe2,0x02,0x72}, + {0xaf,0x75,0x50,0xff,0x15,0x60},{0xad,0x76,0x59,0xf4,0x18,0x6e}, + {0xa3,0x7f,0x66,0xc5,0x3b,0x44},{0xa1,0x7c,0x6f,0xce,0x36,0x4a}, + {0xa7,0x79,0x74,0xd3,0x21,0x58},{0xa5,0x7a,0x7d,0xd8,0x2c,0x56}, + {0xdb,0x3b,0xa1,0x7a,0x0c,0x37},{0xd9,0x38,0xa8,0x71,0x01,0x39}, + {0xdf,0x3d,0xb3,0x6c,0x16,0x2b},{0xdd,0x3e,0xba,0x67,0x1b,0x25}, + {0xd3,0x37,0x85,0x56,0x38,0x0f},{0xd1,0x34,0x8c,0x5d,0x35,0x01}, + {0xd7,0x31,0x97,0x40,0x22,0x13},{0xd5,0x32,0x9e,0x4b,0x2f,0x1d}, + {0xcb,0x23,0xe9,0x22,0x64,0x47},{0xc9,0x20,0xe0,0x29,0x69,0x49}, + {0xcf,0x25,0xfb,0x34,0x7e,0x5b},{0xcd,0x26,0xf2,0x3f,0x73,0x55}, + {0xc3,0x2f,0xcd,0x0e,0x50,0x7f},{0xc1,0x2c,0xc4,0x05,0x5d,0x71}, + {0xc7,0x29,0xdf,0x18,0x4a,0x63},{0xc5,0x2a,0xd6,0x13,0x47,0x6d}, + {0xfb,0x0b,0x31,0xca,0xdc,0xd7},{0xf9,0x08,0x38,0xc1,0xd1,0xd9}, + {0xff,0x0d,0x23,0xdc,0xc6,0xcb},{0xfd,0x0e,0x2a,0xd7,0xcb,0xc5}, + {0xf3,0x07,0x15,0xe6,0xe8,0xef},{0xf1,0x04,0x1c,0xed,0xe5,0xe1}, + {0xf7,0x01,0x07,0xf0,0xf2,0xf3},{0xf5,0x02,0x0e,0xfb,0xff,0xfd}, + {0xeb,0x13,0x79,0x92,0xb4,0xa7},{0xe9,0x10,0x70,0x99,0xb9,0xa9}, + {0xef,0x15,0x6b,0x84,0xae,0xbb},{0xed,0x16,0x62,0x8f,0xa3,0xb5}, + {0xe3,0x1f,0x5d,0xbe,0x80,0x9f},{0xe1,0x1c,0x54,0xb5,0x8d,0x91}, + {0xe7,0x19,0x4f,0xa8,0x9a,0x83},{0xe5,0x1a,0x46,0xa3,0x97,0x8d} +}; + +/*********************** FUNCTION DEFINITIONS ***********************/ +// XORs the in and out buffers, storing the result in out. Length is in bytes. +void xor_buf(const BYTE in[], BYTE out[], size_t len) +{ + size_t idx; + + for (idx = 0; idx < len; idx++) + out[idx] ^= in[idx]; +} + +/******************* +* AES - CBC +*******************/ +int aes_encrypt_cbc(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) +{ + BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; + int blocks, idx; + + if (in_len % AES_BLOCK_SIZE != 0) + return(FALSE); + + blocks = in_len / AES_BLOCK_SIZE; + + memcpy(iv_buf, iv, AES_BLOCK_SIZE); + + for (idx = 0; idx < blocks; idx++) { + memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); + xor_buf(iv_buf, buf_in, AES_BLOCK_SIZE); + aes_encrypt(buf_in, buf_out, key, keysize); + memcpy(&out[idx * AES_BLOCK_SIZE], buf_out, AES_BLOCK_SIZE); + memcpy(iv_buf, buf_out, AES_BLOCK_SIZE); + } + + return(TRUE); +} + +int aes_encrypt_cbc_mac(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) +{ + BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; + int blocks, idx; + + if (in_len % AES_BLOCK_SIZE != 0) + return(FALSE); + + blocks = in_len / AES_BLOCK_SIZE; + + memcpy(iv_buf, iv, AES_BLOCK_SIZE); + + for (idx = 0; idx < blocks; idx++) { + memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); + xor_buf(iv_buf, buf_in, AES_BLOCK_SIZE); + aes_encrypt(buf_in, buf_out, key, keysize); + memcpy(iv_buf, buf_out, AES_BLOCK_SIZE); + // Do not output all encrypted blocks. + } + + memcpy(out, buf_out, AES_BLOCK_SIZE); // Only output the last block. + + return(TRUE); +} + +int aes_decrypt_cbc(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) +{ + BYTE buf_in[AES_BLOCK_SIZE], buf_out[AES_BLOCK_SIZE], iv_buf[AES_BLOCK_SIZE]; + int blocks, idx; + + if (in_len % AES_BLOCK_SIZE != 0) + return(FALSE); + + blocks = in_len / AES_BLOCK_SIZE; + + memcpy(iv_buf, iv, AES_BLOCK_SIZE); + + for (idx = 0; idx < blocks; idx++) { + memcpy(buf_in, &in[idx * AES_BLOCK_SIZE], AES_BLOCK_SIZE); + aes_decrypt(buf_in, buf_out, key, keysize); + xor_buf(iv_buf, buf_out, AES_BLOCK_SIZE); + memcpy(&out[idx * AES_BLOCK_SIZE], buf_out, AES_BLOCK_SIZE); + memcpy(iv_buf, buf_in, AES_BLOCK_SIZE); + } + + return(TRUE); +} + +/******************* +* AES - CTR +*******************/ +void increment_iv(BYTE iv[], int counter_size) +{ + int idx; + + // Use counter_size bytes at the end of the IV as the big-endian integer to increment. + for (idx = AES_BLOCK_SIZE - 1; idx >= AES_BLOCK_SIZE - counter_size; idx--) { + iv[idx]++; + if (iv[idx] != 0 || idx == AES_BLOCK_SIZE - counter_size) + break; + } +} + +// Performs the encryption in-place, the input and output buffers may be the same. +// Input may be an arbitrary length (in bytes). +void aes_encrypt_ctr(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) +{ + size_t idx = 0, last_block_length; + BYTE iv_buf[AES_BLOCK_SIZE], out_buf[AES_BLOCK_SIZE]; + + if (in != out) + memcpy(out, in, in_len); + + memcpy(iv_buf, iv, AES_BLOCK_SIZE); + last_block_length = in_len - AES_BLOCK_SIZE; + + if (in_len > AES_BLOCK_SIZE) { + for (idx = 0; idx < last_block_length; idx += AES_BLOCK_SIZE) { + aes_encrypt(iv_buf, out_buf, key, keysize); + xor_buf(out_buf, &out[idx], AES_BLOCK_SIZE); + increment_iv(iv_buf, AES_BLOCK_SIZE); + } + } + + aes_encrypt(iv_buf, out_buf, key, keysize); + xor_buf(out_buf, &out[idx], in_len - idx); // Use the Most Significant bytes. +} + +void aes_decrypt_ctr(const BYTE in[], size_t in_len, BYTE out[], const WORD key[], int keysize, const BYTE iv[]) +{ + // CTR encryption is its own inverse function. + aes_encrypt_ctr(in, in_len, out, key, keysize, iv); +} + +/******************* +* AES - CCM +*******************/ +// out_len = payload_len + assoc_len +int aes_encrypt_ccm(const BYTE payload[], WORD payload_len, const BYTE assoc[], unsigned short assoc_len, + const BYTE nonce[], unsigned short nonce_len, BYTE out[], WORD *out_len, + WORD mac_len, const BYTE key_str[], int keysize) +{ + BYTE temp_iv[AES_BLOCK_SIZE], counter[AES_BLOCK_SIZE], mac[16], *buf; + int end_of_buf, payload_len_store_size; + WORD key[60]; + + if (mac_len != 4 && mac_len != 6 && mac_len != 8 && mac_len != 10 && + mac_len != 12 && mac_len != 14 && mac_len != 16) + return(FALSE); + + if (nonce_len < 7 || nonce_len > 13) + return(FALSE); + + if (assoc_len > 32768 /* = 2^15 */) + return(FALSE); + + buf = (BYTE *)malloc(payload_len + assoc_len + 48 /*Round both payload and associated data up a block size and add an extra block.*/); + if (!buf) + return(FALSE); + + // Prepare the key for usage. + aes_key_setup(key_str, key, keysize); + + // Format the first block of the formatted data. + payload_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; + ccm_prepare_first_format_blk(buf, assoc_len, payload_len, payload_len_store_size, mac_len, nonce, nonce_len); + end_of_buf = AES_BLOCK_SIZE; + + // Format the Associated Data, aka, assoc[]. + ccm_format_assoc_data(buf, &end_of_buf, assoc, assoc_len); + + // Format the Payload, aka payload[]. + ccm_format_payload_data(buf, &end_of_buf, payload, payload_len); + + // Create the first counter block. + ccm_prepare_first_ctr_blk(counter, nonce, nonce_len, payload_len_store_size); + + // Perform the CBC operation with an IV of zeros on the formatted buffer to calculate the MAC. + memset(temp_iv, 0, AES_BLOCK_SIZE); + aes_encrypt_cbc_mac(buf, end_of_buf, mac, key, keysize, temp_iv); + + // Copy the Payload and MAC to the output buffer. + memcpy(out, payload, payload_len); + memcpy(&out[payload_len], mac, mac_len); + + // Encrypt the Payload with CTR mode with a counter starting at 1. + memcpy(temp_iv, counter, AES_BLOCK_SIZE); + increment_iv(temp_iv, AES_BLOCK_SIZE - 1 - mac_len); // Last argument is the byte size of the counting portion of the counter block. /*BUG?*/ + aes_encrypt_ctr(out, payload_len, out, key, keysize, temp_iv); + + // Encrypt the MAC with CTR mode with a counter starting at 0. + aes_encrypt_ctr(&out[payload_len], mac_len, &out[payload_len], key, keysize, counter); + + free(buf); + *out_len = payload_len + mac_len; + + return(TRUE); +} + +// plaintext_len = ciphertext_len - mac_len +// Needs a flag for whether the MAC matches. +int aes_decrypt_ccm(const BYTE ciphertext[], WORD ciphertext_len, const BYTE assoc[], unsigned short assoc_len, + const BYTE nonce[], unsigned short nonce_len, BYTE plaintext[], WORD *plaintext_len, + WORD mac_len, int *mac_auth, const BYTE key_str[], int keysize) +{ + BYTE temp_iv[AES_BLOCK_SIZE], counter[AES_BLOCK_SIZE], mac[16], mac_buf[16], *buf; + int end_of_buf, plaintext_len_store_size; + WORD key[60]; + + if (ciphertext_len <= mac_len) + return(FALSE); + + buf = (BYTE *)malloc(assoc_len + ciphertext_len /*ciphertext_len = plaintext_len + mac_len*/ + 48); + if (!buf) + return(FALSE); + + // Prepare the key for usage. + aes_key_setup(key_str, key, keysize); + + // Copy the plaintext and MAC to the output buffers. + *plaintext_len = ciphertext_len - mac_len; + plaintext_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; + memcpy(plaintext, ciphertext, *plaintext_len); + memcpy(mac, &ciphertext[*plaintext_len], mac_len); + + // Prepare the first counter block for use in decryption. + ccm_prepare_first_ctr_blk(counter, nonce, nonce_len, plaintext_len_store_size); + + // Decrypt the Payload with CTR mode with a counter starting at 1. + memcpy(temp_iv, counter, AES_BLOCK_SIZE); + increment_iv(temp_iv, AES_BLOCK_SIZE - 1 - mac_len); // (AES_BLOCK_SIZE - 1 - mac_len) is the byte size of the counting portion of the counter block. + aes_decrypt_ctr(plaintext, *plaintext_len, plaintext, key, keysize, temp_iv); + + // Setting mac_auth to NULL disables the authentication check. + if (mac_auth != NULL) { + // Decrypt the MAC with CTR mode with a counter starting at 0. + aes_decrypt_ctr(mac, mac_len, mac, key, keysize, counter); + + // Format the first block of the formatted data. + plaintext_len_store_size = AES_BLOCK_SIZE - 1 - nonce_len; + ccm_prepare_first_format_blk(buf, assoc_len, *plaintext_len, plaintext_len_store_size, mac_len, nonce, nonce_len); + end_of_buf = AES_BLOCK_SIZE; + + // Format the Associated Data into the authentication buffer. + ccm_format_assoc_data(buf, &end_of_buf, assoc, assoc_len); + + // Format the Payload into the authentication buffer. + ccm_format_payload_data(buf, &end_of_buf, plaintext, *plaintext_len); + + // Perform the CBC operation with an IV of zeros on the formatted buffer to calculate the MAC. + memset(temp_iv, 0, AES_BLOCK_SIZE); + aes_encrypt_cbc_mac(buf, end_of_buf, mac_buf, key, keysize, temp_iv); + + // Compare the calculated MAC against the MAC embedded in the ciphertext to see if they are the same. + if (!memcmp(mac, mac_buf, mac_len)) { + *mac_auth = TRUE; + } + else { + *mac_auth = FALSE; + memset(plaintext, 0, *plaintext_len); + } + } + + free(buf); + + return(TRUE); +} + +// Creates the first counter block. First byte is flags, then the nonce, then the incremented part. +void ccm_prepare_first_ctr_blk(BYTE counter[], const BYTE nonce[], int nonce_len, int payload_len_store_size) +{ + memset(counter, 0, AES_BLOCK_SIZE); + counter[0] = (payload_len_store_size - 1) & 0x07; + memcpy(&counter[1], nonce, nonce_len); +} + +void ccm_prepare_first_format_blk(BYTE buf[], int assoc_len, int payload_len, int payload_len_store_size, int mac_len, const BYTE nonce[], int nonce_len) +{ + // Set the flags for the first byte of the first block. + buf[0] = ((((mac_len - 2) / 2) & 0x07) << 3) | ((payload_len_store_size - 1) & 0x07); + if (assoc_len > 0) + buf[0] += 0x40; + // Format the rest of the first block, storing the nonce and the size of the payload. + memcpy(&buf[1], nonce, nonce_len); + memset(&buf[1 + nonce_len], 0, AES_BLOCK_SIZE - 1 - nonce_len); + buf[15] = payload_len & 0x000000FF; + buf[14] = (payload_len >> 8) & 0x000000FF; +} + +void ccm_format_assoc_data(BYTE buf[], int *end_of_buf, const BYTE assoc[], int assoc_len) +{ + int pad; + + buf[*end_of_buf + 1] = assoc_len & 0x00FF; + buf[*end_of_buf] = (assoc_len >> 8) & 0x00FF; + *end_of_buf += 2; + memcpy(&buf[*end_of_buf], assoc, assoc_len); + *end_of_buf += assoc_len; + pad = AES_BLOCK_SIZE - (*end_of_buf % AES_BLOCK_SIZE);/*BUG?*/ + memset(&buf[*end_of_buf], 0, pad); + *end_of_buf += pad; +} + +void ccm_format_payload_data(BYTE buf[], int *end_of_buf, const BYTE payload[], int payload_len) +{ + int pad; + + memcpy(&buf[*end_of_buf], payload, payload_len); + *end_of_buf += payload_len; + pad = *end_of_buf % AES_BLOCK_SIZE; + if (pad != 0) + pad = AES_BLOCK_SIZE - pad; + memset(&buf[*end_of_buf], 0, pad); + *end_of_buf += pad; +} + +/******************* +* AES +*******************/ +///////////////// +// KEY EXPANSION +///////////////// + +// Substitutes a word using the AES S-Box. +WORD SubWord(WORD word) +{ + unsigned int result; + + result = (int)aes_sbox[(word >> 4) & 0x0000000F][word & 0x0000000F]; + result += (int)aes_sbox[(word >> 12) & 0x0000000F][(word >> 8) & 0x0000000F] << 8; + result += (int)aes_sbox[(word >> 20) & 0x0000000F][(word >> 16) & 0x0000000F] << 16; + result += (int)aes_sbox[(word >> 28) & 0x0000000F][(word >> 24) & 0x0000000F] << 24; + return(result); +} + +// Performs the action of generating the keys that will be used in every round of +// encryption. "key" is the user-supplied input key, "w" is the output key schedule, +// "keysize" is the length in bits of "key", must be 128, 192, or 256. +void aes_key_setup(const BYTE key[], WORD w[], int keysize) +{ + int Nb = 4,Nr,Nk,idx; + WORD temp,Rcon[] = {0x01000000,0x02000000,0x04000000,0x08000000,0x10000000,0x20000000, + 0x40000000,0x80000000,0x1b000000,0x36000000,0x6c000000,0xd8000000, + 0xab000000,0x4d000000,0x9a000000}; + + switch (keysize) { + case 128: Nr = 10; Nk = 4; break; + case 192: Nr = 12; Nk = 6; break; + case 256: Nr = 14; Nk = 8; break; + default: return; + } + + for (idx = 0; idx < Nk; ++idx) { + w[idx] = ((key[4 * idx]) << 24) | ((key[4 * idx + 1]) << 16) | + ((key[4 * idx + 2]) << 8) | ((key[4 * idx + 3])); + } + + for (idx = Nk; idx < Nb * (Nr + 1); ++idx) { + temp = w[idx - 1]; + if ((idx % Nk) == 0) + temp = SubWord(KE_ROTWORD(temp)) ^ Rcon[(idx - 1) / Nk]; + else if (Nk > 6 && (idx % Nk) == 4) + temp = SubWord(temp); + w[idx] = w[idx - Nk] ^ temp; + } +} + +///////////////// +// ADD ROUND KEY +///////////////// + +// Performs the AddRoundKey step. Each round has its own pre-generated 16-byte key in the +// form of 4 integers (the "w" array). Each integer is XOR'd by one column of the state. +// Also performs the job of InvAddRoundKey(); since the function is a simple XOR process, +// it is its own inverse. +void AddRoundKey(BYTE state[][4], const WORD w[]) +{ + BYTE subkey[4]; + + // memcpy(subkey,&w[idx],4); // Not accurate for big endian machines + // Subkey 1 + subkey[0] = w[0] >> 24; + subkey[1] = w[0] >> 16; + subkey[2] = w[0] >> 8; + subkey[3] = w[0]; + state[0][0] ^= subkey[0]; + state[1][0] ^= subkey[1]; + state[2][0] ^= subkey[2]; + state[3][0] ^= subkey[3]; + // Subkey 2 + subkey[0] = w[1] >> 24; + subkey[1] = w[1] >> 16; + subkey[2] = w[1] >> 8; + subkey[3] = w[1]; + state[0][1] ^= subkey[0]; + state[1][1] ^= subkey[1]; + state[2][1] ^= subkey[2]; + state[3][1] ^= subkey[3]; + // Subkey 3 + subkey[0] = w[2] >> 24; + subkey[1] = w[2] >> 16; + subkey[2] = w[2] >> 8; + subkey[3] = w[2]; + state[0][2] ^= subkey[0]; + state[1][2] ^= subkey[1]; + state[2][2] ^= subkey[2]; + state[3][2] ^= subkey[3]; + // Subkey 4 + subkey[0] = w[3] >> 24; + subkey[1] = w[3] >> 16; + subkey[2] = w[3] >> 8; + subkey[3] = w[3]; + state[0][3] ^= subkey[0]; + state[1][3] ^= subkey[1]; + state[2][3] ^= subkey[2]; + state[3][3] ^= subkey[3]; +} + +///////////////// +// (Inv)SubBytes +///////////////// + +// Performs the SubBytes step. All bytes in the state are substituted with a +// pre-calculated value from a lookup table. +void SubBytes(BYTE state[][4]) +{ + state[0][0] = aes_sbox[state[0][0] >> 4][state[0][0] & 0x0F]; + state[0][1] = aes_sbox[state[0][1] >> 4][state[0][1] & 0x0F]; + state[0][2] = aes_sbox[state[0][2] >> 4][state[0][2] & 0x0F]; + state[0][3] = aes_sbox[state[0][3] >> 4][state[0][3] & 0x0F]; + state[1][0] = aes_sbox[state[1][0] >> 4][state[1][0] & 0x0F]; + state[1][1] = aes_sbox[state[1][1] >> 4][state[1][1] & 0x0F]; + state[1][2] = aes_sbox[state[1][2] >> 4][state[1][2] & 0x0F]; + state[1][3] = aes_sbox[state[1][3] >> 4][state[1][3] & 0x0F]; + state[2][0] = aes_sbox[state[2][0] >> 4][state[2][0] & 0x0F]; + state[2][1] = aes_sbox[state[2][1] >> 4][state[2][1] & 0x0F]; + state[2][2] = aes_sbox[state[2][2] >> 4][state[2][2] & 0x0F]; + state[2][3] = aes_sbox[state[2][3] >> 4][state[2][3] & 0x0F]; + state[3][0] = aes_sbox[state[3][0] >> 4][state[3][0] & 0x0F]; + state[3][1] = aes_sbox[state[3][1] >> 4][state[3][1] & 0x0F]; + state[3][2] = aes_sbox[state[3][2] >> 4][state[3][2] & 0x0F]; + state[3][3] = aes_sbox[state[3][3] >> 4][state[3][3] & 0x0F]; +} + +void InvSubBytes(BYTE state[][4]) +{ + state[0][0] = aes_invsbox[state[0][0] >> 4][state[0][0] & 0x0F]; + state[0][1] = aes_invsbox[state[0][1] >> 4][state[0][1] & 0x0F]; + state[0][2] = aes_invsbox[state[0][2] >> 4][state[0][2] & 0x0F]; + state[0][3] = aes_invsbox[state[0][3] >> 4][state[0][3] & 0x0F]; + state[1][0] = aes_invsbox[state[1][0] >> 4][state[1][0] & 0x0F]; + state[1][1] = aes_invsbox[state[1][1] >> 4][state[1][1] & 0x0F]; + state[1][2] = aes_invsbox[state[1][2] >> 4][state[1][2] & 0x0F]; + state[1][3] = aes_invsbox[state[1][3] >> 4][state[1][3] & 0x0F]; + state[2][0] = aes_invsbox[state[2][0] >> 4][state[2][0] & 0x0F]; + state[2][1] = aes_invsbox[state[2][1] >> 4][state[2][1] & 0x0F]; + state[2][2] = aes_invsbox[state[2][2] >> 4][state[2][2] & 0x0F]; + state[2][3] = aes_invsbox[state[2][3] >> 4][state[2][3] & 0x0F]; + state[3][0] = aes_invsbox[state[3][0] >> 4][state[3][0] & 0x0F]; + state[3][1] = aes_invsbox[state[3][1] >> 4][state[3][1] & 0x0F]; + state[3][2] = aes_invsbox[state[3][2] >> 4][state[3][2] & 0x0F]; + state[3][3] = aes_invsbox[state[3][3] >> 4][state[3][3] & 0x0F]; +} + +///////////////// +// (Inv)ShiftRows +///////////////// + +// Performs the ShiftRows step. All rows are shifted cylindrically to the left. +void ShiftRows(BYTE state[][4]) +{ + int t; + + // Shift left by 1 + t = state[1][0]; + state[1][0] = state[1][1]; + state[1][1] = state[1][2]; + state[1][2] = state[1][3]; + state[1][3] = t; + // Shift left by 2 + t = state[2][0]; + state[2][0] = state[2][2]; + state[2][2] = t; + t = state[2][1]; + state[2][1] = state[2][3]; + state[2][3] = t; + // Shift left by 3 + t = state[3][0]; + state[3][0] = state[3][3]; + state[3][3] = state[3][2]; + state[3][2] = state[3][1]; + state[3][1] = t; +} + +// All rows are shifted cylindrically to the right. +void InvShiftRows(BYTE state[][4]) +{ + int t; + + // Shift right by 1 + t = state[1][3]; + state[1][3] = state[1][2]; + state[1][2] = state[1][1]; + state[1][1] = state[1][0]; + state[1][0] = t; + // Shift right by 2 + t = state[2][3]; + state[2][3] = state[2][1]; + state[2][1] = t; + t = state[2][2]; + state[2][2] = state[2][0]; + state[2][0] = t; + // Shift right by 3 + t = state[3][3]; + state[3][3] = state[3][0]; + state[3][0] = state[3][1]; + state[3][1] = state[3][2]; + state[3][2] = t; +} + +///////////////// +// (Inv)MixColumns +///////////////// + +// Performs the MixColums step. The state is multiplied by itself using matrix +// multiplication in a Galios Field 2^8. All multiplication is pre-computed in a table. +// Addition is equivilent to XOR. (Must always make a copy of the column as the original +// values will be destoyed.) +void MixColumns(BYTE state[][4]) +{ + BYTE col[4]; + + // Column 1 + col[0] = state[0][0]; + col[1] = state[1][0]; + col[2] = state[2][0]; + col[3] = state[3][0]; + state[0][0] = gf_mul[col[0]][0]; + state[0][0] ^= gf_mul[col[1]][1]; + state[0][0] ^= col[2]; + state[0][0] ^= col[3]; + state[1][0] = col[0]; + state[1][0] ^= gf_mul[col[1]][0]; + state[1][0] ^= gf_mul[col[2]][1]; + state[1][0] ^= col[3]; + state[2][0] = col[0]; + state[2][0] ^= col[1]; + state[2][0] ^= gf_mul[col[2]][0]; + state[2][0] ^= gf_mul[col[3]][1]; + state[3][0] = gf_mul[col[0]][1]; + state[3][0] ^= col[1]; + state[3][0] ^= col[2]; + state[3][0] ^= gf_mul[col[3]][0]; + // Column 2 + col[0] = state[0][1]; + col[1] = state[1][1]; + col[2] = state[2][1]; + col[3] = state[3][1]; + state[0][1] = gf_mul[col[0]][0]; + state[0][1] ^= gf_mul[col[1]][1]; + state[0][1] ^= col[2]; + state[0][1] ^= col[3]; + state[1][1] = col[0]; + state[1][1] ^= gf_mul[col[1]][0]; + state[1][1] ^= gf_mul[col[2]][1]; + state[1][1] ^= col[3]; + state[2][1] = col[0]; + state[2][1] ^= col[1]; + state[2][1] ^= gf_mul[col[2]][0]; + state[2][1] ^= gf_mul[col[3]][1]; + state[3][1] = gf_mul[col[0]][1]; + state[3][1] ^= col[1]; + state[3][1] ^= col[2]; + state[3][1] ^= gf_mul[col[3]][0]; + // Column 3 + col[0] = state[0][2]; + col[1] = state[1][2]; + col[2] = state[2][2]; + col[3] = state[3][2]; + state[0][2] = gf_mul[col[0]][0]; + state[0][2] ^= gf_mul[col[1]][1]; + state[0][2] ^= col[2]; + state[0][2] ^= col[3]; + state[1][2] = col[0]; + state[1][2] ^= gf_mul[col[1]][0]; + state[1][2] ^= gf_mul[col[2]][1]; + state[1][2] ^= col[3]; + state[2][2] = col[0]; + state[2][2] ^= col[1]; + state[2][2] ^= gf_mul[col[2]][0]; + state[2][2] ^= gf_mul[col[3]][1]; + state[3][2] = gf_mul[col[0]][1]; + state[3][2] ^= col[1]; + state[3][2] ^= col[2]; + state[3][2] ^= gf_mul[col[3]][0]; + // Column 4 + col[0] = state[0][3]; + col[1] = state[1][3]; + col[2] = state[2][3]; + col[3] = state[3][3]; + state[0][3] = gf_mul[col[0]][0]; + state[0][3] ^= gf_mul[col[1]][1]; + state[0][3] ^= col[2]; + state[0][3] ^= col[3]; + state[1][3] = col[0]; + state[1][3] ^= gf_mul[col[1]][0]; + state[1][3] ^= gf_mul[col[2]][1]; + state[1][3] ^= col[3]; + state[2][3] = col[0]; + state[2][3] ^= col[1]; + state[2][3] ^= gf_mul[col[2]][0]; + state[2][3] ^= gf_mul[col[3]][1]; + state[3][3] = gf_mul[col[0]][1]; + state[3][3] ^= col[1]; + state[3][3] ^= col[2]; + state[3][3] ^= gf_mul[col[3]][0]; +} + +void InvMixColumns(BYTE state[][4]) +{ + BYTE col[4]; + + // Column 1 + col[0] = state[0][0]; + col[1] = state[1][0]; + col[2] = state[2][0]; + col[3] = state[3][0]; + state[0][0] = gf_mul[col[0]][5]; + state[0][0] ^= gf_mul[col[1]][3]; + state[0][0] ^= gf_mul[col[2]][4]; + state[0][0] ^= gf_mul[col[3]][2]; + state[1][0] = gf_mul[col[0]][2]; + state[1][0] ^= gf_mul[col[1]][5]; + state[1][0] ^= gf_mul[col[2]][3]; + state[1][0] ^= gf_mul[col[3]][4]; + state[2][0] = gf_mul[col[0]][4]; + state[2][0] ^= gf_mul[col[1]][2]; + state[2][0] ^= gf_mul[col[2]][5]; + state[2][0] ^= gf_mul[col[3]][3]; + state[3][0] = gf_mul[col[0]][3]; + state[3][0] ^= gf_mul[col[1]][4]; + state[3][0] ^= gf_mul[col[2]][2]; + state[3][0] ^= gf_mul[col[3]][5]; + // Column 2 + col[0] = state[0][1]; + col[1] = state[1][1]; + col[2] = state[2][1]; + col[3] = state[3][1]; + state[0][1] = gf_mul[col[0]][5]; + state[0][1] ^= gf_mul[col[1]][3]; + state[0][1] ^= gf_mul[col[2]][4]; + state[0][1] ^= gf_mul[col[3]][2]; + state[1][1] = gf_mul[col[0]][2]; + state[1][1] ^= gf_mul[col[1]][5]; + state[1][1] ^= gf_mul[col[2]][3]; + state[1][1] ^= gf_mul[col[3]][4]; + state[2][1] = gf_mul[col[0]][4]; + state[2][1] ^= gf_mul[col[1]][2]; + state[2][1] ^= gf_mul[col[2]][5]; + state[2][1] ^= gf_mul[col[3]][3]; + state[3][1] = gf_mul[col[0]][3]; + state[3][1] ^= gf_mul[col[1]][4]; + state[3][1] ^= gf_mul[col[2]][2]; + state[3][1] ^= gf_mul[col[3]][5]; + // Column 3 + col[0] = state[0][2]; + col[1] = state[1][2]; + col[2] = state[2][2]; + col[3] = state[3][2]; + state[0][2] = gf_mul[col[0]][5]; + state[0][2] ^= gf_mul[col[1]][3]; + state[0][2] ^= gf_mul[col[2]][4]; + state[0][2] ^= gf_mul[col[3]][2]; + state[1][2] = gf_mul[col[0]][2]; + state[1][2] ^= gf_mul[col[1]][5]; + state[1][2] ^= gf_mul[col[2]][3]; + state[1][2] ^= gf_mul[col[3]][4]; + state[2][2] = gf_mul[col[0]][4]; + state[2][2] ^= gf_mul[col[1]][2]; + state[2][2] ^= gf_mul[col[2]][5]; + state[2][2] ^= gf_mul[col[3]][3]; + state[3][2] = gf_mul[col[0]][3]; + state[3][2] ^= gf_mul[col[1]][4]; + state[3][2] ^= gf_mul[col[2]][2]; + state[3][2] ^= gf_mul[col[3]][5]; + // Column 4 + col[0] = state[0][3]; + col[1] = state[1][3]; + col[2] = state[2][3]; + col[3] = state[3][3]; + state[0][3] = gf_mul[col[0]][5]; + state[0][3] ^= gf_mul[col[1]][3]; + state[0][3] ^= gf_mul[col[2]][4]; + state[0][3] ^= gf_mul[col[3]][2]; + state[1][3] = gf_mul[col[0]][2]; + state[1][3] ^= gf_mul[col[1]][5]; + state[1][3] ^= gf_mul[col[2]][3]; + state[1][3] ^= gf_mul[col[3]][4]; + state[2][3] = gf_mul[col[0]][4]; + state[2][3] ^= gf_mul[col[1]][2]; + state[2][3] ^= gf_mul[col[2]][5]; + state[2][3] ^= gf_mul[col[3]][3]; + state[3][3] = gf_mul[col[0]][3]; + state[3][3] ^= gf_mul[col[1]][4]; + state[3][3] ^= gf_mul[col[2]][2]; + state[3][3] ^= gf_mul[col[3]][5]; +} + +///////////////// +// (En/De)Crypt +///////////////// + +void aes_encrypt(const BYTE in[], BYTE out[], const WORD key[], int keysize) +{ + BYTE state[4][4]; + + // Copy input array (should be 16 bytes long) to a matrix (sequential bytes are ordered + // by row, not col) called "state" for processing. + // *** Implementation note: The official AES documentation references the state by + // column, then row. Accessing an element in C requires row then column. Thus, all state + // references in AES must have the column and row indexes reversed for C implementation. + state[0][0] = in[0]; + state[1][0] = in[1]; + state[2][0] = in[2]; + state[3][0] = in[3]; + state[0][1] = in[4]; + state[1][1] = in[5]; + state[2][1] = in[6]; + state[3][1] = in[7]; + state[0][2] = in[8]; + state[1][2] = in[9]; + state[2][2] = in[10]; + state[3][2] = in[11]; + state[0][3] = in[12]; + state[1][3] = in[13]; + state[2][3] = in[14]; + state[3][3] = in[15]; + + // Perform the necessary number of rounds. The round key is added first. + // The last round does not perform the MixColumns step. + AddRoundKey(state,&key[0]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[4]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[8]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[12]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[16]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[20]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[24]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[28]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[32]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[36]); + if (keysize != 128) { + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[40]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[44]); + if (keysize != 192) { + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[48]); + SubBytes(state); ShiftRows(state); MixColumns(state); AddRoundKey(state,&key[52]); + SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[56]); + } + else { + SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[48]); + } + } + else { + SubBytes(state); ShiftRows(state); AddRoundKey(state,&key[40]); + } + + // Copy the state to the output array. + out[0] = state[0][0]; + out[1] = state[1][0]; + out[2] = state[2][0]; + out[3] = state[3][0]; + out[4] = state[0][1]; + out[5] = state[1][1]; + out[6] = state[2][1]; + out[7] = state[3][1]; + out[8] = state[0][2]; + out[9] = state[1][2]; + out[10] = state[2][2]; + out[11] = state[3][2]; + out[12] = state[0][3]; + out[13] = state[1][3]; + out[14] = state[2][3]; + out[15] = state[3][3]; +} + +void aes_decrypt(const BYTE in[], BYTE out[], const WORD key[], int keysize) +{ + BYTE state[4][4]; + + // Copy the input to the state. + state[0][0] = in[0]; + state[1][0] = in[1]; + state[2][0] = in[2]; + state[3][0] = in[3]; + state[0][1] = in[4]; + state[1][1] = in[5]; + state[2][1] = in[6]; + state[3][1] = in[7]; + state[0][2] = in[8]; + state[1][2] = in[9]; + state[2][2] = in[10]; + state[3][2] = in[11]; + state[0][3] = in[12]; + state[1][3] = in[13]; + state[2][3] = in[14]; + state[3][3] = in[15]; + + // Perform the necessary number of rounds. The round key is added first. + // The last round does not perform the MixColumns step. + if (keysize > 128) { + if (keysize > 192) { + AddRoundKey(state,&key[56]); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[52]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[48]);InvMixColumns(state); + } + else { + AddRoundKey(state,&key[48]); + } + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[44]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[40]);InvMixColumns(state); + } + else { + AddRoundKey(state,&key[40]); + } + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[36]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[32]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[28]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[24]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[20]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[16]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[12]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[8]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[4]);InvMixColumns(state); + InvShiftRows(state);InvSubBytes(state);AddRoundKey(state,&key[0]); + + // Copy the state to the output array. + out[0] = state[0][0]; + out[1] = state[1][0]; + out[2] = state[2][0]; + out[3] = state[3][0]; + out[4] = state[0][1]; + out[5] = state[1][1]; + out[6] = state[2][1]; + out[7] = state[3][1]; + out[8] = state[0][2]; + out[9] = state[1][2]; + out[10] = state[2][2]; + out[11] = state[3][2]; + out[12] = state[0][3]; + out[13] = state[1][3]; + out[14] = state[2][3]; + out[15] = state[3][3]; +} + +/******************* +** AES DEBUGGING FUNCTIONS +*******************/ +/* + // This prints the "state" grid as a linear hex string. + void print_state(BYTE state[][4]) + { + int idx,idx2; + + for (idx=0; idx < 4; idx++) + for (idx2=0; idx2 < 4; idx2++) + printf("%02x",state[idx2][idx]); + printf("\n"); + } + + // This prints the key (4 consecutive ints) used for a given round as a linear hex string. + void print_rnd_key(WORD key[]) + { + int idx; + + for (idx=0; idx < 4; idx++) + printf("%08x",key[idx]); + printf("\n"); + } + */ diff --git a/version2/src/C/pbkdf2-sha256.cc b/version2/src/C/pbkdf2-sha256.cc deleted file mode 100644 index 2a78fa0..0000000 --- a/version2/src/C/pbkdf2-sha256.cc +++ /dev/null @@ -1,859 +0,0 @@ -/* - * FIPS-180-2 compliant SHA-256 implementation - * - * Copyright (C) 2006-2010, Brainspark B.V. - * - * This file is part of PolarSSL (http://www.polarssl.org) - * Lead Maintainer: Paul Bakker - * - * All rights reserved. - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation; either version 2 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License along - * with this program; if not, write to the Free Software Foundation, Inc., - * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. - */ -/* - * The SHA-256 Secure Hash Standard was published by NIST in 2002. - * - * http://csrc.nist.gov/publications/fips/fips180-2/fips180-2.pdf - */ - -#include -#include -#include -#include "pbkdf2-sha256.h" - -/* - * 32-bit integer manipulation macros (big endian) - */ -#ifndef GET_ULONG_BE -#define GET_ULONG_BE(n,b,i) \ - { \ - (n) = ( (unsigned long) (b)[(i) ] << 24) \ - | ( (unsigned long) (b)[(i) + 1] << 16) \ - | ( (unsigned long) (b)[(i) + 2] << 8) \ - | ( (unsigned long) (b)[(i) + 3]); \ - } -#endif - -#ifndef PUT_ULONG_BE -#define PUT_ULONG_BE(n,b,i) \ - { \ - (b)[(i) ] = (unsigned char) ( (n) >> 24); \ - (b)[(i) + 1] = (unsigned char) ( (n) >> 16); \ - (b)[(i) + 2] = (unsigned char) ( (n) >> 8); \ - (b)[(i) + 3] = (unsigned char) ( (n) ); \ - } -#endif - -/* - * SHA-256 context setup - */ -void sha2_starts( sha2_context *ctx, int is224 ) -{ - ctx->total[0] = 0; - ctx->total[1] = 0; - - if ( is224 == 0 ) - { - /* SHA-256 */ - ctx->state[0] = 0x6A09E667; - ctx->state[1] = 0xBB67AE85; - ctx->state[2] = 0x3C6EF372; - ctx->state[3] = 0xA54FF53A; - ctx->state[4] = 0x510E527F; - ctx->state[5] = 0x9B05688C; - ctx->state[6] = 0x1F83D9AB; - ctx->state[7] = 0x5BE0CD19; - } - else - { - /* SHA-224 */ - ctx->state[0] = 0xC1059ED8; - ctx->state[1] = 0x367CD507; - ctx->state[2] = 0x3070DD17; - ctx->state[3] = 0xF70E5939; - ctx->state[4] = 0xFFC00B31; - ctx->state[5] = 0x68581511; - ctx->state[6] = 0x64F98FA7; - ctx->state[7] = 0xBEFA4FA4; - } - - ctx->is224 = is224; -} - -static void sha2_process( sha2_context *ctx, const unsigned char data[64] ) -{ - unsigned long temp1, temp2, W[64]; - unsigned long A, B, C, D, E, F, G, H; - - GET_ULONG_BE( W[ 0], data, 0 ); - GET_ULONG_BE( W[ 1], data, 4 ); - GET_ULONG_BE( W[ 2], data, 8 ); - GET_ULONG_BE( W[ 3], data, 12 ); - GET_ULONG_BE( W[ 4], data, 16 ); - GET_ULONG_BE( W[ 5], data, 20 ); - GET_ULONG_BE( W[ 6], data, 24 ); - GET_ULONG_BE( W[ 7], data, 28 ); - GET_ULONG_BE( W[ 8], data, 32 ); - GET_ULONG_BE( W[ 9], data, 36 ); - GET_ULONG_BE( W[10], data, 40 ); - GET_ULONG_BE( W[11], data, 44 ); - GET_ULONG_BE( W[12], data, 48 ); - GET_ULONG_BE( W[13], data, 52 ); - GET_ULONG_BE( W[14], data, 56 ); - GET_ULONG_BE( W[15], data, 60 ); - -#define SHR(x,n) ((x & 0xFFFFFFFF) >> n) -#define ROTR(x,n) (SHR(x,n) | (x << (32 - n))) - -#define S0(x) (ROTR(x, 7) ^ ROTR(x,18) ^ SHR(x, 3)) -#define S1(x) (ROTR(x,17) ^ ROTR(x,19) ^ SHR(x,10)) - -#define S2(x) (ROTR(x, 2) ^ ROTR(x,13) ^ ROTR(x,22)) -#define S3(x) (ROTR(x, 6) ^ ROTR(x,11) ^ ROTR(x,25)) - -#define F0(x,y,z) ((x & y) | (z & (x | y))) -#define F1(x,y,z) (z ^ (x & (y ^ z))) - -#define R(t) \ - ( \ - W[t] = S1(W[t - 2]) + W[t - 7] + \ - S0(W[t - 15]) + W[t - 16] \ - ) - -#define P(a,b,c,d,e,f,g,h,x,K) \ - { \ - temp1 = h + S3(e) + F1(e,f,g) + K + x; \ - temp2 = S2(a) + F0(a,b,c); \ - d += temp1; h = temp1 + temp2; \ - } - - A = ctx->state[0]; - B = ctx->state[1]; - C = ctx->state[2]; - D = ctx->state[3]; - E = ctx->state[4]; - F = ctx->state[5]; - G = ctx->state[6]; - H = ctx->state[7]; - - P( A, B, C, D, E, F, G, H, W[ 0], 0x428A2F98 ); - P( H, A, B, C, D, E, F, G, W[ 1], 0x71374491 ); - P( G, H, A, B, C, D, E, F, W[ 2], 0xB5C0FBCF ); - P( F, G, H, A, B, C, D, E, W[ 3], 0xE9B5DBA5 ); - P( E, F, G, H, A, B, C, D, W[ 4], 0x3956C25B ); - P( D, E, F, G, H, A, B, C, W[ 5], 0x59F111F1 ); - P( C, D, E, F, G, H, A, B, W[ 6], 0x923F82A4 ); - P( B, C, D, E, F, G, H, A, W[ 7], 0xAB1C5ED5 ); - P( A, B, C, D, E, F, G, H, W[ 8], 0xD807AA98 ); - P( H, A, B, C, D, E, F, G, W[ 9], 0x12835B01 ); - P( G, H, A, B, C, D, E, F, W[10], 0x243185BE ); - P( F, G, H, A, B, C, D, E, W[11], 0x550C7DC3 ); - P( E, F, G, H, A, B, C, D, W[12], 0x72BE5D74 ); - P( D, E, F, G, H, A, B, C, W[13], 0x80DEB1FE ); - P( C, D, E, F, G, H, A, B, W[14], 0x9BDC06A7 ); - P( B, C, D, E, F, G, H, A, W[15], 0xC19BF174 ); - P( A, B, C, D, E, F, G, H, R(16), 0xE49B69C1 ); - P( H, A, B, C, D, E, F, G, R(17), 0xEFBE4786 ); - P( G, H, A, B, C, D, E, F, R(18), 0x0FC19DC6 ); - P( F, G, H, A, B, C, D, E, R(19), 0x240CA1CC ); - P( E, F, G, H, A, B, C, D, R(20), 0x2DE92C6F ); - P( D, E, F, G, H, A, B, C, R(21), 0x4A7484AA ); - P( C, D, E, F, G, H, A, B, R(22), 0x5CB0A9DC ); - P( B, C, D, E, F, G, H, A, R(23), 0x76F988DA ); - P( A, B, C, D, E, F, G, H, R(24), 0x983E5152 ); - P( H, A, B, C, D, E, F, G, R(25), 0xA831C66D ); - P( G, H, A, B, C, D, E, F, R(26), 0xB00327C8 ); - P( F, G, H, A, B, C, D, E, R(27), 0xBF597FC7 ); - P( E, F, G, H, A, B, C, D, R(28), 0xC6E00BF3 ); - P( D, E, F, G, H, A, B, C, R(29), 0xD5A79147 ); - P( C, D, E, F, G, H, A, B, R(30), 0x06CA6351 ); - P( B, C, D, E, F, G, H, A, R(31), 0x14292967 ); - P( A, B, C, D, E, F, G, H, R(32), 0x27B70A85 ); - P( H, A, B, C, D, E, F, G, R(33), 0x2E1B2138 ); - P( G, H, A, B, C, D, E, F, R(34), 0x4D2C6DFC ); - P( F, G, H, A, B, C, D, E, R(35), 0x53380D13 ); - P( E, F, G, H, A, B, C, D, R(36), 0x650A7354 ); - P( D, E, F, G, H, A, B, C, R(37), 0x766A0ABB ); - P( C, D, E, F, G, H, A, B, R(38), 0x81C2C92E ); - P( B, C, D, E, F, G, H, A, R(39), 0x92722C85 ); - P( A, B, C, D, E, F, G, H, R(40), 0xA2BFE8A1 ); - P( H, A, B, C, D, E, F, G, R(41), 0xA81A664B ); - P( G, H, A, B, C, D, E, F, R(42), 0xC24B8B70 ); - P( F, G, H, A, B, C, D, E, R(43), 0xC76C51A3 ); - P( E, F, G, H, A, B, C, D, R(44), 0xD192E819 ); - P( D, E, F, G, H, A, B, C, R(45), 0xD6990624 ); - P( C, D, E, F, G, H, A, B, R(46), 0xF40E3585 ); - P( B, C, D, E, F, G, H, A, R(47), 0x106AA070 ); - P( A, B, C, D, E, F, G, H, R(48), 0x19A4C116 ); - P( H, A, B, C, D, E, F, G, R(49), 0x1E376C08 ); - P( G, H, A, B, C, D, E, F, R(50), 0x2748774C ); - P( F, G, H, A, B, C, D, E, R(51), 0x34B0BCB5 ); - P( E, F, G, H, A, B, C, D, R(52), 0x391C0CB3 ); - P( D, E, F, G, H, A, B, C, R(53), 0x4ED8AA4A ); - P( C, D, E, F, G, H, A, B, R(54), 0x5B9CCA4F ); - P( B, C, D, E, F, G, H, A, R(55), 0x682E6FF3 ); - P( A, B, C, D, E, F, G, H, R(56), 0x748F82EE ); - P( H, A, B, C, D, E, F, G, R(57), 0x78A5636F ); - P( G, H, A, B, C, D, E, F, R(58), 0x84C87814 ); - P( F, G, H, A, B, C, D, E, R(59), 0x8CC70208 ); - P( E, F, G, H, A, B, C, D, R(60), 0x90BEFFFA ); - P( D, E, F, G, H, A, B, C, R(61), 0xA4506CEB ); - P( C, D, E, F, G, H, A, B, R(62), 0xBEF9A3F7 ); - P( B, C, D, E, F, G, H, A, R(63), 0xC67178F2 ); - - ctx->state[0] += A; - ctx->state[1] += B; - ctx->state[2] += C; - ctx->state[3] += D; - ctx->state[4] += E; - ctx->state[5] += F; - ctx->state[6] += G; - ctx->state[7] += H; -} - -/* - * SHA-256 process buffer - */ -void sha2_update( sha2_context *ctx, const unsigned char *input, size_t ilen ) -{ - size_t fill; - unsigned long left; - - if ( ilen <= 0 ) - return; - - left = ctx->total[0] & 0x3F; - fill = 64 - left; - - ctx->total[0] += (unsigned long) ilen; - ctx->total[0] &= 0xFFFFFFFF; - - if ( ctx->total[0] < (unsigned long) ilen ) - ctx->total[1]++; - - if ( left && ilen >= fill ) - { - memcpy( (void *) (ctx->buffer + left), - (void *) input, fill ); - sha2_process( ctx, ctx->buffer ); - input += fill; - ilen -= fill; - left = 0; - } - - while ( ilen >= 64 ) - { - sha2_process( ctx, input ); - input += 64; - ilen -= 64; - } - - if ( ilen > 0 ) - { - memcpy( (void *) (ctx->buffer + left), - (void *) input, ilen ); - } -} - -static const unsigned char sha2_padding[64] = -{ - 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 -}; - -/* - * SHA-256 final digest - */ -void sha2_finish( sha2_context *ctx, unsigned char output[32] ) -{ - unsigned long last, padn; - unsigned long high, low; - unsigned char msglen[8]; - - high = (ctx->total[0] >> 29) - | (ctx->total[1] << 3); - low = (ctx->total[0] << 3); - - PUT_ULONG_BE( high, msglen, 0 ); - PUT_ULONG_BE( low, msglen, 4 ); - - last = ctx->total[0] & 0x3F; - padn = (last < 56) ? (56 - last) : (120 - last); - - sha2_update( ctx, (unsigned char *) sha2_padding, padn ); - sha2_update( ctx, msglen, 8 ); - - PUT_ULONG_BE( ctx->state[0], output, 0 ); - PUT_ULONG_BE( ctx->state[1], output, 4 ); - PUT_ULONG_BE( ctx->state[2], output, 8 ); - PUT_ULONG_BE( ctx->state[3], output, 12 ); - PUT_ULONG_BE( ctx->state[4], output, 16 ); - PUT_ULONG_BE( ctx->state[5], output, 20 ); - PUT_ULONG_BE( ctx->state[6], output, 24 ); - - if ( ctx->is224 == 0 ) - PUT_ULONG_BE( ctx->state[7], output, 28 ); -} - -/* - * output = SHA-256( input buffer ) - */ -void sha2( const unsigned char *input, size_t ilen, - unsigned char output[32], int is224 ) -{ - sha2_context ctx; - - sha2_starts( &ctx, is224 ); - sha2_update( &ctx, input, ilen ); - sha2_finish( &ctx, output ); - - memset( &ctx, 0, sizeof(sha2_context) ); -} - -/* - * SHA-256 HMAC context setup - */ -void sha2_hmac_starts( sha2_context *ctx, const unsigned char *key, size_t keylen, - int is224 ) -{ - size_t i; - unsigned char sum[32]; - - if ( keylen > 64 ) - { - sha2( key, keylen, sum, is224 ); - keylen = (is224) ? 28 : 32; - key = sum; - } - - memset( ctx->ipad, 0x36, 64 ); - memset( ctx->opad, 0x5C, 64 ); - - for ( i = 0; i < keylen; i++ ) - { - ctx->ipad[i] = (unsigned char)(ctx->ipad[i] ^ key[i]); - ctx->opad[i] = (unsigned char)(ctx->opad[i] ^ key[i]); - } - - sha2_starts( ctx, is224 ); - sha2_update( ctx, ctx->ipad, 64 ); - - memset( sum, 0, sizeof(sum) ); -} - -/* - * SHA-256 HMAC process buffer - */ -void sha2_hmac_update( sha2_context *ctx, const unsigned char *input, size_t ilen ) -{ - sha2_update( ctx, input, ilen ); -} - -/* - * SHA-256 HMAC final digest - */ -void sha2_hmac_finish( sha2_context *ctx, unsigned char output[32] ) -{ - int is224, hlen; - unsigned char tmpbuf[32]; - - is224 = ctx->is224; - hlen = (is224 == 0) ? 32 : 28; - - sha2_finish( ctx, tmpbuf ); - sha2_starts( ctx, is224 ); - sha2_update( ctx, ctx->opad, 64 ); - sha2_update( ctx, tmpbuf, hlen ); - sha2_finish( ctx, output ); - - memset( tmpbuf, 0, sizeof(tmpbuf) ); -} - -/* - * SHA-256 HMAC context reset - */ -void sha2_hmac_reset( sha2_context *ctx ) -{ - sha2_starts( ctx, ctx->is224 ); - sha2_update( ctx, ctx->ipad, 64 ); -} - -/* - * output = HMAC-SHA-256( hmac key, input buffer ) - */ -void sha2_hmac( const unsigned char *key, size_t keylen, - const unsigned char *input, size_t ilen, - unsigned char output[32], int is224 ) -{ - sha2_context ctx; - - sha2_hmac_starts( &ctx, key, keylen, is224 ); - sha2_hmac_update( &ctx, input, ilen ); - sha2_hmac_finish( &ctx, output ); - - memset( &ctx, 0, sizeof(sha2_context) ); -} - - - - - -#ifndef min -#define min( a, b ) ( ((a) < (b)) ? (a) : (b) ) -#endif - -void PKCS5_PBKDF2_HMAC(unsigned char *password, size_t plen, - unsigned char *salt, size_t slen, - const unsigned long iteration_count, const unsigned long key_length, - unsigned char *output) -{ - sha2_context ctx; - sha2_starts(&ctx, 0); - - // Size of the generated digest - unsigned char md_size = 32; - unsigned char md1[32]; - unsigned char work[32]; - - unsigned long counter = 1; - unsigned long generated_key_length = 0; - while (generated_key_length < key_length) { - // U1 ends up in md1 and work - unsigned char c[4]; - c[0] = (counter >> 24) & 0xff; - c[1] = (counter >> 16) & 0xff; - c[2] = (counter >> 8) & 0xff; - c[3] = (counter >> 0) & 0xff; - - sha2_hmac_starts(&ctx, password, plen, 0); - sha2_hmac_update(&ctx, salt, slen); - sha2_hmac_update(&ctx, c, 4); - sha2_hmac_finish(&ctx, md1); - memcpy(work, md1, md_size); - - unsigned long ic = 1; - for (ic = 1; ic < iteration_count; ic++) { - // U2 ends up in md1 - sha2_hmac_starts(&ctx, password, plen, 0); - sha2_hmac_update(&ctx, md1, md_size); - sha2_hmac_finish(&ctx, md1); - // U1 xor U2 - unsigned long i = 0; - for (i = 0; i < md_size; i++) { - work[i] ^= md1[i]; - } - // and so on until iteration_count - } - - // Copy the generated bytes to the key - unsigned long bytes_to_write = - min((key_length - generated_key_length), md_size); - memcpy(output + generated_key_length, work, bytes_to_write); - generated_key_length += bytes_to_write; - ++counter; - } -} - - - - - - - - - -#ifdef TEST -/* - * FIPS-180-2 test vectors - */ -static unsigned char sha2_test_buf[3][57] = -{ - { "abc" }, - { "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" }, - { "" } -}; - -static const int sha2_test_buflen[3] = -{ - 3, 56, 1000 -}; - -static const unsigned char sha2_test_sum[6][32] = -{ - /* - * SHA-224 test vectors - */ - { 0x23, 0x09, 0x7D, 0x22, 0x34, 0x05, 0xD8, 0x22, - 0x86, 0x42, 0xA4, 0x77, 0xBD, 0xA2, 0x55, 0xB3, - 0x2A, 0xAD, 0xBC, 0xE4, 0xBD, 0xA0, 0xB3, 0xF7, - 0xE3, 0x6C, 0x9D, 0xA7 }, - { 0x75, 0x38, 0x8B, 0x16, 0x51, 0x27, 0x76, 0xCC, - 0x5D, 0xBA, 0x5D, 0xA1, 0xFD, 0x89, 0x01, 0x50, - 0xB0, 0xC6, 0x45, 0x5C, 0xB4, 0xF5, 0x8B, 0x19, - 0x52, 0x52, 0x25, 0x25 }, - { 0x20, 0x79, 0x46, 0x55, 0x98, 0x0C, 0x91, 0xD8, - 0xBB, 0xB4, 0xC1, 0xEA, 0x97, 0x61, 0x8A, 0x4B, - 0xF0, 0x3F, 0x42, 0x58, 0x19, 0x48, 0xB2, 0xEE, - 0x4E, 0xE7, 0xAD, 0x67 }, - - /* - * SHA-256 test vectors - */ - { 0xBA, 0x78, 0x16, 0xBF, 0x8F, 0x01, 0xCF, 0xEA, - 0x41, 0x41, 0x40, 0xDE, 0x5D, 0xAE, 0x22, 0x23, - 0xB0, 0x03, 0x61, 0xA3, 0x96, 0x17, 0x7A, 0x9C, - 0xB4, 0x10, 0xFF, 0x61, 0xF2, 0x00, 0x15, 0xAD }, - { 0x24, 0x8D, 0x6A, 0x61, 0xD2, 0x06, 0x38, 0xB8, - 0xE5, 0xC0, 0x26, 0x93, 0x0C, 0x3E, 0x60, 0x39, - 0xA3, 0x3C, 0xE4, 0x59, 0x64, 0xFF, 0x21, 0x67, - 0xF6, 0xEC, 0xED, 0xD4, 0x19, 0xDB, 0x06, 0xC1 }, - { 0xCD, 0xC7, 0x6E, 0x5C, 0x99, 0x14, 0xFB, 0x92, - 0x81, 0xA1, 0xC7, 0xE2, 0x84, 0xD7, 0x3E, 0x67, - 0xF1, 0x80, 0x9A, 0x48, 0xA4, 0x97, 0x20, 0x0E, - 0x04, 0x6D, 0x39, 0xCC, 0xC7, 0x11, 0x2C, 0xD0 } -}; - -/* - * RFC 4231 test vectors - */ -static unsigned char sha2_hmac_test_key[7][26] = { - {"\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B" - "\x0B\x0B\x0B\x0B"}, - {"Jefe"}, - {"\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA" - "\xAA\xAA\xAA\xAA"}, - {"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F\x10" - "\x11\x12\x13\x14\x15\x16\x17\x18\x19"}, - {"\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C" - "\x0C\x0C\x0C\x0C"}, - {""}, /* 0xAA 131 times */ - {""} -}; - -static const int sha2_hmac_test_keylen[7] = { - 20, 4, 20, 25, 20, 131, 131 -}; - -static unsigned char sha2_hmac_test_buf[7][153] = -{ - { "Hi There" }, - { "what do ya want for nothing?" }, - { "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" - "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" - "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" - "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" - "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" }, - { "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" - "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" - "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" - "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" - "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" }, - { "Test With Truncation" }, - { "Test Using Larger Than Block-Size Key - Hash Key First" }, - { "This is a test using a larger than block-size key " - "and a larger than block-size data. The key needs to " - "be hashed before being used by the HMAC algorithm." } -}; - -static const int sha2_hmac_test_buflen[7] = -{ - 8, 28, 50, 50, 20, 54, 152 -}; - -static const unsigned char sha2_hmac_test_sum[14][32] = -{ - /* - * HMAC-SHA-224 test vectors - */ - { 0x89, 0x6F, 0xB1, 0x12, 0x8A, 0xBB, 0xDF, 0x19, - 0x68, 0x32, 0x10, 0x7C, 0xD4, 0x9D, 0xF3, 0x3F, - 0x47, 0xB4, 0xB1, 0x16, 0x99, 0x12, 0xBA, 0x4F, - 0x53, 0x68, 0x4B, 0x22 }, - { 0xA3, 0x0E, 0x01, 0x09, 0x8B, 0xC6, 0xDB, 0xBF, - 0x45, 0x69, 0x0F, 0x3A, 0x7E, 0x9E, 0x6D, 0x0F, - 0x8B, 0xBE, 0xA2, 0xA3, 0x9E, 0x61, 0x48, 0x00, - 0x8F, 0xD0, 0x5E, 0x44 }, - { 0x7F, 0xB3, 0xCB, 0x35, 0x88, 0xC6, 0xC1, 0xF6, - 0xFF, 0xA9, 0x69, 0x4D, 0x7D, 0x6A, 0xD2, 0x64, - 0x93, 0x65, 0xB0, 0xC1, 0xF6, 0x5D, 0x69, 0xD1, - 0xEC, 0x83, 0x33, 0xEA }, - { 0x6C, 0x11, 0x50, 0x68, 0x74, 0x01, 0x3C, 0xAC, - 0x6A, 0x2A, 0xBC, 0x1B, 0xB3, 0x82, 0x62, 0x7C, - 0xEC, 0x6A, 0x90, 0xD8, 0x6E, 0xFC, 0x01, 0x2D, - 0xE7, 0xAF, 0xEC, 0x5A }, - { 0x0E, 0x2A, 0xEA, 0x68, 0xA9, 0x0C, 0x8D, 0x37, - 0xC9, 0x88, 0xBC, 0xDB, 0x9F, 0xCA, 0x6F, 0xA8 }, - { 0x95, 0xE9, 0xA0, 0xDB, 0x96, 0x20, 0x95, 0xAD, - 0xAE, 0xBE, 0x9B, 0x2D, 0x6F, 0x0D, 0xBC, 0xE2, - 0xD4, 0x99, 0xF1, 0x12, 0xF2, 0xD2, 0xB7, 0x27, - 0x3F, 0xA6, 0x87, 0x0E }, - { 0x3A, 0x85, 0x41, 0x66, 0xAC, 0x5D, 0x9F, 0x02, - 0x3F, 0x54, 0xD5, 0x17, 0xD0, 0xB3, 0x9D, 0xBD, - 0x94, 0x67, 0x70, 0xDB, 0x9C, 0x2B, 0x95, 0xC9, - 0xF6, 0xF5, 0x65, 0xD1 }, - - /* - * HMAC-SHA-256 test vectors - */ - { 0xB0, 0x34, 0x4C, 0x61, 0xD8, 0xDB, 0x38, 0x53, - 0x5C, 0xA8, 0xAF, 0xCE, 0xAF, 0x0B, 0xF1, 0x2B, - 0x88, 0x1D, 0xC2, 0x00, 0xC9, 0x83, 0x3D, 0xA7, - 0x26, 0xE9, 0x37, 0x6C, 0x2E, 0x32, 0xCF, 0xF7 }, - { 0x5B, 0xDC, 0xC1, 0x46, 0xBF, 0x60, 0x75, 0x4E, - 0x6A, 0x04, 0x24, 0x26, 0x08, 0x95, 0x75, 0xC7, - 0x5A, 0x00, 0x3F, 0x08, 0x9D, 0x27, 0x39, 0x83, - 0x9D, 0xEC, 0x58, 0xB9, 0x64, 0xEC, 0x38, 0x43 }, - { 0x77, 0x3E, 0xA9, 0x1E, 0x36, 0x80, 0x0E, 0x46, - 0x85, 0x4D, 0xB8, 0xEB, 0xD0, 0x91, 0x81, 0xA7, - 0x29, 0x59, 0x09, 0x8B, 0x3E, 0xF8, 0xC1, 0x22, - 0xD9, 0x63, 0x55, 0x14, 0xCE, 0xD5, 0x65, 0xFE }, - { 0x82, 0x55, 0x8A, 0x38, 0x9A, 0x44, 0x3C, 0x0E, - 0xA4, 0xCC, 0x81, 0x98, 0x99, 0xF2, 0x08, 0x3A, - 0x85, 0xF0, 0xFA, 0xA3, 0xE5, 0x78, 0xF8, 0x07, - 0x7A, 0x2E, 0x3F, 0xF4, 0x67, 0x29, 0x66, 0x5B }, - { 0xA3, 0xB6, 0x16, 0x74, 0x73, 0x10, 0x0E, 0xE0, - 0x6E, 0x0C, 0x79, 0x6C, 0x29, 0x55, 0x55, 0x2B }, - { 0x60, 0xE4, 0x31, 0x59, 0x1E, 0xE0, 0xB6, 0x7F, - 0x0D, 0x8A, 0x26, 0xAA, 0xCB, 0xF5, 0xB7, 0x7F, - 0x8E, 0x0B, 0xC6, 0x21, 0x37, 0x28, 0xC5, 0x14, - 0x05, 0x46, 0x04, 0x0F, 0x0E, 0xE3, 0x7F, 0x54 }, - { 0x9B, 0x09, 0xFF, 0xA7, 0x1B, 0x94, 0x2F, 0xCB, - 0x27, 0x63, 0x5F, 0xBC, 0xD5, 0xB0, 0xE9, 0x44, - 0xBF, 0xDC, 0x63, 0x64, 0x4F, 0x07, 0x13, 0x93, - 0x8A, 0x7F, 0x51, 0x53, 0x5C, 0x3A, 0x35, 0xE2 } -}; -typedef struct { - char *t; - char *p; - int plen; - char *s; - int slen; - int c; - int dkLen; - char dk[1024]; // Remember to set this to max dkLen -} testvector; - -int do_test(testvector *tv) -{ - printf("Started %s\n", tv->t); - fflush(stdout); - char *key = malloc(tv->dkLen); - if (key == 0) { - return -1; - } - - PKCS5_PBKDF2_HMAC((unsigned char *)tv->p, tv->plen, - (unsigned char *)tv->s, tv->slen, tv->c, - tv->dkLen, (unsigned char *)key); - - if (memcmp(tv->dk, key, tv->dkLen) != 0) { - // Failed - return -1; - } - - return 0; -} - -/* - * Checkup routine - */ -int main() -{ - int verbose = 1; - int i, j, k, buflen; - unsigned char buf[1024]; - unsigned char sha2sum[32]; - sha2_context ctx; - - for (i = 0; i < 6; i++) { - j = i % 3; - k = i < 3; - - if (verbose != 0) - printf(" SHA-%d test #%d: ", 256 - k * 32, j + 1); - - sha2_starts(&ctx, k); - - if (j == 2) { - memset(buf, 'a', buflen = 1000); - - for (j = 0; j < 1000; j++) - sha2_update(&ctx, buf, buflen); - } else - sha2_update(&ctx, sha2_test_buf[j], - sha2_test_buflen[j]); - - sha2_finish(&ctx, sha2sum); - - if (memcmp(sha2sum, sha2_test_sum[i], 32 - k * 4) != 0) { - if (verbose != 0) - printf("failed\n"); - - return (1); - } - - if (verbose != 0) - printf("passed\n"); - } - - if (verbose != 0) - printf("\n"); - - for (i = 0; i < 14; i++) { - j = i % 7; - k = i < 7; - - if (verbose != 0) - printf(" HMAC-SHA-%d test #%d: ", 256 - k * 32, - j + 1); - - if (j == 5 || j == 6) { - memset(buf, '\xAA', buflen = 131); - sha2_hmac_starts(&ctx, buf, buflen, k); - } else - sha2_hmac_starts(&ctx, sha2_hmac_test_key[j], - sha2_hmac_test_keylen[j], k); - - sha2_hmac_update(&ctx, sha2_hmac_test_buf[j], - sha2_hmac_test_buflen[j]); - - sha2_hmac_finish(&ctx, sha2sum); - - buflen = (j == 4) ? 16 : 32 - k * 4; - - if (memcmp(sha2sum, sha2_hmac_test_sum[i], buflen) != 0) { - if (verbose != 0) - printf("failed\n"); - - return (1); - } - - if (verbose != 0) - printf("passed\n"); - } - - if (verbose != 0) - printf("\n"); - - testvector *tv = 0; - int res = 0; - - testvector t1 = { - "Test 1", - "password", 8, "salt", 4, 1, 32, - .dk = { 0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, - 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8, 0x37, - 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, - 0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b, 0xe1, 0x7b } - }; - - tv = &t1; - res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - testvector t2 = { - "Test 2", - "password", 8, "salt", 4, 2, 32, { - 0xae, 0x4d, 0x0c, 0x95, 0xaf, 0x6b, 0x46, 0xd3, - 0x2d, 0x0a, 0xdf, 0xf9, 0x28, 0xf0, 0x6d, 0xd0, - 0x2a, 0x30, 0x3f, 0x8e, 0xf3, 0xc2, 0x51, 0xdf, - 0xd6, 0xe2, 0xd8, 0x5a, 0x95, 0x47, 0x4c, 0x43 - } - }; - - tv = &t2; - res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - testvector t3 = { - "Test 3", - "password", 8, "salt", 4, 4096, 32, { - 0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, - 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c, 0x4c, 0x8d, - 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11, - 0xa4, 0x96, 0x38, 0x73, 0xaa, 0x98, 0x13, 0x4a - } - }; - - tv = &t3; - res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - testvector t4 = { - "Test 4", - "password", 8, "salt", 4, 16777216, 32, { - 0xcf, 0x81, 0xc6, 0x6f, 0xe8, 0xcf, 0xc0, 0x4d, - 0x1f, 0x31, 0xec, 0xb6, 0x5d, 0xab, 0x40, 0x89, - 0xf7, 0xf1, 0x79, 0xe8, 0x9b, 0x3b, 0x0b, 0xcb, - 0x17, 0xad, 0x10, 0xe3, 0xac, 0x6e, 0xba, 0x46 - } - }; - - tv = &t4; - // res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - testvector t5 = { - "Test 5", - "passwordPASSWORDpassword", 24, - "saltSALTsaltSALTsaltSALTsaltSALTsalt", 36, 4096, 40, { - 0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f, - 0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e, 0x84, 0xcf, - 0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18, - 0x1c, 0x4e, 0x2a, 0x1f, 0xb8, 0xdd, 0x53, 0xe1, - 0xc6, 0x35, 0x51, 0x8c, 0x7d, 0xac, 0x47, 0xe9 - } - }; - - tv = &t5; - res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - testvector t6 = { - "Test 6", - "pass\0word", 9, "sa\0lt", 5, 4096, 16, { - 0x89, 0xb6, 0x9d, 0x05, 0x16, 0xf8, 0x29, 0x89, - 0x3c, 0x69, 0x62, 0x26, 0x65, 0x0a, 0x86, 0x87 - } - }; - - tv = &t6; - res = do_test(tv); - if (res != 0) { - printf("%s failed\n", tv->t); - return res; - } - - return (0); -} - -#endif diff --git a/version2/src/C/pbkdf2-sha256.cpp b/version2/src/C/pbkdf2-sha256.cpp new file mode 100644 index 0000000..2a78fa0 --- /dev/null +++ b/version2/src/C/pbkdf2-sha256.cpp @@ -0,0 +1,859 @@ +/* + * FIPS-180-2 compliant SHA-256 implementation + * + * Copyright (C) 2006-2010, Brainspark B.V. + * + * This file is part of PolarSSL (http://www.polarssl.org) + * Lead Maintainer: Paul Bakker + * + * All rights reserved. + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License along + * with this program; if not, write to the Free Software Foundation, Inc., + * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + */ +/* + * The SHA-256 Secure Hash Standard was published by NIST in 2002. + * + * http://csrc.nist.gov/publications/fips/fips180-2/fips180-2.pdf + */ + +#include +#include +#include +#include "pbkdf2-sha256.h" + +/* + * 32-bit integer manipulation macros (big endian) + */ +#ifndef GET_ULONG_BE +#define GET_ULONG_BE(n,b,i) \ + { \ + (n) = ( (unsigned long) (b)[(i) ] << 24) \ + | ( (unsigned long) (b)[(i) + 1] << 16) \ + | ( (unsigned long) (b)[(i) + 2] << 8) \ + | ( (unsigned long) (b)[(i) + 3]); \ + } +#endif + +#ifndef PUT_ULONG_BE +#define PUT_ULONG_BE(n,b,i) \ + { \ + (b)[(i) ] = (unsigned char) ( (n) >> 24); \ + (b)[(i) + 1] = (unsigned char) ( (n) >> 16); \ + (b)[(i) + 2] = (unsigned char) ( (n) >> 8); \ + (b)[(i) + 3] = (unsigned char) ( (n) ); \ + } +#endif + +/* + * SHA-256 context setup + */ +void sha2_starts( sha2_context *ctx, int is224 ) +{ + ctx->total[0] = 0; + ctx->total[1] = 0; + + if ( is224 == 0 ) + { + /* SHA-256 */ + ctx->state[0] = 0x6A09E667; + ctx->state[1] = 0xBB67AE85; + ctx->state[2] = 0x3C6EF372; + ctx->state[3] = 0xA54FF53A; + ctx->state[4] = 0x510E527F; + ctx->state[5] = 0x9B05688C; + ctx->state[6] = 0x1F83D9AB; + ctx->state[7] = 0x5BE0CD19; + } + else + { + /* SHA-224 */ + ctx->state[0] = 0xC1059ED8; + ctx->state[1] = 0x367CD507; + ctx->state[2] = 0x3070DD17; + ctx->state[3] = 0xF70E5939; + ctx->state[4] = 0xFFC00B31; + ctx->state[5] = 0x68581511; + ctx->state[6] = 0x64F98FA7; + ctx->state[7] = 0xBEFA4FA4; + } + + ctx->is224 = is224; +} + +static void sha2_process( sha2_context *ctx, const unsigned char data[64] ) +{ + unsigned long temp1, temp2, W[64]; + unsigned long A, B, C, D, E, F, G, H; + + GET_ULONG_BE( W[ 0], data, 0 ); + GET_ULONG_BE( W[ 1], data, 4 ); + GET_ULONG_BE( W[ 2], data, 8 ); + GET_ULONG_BE( W[ 3], data, 12 ); + GET_ULONG_BE( W[ 4], data, 16 ); + GET_ULONG_BE( W[ 5], data, 20 ); + GET_ULONG_BE( W[ 6], data, 24 ); + GET_ULONG_BE( W[ 7], data, 28 ); + GET_ULONG_BE( W[ 8], data, 32 ); + GET_ULONG_BE( W[ 9], data, 36 ); + GET_ULONG_BE( W[10], data, 40 ); + GET_ULONG_BE( W[11], data, 44 ); + GET_ULONG_BE( W[12], data, 48 ); + GET_ULONG_BE( W[13], data, 52 ); + GET_ULONG_BE( W[14], data, 56 ); + GET_ULONG_BE( W[15], data, 60 ); + +#define SHR(x,n) ((x & 0xFFFFFFFF) >> n) +#define ROTR(x,n) (SHR(x,n) | (x << (32 - n))) + +#define S0(x) (ROTR(x, 7) ^ ROTR(x,18) ^ SHR(x, 3)) +#define S1(x) (ROTR(x,17) ^ ROTR(x,19) ^ SHR(x,10)) + +#define S2(x) (ROTR(x, 2) ^ ROTR(x,13) ^ ROTR(x,22)) +#define S3(x) (ROTR(x, 6) ^ ROTR(x,11) ^ ROTR(x,25)) + +#define F0(x,y,z) ((x & y) | (z & (x | y))) +#define F1(x,y,z) (z ^ (x & (y ^ z))) + +#define R(t) \ + ( \ + W[t] = S1(W[t - 2]) + W[t - 7] + \ + S0(W[t - 15]) + W[t - 16] \ + ) + +#define P(a,b,c,d,e,f,g,h,x,K) \ + { \ + temp1 = h + S3(e) + F1(e,f,g) + K + x; \ + temp2 = S2(a) + F0(a,b,c); \ + d += temp1; h = temp1 + temp2; \ + } + + A = ctx->state[0]; + B = ctx->state[1]; + C = ctx->state[2]; + D = ctx->state[3]; + E = ctx->state[4]; + F = ctx->state[5]; + G = ctx->state[6]; + H = ctx->state[7]; + + P( A, B, C, D, E, F, G, H, W[ 0], 0x428A2F98 ); + P( H, A, B, C, D, E, F, G, W[ 1], 0x71374491 ); + P( G, H, A, B, C, D, E, F, W[ 2], 0xB5C0FBCF ); + P( F, G, H, A, B, C, D, E, W[ 3], 0xE9B5DBA5 ); + P( E, F, G, H, A, B, C, D, W[ 4], 0x3956C25B ); + P( D, E, F, G, H, A, B, C, W[ 5], 0x59F111F1 ); + P( C, D, E, F, G, H, A, B, W[ 6], 0x923F82A4 ); + P( B, C, D, E, F, G, H, A, W[ 7], 0xAB1C5ED5 ); + P( A, B, C, D, E, F, G, H, W[ 8], 0xD807AA98 ); + P( H, A, B, C, D, E, F, G, W[ 9], 0x12835B01 ); + P( G, H, A, B, C, D, E, F, W[10], 0x243185BE ); + P( F, G, H, A, B, C, D, E, W[11], 0x550C7DC3 ); + P( E, F, G, H, A, B, C, D, W[12], 0x72BE5D74 ); + P( D, E, F, G, H, A, B, C, W[13], 0x80DEB1FE ); + P( C, D, E, F, G, H, A, B, W[14], 0x9BDC06A7 ); + P( B, C, D, E, F, G, H, A, W[15], 0xC19BF174 ); + P( A, B, C, D, E, F, G, H, R(16), 0xE49B69C1 ); + P( H, A, B, C, D, E, F, G, R(17), 0xEFBE4786 ); + P( G, H, A, B, C, D, E, F, R(18), 0x0FC19DC6 ); + P( F, G, H, A, B, C, D, E, R(19), 0x240CA1CC ); + P( E, F, G, H, A, B, C, D, R(20), 0x2DE92C6F ); + P( D, E, F, G, H, A, B, C, R(21), 0x4A7484AA ); + P( C, D, E, F, G, H, A, B, R(22), 0x5CB0A9DC ); + P( B, C, D, E, F, G, H, A, R(23), 0x76F988DA ); + P( A, B, C, D, E, F, G, H, R(24), 0x983E5152 ); + P( H, A, B, C, D, E, F, G, R(25), 0xA831C66D ); + P( G, H, A, B, C, D, E, F, R(26), 0xB00327C8 ); + P( F, G, H, A, B, C, D, E, R(27), 0xBF597FC7 ); + P( E, F, G, H, A, B, C, D, R(28), 0xC6E00BF3 ); + P( D, E, F, G, H, A, B, C, R(29), 0xD5A79147 ); + P( C, D, E, F, G, H, A, B, R(30), 0x06CA6351 ); + P( B, C, D, E, F, G, H, A, R(31), 0x14292967 ); + P( A, B, C, D, E, F, G, H, R(32), 0x27B70A85 ); + P( H, A, B, C, D, E, F, G, R(33), 0x2E1B2138 ); + P( G, H, A, B, C, D, E, F, R(34), 0x4D2C6DFC ); + P( F, G, H, A, B, C, D, E, R(35), 0x53380D13 ); + P( E, F, G, H, A, B, C, D, R(36), 0x650A7354 ); + P( D, E, F, G, H, A, B, C, R(37), 0x766A0ABB ); + P( C, D, E, F, G, H, A, B, R(38), 0x81C2C92E ); + P( B, C, D, E, F, G, H, A, R(39), 0x92722C85 ); + P( A, B, C, D, E, F, G, H, R(40), 0xA2BFE8A1 ); + P( H, A, B, C, D, E, F, G, R(41), 0xA81A664B ); + P( G, H, A, B, C, D, E, F, R(42), 0xC24B8B70 ); + P( F, G, H, A, B, C, D, E, R(43), 0xC76C51A3 ); + P( E, F, G, H, A, B, C, D, R(44), 0xD192E819 ); + P( D, E, F, G, H, A, B, C, R(45), 0xD6990624 ); + P( C, D, E, F, G, H, A, B, R(46), 0xF40E3585 ); + P( B, C, D, E, F, G, H, A, R(47), 0x106AA070 ); + P( A, B, C, D, E, F, G, H, R(48), 0x19A4C116 ); + P( H, A, B, C, D, E, F, G, R(49), 0x1E376C08 ); + P( G, H, A, B, C, D, E, F, R(50), 0x2748774C ); + P( F, G, H, A, B, C, D, E, R(51), 0x34B0BCB5 ); + P( E, F, G, H, A, B, C, D, R(52), 0x391C0CB3 ); + P( D, E, F, G, H, A, B, C, R(53), 0x4ED8AA4A ); + P( C, D, E, F, G, H, A, B, R(54), 0x5B9CCA4F ); + P( B, C, D, E, F, G, H, A, R(55), 0x682E6FF3 ); + P( A, B, C, D, E, F, G, H, R(56), 0x748F82EE ); + P( H, A, B, C, D, E, F, G, R(57), 0x78A5636F ); + P( G, H, A, B, C, D, E, F, R(58), 0x84C87814 ); + P( F, G, H, A, B, C, D, E, R(59), 0x8CC70208 ); + P( E, F, G, H, A, B, C, D, R(60), 0x90BEFFFA ); + P( D, E, F, G, H, A, B, C, R(61), 0xA4506CEB ); + P( C, D, E, F, G, H, A, B, R(62), 0xBEF9A3F7 ); + P( B, C, D, E, F, G, H, A, R(63), 0xC67178F2 ); + + ctx->state[0] += A; + ctx->state[1] += B; + ctx->state[2] += C; + ctx->state[3] += D; + ctx->state[4] += E; + ctx->state[5] += F; + ctx->state[6] += G; + ctx->state[7] += H; +} + +/* + * SHA-256 process buffer + */ +void sha2_update( sha2_context *ctx, const unsigned char *input, size_t ilen ) +{ + size_t fill; + unsigned long left; + + if ( ilen <= 0 ) + return; + + left = ctx->total[0] & 0x3F; + fill = 64 - left; + + ctx->total[0] += (unsigned long) ilen; + ctx->total[0] &= 0xFFFFFFFF; + + if ( ctx->total[0] < (unsigned long) ilen ) + ctx->total[1]++; + + if ( left && ilen >= fill ) + { + memcpy( (void *) (ctx->buffer + left), + (void *) input, fill ); + sha2_process( ctx, ctx->buffer ); + input += fill; + ilen -= fill; + left = 0; + } + + while ( ilen >= 64 ) + { + sha2_process( ctx, input ); + input += 64; + ilen -= 64; + } + + if ( ilen > 0 ) + { + memcpy( (void *) (ctx->buffer + left), + (void *) input, ilen ); + } +} + +static const unsigned char sha2_padding[64] = +{ + 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 +}; + +/* + * SHA-256 final digest + */ +void sha2_finish( sha2_context *ctx, unsigned char output[32] ) +{ + unsigned long last, padn; + unsigned long high, low; + unsigned char msglen[8]; + + high = (ctx->total[0] >> 29) + | (ctx->total[1] << 3); + low = (ctx->total[0] << 3); + + PUT_ULONG_BE( high, msglen, 0 ); + PUT_ULONG_BE( low, msglen, 4 ); + + last = ctx->total[0] & 0x3F; + padn = (last < 56) ? (56 - last) : (120 - last); + + sha2_update( ctx, (unsigned char *) sha2_padding, padn ); + sha2_update( ctx, msglen, 8 ); + + PUT_ULONG_BE( ctx->state[0], output, 0 ); + PUT_ULONG_BE( ctx->state[1], output, 4 ); + PUT_ULONG_BE( ctx->state[2], output, 8 ); + PUT_ULONG_BE( ctx->state[3], output, 12 ); + PUT_ULONG_BE( ctx->state[4], output, 16 ); + PUT_ULONG_BE( ctx->state[5], output, 20 ); + PUT_ULONG_BE( ctx->state[6], output, 24 ); + + if ( ctx->is224 == 0 ) + PUT_ULONG_BE( ctx->state[7], output, 28 ); +} + +/* + * output = SHA-256( input buffer ) + */ +void sha2( const unsigned char *input, size_t ilen, + unsigned char output[32], int is224 ) +{ + sha2_context ctx; + + sha2_starts( &ctx, is224 ); + sha2_update( &ctx, input, ilen ); + sha2_finish( &ctx, output ); + + memset( &ctx, 0, sizeof(sha2_context) ); +} + +/* + * SHA-256 HMAC context setup + */ +void sha2_hmac_starts( sha2_context *ctx, const unsigned char *key, size_t keylen, + int is224 ) +{ + size_t i; + unsigned char sum[32]; + + if ( keylen > 64 ) + { + sha2( key, keylen, sum, is224 ); + keylen = (is224) ? 28 : 32; + key = sum; + } + + memset( ctx->ipad, 0x36, 64 ); + memset( ctx->opad, 0x5C, 64 ); + + for ( i = 0; i < keylen; i++ ) + { + ctx->ipad[i] = (unsigned char)(ctx->ipad[i] ^ key[i]); + ctx->opad[i] = (unsigned char)(ctx->opad[i] ^ key[i]); + } + + sha2_starts( ctx, is224 ); + sha2_update( ctx, ctx->ipad, 64 ); + + memset( sum, 0, sizeof(sum) ); +} + +/* + * SHA-256 HMAC process buffer + */ +void sha2_hmac_update( sha2_context *ctx, const unsigned char *input, size_t ilen ) +{ + sha2_update( ctx, input, ilen ); +} + +/* + * SHA-256 HMAC final digest + */ +void sha2_hmac_finish( sha2_context *ctx, unsigned char output[32] ) +{ + int is224, hlen; + unsigned char tmpbuf[32]; + + is224 = ctx->is224; + hlen = (is224 == 0) ? 32 : 28; + + sha2_finish( ctx, tmpbuf ); + sha2_starts( ctx, is224 ); + sha2_update( ctx, ctx->opad, 64 ); + sha2_update( ctx, tmpbuf, hlen ); + sha2_finish( ctx, output ); + + memset( tmpbuf, 0, sizeof(tmpbuf) ); +} + +/* + * SHA-256 HMAC context reset + */ +void sha2_hmac_reset( sha2_context *ctx ) +{ + sha2_starts( ctx, ctx->is224 ); + sha2_update( ctx, ctx->ipad, 64 ); +} + +/* + * output = HMAC-SHA-256( hmac key, input buffer ) + */ +void sha2_hmac( const unsigned char *key, size_t keylen, + const unsigned char *input, size_t ilen, + unsigned char output[32], int is224 ) +{ + sha2_context ctx; + + sha2_hmac_starts( &ctx, key, keylen, is224 ); + sha2_hmac_update( &ctx, input, ilen ); + sha2_hmac_finish( &ctx, output ); + + memset( &ctx, 0, sizeof(sha2_context) ); +} + + + + + +#ifndef min +#define min( a, b ) ( ((a) < (b)) ? (a) : (b) ) +#endif + +void PKCS5_PBKDF2_HMAC(unsigned char *password, size_t plen, + unsigned char *salt, size_t slen, + const unsigned long iteration_count, const unsigned long key_length, + unsigned char *output) +{ + sha2_context ctx; + sha2_starts(&ctx, 0); + + // Size of the generated digest + unsigned char md_size = 32; + unsigned char md1[32]; + unsigned char work[32]; + + unsigned long counter = 1; + unsigned long generated_key_length = 0; + while (generated_key_length < key_length) { + // U1 ends up in md1 and work + unsigned char c[4]; + c[0] = (counter >> 24) & 0xff; + c[1] = (counter >> 16) & 0xff; + c[2] = (counter >> 8) & 0xff; + c[3] = (counter >> 0) & 0xff; + + sha2_hmac_starts(&ctx, password, plen, 0); + sha2_hmac_update(&ctx, salt, slen); + sha2_hmac_update(&ctx, c, 4); + sha2_hmac_finish(&ctx, md1); + memcpy(work, md1, md_size); + + unsigned long ic = 1; + for (ic = 1; ic < iteration_count; ic++) { + // U2 ends up in md1 + sha2_hmac_starts(&ctx, password, plen, 0); + sha2_hmac_update(&ctx, md1, md_size); + sha2_hmac_finish(&ctx, md1); + // U1 xor U2 + unsigned long i = 0; + for (i = 0; i < md_size; i++) { + work[i] ^= md1[i]; + } + // and so on until iteration_count + } + + // Copy the generated bytes to the key + unsigned long bytes_to_write = + min((key_length - generated_key_length), md_size); + memcpy(output + generated_key_length, work, bytes_to_write); + generated_key_length += bytes_to_write; + ++counter; + } +} + + + + + + + + + +#ifdef TEST +/* + * FIPS-180-2 test vectors + */ +static unsigned char sha2_test_buf[3][57] = +{ + { "abc" }, + { "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" }, + { "" } +}; + +static const int sha2_test_buflen[3] = +{ + 3, 56, 1000 +}; + +static const unsigned char sha2_test_sum[6][32] = +{ + /* + * SHA-224 test vectors + */ + { 0x23, 0x09, 0x7D, 0x22, 0x34, 0x05, 0xD8, 0x22, + 0x86, 0x42, 0xA4, 0x77, 0xBD, 0xA2, 0x55, 0xB3, + 0x2A, 0xAD, 0xBC, 0xE4, 0xBD, 0xA0, 0xB3, 0xF7, + 0xE3, 0x6C, 0x9D, 0xA7 }, + { 0x75, 0x38, 0x8B, 0x16, 0x51, 0x27, 0x76, 0xCC, + 0x5D, 0xBA, 0x5D, 0xA1, 0xFD, 0x89, 0x01, 0x50, + 0xB0, 0xC6, 0x45, 0x5C, 0xB4, 0xF5, 0x8B, 0x19, + 0x52, 0x52, 0x25, 0x25 }, + { 0x20, 0x79, 0x46, 0x55, 0x98, 0x0C, 0x91, 0xD8, + 0xBB, 0xB4, 0xC1, 0xEA, 0x97, 0x61, 0x8A, 0x4B, + 0xF0, 0x3F, 0x42, 0x58, 0x19, 0x48, 0xB2, 0xEE, + 0x4E, 0xE7, 0xAD, 0x67 }, + + /* + * SHA-256 test vectors + */ + { 0xBA, 0x78, 0x16, 0xBF, 0x8F, 0x01, 0xCF, 0xEA, + 0x41, 0x41, 0x40, 0xDE, 0x5D, 0xAE, 0x22, 0x23, + 0xB0, 0x03, 0x61, 0xA3, 0x96, 0x17, 0x7A, 0x9C, + 0xB4, 0x10, 0xFF, 0x61, 0xF2, 0x00, 0x15, 0xAD }, + { 0x24, 0x8D, 0x6A, 0x61, 0xD2, 0x06, 0x38, 0xB8, + 0xE5, 0xC0, 0x26, 0x93, 0x0C, 0x3E, 0x60, 0x39, + 0xA3, 0x3C, 0xE4, 0x59, 0x64, 0xFF, 0x21, 0x67, + 0xF6, 0xEC, 0xED, 0xD4, 0x19, 0xDB, 0x06, 0xC1 }, + { 0xCD, 0xC7, 0x6E, 0x5C, 0x99, 0x14, 0xFB, 0x92, + 0x81, 0xA1, 0xC7, 0xE2, 0x84, 0xD7, 0x3E, 0x67, + 0xF1, 0x80, 0x9A, 0x48, 0xA4, 0x97, 0x20, 0x0E, + 0x04, 0x6D, 0x39, 0xCC, 0xC7, 0x11, 0x2C, 0xD0 } +}; + +/* + * RFC 4231 test vectors + */ +static unsigned char sha2_hmac_test_key[7][26] = { + {"\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B\x0B" + "\x0B\x0B\x0B\x0B"}, + {"Jefe"}, + {"\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA" + "\xAA\xAA\xAA\xAA"}, + {"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x0C\x0D\x0E\x0F\x10" + "\x11\x12\x13\x14\x15\x16\x17\x18\x19"}, + {"\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C\x0C" + "\x0C\x0C\x0C\x0C"}, + {""}, /* 0xAA 131 times */ + {""} +}; + +static const int sha2_hmac_test_keylen[7] = { + 20, 4, 20, 25, 20, 131, 131 +}; + +static unsigned char sha2_hmac_test_buf[7][153] = +{ + { "Hi There" }, + { "what do ya want for nothing?" }, + { "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" + "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" + "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" + "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" + "\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD\xDD" }, + { "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" + "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" + "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" + "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" + "\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD\xCD" }, + { "Test With Truncation" }, + { "Test Using Larger Than Block-Size Key - Hash Key First" }, + { "This is a test using a larger than block-size key " + "and a larger than block-size data. The key needs to " + "be hashed before being used by the HMAC algorithm." } +}; + +static const int sha2_hmac_test_buflen[7] = +{ + 8, 28, 50, 50, 20, 54, 152 +}; + +static const unsigned char sha2_hmac_test_sum[14][32] = +{ + /* + * HMAC-SHA-224 test vectors + */ + { 0x89, 0x6F, 0xB1, 0x12, 0x8A, 0xBB, 0xDF, 0x19, + 0x68, 0x32, 0x10, 0x7C, 0xD4, 0x9D, 0xF3, 0x3F, + 0x47, 0xB4, 0xB1, 0x16, 0x99, 0x12, 0xBA, 0x4F, + 0x53, 0x68, 0x4B, 0x22 }, + { 0xA3, 0x0E, 0x01, 0x09, 0x8B, 0xC6, 0xDB, 0xBF, + 0x45, 0x69, 0x0F, 0x3A, 0x7E, 0x9E, 0x6D, 0x0F, + 0x8B, 0xBE, 0xA2, 0xA3, 0x9E, 0x61, 0x48, 0x00, + 0x8F, 0xD0, 0x5E, 0x44 }, + { 0x7F, 0xB3, 0xCB, 0x35, 0x88, 0xC6, 0xC1, 0xF6, + 0xFF, 0xA9, 0x69, 0x4D, 0x7D, 0x6A, 0xD2, 0x64, + 0x93, 0x65, 0xB0, 0xC1, 0xF6, 0x5D, 0x69, 0xD1, + 0xEC, 0x83, 0x33, 0xEA }, + { 0x6C, 0x11, 0x50, 0x68, 0x74, 0x01, 0x3C, 0xAC, + 0x6A, 0x2A, 0xBC, 0x1B, 0xB3, 0x82, 0x62, 0x7C, + 0xEC, 0x6A, 0x90, 0xD8, 0x6E, 0xFC, 0x01, 0x2D, + 0xE7, 0xAF, 0xEC, 0x5A }, + { 0x0E, 0x2A, 0xEA, 0x68, 0xA9, 0x0C, 0x8D, 0x37, + 0xC9, 0x88, 0xBC, 0xDB, 0x9F, 0xCA, 0x6F, 0xA8 }, + { 0x95, 0xE9, 0xA0, 0xDB, 0x96, 0x20, 0x95, 0xAD, + 0xAE, 0xBE, 0x9B, 0x2D, 0x6F, 0x0D, 0xBC, 0xE2, + 0xD4, 0x99, 0xF1, 0x12, 0xF2, 0xD2, 0xB7, 0x27, + 0x3F, 0xA6, 0x87, 0x0E }, + { 0x3A, 0x85, 0x41, 0x66, 0xAC, 0x5D, 0x9F, 0x02, + 0x3F, 0x54, 0xD5, 0x17, 0xD0, 0xB3, 0x9D, 0xBD, + 0x94, 0x67, 0x70, 0xDB, 0x9C, 0x2B, 0x95, 0xC9, + 0xF6, 0xF5, 0x65, 0xD1 }, + + /* + * HMAC-SHA-256 test vectors + */ + { 0xB0, 0x34, 0x4C, 0x61, 0xD8, 0xDB, 0x38, 0x53, + 0x5C, 0xA8, 0xAF, 0xCE, 0xAF, 0x0B, 0xF1, 0x2B, + 0x88, 0x1D, 0xC2, 0x00, 0xC9, 0x83, 0x3D, 0xA7, + 0x26, 0xE9, 0x37, 0x6C, 0x2E, 0x32, 0xCF, 0xF7 }, + { 0x5B, 0xDC, 0xC1, 0x46, 0xBF, 0x60, 0x75, 0x4E, + 0x6A, 0x04, 0x24, 0x26, 0x08, 0x95, 0x75, 0xC7, + 0x5A, 0x00, 0x3F, 0x08, 0x9D, 0x27, 0x39, 0x83, + 0x9D, 0xEC, 0x58, 0xB9, 0x64, 0xEC, 0x38, 0x43 }, + { 0x77, 0x3E, 0xA9, 0x1E, 0x36, 0x80, 0x0E, 0x46, + 0x85, 0x4D, 0xB8, 0xEB, 0xD0, 0x91, 0x81, 0xA7, + 0x29, 0x59, 0x09, 0x8B, 0x3E, 0xF8, 0xC1, 0x22, + 0xD9, 0x63, 0x55, 0x14, 0xCE, 0xD5, 0x65, 0xFE }, + { 0x82, 0x55, 0x8A, 0x38, 0x9A, 0x44, 0x3C, 0x0E, + 0xA4, 0xCC, 0x81, 0x98, 0x99, 0xF2, 0x08, 0x3A, + 0x85, 0xF0, 0xFA, 0xA3, 0xE5, 0x78, 0xF8, 0x07, + 0x7A, 0x2E, 0x3F, 0xF4, 0x67, 0x29, 0x66, 0x5B }, + { 0xA3, 0xB6, 0x16, 0x74, 0x73, 0x10, 0x0E, 0xE0, + 0x6E, 0x0C, 0x79, 0x6C, 0x29, 0x55, 0x55, 0x2B }, + { 0x60, 0xE4, 0x31, 0x59, 0x1E, 0xE0, 0xB6, 0x7F, + 0x0D, 0x8A, 0x26, 0xAA, 0xCB, 0xF5, 0xB7, 0x7F, + 0x8E, 0x0B, 0xC6, 0x21, 0x37, 0x28, 0xC5, 0x14, + 0x05, 0x46, 0x04, 0x0F, 0x0E, 0xE3, 0x7F, 0x54 }, + { 0x9B, 0x09, 0xFF, 0xA7, 0x1B, 0x94, 0x2F, 0xCB, + 0x27, 0x63, 0x5F, 0xBC, 0xD5, 0xB0, 0xE9, 0x44, + 0xBF, 0xDC, 0x63, 0x64, 0x4F, 0x07, 0x13, 0x93, + 0x8A, 0x7F, 0x51, 0x53, 0x5C, 0x3A, 0x35, 0xE2 } +}; +typedef struct { + char *t; + char *p; + int plen; + char *s; + int slen; + int c; + int dkLen; + char dk[1024]; // Remember to set this to max dkLen +} testvector; + +int do_test(testvector *tv) +{ + printf("Started %s\n", tv->t); + fflush(stdout); + char *key = malloc(tv->dkLen); + if (key == 0) { + return -1; + } + + PKCS5_PBKDF2_HMAC((unsigned char *)tv->p, tv->plen, + (unsigned char *)tv->s, tv->slen, tv->c, + tv->dkLen, (unsigned char *)key); + + if (memcmp(tv->dk, key, tv->dkLen) != 0) { + // Failed + return -1; + } + + return 0; +} + +/* + * Checkup routine + */ +int main() +{ + int verbose = 1; + int i, j, k, buflen; + unsigned char buf[1024]; + unsigned char sha2sum[32]; + sha2_context ctx; + + for (i = 0; i < 6; i++) { + j = i % 3; + k = i < 3; + + if (verbose != 0) + printf(" SHA-%d test #%d: ", 256 - k * 32, j + 1); + + sha2_starts(&ctx, k); + + if (j == 2) { + memset(buf, 'a', buflen = 1000); + + for (j = 0; j < 1000; j++) + sha2_update(&ctx, buf, buflen); + } else + sha2_update(&ctx, sha2_test_buf[j], + sha2_test_buflen[j]); + + sha2_finish(&ctx, sha2sum); + + if (memcmp(sha2sum, sha2_test_sum[i], 32 - k * 4) != 0) { + if (verbose != 0) + printf("failed\n"); + + return (1); + } + + if (verbose != 0) + printf("passed\n"); + } + + if (verbose != 0) + printf("\n"); + + for (i = 0; i < 14; i++) { + j = i % 7; + k = i < 7; + + if (verbose != 0) + printf(" HMAC-SHA-%d test #%d: ", 256 - k * 32, + j + 1); + + if (j == 5 || j == 6) { + memset(buf, '\xAA', buflen = 131); + sha2_hmac_starts(&ctx, buf, buflen, k); + } else + sha2_hmac_starts(&ctx, sha2_hmac_test_key[j], + sha2_hmac_test_keylen[j], k); + + sha2_hmac_update(&ctx, sha2_hmac_test_buf[j], + sha2_hmac_test_buflen[j]); + + sha2_hmac_finish(&ctx, sha2sum); + + buflen = (j == 4) ? 16 : 32 - k * 4; + + if (memcmp(sha2sum, sha2_hmac_test_sum[i], buflen) != 0) { + if (verbose != 0) + printf("failed\n"); + + return (1); + } + + if (verbose != 0) + printf("passed\n"); + } + + if (verbose != 0) + printf("\n"); + + testvector *tv = 0; + int res = 0; + + testvector t1 = { + "Test 1", + "password", 8, "salt", 4, 1, 32, + .dk = { 0x12, 0x0f, 0xb6, 0xcf, 0xfc, 0xf8, 0xb3, 0x2c, + 0x43, 0xe7, 0x22, 0x52, 0x56, 0xc4, 0xf8, 0x37, + 0xa8, 0x65, 0x48, 0xc9, 0x2c, 0xcc, 0x35, 0x48, + 0x08, 0x05, 0x98, 0x7c, 0xb7, 0x0b, 0xe1, 0x7b } + }; + + tv = &t1; + res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + testvector t2 = { + "Test 2", + "password", 8, "salt", 4, 2, 32, { + 0xae, 0x4d, 0x0c, 0x95, 0xaf, 0x6b, 0x46, 0xd3, + 0x2d, 0x0a, 0xdf, 0xf9, 0x28, 0xf0, 0x6d, 0xd0, + 0x2a, 0x30, 0x3f, 0x8e, 0xf3, 0xc2, 0x51, 0xdf, + 0xd6, 0xe2, 0xd8, 0x5a, 0x95, 0x47, 0x4c, 0x43 + } + }; + + tv = &t2; + res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + testvector t3 = { + "Test 3", + "password", 8, "salt", 4, 4096, 32, { + 0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, + 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c, 0x4c, 0x8d, + 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11, + 0xa4, 0x96, 0x38, 0x73, 0xaa, 0x98, 0x13, 0x4a + } + }; + + tv = &t3; + res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + testvector t4 = { + "Test 4", + "password", 8, "salt", 4, 16777216, 32, { + 0xcf, 0x81, 0xc6, 0x6f, 0xe8, 0xcf, 0xc0, 0x4d, + 0x1f, 0x31, 0xec, 0xb6, 0x5d, 0xab, 0x40, 0x89, + 0xf7, 0xf1, 0x79, 0xe8, 0x9b, 0x3b, 0x0b, 0xcb, + 0x17, 0xad, 0x10, 0xe3, 0xac, 0x6e, 0xba, 0x46 + } + }; + + tv = &t4; + // res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + testvector t5 = { + "Test 5", + "passwordPASSWORDpassword", 24, + "saltSALTsaltSALTsaltSALTsaltSALTsalt", 36, 4096, 40, { + 0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f, + 0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e, 0x84, 0xcf, + 0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18, + 0x1c, 0x4e, 0x2a, 0x1f, 0xb8, 0xdd, 0x53, 0xe1, + 0xc6, 0x35, 0x51, 0x8c, 0x7d, 0xac, 0x47, 0xe9 + } + }; + + tv = &t5; + res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + testvector t6 = { + "Test 6", + "pass\0word", 9, "sa\0lt", 5, 4096, 16, { + 0x89, 0xb6, 0x9d, 0x05, 0x16, 0xf8, 0x29, 0x89, + 0x3c, 0x69, 0x62, 0x26, 0x65, 0x0a, 0x86, 0x87 + } + }; + + tv = &t6; + res = do_test(tv); + if (res != 0) { + printf("%s failed\n", tv->t); + return res; + } + + return (0); +} + +#endif