should work
[c11concurrency-benchmarks.git] / mabain / binaries / mbc.cpp
diff --git a/mabain/binaries/mbc.cpp b/mabain/binaries/mbc.cpp
new file mode 100644 (file)
index 0000000..70099c9
--- /dev/null
@@ -0,0 +1,676 @@
+/**
+ * Copyright (C) 2017 Cisco Inc.
+ *
+ * This program is free software: you can redistribute it and/or  modify
+ * it under the terms of the GNU General Public License, version 2,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+// @author Changxue Deng <chadeng@cisco.com>
+
+// A mabain command-line client
+
+#include <assert.h>
+#include <ctype.h>
+#include <errno.h>
+#include <signal.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <fstream>
+#include <iostream>
+#include <string>
+
+#include <readline/readline.h>
+#include <readline/history.h>
+
+#include "db.h"
+#include "mb_data.h"
+#include "dict.h"
+#include "error.h"
+#include "version.h"
+
+#include "hexbin.h"
+#include "expr_parser.h"
+
+using namespace mabain;
+
+enum mbc_command {
+    COMMAND_NONE = 0,
+    COMMAND_QUIT = 1,
+    COMMAND_UNKNOWN = 2,
+    COMMAND_STATS = 3,
+    COMMAND_FIND = 4,
+    COMMAND_FIND_ALL = 5,
+    COMMAND_INSERT = 6,
+    COMMAND_REPLACE = 7,
+    COMMAND_DELETE = 8,
+    COMMAND_DELETE_ALL = 9,
+    COMMAND_HELP = 10,
+    COMMAND_RESET_N_WRITER = 11,
+    COMMAND_RESET_N_READER = 12,
+    COMMAND_FIND_LPREFIX = 13,
+    COMMAND_PRINT_HEADER = 14,
+    COMMAND_FIND_HEX = 15,
+    COMMAND_FIND_LPREFIX_HEX = 16,
+    COMMAND_RECLAIM_RESOURCES = 17,
+    COMMAND_PARSING_ERROR = 18,
+};
+
+volatile bool quit_mbc = false;
+static void HandleSignal(int sig)
+{
+    switch(sig)
+    {
+        case SIGSEGV:
+            std::cerr << "process segfault\n";
+            abort();
+        case SIGTERM:
+        case SIGINT:
+        case SIGQUIT:
+        case SIGHUP:
+        case SIGPIPE:
+            quit_mbc = true;
+        case SIGUSR1:
+        case SIGUSR2:
+            break;
+    }
+}
+
+/**
+ * Print the command-line synopsis to stdout and terminate the process
+ * with exit status 1.  This function does not return.
+ *
+ * @param prog  program name shown in the synopsis (argv[0])
+ */
+static void usage(const char *prog)
+{
+    std::cout << "Usage: " << prog << " -d mabain-directory [-im index-memcap] [-dm data-memcap] [-w] [-e query] [-s script-file]"
+                 " [--lru-bucket-size n] [--index-block-size size] [--data-block-size size]\n";
+    std::cout <<"\t-d mabain database directory\n";
+    std::cout <<"\t-im index memcap\n";
+    std::cout <<"\t-dm data memcap\n";
+    std::cout <<"\t-w running in writer mode\n";
+    std::cout <<"\t-e run query on command line\n";
+    std::cout <<"\t-s run queries in a file\n";
+    // The long options below are accepted by main() as well.
+    std::cout <<"\t--lru-bucket-size number of entries per bucket\n";
+    std::cout <<"\t--index-block-size index block size\n";
+    std::cout <<"\t--data-block-size data block size\n";
+    exit(1);
+}
+
+static void show_help()
+{
+    std::cout << "\tfind(\"key\")\t\tsearch entry by key\n";
+    std::cout << "\tfindPrefix(\"key\")\tsearch entry by key using longest prefix match\n";
+    std::cout << "\tfindAll\t\t\tlist all entries\n";
+    std::cout << "\tinsert(\"key\":\"value\")\tinsert a key-value pair\n";
+    std::cout << "\treplace(\"key\":\"value\")\treplace a key-value pair\n";
+    std::cout << "\tdelete(\"key\")\t\tdelete entry by key\n";
+    std::cout << "\tdeleteAll\t\tdelete all entries\n";
+    std::cout << "\tshow\t\t\tshow database statistics\n";
+    std::cout << "\thelp\t\t\tshow helps\n";
+    std::cout << "\tquit\t\t\tquit mabain client\n";
+    std::cout << "\tdecWriterCount\t\tClear writer count in shared memory header\n";
+    std::cout << "\tdecReaderCount\t\tdecrement reader count in shared memory header\n";
+    std::cout << "\tprintHeader\t\tPrint shared memory header\n";
+    std::cout << "\treclaimResources\tReclaim deleted resources\n";
+}
+
+static void trim_spaces(const char *cmd, std::string &cmd_trim)
+{
+    cmd_trim.clear();
+
+    int quotation = 0;
+    const char *p = cmd;
+    while(*p != '\0')
+    {
+        if(*p == '\'' || *p == '\"')
+        {
+            cmd_trim.append(1, '\'');
+            quotation ^= 1; 
+        }
+        else if(!isspace(*p) || quotation) 
+        {
+            cmd_trim.append(1, *p);
+        }
+
+        p++;
+    }
+}
+
+/**
+ * Detect and strip a trailing ".hex()" from a query.
+ *
+ * @param cmd  query text; the suffix is removed in place when present
+ * @return true when cmd ended with ".hex()", false otherwise
+ *         (cmd is left untouched in that case)
+ */
+static bool check_hex_output(std::string &cmd)
+{
+    static const char hex_suffix[] = ".hex()";
+    const size_t suffix_len = sizeof(hex_suffix) - 1;
+
+    // Too short to carry the suffix at all.
+    if(cmd.length() < suffix_len)
+        return false;
+
+    const size_t tail = cmd.length() - suffix_len;
+    if(cmd.compare(tail, suffix_len, hex_suffix) != 0)
+        return false;
+
+    cmd.resize(tail);
+    return true;
+}
+
+static int parse_key_value_pair(const std::string &kv_str,
+                                std::string &key,
+                                std::string &value)
+{
+    // search for ':' that separates key and value pair
+    // currently this utility does not support quotation in
+    // quotation, e.g, "abc\"def" as key or value.
+    size_t pos = 0;
+    int quotation_cnt = 0;
+    for(size_t i = 0; i < kv_str.length(); i++)
+    {
+        if(kv_str[i] == '\'')
+        {
+            quotation_cnt++;
+        }
+        else if(kv_str[i] == ':')
+        {
+            // do not count the ':' in the key or value string.
+            if(quotation_cnt % 2 == 0)
+            {
+                pos = i;
+                break;
+            }
+        }
+    }
+
+    if(pos == 0)
+        return -1;
+
+    ExprParser expr_key(kv_str.substr(0, pos));
+    if(expr_key.Evaluate(key) < 0)
+        return -1;
+
+    ExprParser expr_value(kv_str.substr(pos+1));
+    if(expr_value.Evaluate(value) < 0)
+        return -1;
+
+    return 0;
+}
+
+static int parse_command(std::string &cmd,
+                         std::string &key,
+                         std::string &value)
+{
+    std::string yes;
+    bool hex_output = false;
+
+    key = "";
+    value = "";
+    switch(cmd[0])
+    {
+        case 'q':
+            if(cmd.compare("quit") == 0)
+                return COMMAND_QUIT;
+            break;
+        case 's':
+            if(cmd.compare("show") == 0)
+                return COMMAND_STATS;
+            break;
+        case 'f':
+            hex_output = check_hex_output(cmd);
+            if(cmd.compare(0, 5, "find(") == 0)
+            {
+                if(cmd[cmd.length()-1] != ')')
+                    return COMMAND_UNKNOWN;
+                ExprParser expr(cmd.substr(5, cmd.length()-6));
+                if(expr.Evaluate(key) < 0)
+                    return COMMAND_PARSING_ERROR;
+                if(hex_output)
+                    return COMMAND_FIND_HEX;
+                return COMMAND_FIND;
+            }
+            else if(cmd.compare(0, 11, "findPrefix(") == 0)
+            {
+                if(cmd[cmd.length()-1] != ')')
+                    return COMMAND_UNKNOWN;
+                ExprParser expr(cmd.substr(11, cmd.length()-12));
+                if(expr.Evaluate(key) < 0)
+                    return COMMAND_PARSING_ERROR;
+                if(hex_output)
+                    return COMMAND_FIND_LPREFIX_HEX;
+                return COMMAND_FIND_LPREFIX;
+            }
+            else if(cmd.compare("findAll") == 0)
+                return COMMAND_FIND_ALL;
+            break;
+        case 'd':
+            if(cmd.compare(0, 7, "delete(") == 0)
+            {
+                if(cmd[cmd.length()-1] != ')')
+                    return COMMAND_UNKNOWN;
+                ExprParser expr(cmd.substr(7, cmd.length()-8));
+                if(expr.Evaluate(key) < 0)
+                    return COMMAND_PARSING_ERROR; 
+                return COMMAND_DELETE;
+            }
+            else if(cmd.compare("deleteAll") == 0)
+            {
+                std::cout << "Do you want to delete all entries? Press \'Y\' to continue: ";
+                std::string del_all;
+                std::getline(std::cin, del_all);
+                if(del_all.length() == 0 || del_all[0] != 'Y')
+                    return COMMAND_NONE;
+                return COMMAND_DELETE_ALL;
+            }
+            else if(cmd.compare("decReaderCount") == 0)
+            {
+                std::cout << "Do you want to decrement number of reader? Press \'y\' to continue: ";
+                std::getline(std::cin, yes);   
+                if(yes.length() > 0 && yes[0] == 'y')
+                    return COMMAND_RESET_N_READER;
+                return COMMAND_NONE;
+            }
+            else if(cmd.compare("decWriterCount") == 0)
+            {
+                std::cout << "Do you want to decrement number of writer? Press \'y\' to continue: ";
+                std::getline(std::cin, yes);   
+                if(yes.length() > 0 && yes[0] == 'y')
+                    return COMMAND_RESET_N_WRITER;
+                return COMMAND_NONE;
+            }
+            break;
+        case 'i':
+            if(cmd.compare(0, 7, "insert(") == 0)
+            {
+                if(cmd[cmd.length()-1] != ')')
+                    return COMMAND_UNKNOWN;
+                if(parse_key_value_pair(cmd.substr(7, cmd.length()-8), key, value) < 0)
+                    return COMMAND_PARSING_ERROR;
+                return COMMAND_INSERT;
+            }
+            break;
+        case 'r':
+            if(cmd.compare(0, 8, "replace(") == 0)
+            {
+                if(cmd[cmd.length()-1] != ')')
+                    return COMMAND_UNKNOWN;
+                if(parse_key_value_pair(cmd.substr(8, cmd.length()-9), key, value) < 0)
+                    return COMMAND_PARSING_ERROR;
+                return COMMAND_REPLACE;
+            }
+            else if(cmd.compare("reclaimResources") == 0)
+                return COMMAND_RECLAIM_RESOURCES;
+            else
+                return COMMAND_UNKNOWN;
+            break;
+        case 'h':
+            if(cmd.compare("help") == 0)
+                return COMMAND_HELP;
+            break;
+        case 'p':
+            if(cmd.compare("printHeader") == 0)
+                return COMMAND_PRINT_HEADER;
+            break;
+        default:
+            break;
+    }
+
+    return COMMAND_UNKNOWN;
+}
+
+#define ENTRY_PER_PAGE 20
+static void display_all_kvs(DB *db)
+{
+    if(db == NULL)
+        return;
+
+    int count = 0;
+    for(DB::iterator iter = db->begin(); iter != db->end(); ++iter)
+    {
+        count++;
+        std::cout << iter.key << ": " <<
+                     std::string((char *)iter.value.buff, iter.value.data_len) << "\n";
+        if(count % ENTRY_PER_PAGE == 0)
+        {
+            std::string show_more;
+            std::cout << "Press \'y\' for displaying more: ";
+            std::getline(std::cin, show_more);
+            if(show_more.length() == 0 || show_more[0] != 'y')
+                break;
+        }
+    }
+}
+
+static void display_output(const MBData &mbd, bool hex_output, bool prefix)
+{
+    if(prefix)
+        std::cout << "key length matched: " << mbd.match_len << "\n";
+    if(hex_output)
+    {
+        char hex_buff[256];
+        int len = mbd.data_len;
+        if(256 < 2*len + 1)
+        {
+            std::cout << "display the first 127 bytes\n";
+            len = 127;
+        }
+        if(bin_2_hex((const uint8_t *)mbd.buff, len, hex_buff, 256) < 0)
+            std::cout << "failed to convert binary to hex\n";
+        else 
+            std::cout << hex_buff << "\n";
+    }
+    else
+    {
+        std::cout << std::string((char *)mbd.buff, mbd.data_len) << "\n";
+    }
+}
+
+static int RunCommand(int mode, DB *db, int cmd_id, const std::string &key, const std::string &value)
+{
+    int rval = MBError::SUCCESS;
+    bool overwrite = false;
+    bool hex_output = false;
+    MBData mbd;
+
+    switch(cmd_id)
+    {
+        case COMMAND_NONE:
+            // no opertation needed
+            break;
+        case COMMAND_QUIT:
+            std::cout << "bye\n";
+            quit_mbc = true;
+            break;
+        case COMMAND_FIND_HEX:
+            hex_output = true;
+        case COMMAND_FIND:
+            rval = db->Find(key, mbd);
+            if(rval == MBError::SUCCESS)
+                display_output(mbd, hex_output, false);
+            else
+                std::cout << MBError::get_error_str(rval) << "\n";
+            break;
+        case COMMAND_FIND_LPREFIX_HEX:
+            hex_output = true;
+        case COMMAND_FIND_LPREFIX:
+            rval = db->FindLongestPrefix(key, mbd);
+            if(rval == MBError::SUCCESS)
+                display_output(mbd, hex_output, true);
+            else
+                std::cout << MBError::get_error_str(rval) << "\n";
+            break;
+        case COMMAND_DELETE:
+            if(mode & CONSTS::ACCESS_MODE_WRITER)
+            {
+                rval = db->Remove(key);
+                std::cout << MBError::get_error_str(rval) << "\n";
+            }
+            else
+                std::cout << "permission not allowed\n";
+            break;
+        case COMMAND_REPLACE:
+            overwrite = true;
+        case COMMAND_INSERT:
+            if(mode & CONSTS::ACCESS_MODE_WRITER)
+            {
+                rval = db->Add(key, value, overwrite);
+                std::cout << MBError::get_error_str(rval) << "\n";
+            }
+            else
+                std::cout << "permission not allowed\n";
+            break;
+        case COMMAND_STATS:
+            db->PrintStats();
+            break;
+        case COMMAND_HELP:
+            show_help();
+            break;
+        case COMMAND_DELETE_ALL:
+            if(mode & CONSTS::ACCESS_MODE_WRITER)
+            {
+                rval = db->RemoveAll();
+                std::cout << MBError::get_error_str(rval) << "\n";
+            }
+            else
+                std::cout << "permission not allowed\n";
+            break;
+        case COMMAND_FIND_ALL:
+            display_all_kvs(db);
+            break;
+        case COMMAND_RESET_N_WRITER:
+            if(mode & CONSTS::ACCESS_MODE_WRITER)
+                std::cout << "writer is running, cannot reset writer counter\n";
+            else
+                db->UpdateNumHandlers(CONSTS::ACCESS_MODE_WRITER, -1);
+            break;
+        case COMMAND_RESET_N_READER:
+            db->UpdateNumHandlers(CONSTS::ACCESS_MODE_READER, -1);
+            break;
+        case COMMAND_PRINT_HEADER:
+            db->PrintHeader();
+            break;
+        case COMMAND_RECLAIM_RESOURCES:
+            if(mode & CONSTS::ACCESS_MODE_WRITER)
+                db->CollectResource(1, 1);
+            else
+                std::cout << "writer is not running, can not perform grabage collection" << std::endl;
+            break;
+        case COMMAND_PARSING_ERROR:
+            break;
+        case COMMAND_UNKNOWN:
+        default:
+            std::cout << "unknown query\n";
+            break;
+    }
+
+    return rval;
+}
+
+static void mbclient(DB *db, int mode)
+{
+    rl_bind_key('\t', rl_complete);
+
+    printf("mabain %d.%d.%d shell\n", version[0], version[1], version[2]);
+    std::cout << "database directory: " << db->GetDBDir() << "\n";
+
+    int cmd_id;
+    std::string key;
+    std::string value;
+    std::string cmd;
+
+    while(true)
+    {
+        char* line = readline(">> ");
+        if(line == NULL) break;
+        if(line[0] == '\0')
+        {
+            free(line);
+            continue;
+        }
+
+        trim_spaces(line, cmd);
+        add_history(line);
+        free(line);
+        cmd_id = parse_command(cmd, key, value);
+
+        RunCommand(mode, db, cmd_id, key, value);
+
+        if(quit_mbc) break;
+    }
+}
+
+static void run_query_command(DB *db, int mode, const std::string &command_str)
+{
+    std::string cmd;
+    int cmd_id;
+    std::string key;
+    std::string value;
+
+    trim_spaces(command_str.c_str(), cmd);
+    if(cmd.length() == 0)
+    {
+        std::cerr << command_str << " not a valid command\n";
+        return;
+    }
+
+    cmd_id = parse_command(cmd, key, value);
+    RunCommand(mode, db, cmd_id, key, value);
+}
+
+static void run_script(DB *db, int mode, const std::string &script_file)
+{
+    std::ifstream script_in(script_file);
+    if(!script_in.is_open()) {
+        std::cerr << "cannot open file " << script_file << "\n";
+        return;
+    }
+
+    std::string line;
+    int cmd_id;
+    std::string key;
+    std::string value;
+    std::string cmd;
+    
+    while(getline(script_in, line))
+    {
+        trim_spaces(line.c_str(), cmd);
+        if(cmd.length() == 0)
+        {
+            std::cerr << line << " not a valid query\n";
+            continue;
+        }
+
+        cmd_id = parse_command(cmd, key, value);
+        std::cout << cmd << ": ";
+        RunCommand(mode, db, cmd_id, key, value);
+
+        if(quit_mbc) break;
+    }
+    script_in.close();
+}
+
+int main(int argc, char *argv[])
+{
+    sigset_t mask;
+
+    signal(SIGINT, HandleSignal);
+    signal(SIGTERM, HandleSignal);
+    signal(SIGQUIT, HandleSignal);
+    signal(SIGPIPE, HandleSignal);
+    signal(SIGHUP, HandleSignal);
+    signal(SIGSEGV, HandleSignal);
+    signal(SIGUSR1, HandleSignal);
+    signal(SIGUSR2, HandleSignal);
+
+    sigemptyset(&mask);
+    sigaddset(&mask, SIGTERM);
+    sigaddset(&mask, SIGQUIT);
+    sigaddset(&mask, SIGINT);
+    sigaddset(&mask, SIGPIPE);
+    sigaddset(&mask, SIGHUP);
+    sigaddset(&mask, SIGSEGV);
+    sigaddset(&mask, SIGUSR1);
+    sigaddset(&mask, SIGUSR2);
+    pthread_sigmask(SIG_SETMASK, &mask, NULL);
+
+    sigemptyset(&mask);
+    pthread_sigmask(SIG_SETMASK, &mask, NULL);
+
+    int64_t memcap_i = 1024*1024LL;
+    int64_t memcap_d = 1024*1024LL;
+    const char *db_dir = NULL;
+    int mode = 0;
+    std::string query_cmd = "";
+    std::string script_file = "";
+    int64_t index_blk_size = 64LL*1024*1024;
+    int64_t data_blk_size = 64LL*1024*1024;
+    int64_t lru_bucket_size = 1000;
+
+    for(int i = 1; i < argc; i++)
+    {
+        if(strcmp(argv[i], "-d") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            db_dir = argv[i];
+        }
+        else if(strcmp(argv[i], "-im") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            memcap_i = atoi(argv[i]);
+        }
+        else if(strcmp(argv[i], "-dm") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            memcap_d = atoi(argv[i]);
+        }
+        else if(strcmp(argv[i], "-w") == 0)
+        {
+            mode |= CONSTS::ACCESS_MODE_WRITER;
+        }
+        else if(strcmp(argv[i], "-e") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            query_cmd = argv[i];
+        }
+        else if(strcmp(argv[i], "-s") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            script_file = argv[i];
+        }
+        else if(strcmp(argv[i], "--lru-bucket-size") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            lru_bucket_size = atoi(argv[i]);
+        }
+        else if(strcmp(argv[i], "--index-block-size") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            index_blk_size = atoi(argv[i]);
+        }
+        else if(strcmp(argv[i], "--data-block-size") == 0)
+        {
+            if(++i >= argc)
+                usage(argv[0]);
+            data_blk_size = atoi(argv[i]);
+        }
+        else
+            usage(argv[0]);
+    }
+
+    if(db_dir == NULL)
+        usage(argv[0]);
+
+    MBConfig mbconf;
+    memset(&mbconf, 0, sizeof(mbconf));
+    mbconf.mbdir = db_dir;
+    mbconf.options = mode;
+    mbconf.memcap_index = memcap_i;
+    mbconf.memcap_data = memcap_d;
+    mbconf.block_size_index = index_blk_size;
+    mbconf.block_size_data = data_blk_size;
+    mbconf.num_entry_per_bucket = lru_bucket_size;
+    DB *db = new DB(mbconf);
+    if(!db->is_open())
+    {
+        std::cout << db->StatusStr() << "\n";
+        exit(1);
+    }
+
+    // DB::SetLogFile("/var/tmp/mabain.log");
+    // DB::LogDebug();
+
+    if(query_cmd.length() != 0)
+    {
+        run_query_command(db, mode, query_cmd);
+    }
+    else if(script_file.length() != 0)
+    {
+        run_script(db, mode, script_file);
+    }
+    else
+    {
+        mbclient(db, mode);
+    }
+
+    db->Close();
+    // DB::CloseLogFile();
+    delete db;
+    return 0;
+}