diff --git a/.github/workflows/test-aby3.yml b/.github/workflows/test-aby3.yml new file mode 100644 index 0000000..fd140ea --- /dev/null +++ b/.github/workflows/test-aby3.yml @@ -0,0 +1,28 @@ +name: Build and test the Docker container for ABY3 + +on: + push: + branches: [ master ] + paths: + - aby3/* + pull_request: + branches: [ master ] + paths: + - aby3/* + +jobs: + build: + name: Build container for ABY3 + runs-on: ubuntu-latest + steps: + - name: Check out the repo + uses: actions/checkout@v3 + - uses: docker/setup-buildx-action@v2 + - uses: docker/build-push-action@v4 + with: + context: aby3 + tags: aby3 + load: true + cache-from: type=gha + cache-to: type=gha,mode=max + push: false diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..0d39b48 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,50 @@ +{ + "files.associations": { + "array": "cpp", + "atomic": "cpp", + "bit": "cpp", + "*.tcc": "cpp", + "cctype": "cpp", + "clocale": "cpp", + "cmath": "cpp", + "compare": "cpp", + "concepts": "cpp", + "cstdarg": "cpp", + "cstddef": "cpp", + "cstdint": "cpp", + "cstdio": "cpp", + "cstdlib": "cpp", + "cwchar": "cpp", + "cwctype": "cpp", + "deque": "cpp", + "string": "cpp", + "unordered_map": "cpp", + "vector": "cpp", + "exception": "cpp", + "algorithm": "cpp", + "functional": "cpp", + "iterator": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "numeric": "cpp", + "random": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "initializer_list": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "limits": "cpp", + "new": "cpp", + "numbers": "cpp", + "ostream": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "cinttypes": "cpp", + "typeinfo": "cpp" + } +} \ No newline at end of file diff --git a/aby3/Dockerfile b/aby3/Dockerfile new file mode 100644 index 0000000..79a5d1d --- /dev/null +++ b/aby3/Dockerfile @@ -0,0 +1,18 @@ +FROM ubuntu:22.04 + +WORKDIR /root + +RUN apt update \ + && apt install -y vim git build-essential python3 cmake + +ADD install.sh . +RUN ./install.sh + +WORKDIR /root/aby3 +ADD source ./samples +# build our examples +RUN python3 build.py + +ADD run_example.sh . + +ENTRYPOINT ./run_example.sh xtabs diff --git a/aby3/README.md b/aby3/README.md new file mode 100644 index 0000000..6a05b6e --- /dev/null +++ b/aby3/README.md @@ -0,0 +1,7 @@ +ABY^3 was developed by Peter Rindal. + +Currently fails to build with error +``` +error: no matching function for call to 'osuCrypto::Channel::asyncSend(unsigned char*&, long unsigned int&, osuCrypto::OblvPermutation::send(osuCrypto::Channel&, osuCrypto::Channel&, osuCrypto::Matrix, std::__cxx11::string)::)' + recvrChl.asyncSend(data, size, [a = std::move(src)](){}); +``` diff --git a/aby3/install.sh b/aby3/install.sh new file mode 100755 index 0000000..96403c5 --- /dev/null +++ b/aby3/install.sh @@ -0,0 +1,12 @@ +#!/bin/sh + +set -ex + +git clone https://github.com/ladnir/aby3.git +cd aby3/ +git checkout 854ca1b +python3 build.py --setup +python3 build.py + +mkdir samples +printf "add_subdirectory(samples)\n" >> CMakeLists.txt diff --git a/aby3/run_example.sh b/aby3/run_example.sh new file mode 100755 index 0000000..ad24f5c --- /dev/null +++ b/aby3/run_example.sh @@ -0,0 +1,16 @@ +#!/bin/sh + +set -ex + +SAMPLE=$1 + +BIN=out/build/linux/samples/samples.exe + +if [ $SAMPLE = mult3 -o $SAMPLE = innerprod -o $SAMPLE = xtabs ]; then + $BIN -p 0 -u $SAMPLE & + $BIN -p 1 -u $SAMPLE & + $BIN -p 2 -u $SAMPLE +else + echo "Bad sample name. Supported samples: mult3, innerprod, xtabs" + exit 1 +fi diff --git a/aby3/source/CMakeLists.txt b/aby3/source/CMakeLists.txt new file mode 100644 index 0000000..7dede90 --- /dev/null +++ b/aby3/source/CMakeLists.txt @@ -0,0 +1,17 @@ + +#project(samples) + +############################################# +# Build samples.exe # +############################################# + +file(GLOB_RECURSE SRC_SAMPLES ${CMAKE_SOURCE_DIR}/samples/*.cpp) +include_directories(${CMAKE_SOURCE_DIR}/samples/) + +add_executable(samples.exe ${SRC_SAMPLES}) + +target_link_libraries(samples.exe com-psi) +target_link_libraries(samples.exe aby3-ML) +target_link_libraries(samples.exe com-psi_Tests) +target_link_libraries(samples.exe aby3_Tests) +target_link_libraries(samples.exe oc::tests_cryptoTools) diff --git a/aby3/source/innerprod.cpp b/aby3/source/innerprod.cpp new file mode 100644 index 0000000..2c279cb --- /dev/null +++ b/aby3/source/innerprod.cpp @@ -0,0 +1,68 @@ +#include +#include + +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" + +#include "innerprod.h" + +using namespace oc; +using namespace aby3; + +void innerprod_test(oc::u64 partyIdx, std::vector values) +{ + if (partyIdx == 0) + std::cout << "testing innerprod..." << std::endl; + + IOService ios; + Sh3Encryptor enc; + Sh3Evaluator eval; + Sh3Runtime runtime; + setup_samples(partyIdx, ios, enc, eval, runtime); + + // encrypt (only parties 0,1 provide input) + u64 rows = values.size(); + std::vector A(rows); + std::vector B(rows); + + // note: this fails if you try to make multiple local/remote calls in the same if/else statement + // maybe could fix this with tasks? + for (int i = 0; i < 10; i++) + { + if (partyIdx == 0) + { + enc.localInt(runtime, values[i], A[i]).get(); + } + else + { + enc.remoteInt(runtime, A[i]).get(); + } + + if (partyIdx == 1) + { + enc.localInt(runtime, values[i], B[i]).get(); + } + else + { + enc.remoteInt(runtime, B[i]).get(); + } + } + + // parallel multiplications + std::vector prods(rows); + Sh3Task task = runtime.noDependencies(); + for (u64 i = 0; i < rows; ++i) + task = eval.asyncMul(task, B[i], A[i], prods[i]); + task.get(); + + // addition + si64 sum = prods[0]; + for (u64 i = 1; i < rows; ++i) + sum = sum + (si64)prods[i]; + + // reveal result + i64 result; + enc.revealAll(runtime, sum, result).get(); + std::cout << "result: " << result << std::endl; +} diff --git a/aby3/source/innerprod.h b/aby3/source/innerprod.h new file mode 100644 index 0000000..a312a01 --- /dev/null +++ b/aby3/source/innerprod.h @@ -0,0 +1,14 @@ +#pragma once +#include +void innerprod_test(oc::u64 partyIdx, std::vector values); + +#include +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" +void setup_samples( + aby3::u64 partyIdx, + oc::IOService &ios, + aby3::Sh3Encryptor &enc, + aby3::Sh3Evaluator &eval, + aby3::Sh3Runtime &runtime); diff --git a/aby3/source/main.cpp b/aby3/source/main.cpp new file mode 100644 index 0000000..ff675d4 --- /dev/null +++ b/aby3/source/main.cpp @@ -0,0 +1,141 @@ +#include "mult3.h" +#include "innerprod.h" +#include "xtabs.h" + +// testing and command line parsing +#include +#include +std::vector unitTestTag{"u", "unitTest"}; +std::vector playerTag{"p", "player"}; + +// convenience function +#include +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" + +// This function sets up the basic classes that we will +// use to perform some computation. This mostly consists +// of creating Channels (network sockets) to the other +// parties and then establishing some shared randomness. +void setup_samples( + aby3::u64 partyIdx, + oc::IOService &ios, + aby3::Sh3Encryptor &enc, + aby3::Sh3Evaluator &eval, + aby3::Sh3Runtime &runtime); + +void help() +{ + std::cout << "-u ~~ to run all tests" << std::endl; + std::cout << "-u n1 [n2 ...] ~~ to run test n1, n2, ..." << std::endl; + std::cout << "-u -list ~~ to list all tests" << std::endl; + + std::cout << "-p party ~~ indicate which party you are" << std::endl; +} + +int main(int argc, char **argv) +{ + try + { + + oc::CLP cmd(argc, argv); + + oc::u64 player; + if (cmd.isSet(playerTag)) + { + player = cmd.getOr(playerTag, -1); + } + else + { + std::cout << "You need to specify which party you are" << std::endl; + return 0; + } + + // this calls the appropriate test function based on command-line arg + // functions are defined in their respective .h files + if (cmd.isSet(unitTestTag)) + { + std::string none = ""; + if (cmd.getOr(unitTestTag, none).compare("mult3") == 0) + { + // hardcoded input value. expected result: 3*4*5 = 60 + int value = 3 + player; + mult3_test(player, value); + } + if (cmd.getOr(unitTestTag, none).compare("innerprod") == 0) + { + std::vector values(10); + // expected result: (0*0) + (1*2) + (2*4) + ... = 570 + for (uint i = 0; i < values.size(); i++) + { + values[i] = (player + 1) * i; + } + innerprod_test(player, values); + } + if (cmd.getOr(unitTestTag, none).compare("xtabs") == 0) + { + std::vector ids{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::vector values{0, 2, 2, 3, 4, 5, 6, 7, 8, 9}; + xtabs_test(player, ids, values, 10); + } + + return 0; + } + + help(); + } + catch (std::exception &e) + { + std::cout << e.what() << std::endl; + } + + return 0; +} + +using namespace aby3; +using namespace oc; +using namespace aby3; +using namespace oc; +void setup_samples( + aby3::u64 partyIdx, + oc::IOService &ios, + aby3::Sh3Encryptor &enc, + aby3::Sh3Evaluator &eval, + aby3::Sh3Runtime &runtime) +{ + // A CommPkg is a pair of Channels (network sockets) to the other parties. + // See cryptoTools\frontend_cryptoTools\Tutorials\Network.cpp + // for details. + // since we're running them all locally, they're sitting on 3 different ports. + aby3::CommPkg comm; + switch (partyIdx) + { + case 0: + comm.mNext = oc::Session(ios, "127.0.0.1:1313", oc::SessionMode::Server, "01").addChannel(); + comm.mPrev = oc::Session(ios, "127.0.0.1:1314", oc::SessionMode::Server, "02").addChannel(); + break; + case 1: + comm.mNext = oc::Session(ios, "127.0.0.1:1315", oc::SessionMode::Server, "12").addChannel(); + comm.mPrev = oc::Session(ios, "127.0.0.1:1313", oc::SessionMode::Client, "01").addChannel(); + break; + default: + comm.mNext = oc::Session(ios, "127.0.0.1:1314", oc::SessionMode::Client, "02").addChannel(); + comm.mPrev = oc::Session(ios, "127.0.0.1:1315", oc::SessionMode::Client, "12").addChannel(); + break; + } + + // in a real work example, where parties + // have different IPs, you have to give the + // Clients the IP of the server and you give + // the servers their own IP (to listen to). + + // Establishes some shared randomness needed for the later protocols + enc.init(partyIdx, comm, sysRandomSeed()); + + // Establishes some shared randomness needed for the later protocols + eval.init(partyIdx, comm, sysRandomSeed()); + + // Copies the Channels and will use them for later protcols. + runtime.init(partyIdx, comm); +} diff --git a/aby3/source/mult3.cpp b/aby3/source/mult3.cpp new file mode 100644 index 0000000..cad7052 --- /dev/null +++ b/aby3/source/mult3.cpp @@ -0,0 +1,63 @@ +#include +#include + +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" + +#include "mult3.h" + +using namespace oc; +using namespace aby3; +/* + * mult3 + * + * multiplies three numbers together + * lots of this code is ported from Peter Rindal's aby3Tutorial.cpp + * + */ + +void mult3_test(u64 partyIdx, int value) +{ + if (partyIdx == 0) + std::cout << "testing mult3..." << std::endl; + + IOService ios; + + // Sh3Encryptor allows us to generate and reconstruct secret shared values. + Sh3Encryptor enc; + // Sh3Evaluator will allow us to perform some of the + // most common interactive protocols, e.g. multiplication. + Sh3Evaluator eval; + // Sh3Runtime does networking and helps schedule operations in parallel + Sh3Runtime runtime; + setup_samples(partyIdx, ios, enc, eval, runtime); + + std::vector sharedVec(3); + + /* Convert clear values from each player to secure types. + * tutorial suggests doing this asynchronously by task &='ing each call + * instead of calling .get() immediately + * This didn't compile for me; the runtime destructor was getting called way too early. + */ + for (u64 i = 0; i < sharedVec.size(); ++i) + { + if (i % 3 == partyIdx) + enc.localInt(runtime, value, sharedVec[i]).get(); + else + enc.remoteInt(runtime, sharedVec[i]).get(); + } + + /* multiply them together */ + si64 prod = sharedVec[0]; + Sh3Task task = runtime.noDependencies(); + for (u64 i = 1; i < sharedVec.size(); ++i) + task = eval.asyncMul(task, prod, sharedVec[i], prod); + + task.get(); + + /* reveal result */ + i64 result; + enc.revealAll(runtime, prod, result).get(); + std::cout << "product: " << result << std::endl; +} diff --git a/aby3/source/mult3.h b/aby3/source/mult3.h new file mode 100644 index 0000000..e08ed7d --- /dev/null +++ b/aby3/source/mult3.h @@ -0,0 +1,14 @@ +#pragma once +#include +void mult3_test(oc::u64 partyIdx, int value); + +#include +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" +void setup_samples( + aby3::u64 partyIdx, + oc::IOService &ios, + aby3::Sh3Encryptor &enc, + aby3::Sh3Evaluator &eval, + aby3::Sh3Runtime &runtime); diff --git a/aby3/source/xtabs.cpp b/aby3/source/xtabs.cpp new file mode 100644 index 0000000..d259ed2 --- /dev/null +++ b/aby3/source/xtabs.cpp @@ -0,0 +1,108 @@ +#include +#include +#include +#include "xtabs.h" + +#include "aby3-DB/DBServer.h" +#include +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" + +using namespace oc; +using namespace aby3; +using namespace std; + +void xtabs_test(u64 partyIdx, std::vector ids, std::vector values, int nCats) +{ + std::cout << "testing xtabs..." << std::endl; + + IOService ios; + + DBServer srv; + + PRNG prng(ZeroBlock); + + if (partyIdx == 0) + { + Session s01(ios, "127.0.0.1:3030", SessionMode::Server, "01"); + Session s02(ios, "127.0.0.1:3031", SessionMode::Server, "02"); + srv.init(0, s02, s01, prng); + } + else if (partyIdx == 1) + { + Session s10(ios, "127.0.0.1:3030", SessionMode::Client, "01"); + Session s12(ios, "127.0.0.1:3032", SessionMode::Server, "12"); + srv.init(1, s10, s12, prng); + } + else + { + Session s20(ios, "127.0.0.1:3031", SessionMode::Client, "02"); + Session s21(ios, "127.0.0.1:3032", SessionMode::Client, "12"); + srv.init(2, s21, s20, prng); + } + + auto keyBitCount = srv.mKeyBitCount; + std::vector + catCols = {ColumnInfo{"key", TypeID::IntID, keyBitCount}, + ColumnInfo{"cat", TypeID::IntID, keyBitCount}}, + valCols = {ColumnInfo{"key", TypeID::IntID, keyBitCount}, + ColumnInfo{"val", TypeID::IntID, keyBitCount}}; + + u64 rows = ids.size(); + assert(ids.size() == rows); + assert(values.size() == rows); + + Table catData(rows, catCols), valData(rows, valCols); + + // initializes data into Table (still in the clear) + for (u64 i = 0; i < rows; ++i) + { + if (partyIdx == 0) + { + catData.mColumns[0].mData(i, 0) = ids[i]; + catData.mColumns[0].mData(i, 1) = values[i]; + } + else if (partyIdx == 1) + { + valData.mColumns[0].mData(i, 0) = ids[i]; + valData.mColumns[0].mData(i, 1) = values[i]; + } + } + + SharedTable catTable, valTable; + catTable = (partyIdx == 0) ? srv.localInput(catData) : srv.remoteInput(0); + valTable = (partyIdx == 1) ? srv.localInput(valData) : srv.remoteInput(1); + + // i64Matrix keys(catTable.mColumns[0].rows(), catTable.mColumns[0].i64Cols()); + // srv.mEnc.revealAll(srv.mRt.mComm, catTable.mColumns[0], keys); + // + // if (partyIdx == 0) + //{ + // std::cout << keys << std::endl; + //} + + auto res = srv.join(catTable["key"], valTable["key"], {catTable["cat"], valTable["val"]}); + cout << "not reached " << endl; + // for (auto c : res.mColumns) { + // cout << c.mName << "xXx"; + // } + + aby3::i64Matrix cats(res.mColumns[0].rows(), res.mColumns[0].i64Cols()); + aby3::i64Matrix vals(res.mColumns[1].rows(), res.mColumns[1].i64Cols()); + + srv.mEnc.revealAll(srv.mRt.mComm, res.mColumns[0], cats); //debug + srv.mEnc.revealAll(srv.mRt.mComm, res.mColumns[1], vals); + + if (partyIdx == 0) + { + std::cout << cats << std::endl; + } + + std::vector sums(nCats); + for (int i = 0; i < res.mColumns[0].rows(); i++) { + for(int c = 0; c < nCats; c++) { + // if res(0, i) == c then + } + } +} diff --git a/aby3/source/xtabs.h b/aby3/source/xtabs.h new file mode 100644 index 0000000..3f6cef0 --- /dev/null +++ b/aby3/source/xtabs.h @@ -0,0 +1,14 @@ +#pragma once +#include +void xtabs_test(oc::u64 partyIdx, std::vector ids, std::vector values, int nCats); + +#include +#include "aby3/sh3/Sh3Runtime.h" +#include "aby3/sh3/Sh3Encryptor.h" +#include "aby3/sh3/Sh3Evaluator.h" +void setup_samples( + aby3::u64 partyIdx, + oc::IOService &ios, + aby3::Sh3Encryptor &enc, + aby3::Sh3Evaluator &eval, + aby3::Sh3Runtime &runtime);