diff --git a/.gitmodules b/.gitmodules index d11ae3e..c186e0c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -28,3 +28,6 @@ [submodule "fmt"] path = thirdparty/fmt url = https://github.com/fmtlib/fmt.git +[submodule "thirdparty/loguru"] + path = thirdparty/loguru + url = https://github.com/emilk/loguru.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 2389283..1e944fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,8 @@ if (IYOKAN_ENABLE_CUDA) else() add_subdirectory(thirdparty/cuFHE/thirdparties/TFHEpp) endif(IYOKAN_ENABLE_CUDA) + +add_subdirectory(thirdparty/fmt) add_subdirectory(thirdparty/spdlog) set(IYOKAN_CXXFLAGS -Wall -Wextra -Wno-sign-compare) @@ -47,13 +49,13 @@ set(IYOKAN_INCLUDE_DIRS $ $ $ - $ $ + $ ) if (IYOKAN_ENABLE_CUDA) list(APPEND IYOKAN_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) endif(IYOKAN_ENABLE_CUDA) -set(IYOKAN_LIBS tfhe++ Threads::Threads OpenMP::OpenMP_CXX Backward::Backward stdc++fs) +set(IYOKAN_LIBS tfhe++ Threads::Threads OpenMP::OpenMP_CXX Backward::Backward stdc++fs fmt::fmt) if (IYOKAN_80BIT_SECURITY) # For TFHEpp headers list(APPEND IYOKAN_COMPILE_DEFINITIONS USE_80BIT_SECURITY) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e7f0d4f..8dad3fe 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,7 +31,10 @@ endif(IYOKAN_ENABLE_CUDA) ##### test0 add_executable(test0 - test0.cpp iyokan.cpp iyokan_plain.cpp iyokan_tfhepp.cpp error.cpp + test0.cpp + iyokan_nt.cpp iyokan_nt_plain.cpp iyokan_nt_tfhepp.cpp + error_nt.cpp packet_nt.cpp dataholder_nt.cpp + network_reader.cpp blueprint.cpp label.cpp allocator.cpp snapshot.cpp ${CMAKE_CURRENT_BINARY_DIR}/mux-ram-8-8-8.o ${CMAKE_CURRENT_BINARY_DIR}/mux-ram-8-16-16.o ${CMAKE_CURRENT_BINARY_DIR}/mux-ram-9-16-16.o) @@ -41,11 +44,11 @@ target_compile_options(test0 PUBLIC "$<$:${IYOKAN_CXXFLAGS_RELEA target_link_libraries(test0 ${IYOKAN_LIBS}) target_include_directories(test0 PRIVATE ${IYOKAN_INCLUDE_DIRS}) target_compile_definitions(test0 PRIVATE ${IYOKAN_COMPILE_DEFINITIONS}) -if (IYOKAN_ENABLE_CUDA) - target_sources(test0 PRIVATE iyokan_cufhe.cpp) - target_link_libraries(test0 cufhe_gpu ${CUDA_LIBRARIES}) - target_compile_definitions(test0 PRIVATE IYOKAN_CUDA_ENABLED) -endif(IYOKAN_ENABLE_CUDA) +#if (IYOKAN_ENABLE_CUDA) +# target_sources(test0 PRIVATE iyokan_cufhe.cpp) +# target_link_libraries(test0 cufhe_gpu ${CUDA_LIBRARIES}) +# target_compile_definitions(test0 PRIVATE IYOKAN_CUDA_ENABLED) +#endif(IYOKAN_ENABLE_CUDA) ##### iyokan-packet add_executable(iyokan-packet iyokan-packet.cpp error.cpp) diff --git a/src/allocator.cpp b/src/allocator.cpp new file mode 100644 index 0000000..c10fb1b --- /dev/null +++ b/src/allocator.cpp @@ -0,0 +1,120 @@ +#include "dataholder_nt.hpp" +#include "iyokan_nt.hpp" +#include "packet_nt.hpp" + +#include +#include +#include + +#include +#include +#include + +namespace { + +using namespace nt; + +template +void save(Archive& ar, const Bit& b, const std::uint32_t version) +{ + assert(version == 1); + + ar(static_cast(b)); +} + +template +void load(Archive& ar, Bit& b, const std::uint32_t version) +{ + assert(version == 1); + + bool bl; + ar(bl); + b = static_cast(b); +} + +struct Serializable { + std::variant data; + + template + void serialize(Archive& ar, const std::uint32_t version) + { + assert(version == 1); + + ar(data); + } +}; + +} // namespace + +CEREAL_CLASS_VERSION(Serializable, 1); + +namespace nt { + +/* class Allocator */ + +Allocator::Allocator() + : hasLoadedFromIStream_(false), indexToBeMade_(0), data_() +{ +} + +Allocator::Allocator(cereal::PortableBinaryInputArchive& ar) + : hasLoadedFromIStream_(true), indexToBeMade_(0), data_() +{ + // Read and de-serialize data from the snapshot file + size_t size; + Serializable buf; + + ar(size); + for (size_t i = 0; i < size; i++) { + ar(buf); + switch (buf.data.index()) { + case 0: { // Bit + Bit b = std::get<0>(buf.data); + data_.emplace_back(b); + break; + } + + case 1: { // TLWELvl0 + const TLWELvl0& tlwe = std::get<1>(buf.data); + data_.emplace_back(tlwe); + break; + } + + default: + ERR_UNREACHABLE; + } + } +} + +void Allocator::dumpAllocatedData(cereal::PortableBinaryOutputArchive& ar) const +{ + // FIXME: WE KNOW the code in this function is TREMENDOUSLY UGLY (and + // inefficient). We need to find some more sophisticated ways to do this. + + Serializable buf; + + // Serialization process for each type + std::unordered_map> + tyHandlers; + tyHandlers[typeid(Bit)] = [&](const std::any& any) { + const Bit* src = std::any_cast(&any); + buf.data = *src; + ar(buf); + }; + tyHandlers[typeid(TLWELvl0)] = [&](const std::any& any) { + const TLWELvl0* src = std::any_cast(&any); + buf.data = *src; + ar(buf); + }; + + // First serialize the size of the entries + ar(static_cast(data_.size())); + + // Dispatch + for (size_t i = 0; i < data_.size(); i++) { + const std::any& src = data_.at(i); + tyHandlers.at(src.type())(src); + } +} + +} // namespace nt diff --git a/src/blueprint.cpp b/src/blueprint.cpp new file mode 100644 index 0000000..ffaa9c7 --- /dev/null +++ b/src/blueprint.cpp @@ -0,0 +1,299 @@ +#include "blueprint.hpp" +#include "error_nt.hpp" + +#include +#include + +#include + +namespace { +std::vector regexMatch(const std::string& text, + const std::regex& re) +{ + std::vector ret; + std::smatch m; + if (!std::regex_match(text, m, re)) + return ret; + for (auto&& elm : m) + ret.push_back(elm.str()); + return ret; +} + +} // namespace + +namespace nt { + +/* class Blueprint */ + +Blueprint::Blueprint(const std::string& fileName) +{ + namespace fs = std::filesystem; + + // Read the file + std::stringstream inputStream; + { + std::ifstream ifs{fileName}; + if (!ifs) + ERR_DIE("File not found: " << fileName); + inputStream << ifs.rdbuf(); + source_ = inputStream.str(); + inputStream.seekg(std::ios::beg); + } + + // Parse config file + const auto src = toml::parse(inputStream, fileName); + + // Find working directory of config + fs::path wd = fs::absolute(fileName); + wd.remove_filename(); + + // [[file]] + { + const auto srcFiles = + toml::find_or>(src, "file", {}); + for (const auto& srcFile : srcFiles) { + std::string typeStr = toml::find(srcFile, "type"); + fs::path path = toml::find(srcFile, "path"); + std::string name = toml::find(srcFile, "name"); + + blueprint::File::TYPE type; + if (typeStr == "iyokanl1-json") + type = blueprint::File::TYPE::IYOKANL1_JSON; + else if (typeStr == "yosys-json") + type = blueprint::File::TYPE::YOSYS_JSON; + else + ERR_DIE("Invalid file type: " << typeStr); + + if (path.is_relative()) + path = wd / path; // Make path absolute + + files_.push_back(blueprint::File{type, path.string(), name}); + } + } + + // [[builtin]] + { + const auto srcBuiltins = + toml::find_or>(src, "builtin", {}); + for (const auto& srcBuiltin : srcBuiltins) { + const auto type = toml::find(srcBuiltin, "type"); + const auto name = toml::find(srcBuiltin, "name"); + + if (type == "rom" || type == "mux-rom") { + auto romType = type == "rom" + ? blueprint::BuiltinROM::TYPE::CMUX_MEMORY + : blueprint::BuiltinROM::TYPE::MUX; + const auto inAddrWidth = + toml::find(srcBuiltin, "in_addr_width"); + const auto outRdataWidth = + toml::find(srcBuiltin, "out_rdata_width"); + + builtinROMs_.push_back(blueprint::BuiltinROM{ + romType, name, inAddrWidth, outRdataWidth}); + } + else if (type == "ram" || type == "mux-ram") { + auto ramType = type == "ram" + ? blueprint::BuiltinRAM::TYPE::CMUX_MEMORY + : blueprint::BuiltinRAM::TYPE::MUX; + const auto inAddrWidth = + toml::find(srcBuiltin, "in_addr_width"); + const auto inWdataWidth = + toml::find(srcBuiltin, "in_wdata_width"); + const auto outRdataWidth = + toml::find(srcBuiltin, "out_rdata_width"); + + builtinRAMs_.push_back(blueprint::BuiltinRAM{ + ramType, name, inAddrWidth, inWdataWidth, outRdataWidth}); + } + } + } + + // [connect] + { + const auto srcConnect = toml::find_or(src, "connect", {}); + for (const auto& [srcKey, srcValue] : srcConnect) { + if (srcKey == "TOGND") { // TOGND = [@...[n:m], @...[n:m], ...] + auto ary = toml::get>(srcValue); + for (const auto& portStr : ary) { // @...[n:m] + if (portStr.empty() || portStr.at(0) != '@') + ERR_DIE("Invalid port name for TOGND: " << portStr); + auto ports = parsePortString(portStr, Label::OUTPUT); + for (auto&& port : ports) { // @...[n] + const std::string& name = port.cname.portName; + int bit = port.cname.portBit; + auto [it, inserted] = atPortWidths_.emplace(name, 0); + it->second = std::max(it->second, bit + 1); + } + } + continue; + } + + std::string srcTo = srcKey, + srcFrom = toml::get(srcValue), + errMsg = fmt::format("Invalid connect: {} = {}", srcTo, + srcFrom); + + // Check if input is correct. + if (srcTo.empty() || srcFrom.empty() || + (srcTo[0] == '@' && srcFrom[0] == '@')) + ERR_DIE(errMsg); + + // Others. + std::vector portsTo = parsePortString( + srcTo, Label::INPUT), + portsFrom = parsePortString( + srcFrom, Label::OUTPUT); + if (portsTo.size() != portsFrom.size()) + ERR_DIE(errMsg); + + for (size_t i = 0; i < portsTo.size(); i++) { + const blueprint::Port& to = portsTo[i]; + const blueprint::Port& from = portsFrom[i]; + + if (srcTo[0] == '@') { // @... = ... + if (!to.cname.nodeName.empty() || + from.cname.nodeName.empty()) + ERR_DIE(errMsg); + + const std::string& name = to.cname.portName; + int bit = to.cname.portBit; + + { + auto [it, inserted] = + atPorts_.emplace(std::make_tuple(name, bit), from); + if (!inserted) + LOG_S(WARNING) + << srcTo + << " is used multiple times. Only the first " + "one is effective."; + } + + auto [it, inserted] = atPortWidths_.emplace(name, 0); + it->second = std::max(it->second, bit + 1); + } + else if (srcFrom[0] == '@') { // ... = @... + if (!from.cname.nodeName.empty() || + to.cname.nodeName.empty()) + ERR_DIE(errMsg); + + const std::string& name = from.cname.portName; + int bit = from.cname.portBit; + + { + auto [it, inserted] = + atPorts_.emplace(std::make_tuple(name, bit), to); + if (!inserted) + LOG_S(WARNING) + << srcFrom + << " is used multiple times. Only the first " + "one is effective. (FIXME)"; + } + + auto [it, inserted] = atPortWidths_.emplace(name, 0); + it->second = std::max(it->second, bit + 1); + } + else { // ... = ... + edges_.emplace_back(from, to); + } + } + } + } +} + +std::vector Blueprint::parsePortString(const std::string& src, + const char* const kind) +{ + std::string nodeName, portName; + int portBitFrom, portBitTo; + + auto match = regexMatch( + src, + std::regex(R"(^@?(?:([^/]+)/)?([^[]+)(?:\[([0-9]+):([0-9]+)\])?$)")); + if (match.empty()) + ERR_DIE("Invalid port string: " << src); + + assert(match.size() == 1 + 4); + + nodeName = match[1]; + portName = match[2]; + + if (match[3].empty()) { // hoge/piyo + assert(match[4].empty()); + portBitFrom = 0; + portBitTo = 0; + } + else { // hoge/piyo[foo:bar] + assert(!match[4].empty()); + portBitFrom = std::stoi(match[3]); + portBitTo = std::stoi(match[4]); + } + + std::vector ret; + for (int i = portBitFrom; i < portBitTo + 1; i++) + ret.push_back(blueprint::Port{kind, {nodeName, portName, i}}); + return ret; +} + +bool Blueprint::needsCircuitKey() const +{ + for (const auto& bprom : builtinROMs_) + if (bprom.type == blueprint::BuiltinROM::TYPE::CMUX_MEMORY) + return true; + for (const auto& bpram : builtinRAMs_) + if (bpram.type == blueprint::BuiltinRAM::TYPE::CMUX_MEMORY) + return true; + return false; +} + +const std::string& Blueprint::sourceFile() const +{ + return sourceFile_; +} + +const std::string& Blueprint::source() const +{ + return source_; +} + +const std::vector& Blueprint::files() const +{ + return files_; +} + +const std::vector& Blueprint::builtinROMs() const +{ + return builtinROMs_; +} + +const std::vector& Blueprint::builtinRAMs() const +{ + return builtinRAMs_; +} + +const std::vector>& +Blueprint::edges() const +{ + return edges_; +} + +const std::map, blueprint::Port>& +Blueprint::atPorts() const +{ + return atPorts_; +} + +std::optional Blueprint::at(const std::string& portName, + int portBit) const +{ + auto it = atPorts_.find(std::make_tuple(portName, portBit)); + if (it == atPorts_.end()) + return std::nullopt; + return it->second; +} + +const std::unordered_map& Blueprint::atPortWidths() const +{ + return atPortWidths_; +} + +} // namespace nt diff --git a/src/blueprint.hpp b/src/blueprint.hpp new file mode 100644 index 0000000..b9b8559 --- /dev/null +++ b/src/blueprint.hpp @@ -0,0 +1,82 @@ +#ifndef VIRTUALSECUREPLATFORM_BLUEPRINT_HPP +#define VIRTUALSECUREPLATFORM_BLUEPRINT_HPP + +#include "label.hpp" + +#include +#include +#include +#include + +namespace nt { + +namespace blueprint { // blueprint components +struct File { + enum class TYPE { + IYOKANL1_JSON, + YOSYS_JSON, + } type; + std::string path, name; +}; + +struct BuiltinROM { + enum class TYPE { + CMUX_MEMORY, + MUX, + } type; + std::string name; + size_t inAddrWidth, outRdataWidth; +}; + +struct BuiltinRAM { + enum class TYPE { + CMUX_MEMORY, + MUX, + } type; + std::string name; + size_t inAddrWidth, inWdataWidth, outRdataWidth; +}; + +struct Port { + const char* kind; // Store a string literal + ConfigName cname; +}; +} // namespace blueprint + +class Blueprint { +private: + std::string sourceFile_, source_; + + std::vector files_; + std::vector builtinROMs_; + std::vector builtinRAMs_; + std::vector> edges_; + + std::map, blueprint::Port> atPorts_; + std::unordered_map atPortWidths_; + +private: + std::vector parsePortString(const std::string& src, + const char* const kind); + +public: + Blueprint(const std::string& fileName); + + bool needsCircuitKey() const; + const std::string& sourceFile() const; + const std::string& source() const; + const std::vector& files() const; + const std::vector& builtinROMs() const; + const std::vector& builtinRAMs() const; + const std::vector>& edges() + const; + const std::map, blueprint::Port>& atPorts() + const; + std::optional at(const std::string& portName, + int portBit = 0) const; + const std::unordered_map& atPortWidths() const; +}; + +} // namespace nt + +#endif diff --git a/src/dataholder_nt.cpp b/src/dataholder_nt.cpp new file mode 100644 index 0000000..23d5064 --- /dev/null +++ b/src/dataholder_nt.cpp @@ -0,0 +1,44 @@ +#include "dataholder_nt.hpp" + +namespace nt { + +/* class DataHolder */ +DataHolder::DataHolder() : dataBit_(nullptr), type_(TYPE::UND) +{ +} + +DataHolder::DataHolder(const Bit* const dataBit) + : dataBit_(dataBit), type_(TYPE::BIT) +{ +} + +DataHolder::DataHolder(const TLWELvl0* const dataTLWELvl0) + : dataTLWELvl0_(dataTLWELvl0), type_(TYPE::TLWE_LVL0) +{ +} + +Bit DataHolder::getBit() const +{ + assert(type_ == TYPE::BIT); + return *dataBit_; +} + +void DataHolder::setBit(const Bit* const dataBit) +{ + dataBit_ = dataBit; + type_ = TYPE::BIT; +} + +void DataHolder::getTLWELvl0(TLWELvl0& out) const +{ + assert(type_ == TYPE::TLWE_LVL0); + out = *dataTLWELvl0_; +} + +void DataHolder::setTLWELvl0(const TLWELvl0* const dataTLWELvl0) +{ + dataTLWELvl0_ = dataTLWELvl0; + type_ = TYPE::TLWE_LVL0; +} + +} // namespace nt diff --git a/src/dataholder_nt.hpp b/src/dataholder_nt.hpp new file mode 100644 index 0000000..47a7f82 --- /dev/null +++ b/src/dataholder_nt.hpp @@ -0,0 +1,46 @@ +#ifndef VIRTUALSECUREPLATFORM_DATAHOLDER_NT_HPP +#define VIRTUALSECUREPLATFORM_DATAHOLDER_NT_HPP + +#include "tfhepp_cufhe_wrapper.hpp" + +#include + +#include + +namespace nt { + +enum class Bit : bool; + +// DataHolder holds data using Task::setInput/Task::getOutput. +// FIXME: The name "DataHolder" is misleading. Maybe "IODataPointer" or +// something is more suitable, because we actually do not "hold" the data but +// "point" them. DataHolder::setBit() does not set the bit itself but set the +// pointer to a bit. +class DataHolder { +private: + union { + const Bit* dataBit_; + const TLWELvl0* dataTLWELvl0_; + }; + + enum class TYPE { + UND, + BIT, + TLWE_LVL0, + } type_; + +public: + DataHolder(); + DataHolder(const Bit* const dataBit); + DataHolder(const TLWELvl0* const dataTLWELvl0); + + Bit getBit() const; + void setBit(const Bit* const dataBit); + + void getTLWELvl0(TLWELvl0& out) const; + void setTLWELvl0(const TLWELvl0* const dataTLWELvl0); +}; + +} // namespace nt + +#endif diff --git a/src/error.hpp b/src/error.hpp index ddc4578..494963d 100644 --- a/src/error.hpp +++ b/src/error.hpp @@ -18,6 +18,31 @@ inline void initialize(const std::string& tag) spdlog::set_default_logger(spdlog::stderr_color_mt(tag)); } +template +[[noreturn]] void diefmt(Args&&... args) +{ + using namespace backward; + + // Print error message + spdlog::error(std::forward(args)...); + +#ifndef NDEBUG + { + // Print backtrace + spdlog::error("Preparing backtrace..."); + std::stringstream ss; + StackTrace st; + st.load_here(32); + Printer p; + p.print(st, ss); + spdlog::error(ss.str()); + } +#endif + + // Abort + std::exit(EXIT_FAILURE); +} + template [[noreturn]] void die(Args... args) { diff --git a/src/error_nt.cpp b/src/error_nt.cpp new file mode 100644 index 0000000..84d20e5 --- /dev/null +++ b/src/error_nt.cpp @@ -0,0 +1,40 @@ +#include "error_nt.hpp" + +#include + +#include + +namespace nt::error { +void initialize() +{ +#ifdef NDEBUG + // Release + loguru::g_stderr_verbosity = loguru::Verbosity_INFO; +#else + // Debug + loguru::g_stderr_verbosity = loguru::Verbosity_1; // Show LOG_DBG messages +#endif +} + +void abortWithBacktrace() +{ + using namespace backward; + +#ifndef NDEBUG + { + // Print backtrace + LOG_F(ERROR, "Preparing backtrace..."); + StackTrace st; + st.load_here(32); + Printer p; + p.print(st, stderr); + } +#endif + + // Abort + std::exit(EXIT_FAILURE); +} + +} // namespace nt::error + +#include diff --git a/src/error_nt.hpp b/src/error_nt.hpp new file mode 100644 index 0000000..1a5a847 --- /dev/null +++ b/src/error_nt.hpp @@ -0,0 +1,24 @@ +#ifndef VIRTUALSECUREPLATFORM_ERROR_NT_HPP +#define VIRTUALSECUREPLATFORM_ERROR_NT_HPP + +#define LOGURU_WITH_STREAMS 1 +#include + +namespace nt::error { +void initialize(); +[[noreturn]] void abortWithBacktrace(); +} // namespace nt::error + +#define LOG_DBG LOG_S(1) + +#define ERR_DIE(cont) \ + do { \ + LOG_S(ERROR) << cont; \ + nt::error::abortWithBacktrace(); \ + } while (false); + +#define ERR_UNREACHABLE ERR_DIE("Internal error: unreachable here") + +#define LOG_DBG_SCOPE(...) LOG_SCOPE_F(1, __VA_ARGS__); + +#endif diff --git a/src/iyokan-packet.cpp b/src/iyokan-packet.cpp index 2109faf..21e775a 100644 --- a/src/iyokan-packet.cpp +++ b/src/iyokan-packet.cpp @@ -3,6 +3,8 @@ #include #include +#include + namespace { enum class TYPE { @@ -336,7 +338,6 @@ int main(int argc, char** argv) toml2packet --in packet.toml --out packet.plain */ - using namespace utility; using namespace std::chrono; error::initialize("iyokan-packet"); diff --git a/src/iyokan_nt.cpp b/src/iyokan_nt.cpp new file mode 100644 index 0000000..c0fce73 --- /dev/null +++ b/src/iyokan_nt.cpp @@ -0,0 +1,945 @@ +#include "iyokan_nt.hpp" +#include "blueprint.hpp" +#include "dataholder_nt.hpp" +#include "error_nt.hpp" + +#include + +#include + +namespace { + +// Visit tasks in the network in topological order. +template +void visitTaskTopo(const nt::Network& net, F f) +{ + using namespace nt; + + std::unordered_map numReadyParents; + std::queue que; + net.eachTask([&](Task* task) { + numReadyParents.emplace(task, 0); + if (task->areAllInputsReady()) + que.push(task); + }); + while (!que.empty()) { + Task* task = que.front(); + que.pop(); + f(task); + for (Task* child : task->children()) { + if (child->areAllInputsReady()) // false parent-child relationship + continue; + numReadyParents.at(child)++; + assert(child->parents().size() >= numReadyParents.at(child)); + if (child->parents().size() == numReadyParents.at(child)) + que.push(child); + } + } +} + +// Visit tasks in the network in reversed topological order +template +void visitTaskRevTopo(const nt::Network& net, F f) +{ + using namespace nt; + + std::unordered_map + numReadyChildren; // task |-> # of ready children + std::queue que; // Initial tasks to be visited + net.eachTask([&](Task* task) { + const std::vector& children = task->children(); + + // Count the children that have no inputs to wait for + size_t n = std::count_if( + children.begin(), children.end(), + [&](Task* child) { return child->areAllInputsReady(); }); + numReadyChildren.emplace(task, n); + + // Initial nodes should be "terminals", that is, + // they have no children OR all of their children has no inputs to wait + // for. + if (children.size() == n) + que.push(task); + }); + assert(!que.empty()); + + while (!que.empty()) { + Task* task = que.front(); + que.pop(); + f(task); + if (task->areAllInputsReady()) // The end of the travel + continue; + // Push parents into the queue if all of their children are ready + for (Task* parent : task->parents()) { + numReadyChildren.at(parent)++; + assert(parent->children().size() >= numReadyChildren.at(parent)); + if (parent->children().size() == numReadyChildren.at(parent)) + que.push(parent); + } + } +} + +void prioritizeTaskByTopo(const nt::Network& net) +{ + using namespace nt; + + size_t numPrioritizedTasks = 0; + visitTaskTopo(net, [&](Task* task) { + // Calculate and set the priority for the task + int pri = -1; + if (!task->areAllInputsReady()) + for (Task* parent : task->parents()) + pri = std::max(pri, parent->priority()); + task->setPriority(pri + 1); + numPrioritizedTasks++; + }); + + if (net.size() > numPrioritizedTasks) { + LOG_DBG << "net.size() " << net.size() << " != numPrioritizedTasks " + << numPrioritizedTasks; + net.eachTask([&](Task* task) { + const Label& l = task->label(); + if (task->priority() == -1) + LOG_DBG << "\t" << l.uid << " " << l.kind << " "; + }); + ERR_DIE("Invalid network; some nodes will not be executed."); + } + assert(net.size() == numPrioritizedTasks); +} + +void prioritizeTaskByRanku(const nt::Network& net) +{ + // c.f. https://en.wikipedia.org/wiki/Heterogeneous_Earliest_Finish_Time + // FIXME: Take communication costs into account + // FIXME: Tune computation costs by dynamic measurements + + using namespace nt; + + size_t numPrioritizedTasks = 0; + visitTaskRevTopo(net, [&](Task* task) { + // Calculate and set the priority for the task + int pri = 0; + for (Task* child : task->children()) + // Only take valid children (i.e., ones that have no some inputs to + // wait for) into account + if (!child->areAllInputsReady()) + pri = std::max(pri, child->priority()); + task->setPriority(pri + task->getComputationCost()); + numPrioritizedTasks++; + }); + + if (net.size() > numPrioritizedTasks) { + LOG_DBG << "net.size() " << net.size() << " != numPrioritizedTasks " + << numPrioritizedTasks; + net.eachTask([&](Task* task) { + const Label& l = task->label(); + if (task->priority() == -1) + LOG_DBG << "\t" << l.uid << " " << l.kind << " "; + }); + ERR_DIE("Invalid network; some nodes will not be executed."); + } + assert(net.size() == numPrioritizedTasks); +} + +} // namespace + +namespace nt { + +/* class Task */ + +Task::Task(Label label) + : label_(std::move(label)), + parents_(), + children_(), + priority_(-1), + hasQueued_(false) +{ +} + +Task::~Task() +{ +} + +const Label& Task::label() const +{ + return label_; +} + +const std::vector& Task::parents() const +{ + return parents_; +} + +const std::vector& Task::children() const +{ + return children_; +} + +int Task::priority() const +{ + return priority_; +} + +bool Task::hasQueued() const +{ + return hasQueued_; +} + +void Task::addChild(Task* task) +{ + assert(task != nullptr); + children_.push_back(task); +} + +void Task::addParent(Task* task) +{ + assert(task != nullptr); + parents_.push_back(task); +} + +void Task::setPriority(int newPri) +{ + priority_ = newPri; +} + +void Task::setQueued() +{ + assert(!hasQueued_); + hasQueued_ = true; +} + +int Task::getComputationCost() const +{ + return 0; +} + +void Task::tick() +{ + hasQueued_ = false; +} + +void Task::getOutput(DataHolder&) +{ + ERR_UNREACHABLE; +} + +void Task::setInput(const DataHolder&) +{ + ERR_UNREACHABLE; +} + +bool Task::canRunPlain() const +{ + return false; +} + +bool Task::canRunTFHEpp() const +{ + return false; +} + +void Task::onAfterTick(size_t) +{ + // Do nothing by default. +} + +void Task::startAsynchronously(plain::WorkerInfo&) +{ + ERR_UNREACHABLE; +} + +void Task::startAsynchronously(tfhepp::WorkerInfo&) +{ + ERR_UNREACHABLE; +} + +/* class TaskFinder */ + +size_t TaskFinder::size() const +{ + return byUID_.size(); +} + +void TaskFinder::add(Task* task) +{ + const Label& label = task->label(); + byUID_.emplace(label.uid, task); + + if (label.cname) { + const ConfigName& cname = label.cname.value(); + auto [it, inserted] = byConfigName_.emplace( + std::make_tuple(cname.nodeName, cname.portName, cname.portBit), + task); + if (!inserted) + ERR_DIE("Same config name already exists: " + << cname.nodeName << "/" << cname.portName << "[" + << cname.portBit << "]"); + } +} + +Task* TaskFinder::findByUID(UID uid) const +{ + return byUID_.at(uid); +} + +Task* TaskFinder::findByConfigName(const ConfigName& cname) const +{ + return byConfigName_.at( + std::make_tuple(cname.nodeName, cname.portName, cname.portBit)); +} + +/* class ReadyQueue */ + +bool ReadyQueue::empty() const +{ + return queue_.empty(); +} + +void ReadyQueue::pop() +{ + queue_.pop(); +} + +Task* ReadyQueue::peek() const +{ + auto [pri, task] = queue_.top(); + return task; +} + +void ReadyQueue::push(Task* task) +{ + queue_.emplace(task->priority(), task); + task->setQueued(); +} + +/* class Network */ + +Network::Network(TaskFinder finder, std::vector> tasks) + : finder_(std::move(finder)), tasks_(std::move(tasks)) +{ +} + +size_t Network::size() const +{ + return tasks_.size(); +} + +const TaskFinder& Network::finder() const +{ + return finder_; +} + +bool Network::checkIfValid() const +{ + bool valid = true; + eachTask([&](Task* task) { + if (!task->checkIfValid()) + valid = false; + }); + + // Check if the network is weekly connected + size_t numConnectedTasks = 0; + visitTaskTopo(*this, [&](Task*) { numConnectedTasks++; }); + if (numConnectedTasks != size()) { + LOG_S(ERROR) << "The network is not weekly connected i.e., there are " + "some nodes that cannot be visited; numConnectedTasks " + << numConnectedTasks << " != size " << size(); + valid = false; + } + + return valid; +} + +/* class NetworkBuilder */ + +NetworkBuilder::NetworkBuilder(Allocator& alc) + : finder_(), tasks_(), consumed_(false), alc_(&alc) +{ +} + +NetworkBuilder::~NetworkBuilder() +{ +} + +Allocator& NetworkBuilder::currentAllocator() +{ + return *alc_; +} + +const TaskFinder& NetworkBuilder::finder() const +{ + assert(!consumed_); + return finder_; +} + +Network NetworkBuilder::createNetwork() +{ + assert(!consumed_); + consumed_ = true; + return Network{std::move(finder_), std::move(tasks_)}; +} + +/* class Worker */ + +Worker::Worker() : target_(nullptr) +{ +} + +Worker::~Worker() +{ +} + +void Worker::update(ReadyQueue& readyQueue, size_t& numFinishedTargets) +{ + if (target_ == nullptr && !readyQueue.empty()) { + // Try to find the task to tackle next + Task* cand = readyQueue.peek(); + assert(cand != nullptr); + if (canExecute(cand)) { + target_ = cand; + readyQueue.pop(); + startTask(target_); + } + } + + if (target_ != nullptr && target_->hasFinished()) { + for (Task* child : target_->children()) { + if (child->hasQueued()) + continue; + child->notifyOneInputReady(); + if (child->areAllInputsReady()) + readyQueue.push(child); + } + target_ = nullptr; + numFinishedTargets++; + } +} + +bool Worker::isWorking() const +{ + return target_ != nullptr; +} + +/* class NetworkRunner */ + +NetworkRunner::NetworkRunner(Network network, + std::vector> workers) + : network_(std::move(network)), + workers_(std::move(workers)), + readyQueue_(), + numFinishedTargets_(0) +{ + assert(workers_.size() != 0); + for (auto&& w : workers) + assert(w != nullptr); +} + +void NetworkRunner::prepareToRun() +{ + assert(readyQueue_.empty()); + + numFinishedTargets_ = 0; + + // Push ready tasks to the ready queue. + network_.eachTask([&](Task* task) { + if (task->areAllInputsReady()) + readyQueue_.push(task); + }); +} + +void NetworkRunner::update() +{ + for (auto&& w : workers_) + w->update(readyQueue_, numFinishedTargets_); +} + +const Network& NetworkRunner::network() const +{ + return network_; +} + +size_t NetworkRunner::numFinishedTargets() const +{ + return numFinishedTargets_; +} + +bool NetworkRunner::isRunning() const +{ + return std::any_of(workers_.begin(), workers_.end(), + [](auto&& w) { return w->isWorking(); }) || + !readyQueue_.empty(); +} + +void NetworkRunner::run() +{ + prepareToRun(); + while (numFinishedTargets() < network().size()) { + assert(isRunning() && "Invalid network: maybe some unreachable tasks?"); + update(); + } +} + +void NetworkRunner::tick() +{ + network_.eachTask([&](Task* task) { task->tick(); }); +} + +void NetworkRunner::onAfterTick(size_t currentCycle) +{ + network_.eachTask([&](Task* task) { task->onAfterTick(currentCycle); }); +} + +/* struct RunParameter */ + +void RunParameter::print() const +{ + LOG_S(INFO) << "Run parameters"; + LOG_S(INFO) << "\tMode: plain"; + LOG_S(INFO) << "\tBlueprint: " << blueprintFile; + LOG_S(INFO) << "\t# of CPU Workers: " << numCPUWorkers; + LOG_S(INFO) << "\t# of cycles: " << numCycles; + LOG_S(INFO) << "\tCurrent cycle #: " << currentCycle; + LOG_S(INFO) << "\tInput file (request packet): " << inputFile; + LOG_S(INFO) << "\tOutput file (result packet): " << outputFile; + LOG_S(INFO) << "\tSchedule: " << (sched == SCHED::TOPO ? "topo" : "ranku"); +} + +/* class Frontend */ + +Frontend::Frontend(const RunParameter& pr) + : pr_(pr), + network_(std::nullopt), + currentCycle_(pr.currentCycle), + bp_(std::make_unique(pr_.blueprintFile)), + alc_(std::make_shared()) +{ + assert(alc_); +} + +Frontend::Frontend(const Snapshot& ss) + : pr_(ss.getRunParam()), + network_(std::nullopt), + currentCycle_(pr_.currentCycle), + bp_(std::make_unique(pr_.blueprintFile)), + alc_(ss.getAllocator()) +{ + assert(alc_); +} + +void Frontend::buildNetwork(NetworkBuilder& nb) +{ + LOG_DBG_SCOPE("FRONTEND BUILD NETWORK"); + + const Blueprint& bp = blueprint(); + const TaskFinder& finder = nb.finder(); + + // [[file]] + LOG_DBG << "BUILD NETWORK FROM FILE"; + for (auto&& file : bp.files()) + readNetworkFromFile(file, nb); + + // [[builtin]] type = ram | type = mux-ram + LOG_DBG << "BUILD BUILTIN RAM"; + for (auto&& ram : bp.builtinRAMs()) { + // We ignore ram.type and always use mux-ram in plaintext mode. + makeMUXRAM(ram, nb); + } + + // [[builtin]] type = rom | type = mux-rom + LOG_DBG << "BUILD BUILTIN ROM"; + for (auto&& rom : bp.builtinROMs()) { + // We ignore rom.type and always use mux-rom in plaintext mode. + makeMUXROM(rom, nb); + } + + auto get = [&](const blueprint::Port& port) -> Task* { + Task* task = finder.findByConfigName(port.cname); + if (task->label().kind != port.kind) + ERR_DIE("Invalid port: " << port.cname << " is " + << task->label().kind << ", not " + << port.kind); + return task; + }; + + // [connect] + LOG_DBG << "CONNECT"; + // We need to treat "... = @..." and "@... = ..." differently from + // "..." = ...". + // First, check if ports that are connected to or from "@..." exist. + for (auto&& [key, port] : bp.atPorts()) { + get(port); // Only checks if port exists + } + // Then, connect other ports. `get` checks if they also exist. + for (auto&& [src, dst] : bp.edges()) { + assert(src.kind == Label::OUTPUT); + assert(dst.kind == Label::INPUT); + nb.connect(get(src)->label().uid, get(dst)->label().uid); + } + + // Create the network from the builder + LOG_DBG << "CREATE NETWORK FROM BUILDER"; + network_.emplace(nb.createNetwork()); + + // Check if network is valid + LOG_DBG << "CHECK IF NETWORK IS VALID"; + if (!network_->checkIfValid()) + ERR_DIE("Network is not valid"); + + // Set priority to each task + LOG_DBG << "PRIORITIZE TASKS IN NETWORK"; + switch (pr_.sched) { + case SCHED::TOPO: + prioritizeTaskByTopo(network_.value()); + break; + + case SCHED::RANKU: + prioritizeTaskByRanku(network_.value()); + break; + } +} + +Frontend::~Frontend() +{ +} + +void Frontend::run() +{ + const Blueprint& bp = blueprint(); + + DataHolder bit0, bit1; + setBit0(bit0); + setBit1(bit1); + + // Create workers + LOG_DBG << "CREATE WORKERS"; + std::vector> workers = makeWorkers(); + + // Create runner and finder for the network + LOG_DBG << "CREATE RUNNER"; + NetworkRunner runner{std::move(network_.value()), std::move(workers)}; + network_ = std::nullopt; + const TaskFinder& finder = runner.network().finder(); + + // Process reset cycle if @reset is used + // FIXME: Add support for --skip-reset flag + if (currentCycle_ == 0) { + auto reset = bp.at("reset"); + if (reset && reset->kind == Label::INPUT) { + LOG_DBG << "RESET"; + Task* t = finder.findByConfigName(reset->cname); + t->setInput(bit1); // Set reset on + runner.run(); + t->setInput(bit0); // Set reset off + } + } + + // Process normal cycles + for (int i = 0; i < pr_.numCycles; i++, currentCycle_++) { + LOG_DBG_SCOPE("Cycle #%d (i = %d)", currentCycle_, i); + + // Mount new values to DFFs + LOG_DBG << "TICK"; + runner.tick(); + + // Set new input data. If i is equal to 0, it also mounts initial data + // to RAMs. + LOG_DBG << "ON AFTER TICK"; + runner.onAfterTick(currentCycle_); + + // Go computing of each gate + LOG_DBG << "RUN"; + runner.run(); + + /* + // Debug printing of all the gates + runner.network().eachTask([&](Task* t) { + TaskCommon* p = dynamic_cast*>(t); + if (p == nullptr) + return; + const Label& l = t->label(); + if (t->label().cname) + LOG_DBG << l.kind << "\t" << *l.cname << "\t" + << p->DEBUG_output(); + else + LOG_DBG << l.kind << "\t" << p->DEBUG_output(); + }); + */ + } + + // Dump result packet + LOG_DBG << "DUMP RES PACKET"; + dumpResPacket(pr_.outputFile, finder, currentCycle_); + + // Dump snapshot + if (pr_.snapshotFile) { + LOG_DBG << "DUMP SNAPSHOT"; + Snapshot ss{pr_, allocatorPtr()}; + ss.updateCurrentCycle(currentCycle_); + ss.dump(pr_.snapshotFile.value()); + } +} + +/* makeMUXROM */ + +namespace { +void make1bitROMWithMUX(const std::string& nodeName, + const std::vector& addrInputs, + size_t outRdataWidth, size_t indexOutRdata, + NetworkBuilder& nb) +{ + /* + INPUT + addr[1] ------------------------------+ + INPUT | + addr[0] --+-----------------+ | + | | | + | ROM | | + | romdata[0] -- |\ | + | ROM | | --+ | + | romdata[1] -- |/ +-- |\ OUTPUT + | | | -- ... -- rdata[indexOutRdata] + +-----------------+ +-- |/ + | | + ROM | | + romdata[2] -- +\ | + ROM | | --+ + romdata[3] -- |/ + + ... + + ROM + addr[2^inAddrWidth-1] -- ... + */ + + const size_t inAddrWidth = addrInputs.size(); + + // Create ROMs + std::vector workingIds; + for (size_t i = 0; i < (1 << inAddrWidth); i++) { + UID id = nb.ROM(nodeName, "romdata", indexOutRdata + i * outRdataWidth); + workingIds.push_back(id); + } + + // Create MUXs + for (size_t i = 0; i < inAddrWidth; i++) { + assert(workingIds.size() > 0 && workingIds.size() % 2 == 0); + std::vector newWorkingIds; + for (size_t j = 0; j < workingIds.size(); j += 2) { + UID id = nb.MUX(); + nb.connect(workingIds.at(j), id); + nb.connect(workingIds.at(j + 1), id); + nb.connect(addrInputs.at(i), id); + newWorkingIds.push_back(id); + } + workingIds.swap(newWorkingIds); + } + assert(workingIds.size() == 1); + + // Create output + UID id = nb.OUTPUT(nodeName, "rdata", indexOutRdata); + nb.connect(workingIds.at(0), id); +} +} // namespace + +void makeMUXROM(const blueprint::BuiltinROM& rom, NetworkBuilder& nb) +{ + // Create inputs + std::vector addrInputs; + for (size_t i = 0; i < rom.inAddrWidth; i++) { + UID id = nb.INPUT(rom.name, "addr", i); + addrInputs.push_back(id); + } + + // Create 1bit ROMs + for (size_t i = 0; i < rom.outRdataWidth; i++) { + make1bitROMWithMUX(rom.name, addrInputs, rom.outRdataWidth, i, nb); + } +} + +/* makeMUXRAM */ + +namespace { + +void make1bitRAMWithMUX(const std::string& nodeName, + const std::vector& addrInputs, UID wrenInput, + size_t dataWidth, size_t indexWRdata, + NetworkBuilder& nb) +{ + /* + wdata[indexWRdata] + | + | +---------------------+ + | | | + | +--|\ | + | | |-- ramdata[.] --+------------+-|\ + +------|/ | |-- rdata[indexWRdata] + | | +---------------------+ +-|/ + | a +--|\ |---+ + | | |-- ramdata[.] --+ + +---------------|/ + | | + | b + + ... + + | + +---- ... -- ramdata[2^inAddrWidth-1] -- + + + a b + | | + ----- + addr[0] --- /0 1\ DMUX + ------- ... + | | + +-------+-------+ + | + ... ... + | | + | | + ----- + addr[inAddrWidth-1] --- /0 1\ DMUX + ------- + | + wren + + + DMUX: (in, sel) -> (out0, out1) + out0 = andnot(in, sel) + out1 = and(in, sel) + */ + + const size_t inAddrWidth = addrInputs.size(); + + // Create input "wdata[indexWRdata]" + UID wdataInput = nb.INPUT(nodeName, "wdata", indexWRdata); + + // Create DMUXs + std::vector workingIds = {wrenInput}, newWorkingIds; + for (auto it = addrInputs.rbegin(); it != addrInputs.rend(); ++it) { + UID addr = *it; + for (UID src : workingIds) { + // Create DMUX + // dst0 = andnot(src, addr) + // dst1 = and(src, addr) + UID dst0 = nb.ANDNOT(); + UID dst1 = nb.AND(); + nb.connect(src, dst0); + nb.connect(addr, dst0); + nb.connect(src, dst1); + nb.connect(addr, dst1); + + newWorkingIds.push_back(dst0); + newWorkingIds.push_back(dst1); + } + workingIds.swap(newWorkingIds); + newWorkingIds.clear(); + } + assert(workingIds.size() == (1 << inAddrWidth)); + + // Create RAMs + for (size_t addr = 0; addr < (1 << inAddrWidth); addr++) { + /* + +-------------------------+ + | | + +--|\ RAM |-- + INPUT | |-- ramdata[ ... ] --+ + wdata[indexRWdata] ----|/ + | + sel + */ + UID sel = workingIds.at(addr), mux = nb.MUX(), + ram = nb.RAM(nodeName, "ramdata", addr * dataWidth + indexWRdata); + nb.connect(ram, mux); + nb.connect(wdataInput, mux); + nb.connect(sel, mux); + nb.connect(mux, ram); + newWorkingIds.push_back(ram); + } + workingIds.swap(newWorkingIds); + newWorkingIds.clear(); + + // Create MUXs + for (size_t i = 0; i < inAddrWidth; i++) { + assert(workingIds.size() > 0 && workingIds.size() % 2 == 0); + for (size_t j = 0; j < workingIds.size(); j += 2) { + UID id = nb.MUX(); + nb.connect(workingIds.at(j), id); + nb.connect(workingIds.at(j + 1), id); + nb.connect(addrInputs.at(i), id); + newWorkingIds.push_back(id); + } + workingIds.swap(newWorkingIds); + newWorkingIds.clear(); + } + assert(workingIds.size() == 1); + + // Create output "rdata[indexWRdata]" + UID rdataOutput = nb.OUTPUT(nodeName, "rdata", indexWRdata); + nb.connect(workingIds.at(0), rdataOutput); +} + +} // namespace + +extern "C" { +// Iyokan-L1 JSON of MUX RAM pre-compiled (and optimized) by Yosys +extern char _binary_mux_ram_8_8_8_min_json_start[]; +extern char _binary_mux_ram_8_8_8_min_json_end[]; +extern char _binary_mux_ram_8_8_8_min_json_size[]; +extern char _binary_mux_ram_8_16_16_min_json_start[]; +extern char _binary_mux_ram_8_16_16_min_json_end[]; +extern char _binary_mux_ram_8_16_16_min_json_size[]; +extern char _binary_mux_ram_9_16_16_min_json_start[]; +extern char _binary_mux_ram_9_16_16_min_json_end[]; +extern char _binary_mux_ram_9_16_16_min_json_size[]; +} + +void makeMUXRAM(const blueprint::BuiltinRAM& ram, NetworkBuilder& nb) +{ + assert(ram.inWdataWidth == ram.outRdataWidth); + +#define USE_PRECOMPILED_BINARY(addrW, dataW) \ + if (ram.inAddrWidth == addrW && ram.outRdataWidth == dataW) { \ + std::stringstream ss{std::string{ \ + _binary_mux_ram_##addrW##_##dataW##_##dataW##_min_json_start, \ + _binary_mux_ram_##addrW##_##dataW##_##dataW##_min_json_end}}; \ + readPrecompiledRAMNetworkFromFile(ram.name, ss, nb, dataW); \ + return; \ + } + USE_PRECOMPILED_BINARY(8, 8); + USE_PRECOMPILED_BINARY(8, 16); + USE_PRECOMPILED_BINARY(9, 16); +#undef USE_PRECOMPILED_BINARY + + // Create inputs + std::vector addrInputs; + for (size_t i = 0; i < ram.inAddrWidth; i++) { + UID id = nb.INPUT(ram.name, "addr", i); + addrInputs.push_back(id); + } + UID wrenInput = nb.INPUT(ram.name, "wren", 0); + + // Create 1bitRAMs + for (size_t i = 0; i < ram.outRdataWidth; i++) { + make1bitRAMWithMUX(ram.name, addrInputs, wrenInput, ram.outRdataWidth, + i, nb); + } +} + +void test0() +{ + // operator< for ConfigName + { + bool res = false; + res = ConfigName{"abc", "def", 0} < ConfigName{"abc", "dfe", 0}; + assert(res); + res = ConfigName{"acc", "def", 0} < ConfigName{"abc", "dfe", 0}; + assert(!res); + res = ConfigName{"abc", "def", 0} < ConfigName{"abc", "def", 0}; + assert(!res); + res = ConfigName{"abc", "def", 0} < ConfigName{"abc", "def", 1}; + assert(res); + } +} + +} // namespace nt diff --git a/src/iyokan_nt.hpp b/src/iyokan_nt.hpp new file mode 100644 index 0000000..ec597e3 --- /dev/null +++ b/src/iyokan_nt.hpp @@ -0,0 +1,565 @@ +#ifndef VIRTUALSECUREPLATFORM_IYOKAN_NT_HPP +#define VIRTUALSECUREPLATFORM_IYOKAN_NT_HPP + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "error_nt.hpp" +#include "label.hpp" + +// Forward declarations +namespace nt { +namespace plain { +struct WorkerInfo; +} +namespace tfhepp { +struct WorkerInfo; +}; +class DataHolder; +class Blueprint; +namespace blueprint { +struct File; +struct BuiltinROM; +struct BuiltinRAM; +} // namespace blueprint +} // namespace nt +namespace cereal { +class PortableBinaryOutputArchive; +class PortableBinaryInputArchive; +} // namespace cereal + +namespace nt { + +class Allocator { +private: + // Prohibit copying Allcator. Probably moving it is not a problem. + Allocator(const Allocator&) = delete; + Allocator& operator=(const Allocator&) = delete; + +public: + using Index = size_t; + +private: + // Allocator can be constructed in 2 ways: + // - With no arguments. In this case, Allocator has empty data at first, and + // make() really makes a new object. The member variable + // hasLoadedFromIStream_ is false, and indexToBeMade_ is not used. + // - With snapshot file path. In this case, Allocator reads data from the + // snapshot file, and make() returns data_[indexToBeMade_++]. The member + // variable hasLoadedFromIStream_ is true, and indexToBeMade_ indicates + // the next index of the data returned by make(). + bool hasLoadedFromIStream_; + size_t indexToBeMade_; + + // data_ has all the allocated objects in std::any. + // We use std::deque here since push_back/emplace_back of std::deque do + // not invalidate the address of its element, while ones of std::vector + // do. + std::deque data_; + +public: + Allocator(); + Allocator(cereal::PortableBinaryInputArchive& ar); + + void dumpAllocatedData(cereal::PortableBinaryOutputArchive& ar) const; + + template + T* get(Index index) + { + assert(index < data_.size()); + assert(!hasLoadedFromIStream_ || index < indexToBeMade_); + T* ret = std::any_cast(&data_.at(index)); + assert(ret != nullptr); + return ret; + } + + template + T* make() + { + if (hasLoadedFromIStream_) { + return std::any_cast(&data_.at(indexToBeMade_++)); + } + else { + std::any& v = data_.emplace_back(); + return &v.emplace(); + } + } +}; + +class Task { +private: + Label label_; + std::vector parents_, children_; + int priority_; + bool hasQueued_; + +public: + Task(Label label); + virtual ~Task(); + + const Label& label() const; + const std::vector& parents() const; + const std::vector& children() const; + int priority() const; + bool hasQueued() const; + void addChild(Task* task); + void addParent(Task* task); + void setPriority(int newPri); + void setQueued(); + + // Get computation cost of this task. Used for scheduling of tasks. + virtual int getComputationCost() const; + + // Check if the task is valid. Returns true iff it is valid. + virtual bool checkIfValid() const = 0; + + virtual void notifyOneInputReady() = 0; + virtual bool areAllInputsReady() const = 0; + virtual bool hasFinished() const = 0; + + // tick() resets the internal state of the task for the next cycle + virtual void tick(); + + // Get output value. Only available for output and DFF gates. + virtual void getOutput(DataHolder&); + + // Set input value. Only available for input gates. + virtual void setInput(const DataHolder&); + + // onAfterTick() will be called after each tick. + virtual void onAfterTick(size_t currentCycle); + + // Return true iff this task can be run in plaintext mode. + virtual bool canRunPlain() const; + + // Return true iff this task can be run in TFHEpp mode. + virtual bool canRunTFHEpp() const; + + // Start this task asynchronously in plaintext mode. + // Only available when canRunPlain() returns true. + virtual void startAsynchronously(plain::WorkerInfo&); + + // Start this task asynchronously in TFHEpp mode. + // Only available when canRunTFHEpp() returns true. + virtual void startAsynchronously(tfhepp::WorkerInfo&); +}; + +class TaskFinder { +private: + std::unordered_map byUID_; + std::map, Task*> byConfigName_; + +public: + size_t size() const; + void add(Task* task); + Task* findByUID(UID uid) const; + Task* findByConfigName(const ConfigName& cname) const; + + template + void eachTask(F f) const + { + for (auto&& [uid, task] : byUID_) + f(uid, task); + } +}; + +// TaskCommon can be used as base class of many "common" tasks. +// "Common" here means: +// 1. # of outputs is 1. (#inputs can be >1.) +// 2. All the inputs and output have the same type. +// 3. (and so on) +template +class TaskCommon : public Task { +private: + size_t numReadyInputs_; + const size_t numMinExpectedInputs_, numMaxExpectedInputs_; + std::vector inputs_; + T* output_; + +protected: + size_t getInputSize() const + { + return inputs_.size(); + } + + const T& input(size_t i) const + { + assert(i < inputs_.size()); + return *inputs_.at(i); + } + + T& output() + { + assert(output_ != nullptr); + return *output_; + } + +public: + TaskCommon(Label label, Allocator& alc, size_t numMinExpectedInputs, + std::optional numMaxExpectedInputs = std::nullopt) + : Task(std::move(label)), + numReadyInputs_(0), + numMinExpectedInputs_(numMinExpectedInputs), + numMaxExpectedInputs_( + numMaxExpectedInputs.value_or(numMinExpectedInputs)), + inputs_(), + output_(alc.make()) + { + } + + virtual ~TaskCommon() + { + } + + virtual bool checkIfValid() const override + { + assert(output_); + + bool valid = true; + if (getInputSize() < numMinExpectedInputs_) { + LOG_S(ERROR) << "Input size < min. expected size: " + << getInputSize() << " < " << numMinExpectedInputs_; + valid = false; + } + if (getInputSize() > numMaxExpectedInputs_) { + LOG_S(ERROR) << "Input size > max. expected size: " + << getInputSize() << " > " << numMaxExpectedInputs_; + valid = false; + } + if (getInputSize() != parents().size()) { + LOG_S(ERROR) << "Input size != parents size: " << getInputSize() + << " != " << parents().size(); + valid = false; + } + + return valid; + } + + virtual void notifyOneInputReady() override + { + numReadyInputs_++; + assert(numReadyInputs_ <= inputs_.size()); + } + + virtual bool areAllInputsReady() const override + { + return numReadyInputs_ == inputs_.size(); + } + + virtual void tick() override + { + Task::tick(); + numReadyInputs_ = 0; + } + + void addInput(TaskCommon* newIn) + { + assert(newIn != nullptr); + addInput(newIn, newIn->output_); + } + + void addInput(Task* newParent, T* newIn) + { + assert(newParent != nullptr && newIn != nullptr); + assert(inputs_.size() < numMaxExpectedInputs_); + + addParent(newParent); + newParent->addChild(this); + + inputs_.push_back(newIn); + } + + // public output(). Debug purpose only. + T& DEBUG_output() + { + return output(); + } +}; + +// class TaskDFF can be used as base class of DFF tasks. +// TaskDFF inherits TaskCommon, so it has addInput member functions. +// NetworkBuilder can use it to connect common gates with DFFs. +template +class TaskDFF : public TaskCommon { +public: + TaskDFF(Label label, Allocator& alc) + : TaskCommon(std::move(label), alc, 1) + { + } + + virtual ~TaskDFF() + { + } + + void notifyOneInputReady() override + { + ERR_UNREACHABLE; + } + + bool areAllInputsReady() const override + { + // Since areAllInputsReady() is called after calling of tick(), the + // input should already be in output(). + return true; + } + + bool hasFinished() const override + { + // Since hasFinished() is called after calling of tick(), the + // input should already be in output(). + return true; + } + + void tick() override + { + TaskCommon::tick(); + this->output() = this->input(0); + } +}; + +class ReadyQueue { +private: + std::priority_queue> queue_; + +public: + bool empty() const; + void pop(); + Task* peek() const; + void push(Task* task); +}; + +class Network { +private: + TaskFinder finder_; + std::vector> tasks_; + +public: + Network(TaskFinder finder, std::vector> tasks); + + size_t size() const; + const TaskFinder& finder() const; + + // Check if the network is valid. Print error messages if necessary. Returns + // true iff it is valid. + bool checkIfValid() const; + + template + void eachTask(F f) const + { + for (auto&& task : tasks_) + f(task.get()); + } + + template + void eachTask(F f) + { + for (auto&& task : tasks_) + f(task.get()); + } +}; + +class NetworkBuilder { +private: + TaskFinder finder_; + std::vector> tasks_; + bool consumed_; + Allocator* alc_; + +protected: + // Create a new task. T must be derived from class Task. + template + T* emplaceTask(Args&&... args) + { + assert(!consumed_); + T* task = new T(std::forward(args)...); + tasks_.emplace_back(task); + finder_.add(task); + return task; + } + + Allocator& currentAllocator(); + +public: + NetworkBuilder(Allocator& alc); + virtual ~NetworkBuilder(); + + const TaskFinder& finder() const; + + Network createNetwork(); + + virtual void connect(UID from, UID to) = 0; + + // not/and/or are C++ keywords, so the member functions here are in + // capitals. + virtual UID INPUT(const std::string& nodeName, const std::string& portName, + int portBit) = 0; + virtual UID OUTPUT(const std::string& nodeName, const std::string& portName, + int portBit) = 0; + virtual UID ROM(const std::string& nodeName, const std::string& portName, + int portBit) = 0; + virtual UID RAM(const std::string& nodeName, const std::string& portName, + int portBit) = 0; + + virtual UID AND() = 0; + virtual UID ANDNOT() = 0; + virtual UID CONSTONE() = 0; + virtual UID CONSTZERO() = 0; + virtual UID DFF() = 0; + virtual UID MUX() = 0; + virtual UID NAND() = 0; + virtual UID NMUX() = 0; + virtual UID NOR() = 0; + virtual UID NOT() = 0; + virtual UID OR() = 0; + virtual UID ORNOT() = 0; + virtual UID SDFF0() = 0; + virtual UID SDFF1() = 0; + virtual UID XNOR() = 0; + virtual UID XOR() = 0; +}; + +class Worker { +private: + Task* target_; + +public: + Worker(); + virtual ~Worker(); + + void update(ReadyQueue& readyQueue, size_t& numFinishedTargets); + bool isWorking() const; + +protected: + virtual void startTask(Task* task) = 0; + virtual bool canExecute(Task* task) = 0; +}; + +class NetworkRunner { +private: + Network network_; + std::vector> workers_; + ReadyQueue readyQueue_; + size_t numFinishedTargets_; + +private: + void prepareToRun(); + void update(); + +public: + NetworkRunner(Network network, + std::vector> workers); + + const Network& network() const; + size_t numFinishedTargets() const; + bool isRunning() const; + void run(); + void tick(); + void onAfterTick(size_t currentCycle); +}; + +enum class SCHED { + TOPO, + RANKU, +}; + +struct RunParameter { + std::string blueprintFile, inputFile, outputFile, bkeyFile; + int numCPUWorkers, numCycles, currentCycle; + SCHED sched; + + std::optional snapshotFile; + + void print() const; + + // For cereal + template + void serialize(Archive& ar) + { + ar(blueprintFile, inputFile, outputFile, bkeyFile, numCPUWorkers, + numCycles, currentCycle, sched, snapshotFile); + } +}; + +class Snapshot { +private: + RunParameter pr_; + std::shared_ptr alc_; + +public: + Snapshot(const RunParameter& pr, const std::shared_ptr& alc); + Snapshot(const std::string& snapshotFile); + + const RunParameter& getRunParam() const; + const std::shared_ptr& getAllocator() const; + void updateCurrentCycle(int currentCycle); + void updateNumCycles(int numCycles); + void dump(const std::string& snapshotFile) const; +}; + +class Frontend { +private: + // Prohibit copying Frontend. + Frontend(const Frontend&) = delete; + Frontend& operator=(const Frontend&) = delete; + +private: + const RunParameter pr_; + std::optional network_; + int currentCycle_; + std::unique_ptr bp_; + std::shared_ptr alc_; + +protected: + virtual void setBit0(DataHolder& dh) = 0; + virtual void setBit1(DataHolder& dh) = 0; + virtual void dumpResPacket(const std::string& outpath, + const TaskFinder& finder, int numCycles) = 0; + virtual std::vector> makeWorkers() = 0; + + void buildNetwork(NetworkBuilder& nb); + const RunParameter& runParam() const + { + return pr_; + } + const Blueprint& blueprint() const + { + return *bp_; + } + Allocator& allocator() + { + return *alc_; + } + const std::shared_ptr& allocatorPtr() const + { + return alc_; + } + +public: + Frontend(const RunParameter& pr); + Frontend(const Snapshot& snapshot); + virtual ~Frontend(); + + void run(); +}; + +void readPrecompiledRAMNetworkFromFile(const std::string& name, + std::istream& is, nt::NetworkBuilder& nb, + int ramDataWidth); +void readNetworkFromFile(const blueprint::File& file, NetworkBuilder& nb); +void makeMUXROM(const blueprint::BuiltinROM& rom, NetworkBuilder& nb); +void makeMUXRAM(const blueprint::BuiltinRAM& ram, NetworkBuilder& nb); + +void test0(); + +} // namespace nt + +#endif diff --git a/src/iyokan_nt_plain.cpp b/src/iyokan_nt_plain.cpp new file mode 100644 index 0000000..d57ecf7 --- /dev/null +++ b/src/iyokan_nt_plain.cpp @@ -0,0 +1,982 @@ +#include "iyokan_nt_plain.hpp" +#include "blueprint.hpp" +#include "dataholder_nt.hpp" +#include "iyokan_nt.hpp" +#include "packet_nt.hpp" + +namespace nt { +namespace plain { + +struct WorkerInfo { +}; + +class Worker : public nt::Worker { +private: + WorkerInfo wi_; + +protected: + bool canExecute(Task* task) override + { + return task->canRunPlain(); + } + + void startTask(Task* task) override + { + task->startAsynchronously(wi_); + } +}; + +// struct InputSource is used by class TaskInput to set correct input value +// every cycle. +struct InputSource { + int atPortWidth, atPortBit; + std::vector* bits; +}; + +class TaskInput : public TaskCommon { +private: + std::optional source_; + +public: + TaskInput(Label label, Allocator& alc) + : TaskCommon(label, alc, 0, 1), source_(std::nullopt) + { + } + TaskInput(InputSource source, Label label, Allocator& alc) + : TaskCommon(label, alc, 0, 1), source_(source) + { + } + + void onAfterTick(size_t currentCycle) override + { + if (source_) { + // Set the output value from the source + assert(getInputSize() == 0); + InputSource& s = source_.value(); + size_t index = + (s.atPortWidth * currentCycle + s.atPortBit) % s.bits->size(); + output() = s.bits->at(index); + } + } + + void startAsynchronously(WorkerInfo&) override + { + if (getInputSize() == 1) + output() = input(0); + } + + bool hasFinished() const override + { + return true; + } + + bool canRunPlain() const override + { + return true; + } + + void setInput(const DataHolder& h) override + { + // Set the input i.e., set the output value of this gate + output() = h.getBit(); + } +}; + +class TaskOutput : public TaskCommon { +public: + TaskOutput(Label label, Allocator& alc) : TaskCommon(label, alc, 1) + { + } + + void startAsynchronously(WorkerInfo&) override + { + output() = input(0); + } + + bool hasFinished() const override + { + return true; + } + + bool canRunPlain() const override + { + return true; + } + + void getOutput(DataHolder& h) override + { + h.setBit(&output()); + } +}; + +class TaskDFF : public nt::TaskDFF { +private: + std::optional initialValue_; + +public: + TaskDFF(Label label, Allocator& alc) + : nt::TaskDFF(std::move(label), alc), initialValue_(std::nullopt) + { + } + + TaskDFF(Bit initialValue, Label label, Allocator& alc) + : nt::TaskDFF(std::move(label), alc), initialValue_(initialValue) + { + } + + void onAfterTick(size_t currentCycle) override + { + if (currentCycle == 0) + output() = initialValue_.value_or(output()); + } + + bool canRunPlain() const override + { + return true; + } + + void startAsynchronously(WorkerInfo&) override + { + // Nothing to do, because the main process is done in + // nt::TaskDFF::tick(). + } + + void getOutput(DataHolder& h) override + { + h.setBit(&output()); + } +}; + +class TaskROM : public TaskCommon { +public: + TaskROM(Bit value, Label label, Allocator& alc) + : TaskCommon(label, alc, 0) + { + output() = value; + } + + void startAsynchronously(WorkerInfo&) override + { + } + + bool hasFinished() const override + { + return true; + } + + bool canRunPlain() const override + { + return true; + } +}; + +#define DEF_COMMON_TASK_CLASS(CamelName, inputSize, expr) \ + class Task##CamelName : public TaskCommon { \ + public: \ + Task##CamelName(Label label, Allocator& alc) \ + : TaskCommon(std::move(label), alc, inputSize) \ + { \ + } \ + void startAsynchronously(WorkerInfo&) override \ + { \ + output() = (expr); \ + } \ + bool hasFinished() const override \ + { \ + return true; \ + } \ + bool canRunPlain() const override \ + { \ + return true; \ + } \ + }; +DEF_COMMON_TASK_CLASS(And, 2, (input(0) & input(1))); +DEF_COMMON_TASK_CLASS(Andnot, 2, (input(0) & !input(1))); +DEF_COMMON_TASK_CLASS(ConstOne, 0, 1_b); +DEF_COMMON_TASK_CLASS(ConstZero, 0, 0_b); +DEF_COMMON_TASK_CLASS(Mux, 3, input(2) == 0_b ? input(0) : input(1)); +DEF_COMMON_TASK_CLASS(Nand, 2, !(input(0) & input(1))); +DEF_COMMON_TASK_CLASS(Nmux, 3, input(2) == 0_b ? !input(0) : !input(1)); +DEF_COMMON_TASK_CLASS(Nor, 2, !(input(0) | input(1))); +DEF_COMMON_TASK_CLASS(Not, 1, !input(0)); +DEF_COMMON_TASK_CLASS(Or, 2, (input(0) | input(1))); +DEF_COMMON_TASK_CLASS(Ornot, 2, (input(0) | !input(1))); +DEF_COMMON_TASK_CLASS(Xnor, 2, !(input(0) ^ input(1))); +DEF_COMMON_TASK_CLASS(Xor, 2, (input(0) ^ input(1))); +#undef DEF_COMMON_TASK_CLASS + +class NetworkBuilder : public nt::NetworkBuilder { +private: + std::unordered_map*> uid2common_; + UID nextUID_; + const PlainPacket* const reqPacket_; + const std::map* const cname2isource_; + +private: + UID genUID() + { + return nextUID_++; + } + +public: + NetworkBuilder(const std::map& cname2isource, + const PlainPacket& reqPacket, Allocator& alc) + : nt::NetworkBuilder(alc), + uid2common_(), + nextUID_(0), + reqPacket_(&reqPacket), + cname2isource_(&cname2isource) + { + } + + ~NetworkBuilder() + { + } + + void connect(UID fromUID, UID toUID) override + { + auto &from = uid2common_.at(fromUID), &to = uid2common_.at(toUID); + to->addInput(from); + } + +#define DEF_COMMON_TASK(CAPName, CamelName) \ + UID CAPName() override \ + { \ + UID uid = genUID(); \ + Task##CamelName* task = nullptr; \ + task = emplaceTask( \ + Label{uid, #CAPName, std::nullopt}, currentAllocator()); \ + uid2common_.emplace(uid, task); \ + return uid; \ + } + DEF_COMMON_TASK(AND, And); + DEF_COMMON_TASK(ANDNOT, Andnot); + DEF_COMMON_TASK(CONSTONE, ConstOne); + DEF_COMMON_TASK(CONSTZERO, ConstZero); + DEF_COMMON_TASK(DFF, DFF); + DEF_COMMON_TASK(MUX, Mux); + DEF_COMMON_TASK(NAND, Nand); + DEF_COMMON_TASK(NMUX, Nmux); + DEF_COMMON_TASK(NOR, Nor); + DEF_COMMON_TASK(NOT, Not); + DEF_COMMON_TASK(OR, Or); + DEF_COMMON_TASK(ORNOT, Ornot); + DEF_COMMON_TASK(XNOR, Xnor); + DEF_COMMON_TASK(XOR, Xor); +#undef DEF_COMMON_TASK + + UID SDFF0() override + { + UID uid = genUID(); + TaskDFF* task = emplaceTask( + 0_b, Label{uid, "SDFF0", std::nullopt}, currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID SDFF1() override + { + UID uid = genUID(); + TaskDFF* task = emplaceTask( + 1_b, Label{uid, "SDFF1", std::nullopt}, currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID INPUT(const std::string& nodeName, const std::string& portName, + int portBit) override + { + Allocator& alc = currentAllocator(); + UID uid = genUID(); + ConfigName cname = ConfigName{nodeName, portName, portBit}; + Label label{uid, Label::INPUT, cname}; + TaskInput* task = nullptr; + if (auto it = cname2isource_->find(cname); it != cname2isource_->end()) + task = emplaceTask(it->second, label, alc); + else + task = emplaceTask(label, alc); + uid2common_.emplace(uid, task); + return uid; + } + + UID OUTPUT(const std::string& nodeName, const std::string& portName, + int portBit) override + { + UID uid = genUID(); + TaskOutput* task = nullptr; + task = emplaceTask( + Label{uid, Label::OUTPUT, ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID ROM(const std::string& nodeName, const std::string& portName, + int portBit) override + { + assert(reqPacket_ != nullptr); + assert(portName == "romdata"); + + UID uid = genUID(); + TaskROM* task = emplaceTask( + reqPacket_->rom.at(nodeName).at(portBit), + Label{uid, "ROM", ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID RAM(const std::string& nodeName, const std::string& portName, + int portBit) override + { + assert(reqPacket_ != nullptr); + assert(portName == "ramdata"); + + UID uid = genUID(); + TaskDFF* task = emplaceTask( + reqPacket_->ram.at(nodeName).at(portBit), + Label{uid, "RAM", ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + + return uid; + } +}; + +class Frontend : public nt::Frontend { +private: + PlainPacket reqPacket_; + +private: + void setBit0(DataHolder& dh) override; + void setBit1(DataHolder& dh) override; + void dumpResPacket(const std::string& outpath, const TaskFinder& finder, + int numCycles) override; + std::vector> makeWorkers() override; + + // Actual constructor. + void doConstruct(); + +public: + Frontend(const RunParameter& pr); + Frontend(const Snapshot& ss); +}; + +Frontend::Frontend(const RunParameter& pr) + : nt::Frontend(pr), reqPacket_(readPlainPacket(runParam().inputFile)) +{ + doConstruct(); +} + +Frontend::Frontend(const Snapshot& ss) + : nt::Frontend(ss), reqPacket_(readPlainPacket(runParam().inputFile)) +{ + doConstruct(); +} + +void Frontend::doConstruct() +{ + const Blueprint& bp = blueprint(); + + // Create map from ConfigName to InputSource + std::map cname2isource; + for (auto&& [key, port] : bp.atPorts()) { + // Find only inputs, that is, "[connect] ... = @..." + if (port.kind != Label::INPUT) + continue; + + // Get "@atPortName[atPortBit]" + auto& [atPortName, atPortBit] = key; + + // Check if reqPacket_ contains input data for @atPortName + auto it = reqPacket_.bits.find(atPortName); + if (it == reqPacket_.bits.end()) + continue; + + // Die if users try to set the value of @reset[0] since it is set only + // by system + if (atPortName == "reset") + ERR_DIE("@reset cannot be set by user's input"); + + // Add a new entry to cname2isource + InputSource s{bp.atPortWidths().at(atPortName), atPortBit, &it->second}; + cname2isource.emplace(port.cname, s); + } + + // Build the network. The instance is in nt::Frontend + NetworkBuilder nb{cname2isource, reqPacket_, allocator()}; + buildNetwork(nb); +} + +void Frontend::setBit0(DataHolder& dh) +{ + static const Bit bit0 = 0_b; + dh.setBit(&bit0); +} + +void Frontend::setBit1(DataHolder& dh) +{ + static const Bit bit1 = 1_b; + dh.setBit(&bit1); +} + +void Frontend::dumpResPacket(const std::string& outpath, + const TaskFinder& finder, int numCycles) +{ + DataHolder dh; + PlainPacket out; + const Blueprint& bp = blueprint(); + + // Set the current number of cycles + out.numCycles = numCycles; + + // Get values of output @port + out.bits.clear(); + for (auto&& [key, port] : bp.atPorts()) { + // Find "[connect] @atPortName[atPortBit] = ..." + if (port.kind != Label::OUTPUT) + continue; + auto& [atPortName, atPortBit] = key; + + // Get the value + Task* t = finder.findByConfigName(port.cname); + t->getOutput(dh); + + // Assign the value to the corresponding bit of the response packet + auto& bits = out.bits[atPortName]; + if (bits.size() < atPortBit + 1) + bits.resize(atPortBit + 1); + bits.at(atPortBit) = dh.getBit(); + } + + // Get values of RAM + for (auto&& ram : bp.builtinRAMs()) { + std::vector& dst = out.ram[ram.name]; + dst.clear(); + for (size_t i = 0; i < (1 << ram.inAddrWidth) * ram.outRdataWidth; + i++) { + ConfigName cname{ram.name, "ramdata", static_cast(i)}; + Task* t = finder.findByConfigName(cname); + t->getOutput(dh); + dst.push_back(dh.getBit()); + } + } + + // Dump the result packet + writePlainPacket(outpath, out); +} + +std::vector> Frontend::makeWorkers() +{ + const RunParameter& pr = runParam(); + std::vector> workers; + for (size_t i = 0; i < pr.numCPUWorkers; i++) + workers.emplace_back(std::make_unique()); + return workers; +} + +/**************************************************/ +/***** TEST ***************************************/ +/**************************************************/ + +void test0() +{ + WorkerInfo wi; + DataHolder dh; + Bit bit0 = 0_b, bit1 = 1_b; + PlainPacket pkt; + std::map c2is; + + { + Allocator alc; + TaskConstOne t0{Label{1, "", std::nullopt}, alc}; + TaskOutput t1{Label{2, "", std::nullopt}, alc}; + t1.addInput(&t0); + t0.startAsynchronously(wi); + t1.startAsynchronously(wi); + assert(t0.hasFinished()); + t1.getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + Allocator alc; + TaskConstZero t0{Label{0, "", std::nullopt}, alc}; + TaskConstOne t1{Label{1, "", std::nullopt}, alc}; + TaskNand t2{Label{2, "", std::nullopt}, alc}; + TaskOutput t3{Label{3, "", std::nullopt}, alc}; + t2.addInput(&t0); + t2.addInput(&t1); + t3.addInput(&t2); + t0.startAsynchronously(wi); + t1.startAsynchronously(wi); + t2.startAsynchronously(wi); + t3.startAsynchronously(wi); + assert(t0.hasFinished() && t1.hasFinished() && t2.hasFinished() && + t3.hasFinished()); + t3.getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + UID id0 = nb.INPUT("", "A", 0), id1 = nb.INPUT("", "B", 0), + id2 = nb.NAND(), id3 = nb.OUTPUT("", "C", 0); + nb.connect(id0, id2); + nb.connect(id1, id2); + nb.connect(id2, id3); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + Task* t0 = runner.network().finder().findByUID(id0); + Task* t1 = runner.network().finder().findByUID(id1); + Task* t3 = runner.network().finder().findByUID(id3); + + t0->setInput(&bit1); + t1->setInput(&bit1); + runner.run(); + t3->getOutput(dh); + assert(dh.getBit() == 0_b); + } + + { + /* + B D + reset(0) >---> ANDNOT(4) >---> DFF(2) + ^ A v Q + | | + *--< NOT(3) <--*-----> OUTPUT(1) + A + */ + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + UID id0 = nb.INPUT("", "reset", 0), id1 = nb.OUTPUT("", "out", 0), + id2 = nb.DFF(), id3 = nb.NOT(), id4 = nb.ANDNOT(); + nb.connect(id2, id1); + nb.connect(id4, id2); + nb.connect(id2, id3); + nb.connect(id3, id4); + nb.connect(id0, id4); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + Task* t0 = runner.network().finder().findByUID(id0); + Task* t1 = runner.network().finder().findByUID(id1); + + t0->setInput(&bit1); + runner.run(); + t0->setInput(&bit0); + + runner.tick(); + runner.run(); + t1->getOutput(dh); + assert(dh.getBit() == 0_b); + + runner.tick(); + runner.run(); + t1->getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + + readNetworkFromFile( + blueprint::File{blueprint::File::TYPE::YOSYS_JSON, + "test/yosys-json/addr-4bit-yosys.json", "addr"}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tA0 = finder.findByConfigName({"addr", "io_inA", 0}), + *tA1 = finder.findByConfigName({"addr", "io_inA", 1}), + *tA2 = finder.findByConfigName({"addr", "io_inA", 2}), + *tA3 = finder.findByConfigName({"addr", "io_inA", 3}); + Task *tB0 = finder.findByConfigName({"addr", "io_inB", 0}), + *tB1 = finder.findByConfigName({"addr", "io_inB", 1}), + *tB2 = finder.findByConfigName({"addr", "io_inB", 2}), + *tB3 = finder.findByConfigName({"addr", "io_inB", 3}); + Task *tO0 = finder.findByConfigName({"addr", "io_out", 0}), + *tO1 = finder.findByConfigName({"addr", "io_out", 1}), + *tO2 = finder.findByConfigName({"addr", "io_out", 2}), + *tO3 = finder.findByConfigName({"addr", "io_out", 3}); + + tA0->setInput(&bit1); + tA1->setInput(&bit0); + tA2->setInput(&bit1); + tA3->setInput(&bit0); + tB0->setInput(&bit0); + tB1->setInput(&bit1); + tB2->setInput(&bit0); + tB3->setInput(&bit1); + + runner.run(); + + tO0->getOutput(dh); + assert(dh.getBit() == 1_b); + tO1->getOutput(dh); + assert(dh.getBit() == 1_b); + tO2->getOutput(dh); + assert(dh.getBit() == 1_b); + tO3->getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + + readNetworkFromFile( + blueprint::File{blueprint::File::TYPE::IYOKANL1_JSON, + "test/iyokanl1-json/addr-4bit-iyokanl1.json", + "addr"}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tA0 = finder.findByConfigName({"addr", "io_inA", 0}), + *tA1 = finder.findByConfigName({"addr", "io_inA", 1}), + *tA2 = finder.findByConfigName({"addr", "io_inA", 2}), + *tA3 = finder.findByConfigName({"addr", "io_inA", 3}); + Task *tB0 = finder.findByConfigName({"addr", "io_inB", 0}), + *tB1 = finder.findByConfigName({"addr", "io_inB", 1}), + *tB2 = finder.findByConfigName({"addr", "io_inB", 2}), + *tB3 = finder.findByConfigName({"addr", "io_inB", 3}); + Task *tO0 = finder.findByConfigName({"addr", "io_out", 0}), + *tO1 = finder.findByConfigName({"addr", "io_out", 1}), + *tO2 = finder.findByConfigName({"addr", "io_out", 2}), + *tO3 = finder.findByConfigName({"addr", "io_out", 3}); + + tA0->setInput(&bit1); + tA1->setInput(&bit0); + tA2->setInput(&bit0); + tA3->setInput(&bit0); + tB0->setInput(&bit0); + tB1->setInput(&bit1); + tB2->setInput(&bit0); + tB3->setInput(&bit1); + + runner.run(); + + tO0->getOutput(dh); + assert(dh.getBit() == 1_b); + tO1->getOutput(dh); + assert(dh.getBit() == 1_b); + tO2->getOutput(dh); + assert(dh.getBit() == 0_b); + tO3->getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + + readNetworkFromFile( + blueprint::File{blueprint::File::TYPE::YOSYS_JSON, + "test/yosys-json/counter-4bit-yosys.json", + "counter"}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tRst = finder.findByConfigName({"counter", "reset", 0}), + *tOut0 = finder.findByConfigName({"counter", "io_out", 0}), + *tOut1 = finder.findByConfigName({"counter", "io_out", 1}), + *tOut2 = finder.findByConfigName({"counter", "io_out", 2}), + *tOut3 = finder.findByConfigName({"counter", "io_out", 3}); + + tRst->setInput(&bit1); + runner.run(); + tRst->setInput(&bit0); + + // Cycle #1 + runner.tick(); + runner.run(); + // Cycle #2 + runner.tick(); + runner.run(); + // Cycle #3 + runner.tick(); + runner.run(); + + // The output is 2, that is, '0b0010' + tOut0->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut1->getOutput(dh); + assert(dh.getBit() == 1_b); + tOut2->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut3->getOutput(dh); + assert(dh.getBit() == 0_b); + } + + { + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + + readNetworkFromFile( + blueprint::File{blueprint::File::TYPE::YOSYS_JSON, + "test/yosys-json/register-init-4bit-yosys.json", + "register_init"}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tIn0 = finder.findByConfigName({"register_init", "io_in", 0}), + *tIn1 = finder.findByConfigName({"register_init", "io_in", 1}), + *tIn2 = finder.findByConfigName({"register_init", "io_in", 2}), + *tIn3 = finder.findByConfigName({"register_init", "io_in", 3}), + *tOut0 = finder.findByConfigName({"register_init", "io_out", 0}), + *tOut1 = finder.findByConfigName({"register_init", "io_out", 1}), + *tOut2 = finder.findByConfigName({"register_init", "io_out", 2}), + *tOut3 = finder.findByConfigName({"register_init", "io_out", 3}); + + // Set 0xc to input + tIn0->setInput(&bit0); + tIn1->setInput(&bit0); + tIn2->setInput(&bit1); + tIn3->setInput(&bit1); + + // Skip the reset cycle (assume --skip-reset flag). + + // Cycle #1 + runner.tick(); + runner.onAfterTick(0); + runner.run(); + + // The output is 9, that is, '0b1001' + tOut0->getOutput(dh); + assert(dh.getBit() == 1_b); + tOut1->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut2->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut3->getOutput(dh); + assert(dh.getBit() == 1_b); + + // Cycle #2 + runner.tick(); + runner.run(); + + // The output is 12, that is, '0b1100' + tOut0->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut1->getOutput(dh); + assert(dh.getBit() == 0_b); + tOut2->getOutput(dh); + assert(dh.getBit() == 1_b); + tOut3->getOutput(dh); + assert(dh.getBit() == 1_b); + } + + { + PlainPacket pkt; + pkt.rom["rom"] = {0_b, 1_b, 0_b, 0_b, 1_b, 0_b, 1_b, 1_b}; + + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + makeMUXROM(blueprint::BuiltinROM{blueprint::BuiltinROM::TYPE::MUX, + "rom", 2, 2}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tAddr0 = finder.findByConfigName({"rom", "addr", 0}), + *tAddr1 = finder.findByConfigName({"rom", "addr", 1}), + *tRdata0 = finder.findByConfigName({"rom", "rdata", 0}), + *tRdata1 = finder.findByConfigName({"rom", "rdata", 1}); + + tAddr0->setInput(&bit0); + tAddr1->setInput(&bit1); + runner.run(); + + tRdata0->getOutput(dh); + assert(dh.getBit() == 1_b); + tRdata1->getOutput(dh); + assert(dh.getBit() == 0_b); + } + + { + PlainPacket pkt; + pkt.ram["ram"] = {0_b, 1_b, 0_b, 0_b, 1_b, 0_b, 1_b, 1_b}; + + Allocator alc; + NetworkBuilder nb{c2is, pkt, alc}; + makeMUXRAM(blueprint::BuiltinRAM{blueprint::BuiltinRAM::TYPE::MUX, + "ram", 2, 2, 2}, + nb); + + std::vector> workers; + workers.emplace_back(std::make_unique()); + + NetworkRunner runner{nb.createNetwork(), std::move(workers)}; + auto&& finder = runner.network().finder(); + Task *tAddr0 = finder.findByConfigName({"ram", "addr", 0}), + *tAddr1 = finder.findByConfigName({"ram", "addr", 1}), + *tWren = finder.findByConfigName({"ram", "wren", 0}), + *tRdata0 = finder.findByConfigName({"ram", "rdata", 0}), + *tRdata1 = finder.findByConfigName({"ram", "rdata", 1}), + *tWdata0 = finder.findByConfigName({"ram", "wdata", 0}), + *tWdata1 = finder.findByConfigName({"ram", "wdata", 1}); + + // Reset cycle + runner.run(); + + // Cycle #1 + runner.tick(); + runner.onAfterTick(0); + tAddr0->setInput(&bit1); + tAddr1->setInput(&bit0); + tWren->setInput(&bit0); + tWdata0->setInput(&bit1); + tWdata1->setInput(&bit1); + runner.run(); + + tRdata0->getOutput(dh); + assert(dh.getBit() == 0_b); + tRdata1->getOutput(dh); + assert(dh.getBit() == 0_b); + + // Cycle #2 + runner.tick(); + tWren->setInput(&bit1); + runner.run(); + + // Cycle #3 + runner.tick(); + runner.run(); + + tRdata0->getOutput(dh); + assert(dh.getBit() == 1_b); + tRdata1->getOutput(dh); + assert(dh.getBit() == 1_b); + } + + auto go = [&](const std::string& blueprintPath, + const std::string& inPktPath, + const std::string& expectedOutPktPath, int numCycles) { + const char* const reqPktPath = "_test_in"; + const char* const resPktPath = "_test_out"; + + auto inPkt = PlainPacket::fromTOML(inPktPath), + expectedOutPkt = PlainPacket::fromTOML(expectedOutPktPath); + writePlainPacket(reqPktPath, inPkt); + + LOG_DBG_SCOPE("go"); + Frontend frontend{RunParameter{ + blueprintPath, // blueprintFile + reqPktPath, // inputFile + resPktPath, // outputFile + "", // bkeyFile + 2, // numCPUWorkers + numCycles, // numCycles + 0, // currentCycle + SCHED::RANKU, // sched + std::nullopt, // snapshotFile + }}; + frontend.run(); + PlainPacket got = readPlainPacket(resPktPath); + assert(got == expectedOutPkt); + }; + + go("test/config-toml/const-4bit.toml", "test/in/test22.in", + "test/out/test22.out", 1); + go("test/config-toml/addr-4bit.toml", "test/in/test04.in", + "test/out/test04.out", 1); + go("test/config-toml/pass-addr-pass-4bit.toml", "test/in/test04.in", + "test/out/test04.out", 1); + go("test/config-toml/addr-register-4bit.toml", "test/in/test16.in", + "test/out/test16.out", 3); + go("test/config-toml/div-8bit.toml", "test/in/test05.in", + "test/out/test05.out", 1); + go("test/config-toml/ram-addr8bit.toml", "test/in/test06.in", + "test/out/test06.out", 16); + go("test/config-toml/ram-addr9bit.toml", "test/in/test07.in", + "test/out/test07.out", 16); + go("test/config-toml/ram-8-16-16.toml", "test/in/test08.in", + "test/out/test08.out", 8); + go("test/config-toml/rom-4-8.toml", "test/in/test15.in", + "test/out/test15.out", 1); + go("test/config-toml/counter-4bit.toml", "test/in/test13.in", + "test/out/test13.out", 3); + go("test/config-toml/cahp-ruby.toml", "test/in/test09.in", + "test/out/test09-ruby.out", 7); + + auto go_ss = [&](const std::string& blueprintPath, + const std::string& inPktPath, + const std::string& expectedOutPktPath, int numCycles) { + const char* const reqPktPath = "_test_in"; + const char* const resPktPath = "_test_out"; + const char* const snapshotPath = "_test_snapshot"; + + auto inPkt = PlainPacket::fromTOML(inPktPath), + expectedOutPkt = PlainPacket::fromTOML(expectedOutPktPath); + writePlainPacket(reqPktPath, inPkt); + + int secondNumCycles = numCycles / 2, + firstNumCycles = numCycles - secondNumCycles; + + { + LOG_DBG_SCOPE("go_ss 1st"); + Frontend frontend{RunParameter{ + blueprintPath, // blueprintFile + reqPktPath, // inputFile + resPktPath, // outputFile + "", // bkeyFile + 2, // numCPUWorkers + firstNumCycles, // numCycles + 0, // currentCycle + SCHED::RANKU, // sched + snapshotPath, // snapshotFile + }}; + frontend.run(); + } + { + LOG_DBG_SCOPE("go_ss 2nd"); + Snapshot ss{snapshotPath}; + ss.updateNumCycles(secondNumCycles); + Frontend frontend{ss}; + frontend.run(); + + PlainPacket got = readPlainPacket(resPktPath); + assert(got == expectedOutPkt); + } + }; + go_ss("test/config-toml/addr-register-4bit.toml", "test/in/test16.in", + "test/out/test16.out", 3); + go_ss("test/config-toml/ram-addr8bit.toml", "test/in/test06.in", + "test/out/test06.out", 16); + go_ss("test/config-toml/ram-addr9bit.toml", "test/in/test07.in", + "test/out/test07.out", 16); + go_ss("test/config-toml/ram-8-16-16.toml", "test/in/test08.in", + "test/out/test08.out", 8); + go_ss("test/config-toml/counter-4bit.toml", "test/in/test13.in", + "test/out/test13.out", 3); + go_ss("test/config-toml/cahp-ruby.toml", "test/in/test09.in", + "test/out/test09-ruby.out", 7); +} + +} // namespace plain +} // namespace nt diff --git a/src/iyokan_nt_plain.hpp b/src/iyokan_nt_plain.hpp new file mode 100644 index 0000000..1ebe0a1 --- /dev/null +++ b/src/iyokan_nt_plain.hpp @@ -0,0 +1,8 @@ +#ifndef VIRTUALSECUREPLATFORM_IYOKAN_NT_PLAIN_HPP +#define VIRTUALSECUREPLATFORM_IYOKAN_NT_PLAIN_HPP + +namespace nt::plain { +void test0(); +} + +#endif diff --git a/src/iyokan_nt_tfhepp.cpp b/src/iyokan_nt_tfhepp.cpp new file mode 100644 index 0000000..5e1c790 --- /dev/null +++ b/src/iyokan_nt_tfhepp.cpp @@ -0,0 +1,694 @@ +#include "iyokan_nt_tfhepp.hpp" +#include "blueprint.hpp" +#include "dataholder_nt.hpp" +#include "iyokan_nt.hpp" +#include "packet_nt.hpp" + +#include + +namespace { + +/* class Thread */ + +class Thread { +private: + // The C++17 keyword 'inline' is necessary here to avoid duplicate + // definition of 'pool_'. Thanks to: + // https://stackoverflow.com/questions/11709859/how-to-have-static-data-members-in-a-header-only-library + static inline std::shared_ptr pool_; + std::atomic_bool finished_; + +public: + Thread(); + + bool hasFinished() const; + static void setNumThreads(int newNumThreads); + + template + Thread& operator=(Func func) + { + finished_ = false; + pool_->enqueue([this, func]() { + func(); + finished_ = true; + }); + return *this; + } +}; + +Thread::Thread() : finished_(false) +{ + if (!pool_) { + const int DEFAULT_NUM_THREADS = 10; + pool_ = std::make_shared(DEFAULT_NUM_THREADS); + } +} + +bool Thread::hasFinished() const +{ + return finished_; +} + +void Thread::setNumThreads(int newNumThreads) +{ + pool_ = std::make_shared(newNumThreads); +} + +} // namespace + +namespace nt::tfhepp { + +struct WorkerInfo { + TFHEppBKey bkey; + + WorkerInfo(TFHEppBKey bkey) : bkey(std::move(bkey)) + { + } +}; + +class Worker : public nt::Worker { +private: + WorkerInfo wi_; + +protected: + bool canExecute(Task* task) override + { + return task->canRunTFHEpp(); + } + + void startTask(Task* task) override + { + task->startAsynchronously(wi_); + } + +public: + Worker(WorkerInfo wi) : wi_(std::move(wi)) + { + } +}; + +// struct InputSource is used by class TaskInput to set correct input value +// every cycle. +struct InputSource { + int atPortWidth, atPortBit; + std::vector* bits; +}; + +class TaskInput : public TaskCommon { +private: + std::optional source_; + +public: + TaskInput(Label label, Allocator& alc) + : TaskCommon(label, alc, 0, 1), source_(std::nullopt) + { + } + TaskInput(InputSource source, Label label, Allocator& alc) + : TaskCommon(label, alc, 0, 1), source_(source) + { + } + + void onAfterTick(size_t currentCycle) override + { + if (source_) { + // Set the output value from the source + assert(getInputSize() == 0); + InputSource& s = source_.value(); + size_t index = + (s.atPortWidth * currentCycle + s.atPortBit) % s.bits->size(); + output() = s.bits->at(index); + } + } + + void startAsynchronously(WorkerInfo&) override + { + if (getInputSize() == 1) + output() = input(0); + } + + bool hasFinished() const override + { + return true; + } + + bool canRunTFHEpp() const override + { + return true; + } + + void setInput(const DataHolder& h) override + { + // Set the input i.e., set the output value of this gate + h.getTLWELvl0(output()); + } +}; + +class TaskOutput : public TaskCommon { +public: + TaskOutput(Label label, Allocator& alc) + : TaskCommon(label, alc, 1) + { + } + + void startAsynchronously(WorkerInfo&) override + { + output() = input(0); + } + + bool hasFinished() const override + { + return true; + } + + bool canRunTFHEpp() const override + { + return true; + } + + void getOutput(DataHolder& h) override + { + h.setTLWELvl0(&output()); + } +}; + +class TaskDFF : public nt::TaskDFF { +private: + std::optional initialValue_; + +public: + TaskDFF(Label label, Allocator& alc) + : nt::TaskDFF(std::move(label), alc), + initialValue_(std::nullopt) + { + } + + TaskDFF(const TLWELvl0& initialValue, Label label, Allocator& alc) + : nt::TaskDFF(std::move(label), alc), + initialValue_(initialValue) + { + } + + void onAfterTick(size_t currentCycle) override + { + if (currentCycle == 0 && initialValue_) + output() = initialValue_.value(); + } + + bool canRunTFHEpp() const override + { + return true; + } + + void startAsynchronously(WorkerInfo&) override + { + // Nothing to do, because the main process is done in + // nt::TaskDFF::tick(). + } + + void getOutput(DataHolder& h) override + { + h.setTLWELvl0(&output()); + } +}; + +class TaskROM : public TaskCommon { +public: + TaskROM(const TLWELvl0& value, Label label, Allocator& alc) + : TaskCommon(label, alc, 0) + { + output() = value; + } + + void startAsynchronously(WorkerInfo&) override + { + } + + bool hasFinished() const override + { + return true; + } + + bool canRunTFHEpp() const override + { + return true; + } +}; + +#define DEF_TASK(CamelName, inputSize, compCost, expr) \ + class Task##CamelName : public TaskCommon { \ + private: \ + Thread thr_; \ + \ + public: \ + Task##CamelName(Label label, Allocator& alc) \ + : TaskCommon(std::move(label), alc, inputSize) \ + { \ + } \ + void startAsynchronously(WorkerInfo& wi) override \ + { \ + thr_ = [&]() { \ + [[maybe_unused]] const GateKeyFFT& gk = *wi.bkey.gk; \ + (expr); \ + }; \ + } \ + bool hasFinished() const override \ + { \ + return thr_.hasFinished(); \ + } \ + bool canRunTFHEpp() const override \ + { \ + return true; \ + } \ + int getComputationCost() const override \ + { \ + return compCost; \ + } \ + }; +DEF_TASK(And, 2, 10, TFHEpp::HomAND(output(), input(0), input(1), gk)); +DEF_TASK(Andnot, 2, 10, TFHEpp::HomANDYN(output(), input(0), input(1), gk)); +DEF_TASK(ConstOne, 0, 0, TFHEpp::HomCONSTANTONE(output())); +DEF_TASK(ConstZero, 0, 0, TFHEpp::HomCONSTANTZERO(output())); +DEF_TASK(Mux, 3, 20, + TFHEpp::HomMUX(output(), input(2), input(1), input(0), gk)); +DEF_TASK(Nand, 2, 10, TFHEpp::HomNAND(output(), input(0), input(1), gk)); +DEF_TASK(Nmux, 3, 20, + TFHEpp::HomNMUX(output(), input(2), input(1), input(0), gk)); +DEF_TASK(Nor, 2, 10, TFHEpp::HomNOR(output(), input(0), input(1), gk)); +DEF_TASK(Not, 1, 0, TFHEpp::HomNOT(output(), input(0))); +DEF_TASK(Or, 2, 10, TFHEpp::HomOR(output(), input(0), input(1), gk)); +DEF_TASK(Ornot, 2, 10, TFHEpp::HomORYN(output(), input(0), input(1), gk)); +DEF_TASK(Xnor, 2, 10, TFHEpp::HomXNOR(output(), input(0), input(1), gk)); +DEF_TASK(Xor, 2, 10, TFHEpp::HomXOR(output(), input(0), input(1), gk)); +#undef DEF_TASK + +class NetworkBuilder : public nt::NetworkBuilder { +private: + std::unordered_map*> uid2common_; + UID nextUID_; + const TFHEPacket* const reqPacket_; + const std::map* const cname2isource_; + +private: + UID genUID() + { + return nextUID_++; + } + +public: + NetworkBuilder(const std::map& cname2isource, + const TFHEPacket& reqPacket, Allocator& alc) + : nt::NetworkBuilder(alc), + uid2common_(), + nextUID_(0), + reqPacket_(&reqPacket), + cname2isource_(&cname2isource) + { + } + + ~NetworkBuilder() + { + } + + void connect(UID fromUID, UID toUID) override + { + auto &from = uid2common_.at(fromUID), &to = uid2common_.at(toUID); + to->addInput(from); + } + +#define DEF_COMMON_TASK(CAPName, CamelName) \ + UID CAPName() override \ + { \ + UID uid = genUID(); \ + Task##CamelName* task = nullptr; \ + task = emplaceTask( \ + Label{uid, #CAPName, std::nullopt}, currentAllocator()); \ + uid2common_.emplace(uid, task); \ + return uid; \ + } + DEF_COMMON_TASK(AND, And); + DEF_COMMON_TASK(ANDNOT, Andnot); + DEF_COMMON_TASK(CONSTONE, ConstOne); + DEF_COMMON_TASK(CONSTZERO, ConstZero); + DEF_COMMON_TASK(DFF, DFF); + DEF_COMMON_TASK(MUX, Mux); + DEF_COMMON_TASK(NAND, Nand); + DEF_COMMON_TASK(NMUX, Nmux); + DEF_COMMON_TASK(NOR, Nor); + DEF_COMMON_TASK(NOT, Not); + DEF_COMMON_TASK(OR, Or); + DEF_COMMON_TASK(ORNOT, Ornot); + DEF_COMMON_TASK(XNOR, Xnor); + DEF_COMMON_TASK(XOR, Xor); +#undef DEF_COMMON_TASK + + UID SDFF0() override + { + TLWELvl0 tlwe; + TFHEpp::HomCONSTANTZERO(tlwe); + UID uid = genUID(); + TaskDFF* task = emplaceTask( + tlwe, Label{uid, "SDFF0", std::nullopt}, currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID SDFF1() override + { + TLWELvl0 tlwe; + TFHEpp::HomCONSTANTONE(tlwe); + UID uid = genUID(); + TaskDFF* task = emplaceTask( + tlwe, Label{uid, "SDFF1", std::nullopt}, currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID INPUT(const std::string& nodeName, const std::string& portName, + int portBit) override + { + Allocator& alc = currentAllocator(); + UID uid = genUID(); + ConfigName cname = ConfigName{nodeName, portName, portBit}; + Label label{uid, Label::INPUT, cname}; + TaskInput* task = nullptr; + if (auto it = cname2isource_->find(cname); it != cname2isource_->end()) + task = emplaceTask(it->second, label, alc); + else + task = emplaceTask(label, alc); + uid2common_.emplace(uid, task); + return uid; + } + + UID OUTPUT(const std::string& nodeName, const std::string& portName, + int portBit) override + { + UID uid = genUID(); + TaskOutput* task = nullptr; + task = emplaceTask( + Label{uid, Label::OUTPUT, ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID ROM(const std::string& nodeName, const std::string& portName, + int portBit) override + { + assert(reqPacket_ != nullptr); + assert(portName == "romdata"); + + UID uid = genUID(); + TaskROM* task = emplaceTask( + reqPacket_->romInTLWE.at(nodeName).at(portBit), + Label{uid, "ROM", ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + return uid; + } + + UID RAM(const std::string& nodeName, const std::string& portName, + int portBit) override + { + assert(reqPacket_ != nullptr); + assert(portName == "ramdata"); + + UID uid = genUID(); + TaskDFF* task = emplaceTask( + reqPacket_->ramInTLWE.at(nodeName).at(portBit), + Label{uid, "RAM", ConfigName{nodeName, portName, portBit}}, + currentAllocator()); + uid2common_.emplace(uid, task); + + return uid; + } +}; + +class Frontend : public nt::Frontend { +private: + TFHEPacket reqPacket_; + TFHEppBKey bkey_; + +private: + void setBit0(DataHolder& dh) override; + void setBit1(DataHolder& dh) override; + void dumpResPacket(const std::string& outpath, const TaskFinder& finder, + int numCycles) override; + std::vector> makeWorkers() override; + + // Actual constructor. + void doConstruct(); + +public: + Frontend(const RunParameter& pr); + Frontend(const Snapshot& ss); +}; + +Frontend::Frontend(const RunParameter& pr) + : nt::Frontend(pr), reqPacket_(readTFHEPacket(runParam().inputFile)) +{ + doConstruct(); +} + +Frontend::Frontend(const Snapshot& ss) + : nt::Frontend(ss), reqPacket_(readTFHEPacket(runParam().inputFile)) +{ + doConstruct(); +} + +void Frontend::doConstruct() +{ + const Blueprint& bp = blueprint(); + + // Create map from ConfigName to InputSource + std::map cname2isource; + for (auto&& [key, port] : bp.atPorts()) { + // Find only inputs, that is, "[connect] ... = @..." + if (port.kind != Label::INPUT) + continue; + + // Get "@atPortName[atPortBit]" + auto& [atPortName, atPortBit] = key; + + // Check if reqPacket_ contains input data for @atPortName + auto it = reqPacket_.bits.find(atPortName); + if (it == reqPacket_.bits.end()) + continue; + + // Die if users try to set the value of @reset[0] since it is set only + // by system + if (atPortName == "reset") + ERR_DIE("@reset cannot be set by user's input"); + + // Add a new entry to cname2isource + InputSource s{bp.atPortWidths().at(atPortName), atPortBit, &it->second}; + cname2isource.emplace(port.cname, s); + } + + // Build the network. The instance is in nt::Frontend + NetworkBuilder nb{cname2isource, reqPacket_, allocator()}; + buildNetwork(nb); + + // Read bkey + readTFHEppBKey(bkey_, runParam().bkeyFile); +} + +void Frontend::setBit0(DataHolder& dh) +{ + static TLWELvl0 tlwe; + TFHEpp::HomCONSTANTZERO(tlwe); + dh.setTLWELvl0(&tlwe); +} + +void Frontend::setBit1(DataHolder& dh) +{ + static TLWELvl0 tlwe; + TFHEpp::HomCONSTANTONE(tlwe); + dh.setTLWELvl0(&tlwe); +} + +void Frontend::dumpResPacket(const std::string& outpath, + const TaskFinder& finder, int numCycles) +{ + DataHolder dh; + TFHEPacket out; + const Blueprint& bp = blueprint(); + + // Set the current number of cycles + out.numCycles = numCycles; + + // Get values of output @port + out.bits.clear(); + for (auto&& [key, port] : bp.atPorts()) { + // Find "[connect] @atPortName[atPortBit] = ..." + if (port.kind != Label::OUTPUT) + continue; + auto& [atPortName, atPortBit] = key; + + // Get the value + Task* t = finder.findByConfigName(port.cname); + t->getOutput(dh); + + // Assign the value to the corresponding bit of the response packet + auto& bits = out.bits[atPortName]; + if (bits.size() < atPortBit + 1) + bits.resize(atPortBit + 1); + dh.getTLWELvl0(bits.at(atPortBit)); + } + + // Get values of RAM + for (auto&& ram : bp.builtinRAMs()) { + std::vector& dst = out.ramInTLWE[ram.name]; + size_t size = (1 << ram.inAddrWidth) * ram.outRdataWidth; + dst.resize(size); + for (size_t i = 0; i < size; i++) { + ConfigName cname{ram.name, "ramdata", static_cast(i)}; + Task* t = finder.findByConfigName(cname); + t->getOutput(dh); + dh.getTLWELvl0(dst.at(i)); + } + } + + // Dump the result packet + writeTFHEPacket(outpath, out); +} + +std::vector> Frontend::makeWorkers() +{ + const RunParameter& pr = runParam(); + std::vector> workers; + for (size_t i = 0; i < pr.numCPUWorkers; i++) + workers.emplace_back(std::make_unique(bkey_)); + return workers; +} + +/**************************************************/ +/***** TEST ***************************************/ +/**************************************************/ + +void test0() +{ + // Set # of CPU cores as # of threads + int numCPUCores = std::thread::hardware_concurrency(); + Thread::setNumThreads(numCPUCores); + + // Prepare secret and bootstrapping keys + const char* const bkeyPath = "_test_bkey"; + std::optional sk; + { + LOG_DBG_SCOPE("GENERATE KEYS"); + LOG_DBG << "GENERATE SECRET KEY"; + sk.emplace(); + LOG_DBG << "GENERATE BOOTSTRAPPING KEY"; + writeTFHEppBKey(bkeyPath, TFHEppBKey{*sk}); + } + + auto go = [&](const std::string& blueprintPath, + const std::string& inPktPath, + const std::string& expectedOutPktPath, int numCycles) { + const char* const reqPktPath = "_test_in"; + const char* const resPktPath = "_test_out"; + + auto inPlainPkt = PlainPacket::fromTOML(inPktPath), + expectedOutPlainPkt = PlainPacket::fromTOML(expectedOutPktPath); + auto inPkt = inPlainPkt.encrypt(*sk); + writeTFHEPacket(reqPktPath, inPkt); + + LOG_DBG_SCOPE("go"); + Frontend frontend{RunParameter{ + blueprintPath, // blueprintFile + reqPktPath, // inputFile + resPktPath, // outputFile + bkeyPath, // bkeyFile + numCPUCores, // numCPUWorkers + numCycles, // numCycles + 0, // currentCycle + SCHED::RANKU, // sched + std::nullopt, // snapshotFile + }}; + frontend.run(); + + TFHEPacket got = readTFHEPacket(resPktPath); + PlainPacket gotPlain = got.decrypt(*sk); + assert(gotPlain == expectedOutPlainPkt); + }; + + go("test/config-toml/const-4bit.toml", "test/in/test22.in", + "test/out/test22.out", 1); + go("test/config-toml/addr-4bit.toml", "test/in/test04.in", + "test/out/test04.out", 1); + go("test/config-toml/pass-addr-pass-4bit.toml", "test/in/test04.in", + "test/out/test04.out", 1); + go("test/config-toml/addr-register-4bit.toml", "test/in/test16.in", + "test/out/test16.out", 3); + go("test/config-toml/div-8bit.toml", "test/in/test05.in", + "test/out/test05.out", 1); + go("test/config-toml/ram-addr8bit.toml", "test/in/test06.in", + "test/out/test06.out", 16); + go("test/config-toml/ram-addr9bit.toml", "test/in/test07.in", + "test/out/test07.out", 16); + go("test/config-toml/ram-8-16-16.toml", "test/in/test08.in", + "test/out/test08.out", 8); + go("test/config-toml/rom-4-8.toml", "test/in/test15.in", + "test/out/test15.out", 1); + go("test/config-toml/counter-4bit.toml", "test/in/test13.in", + "test/out/test13.out", 3); + go("test/config-toml/cahp-ruby.toml", "test/in/test09.in", + "test/out/test09-ruby.out", 7); + + auto go_ss = [&](const std::string& blueprintPath, + const std::string& inPktPath, + const std::string& expectedOutPktPath, int numCycles) { + const char* const reqPktPath = "_test_in"; + const char* const resPktPath = "_test_out"; + const char* const snapshotPath = "_test_snapshot"; + + auto inPlainPkt = PlainPacket::fromTOML(inPktPath), + expectedOutPlainPkt = PlainPacket::fromTOML(expectedOutPktPath); + auto inPkt = inPlainPkt.encrypt(*sk); + writeTFHEPacket(reqPktPath, inPkt); + + int secondNumCycles = numCycles / 2, + firstNumCycles = numCycles - secondNumCycles; + + { + LOG_DBG_SCOPE("go_ss 1st"); + Frontend frontend{RunParameter{ + blueprintPath, // blueprintFile + reqPktPath, // inputFile + resPktPath, // outputFile + bkeyPath, // bkeyFile + numCPUCores, // numCPUWorkers + firstNumCycles, // numCycles + 0, // currentCycle + SCHED::RANKU, // sched + snapshotPath, // snapshotFile + }}; + frontend.run(); + } + { + LOG_DBG_SCOPE("go_ss 2nd"); + Snapshot ss{snapshotPath}; + ss.updateNumCycles(secondNumCycles); + Frontend frontend{ss}; + frontend.run(); + + TFHEPacket got = readTFHEPacket(resPktPath); + PlainPacket gotPlain = got.decrypt(*sk); + assert(gotPlain == expectedOutPlainPkt); + } + }; + go_ss("test/config-toml/addr-register-4bit.toml", "test/in/test16.in", + "test/out/test16.out", 3); + go_ss("test/config-toml/ram-addr8bit.toml", "test/in/test06.in", + "test/out/test06.out", 16); + go_ss("test/config-toml/ram-addr9bit.toml", "test/in/test07.in", + "test/out/test07.out", 16); + go_ss("test/config-toml/ram-8-16-16.toml", "test/in/test08.in", + "test/out/test08.out", 8); + go_ss("test/config-toml/counter-4bit.toml", "test/in/test13.in", + "test/out/test13.out", 3); + go_ss("test/config-toml/cahp-ruby.toml", "test/in/test09.in", + "test/out/test09-ruby.out", 7); +} + +} // namespace nt::tfhepp diff --git a/src/iyokan_nt_tfhepp.hpp b/src/iyokan_nt_tfhepp.hpp new file mode 100644 index 0000000..0b58ddc --- /dev/null +++ b/src/iyokan_nt_tfhepp.hpp @@ -0,0 +1,8 @@ +#ifndef VIRTUALSECUREPLATFORM_IYOKAN_NT_TFHEPP_HPP +#define VIRTUALSECUREPLATFORM_IYOKAN_NT_TFHEPP_HPP + +namespace nt::tfhepp { +void test0(); +} + +#endif diff --git a/src/label.cpp b/src/label.cpp new file mode 100644 index 0000000..dea2962 --- /dev/null +++ b/src/label.cpp @@ -0,0 +1,24 @@ +#include "label.hpp" + +#include + +/* struct ConfigName */ +std::ostream& operator<<(std::ostream& os, const ConfigName& c) +{ + os << c.nodeName << "/" << c.portName << "[" << c.portBit << "]"; + return os; +} + +bool operator<(const ConfigName& lhs, const ConfigName& rhs) +{ + if (int res = lhs.nodeName.compare(rhs.nodeName); res != 0) + return res < 0; + if (int res = lhs.portName.compare(rhs.portName); res != 0) + return res < 0; + return lhs.portBit < rhs.portBit; +} + +/* struct Label */ +// Initialization of static variables. +const char* const Label::INPUT = "INPUT"; +const char* const Label::OUTPUT = "OUTPUT"; diff --git a/src/label.hpp b/src/label.hpp new file mode 100644 index 0000000..045f7de --- /dev/null +++ b/src/label.hpp @@ -0,0 +1,41 @@ +#ifndef VIRTUALSECUREPLATFORM_CONFIG_NAME_HPP +#define VIRTUALSECUREPLATFORM_CONFIG_NAME_HPP + +#include +#include + +struct ConfigName { + std::string nodeName, portName; + int portBit; + + // Although all member variables of ConfigName are public, we make + // operator<< and operator< its friend function for clarity. + friend std::ostream& operator<<(std::ostream& os, const ConfigName& c); + friend bool operator<(const ConfigName& lhs, const ConfigName& rhs); +}; + +using UID = uint64_t; + +struct Label { + // String literals for member variable `kind`. + // If label is for inputs or outputs, these member variable must be used, + // that is, kind == Label::INPUT or kind == Label::OUTPUT. + // The instances of these variables exist in iyokan_nt.cpp. + // We CANNOT use (C++17) `inline` here, because `inline` does NOT guarantee + // that these variables have the same value in different compilation units + // (FIXME: This behaviour is confirmed only on g++-10. We need to check the + // C++ standard). + static const char* const INPUT; + static const char* const OUTPUT; + + UID uid; + const char* const kind; // Stores a string literal + std::optional cname; + + Label(UID uid, const char* const kind, std::optional cname) + : uid(uid), kind(kind), cname(std::move(cname)) + { + } +}; + +#endif diff --git a/src/network_reader.cpp b/src/network_reader.cpp new file mode 100644 index 0000000..9b7818b --- /dev/null +++ b/src/network_reader.cpp @@ -0,0 +1,452 @@ +#include "blueprint.hpp" +#include "iyokan_nt.hpp" + +#include + +#include + +namespace { + +/* class YosysJSONReader */ + +class YosysJSONReader { +private: + enum class PORT { + IN, + OUT, + }; + + struct Port { + PORT type; + int id, bit; + + Port(PORT type, int id, int bit) : type(type), id(id), bit(bit) + { + } + }; + + enum class CELL { + NOT, + AND, + ANDNOT, + NAND, + OR, + XOR, + XNOR, + NOR, + ORNOT, + DFFP, + SDFFPP0, + SDFFPP1, + MUX, + }; + + struct Cell { + CELL type; + int id, bit0, bit1, bit2; + + Cell(CELL type, int id, int bit0) + : type(type), id(id), bit0(bit0), bit1(-1), bit2(-1) + { + } + Cell(CELL type, int id, int bit0, int bit1) + : type(type), id(id), bit0(bit0), bit1(bit1), bit2(-1) + { + } + Cell(CELL type, int id, int bit0, int bit1, int bit2) + : type(type), id(id), bit0(bit0), bit1(bit1), bit2(bit2) + { + } + }; + +private: + static int getConnBit(const picojson::object& conn, const std::string& key) + { + using namespace picojson; + const auto& bits = conn.at(key).get(); + if (bits.size() != 1) + ERR_DIE("Invalid JSON: wrong conn size: expected 1, got " + << bits.size()); + if (!bits.at(0).is()) + ERR_DIE( + "Connection of cells to a constant driver is not implemented."); + return bits.at(0).get(); + } + +public: + template + static void read(const std::string& nodeName, std::istream& is, + NetworkBuilder& builder) + { + // Convert Yosys JSON to gates. Thanks to: + // https://github.com/virtualsecureplatform/Iyokan-L1/blob/ef7c9a993ddbfd54ef58e66b116b681e59d90a3c/Converter/YosysConverter.cs + using namespace picojson; + + value v; + const std::string err = parse(v, is); + if (!err.empty()) + ERR_DIE("Invalid JSON of network: " << err); + + object& root = v.get(); + object& modules = root.at("modules").get(); + if (modules.size() != 1) + ERR_DIE(".modules should be an object of size 1"); + object& modul = modules.begin()->second.get(); + object& ports = modul.at("ports").get(); + object& cells = modul.at("cells").get(); + + std::unordered_map bit2id; + + // Create INPUT/OUTPUT and extract port connection info + std::vector portvec; + for (auto&& [key, valAny] : ports) { + object& val = valAny.template get(); + std::string& direction = val["direction"].get(); + array& bits = val["bits"].get(); + + if (key == "clock") + continue; + if (key == "reset" && bits.size() == 0) + continue; + if (direction != "input" && direction != "output") + ERR_DIE("Invalid direction token: " << direction); + + const bool isDirInput = direction == "input"; + const std::string& portName = key; + for (size_t i = 0; i < bits.size(); i++) { + const int portBit = i; + + if (bits.at(i).is()) { + // Yosys document + // (https://yosyshq.net/yosys/cmd_write_json.html) says: + // + // Signal bits that are connected to a constant driver + // are denoted as string "0" or "1" instead of a number. + // + // We handle this case here. + + if (isDirInput) + ERR_DIE( + "Invalid bits: INPUT that is connected to a " + "constant driver is not implemented"); + + std::string cnstStr = bits.at(i).get(); + bool cnst = cnstStr == "1"; + if (!cnst && cnstStr != "0") + LOG_S(WARNING) + << "Constant bit of '{}' is regarded as '0'." + << cnstStr; + + int id1 = builder.OUTPUT(nodeName, portName, portBit), + id0 = cnst ? builder.CONSTONE() : builder.CONSTZERO(); + builder.connect(id0, id1); + } + else { + const int bit = bits.at(i).get(); + + int id = isDirInput + ? builder.INPUT(nodeName, portName, portBit) + : builder.OUTPUT(nodeName, portName, portBit); + portvec.emplace_back(isDirInput ? PORT::IN : PORT::OUT, id, + bit); + if (isDirInput) + bit2id.emplace(bit, id); + } + } + } + + // Create gates and extract gate connection info + const std::unordered_map mapCell = { + {"$_NOT_", CELL::NOT}, + {"$_AND_", CELL::AND}, + {"$_ANDNOT_", CELL::ANDNOT}, + {"$_NAND_", CELL::NAND}, + {"$_OR_", CELL::OR}, + {"$_XOR_", CELL::XOR}, + {"$_XNOR_", CELL::XNOR}, + {"$_NOR_", CELL::NOR}, + {"$_ORNOT_", CELL::ORNOT}, + {"$_DFF_P_", CELL::DFFP}, + {"$_SDFF_PP0_", CELL::SDFFPP0}, + {"$_SDFF_PP1_", CELL::SDFFPP1}, + {"$_MUX_", CELL::MUX}, + }; + std::vector cellvec; + for (auto&& [_key, valAny] : cells) { + object& val = valAny.template get(); + const std::string& type = val.at("type").get(); + object& conn = val.at("connections").get(); + auto get = [&](const char* key) -> int { + return getConnBit(conn, key); + }; + + int bit = -1, id = -1; + switch (mapCell.at(type)) { + case CELL::AND: + id = builder.AND(); + cellvec.emplace_back(CELL::AND, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::NAND: + id = builder.NAND(); + cellvec.emplace_back(CELL::NAND, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::XOR: + id = builder.XOR(); + cellvec.emplace_back(CELL::XOR, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::XNOR: + id = builder.XNOR(); + cellvec.emplace_back(CELL::XNOR, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::NOR: + id = builder.NOR(); + cellvec.emplace_back(CELL::NOR, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::ANDNOT: + id = builder.ANDNOT(); + cellvec.emplace_back(CELL::ANDNOT, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::OR: + id = builder.OR(); + cellvec.emplace_back(CELL::OR, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::ORNOT: + id = builder.ORNOT(); + cellvec.emplace_back(CELL::ORNOT, id, get("A"), get("B")); + bit = get("Y"); + break; + case CELL::DFFP: + id = builder.DFF(); + cellvec.emplace_back(CELL::DFFP, id, get("D")); + bit = get("Q"); + break; + case CELL::SDFFPP0: + id = builder.SDFF0(); + cellvec.emplace_back(CELL::DFFP, id, get("D")); + bit = get("Q"); + break; + case CELL::SDFFPP1: + id = builder.SDFF1(); + cellvec.emplace_back(CELL::DFFP, id, get("D")); + bit = get("Q"); + break; + case CELL::NOT: + id = builder.NOT(); + cellvec.emplace_back(CELL::NOT, id, get("A")); + bit = get("Y"); + break; + case CELL::MUX: + id = builder.MUX(); + cellvec.emplace_back(CELL::MUX, id, get("A"), get("B"), + get("S")); + bit = get("Y"); + break; + } + bit2id.emplace(bit, id); + } + + for (auto&& port : portvec) { + if (port.type == PORT::IN) + // Actually nothing to do! + continue; + builder.connect(bit2id.at(port.bit), port.id); + } + + for (auto&& cell : cellvec) { + switch (cell.type) { + case CELL::AND: + case CELL::NAND: + case CELL::XOR: + case CELL::XNOR: + case CELL::NOR: + case CELL::ANDNOT: + case CELL::OR: + case CELL::ORNOT: + builder.connect(bit2id.at(cell.bit0), cell.id); + builder.connect(bit2id.at(cell.bit1), cell.id); + break; + case CELL::DFFP: + case CELL::SDFFPP0: + case CELL::SDFFPP1: + case CELL::NOT: + builder.connect(bit2id.at(cell.bit0), cell.id); + break; + case CELL::MUX: + builder.connect(bit2id.at(cell.bit0), cell.id); + builder.connect(bit2id.at(cell.bit1), cell.id); + builder.connect(bit2id.at(cell.bit2), cell.id); + break; + } + } + } +}; + +/* class IyokanL1JSONReader */ + +class IyokanL1JSONReader { +public: + template + static void read(const std::string& nodeName, std::istream& is, + NetworkBuilder& builder, std::optional ramDataWidth) + { + std::unordered_map id2taskId; + auto addId = [&](int id, int taskId) { id2taskId.emplace(id, taskId); }; + auto findTaskId = [&](int id) { + auto it = id2taskId.find(id); + if (it == id2taskId.end()) + ERR_DIE("Invalid JSON"); + return it->second; + }; + auto connectIds = [&](int from, int to) { + builder.connect(findTaskId(from), findTaskId(to)); + }; + + picojson::value v; + const std::string err = picojson::parse(v, is); + if (!err.empty()) + ERR_DIE("Invalid JSON of network: " << err); + + picojson::object& obj = v.get(); + picojson::array& cells = obj["cells"].get(); + picojson::array& ports = obj["ports"].get(); + for (const auto& e : ports) { + picojson::object port = e.get(); + std::string type = port.at("type").get(); + int id = static_cast(port.at("id").get()); + std::string portName = port.at("portName").get(); + int portBit = static_cast(port.at("portBit").get()); + if (type == "input") + addId(id, builder.INPUT(nodeName, portName, portBit)); + else if (type == "output") + addId(id, builder.OUTPUT(nodeName, portName, portBit)); + } + for (const auto& e : cells) { + picojson::object cell = e.get(); + std::string type = cell.at("type").get(); + int id = static_cast(cell.at("id").get()); + if (type == "AND") + addId(id, builder.AND()); + else if (type == "NAND") + addId(id, builder.NAND()); + else if (type == "ANDNOT") + addId(id, builder.ANDNOT()); + else if (type == "XOR") + addId(id, builder.XOR()); + else if (type == "XNOR") + addId(id, builder.XNOR()); + else if (type == "DFFP") + addId(id, builder.DFF()); + else if (type == "NOT") + addId(id, builder.NOT()); + else if (type == "NOR") + addId(id, builder.NOR()); + else if (type == "OR") + addId(id, builder.OR()); + else if (type == "ORNOT") + addId(id, builder.ORNOT()); + else if (type == "MUX") + addId(id, builder.MUX()); + else { + bool valid = false; + if (type == "RAM" && ramDataWidth) { + valid = true; + int addr = cell.at("ramAddress").get(), + bit = cell.at("ramBit").get(); + addId(id, builder.RAM(nodeName, "ramdata", + addr * ramDataWidth.value() + bit)); + } + if (!valid) + ERR_DIE("Invalid JSON of network. Invalid type: " << type); + } + } + for (const auto& e : ports) { + picojson::object port = e.get(); + std::string type = port.at("type").get(); + int id = static_cast(port.at("id").get()); + picojson::array& bits = port.at("bits").get(); + if (type == "input") { + // nothing to do! + } + else if (type == "output") { + for (const auto& b : bits) { + int logic = static_cast(b.get()); + connectIds(logic, id); + } + } + } + for (const auto& e : cells) { + picojson::object cell = e.get(); + std::string type = cell.at("type").get(); + int id = static_cast(cell.at("id").get()); + picojson::object input = cell.at("input").get(); + if (type == "AND" || type == "NAND" || type == "XOR" || + type == "XNOR" || type == "NOR" || type == "ANDNOT" || + type == "OR" || type == "ORNOT") { + int A = static_cast(input.at("A").get()); + int B = static_cast(input.at("B").get()); + connectIds(A, id); + connectIds(B, id); + } + else if (type == "DFFP" || type == "RAM") { + int D = static_cast(input.at("D").get()); + connectIds(D, id); + } + else if (type == "NOT") { + int A = static_cast(input.at("A").get()); + connectIds(A, id); + } + else if (type == "MUX") { + int A = static_cast(input.at("A").get()); + int B = static_cast(input.at("B").get()); + int S = static_cast(input.at("S").get()); + connectIds(A, id); + connectIds(B, id); + connectIds(S, id); + } + else { + ERR_DIE("Invalid JSON of network. Invalid type: " << type); + } + } + } +}; + +} // namespace + +namespace nt { + +void readPrecompiledRAMNetworkFromFile(const std::string& name, + std::istream& is, nt::NetworkBuilder& nb, + int ramDataWidth) +{ + IyokanL1JSONReader::read(name, is, nb, ramDataWidth); +} + +void readNetworkFromFile(const blueprint::File& file, nt::NetworkBuilder& nb) +{ + std::ifstream ifs{file.path, std::ios::binary}; + if (!ifs) + ERR_DIE("Invalid [[file]] path: " << file.path); + + switch (file.type) { + case blueprint::File::TYPE::IYOKANL1_JSON: + LOG_S(WARNING) + << "[[file]] of type 'iyokanl1-json' is deprecated. You don't need " + "to use Iyokan-L1. Use Yosys JSON directly by specifying type " + "'yosys-json'."; + IyokanL1JSONReader::read(file.name, ifs, nb, std::nullopt); + break; + + case blueprint::File::TYPE::YOSYS_JSON: + YosysJSONReader::read(file.name, ifs, nb); + break; + } +} + +} // namespace nt diff --git a/src/packet.hpp b/src/packet.hpp index c9c9847..f828349 100644 --- a/src/packet.hpp +++ b/src/packet.hpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include @@ -17,7 +16,6 @@ #include "tfhepp_cufhe_wrapper.hpp" #include "error.hpp" -#include "utility.hpp" enum class Bit : bool {}; inline constexpr Bit operator~(Bit l) noexcept diff --git a/src/packet_nt.cpp b/src/packet_nt.cpp new file mode 100644 index 0000000..1029f59 --- /dev/null +++ b/src/packet_nt.cpp @@ -0,0 +1,394 @@ +#include "packet_nt.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +template +void readFromArchive(T& res, std::istream& is) +{ + cereal::PortableBinaryInputArchive ar{is}; + ar(res); +} + +template +void readFromArchive(T& res, const std::string& path) +{ + try { + std::ifstream ifs{path, std::ios::binary}; + if (!ifs) + ERR_DIE( + "Can't open the file to read from; Maybe not found?: " << path); + readFromArchive(res, ifs); + } + catch (std::exception& ex) { + ERR_DIE("Invalid archive: " << path); + } +} + +template +T readFromArchive(std::istream& is) +{ + T ret; + readFromArchive(ret, is); + return ret; +} + +template +T readFromArchive(const std::string& path) +{ + T ret; + readFromArchive(ret, path); + return ret; +} + +template +void writeToArchive(std::ostream& os, const T& src) +{ + cereal::PortableBinaryOutputArchive ar{os}; + ar(src); +} + +template +void writeToArchive(const std::string& path, const T& src) +{ + try { + std::ofstream ofs{path, std::ios::binary}; + if (!ofs) + ERR_DIE("Can't open the file to write in; maybe not allowed?: " + << path); + return writeToArchive(ofs, src); + } + catch (std::exception& ex) { + ERR_DIE("Unable to write into archive: " << path << ": " << ex.what()); + } +} + +template +bool isCorrectArchive(const std::string& path) +{ + try { + std::ifstream ifs{path, std::ios::binary}; + if (!ifs) + return false; + T cont; + readFromArchive(cont, ifs); + return true; + } + catch (std::exception& ex) { + return false; + } +} + +} // namespace + +namespace nt { +uint64_t bitvec2i(const std::vector& src, int start, int end) +{ + if (end == -1) + end = src.size(); + assert(end - start < 64); + uint64_t ret = 0; + for (size_t i = start; i < end; i++) + ret |= (static_cast(src.at(i)) << (i - start)); + return ret; +} + +std::vector encryptBits(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + std::vector in; + in.reserve(src.size()); + for (auto&& bit : src) + in.push_back(bit == 1_b ? 1 : 0); + return TFHEpp::bootsSymEncrypt(in, key); +} + +std::vector encryptROM(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + using P = TFHEpp::lvl1param; + std::vector ret; + + PolyLvl1 pmu = {}; + for (size_t i = 0; i < src.size(); i++) { + pmu[i % P::n] = src[i] == 1_b ? P::μ : -P::μ; + if (i % P::n == P::n - 1) + ret.push_back( + TFHEpp::trlweSymEncrypt(pmu, P::α, key.key.lvl1)); + } + if (src.size() % P::n != 0) + ret.push_back(TFHEpp::trlweSymEncrypt(pmu, P::α, key.key.lvl1)); + + return ret; +} + +std::vector encryptROMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + return encryptBits(key, src); +} + +std::vector encryptRAM(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + using P = TFHEpp::lvl1param; + std::vector ret; + + for (auto&& bit : src) { + PolyLvl1 pmu = {}; + pmu[0] = bit == 1_b ? P::μ : -P::μ; + ret.push_back(TFHEpp::trlweSymEncrypt(pmu, P::α, key.key.lvl1)); + } + + return ret; +} + +std::vector encryptRAMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + return encryptBits(key, src); +} + +std::vector decrypt(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + std::vector ret; + for (auto it = src.begin(); it != src.end();) { + uint8_t byte = 0; + for (uint32_t i = 0; i < 8; i++, ++it) { + assert(it != src.end()); + uint8_t val = TFHEpp::bootsSymDecrypt(std::vector{*it}, key).at(0); + byte |= (val & 1u) << i; + } + ret.push_back(byte); + } + return ret; +} + +std::vector decryptBits(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + auto bitvals = TFHEpp::bootsSymDecrypt(src, key); + std::vector bits; + for (auto&& bitval : bitvals) + bits.push_back(bitval != 0 ? 1_b : 0_b); + return bits; +} + +std::vector decryptRAM(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + std::vector ret; + for (auto&& encbit : src) { + uint8_t bitval = + TFHEpp::trlweSymDecrypt(encbit, key.key.lvl1).at(0); + ret.push_back(bitval != 0 ? 1_b : 0_b); + } + + return ret; +} + +std::vector decryptRAMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + return decryptBits(key, src); +} + +std::vector decryptROM(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + std::vector ret; + for (auto&& encblk : src) { + auto blk = TFHEpp::trlweSymDecrypt(encblk, key.key.lvl1); + for (uint8_t bitval : blk) + ret.push_back(bitval != 0 ? 1_b : 0_b); + } + + return ret; +} + +std::vector decryptROMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src) +{ + return decryptBits(key, src); +} + +bool PlainPacket::operator==(const PlainPacket& rhs) const +{ + // Check if member variables of *this and rhs are equal. + // If we used C++20, we could make C++ compilers derive this code by using + // '= default'! + return ram == rhs.ram && rom == rhs.rom && bits == rhs.bits && + numCycles == rhs.numCycles; +} + +TFHEPacket PlainPacket::encrypt(const TFHEpp::SecretKey& key) const +{ + TFHEPacket tfhe{{}, {}, {}, {}, {}, numCycles}; + + // Encrypt RAM + for (auto&& [name, src] : ram) { + if (auto [it, inserted] = tfhe.ram.emplace(name, encryptRAM(key, src)); + !inserted) + ERR_DIE("Invalid PlainPacket. Duplicate ram's key: " << name); + if (auto [it, inserted] = + tfhe.ramInTLWE.emplace(name, encryptRAMInTLWE(key, src)); + !inserted) + ERR_DIE("Invalid PlainPacket. Duplicate ram's key: " << name); + } + + // Encrypt ROM + for (auto&& [name, src] : rom) { + if (auto [it, inserted] = tfhe.rom.emplace(name, encryptROM(key, src)); + !inserted) + ERR_DIE("Invalid PlainPacket. Duplicate rom's key: " << name); + if (auto [it, inserted] = + tfhe.romInTLWE.emplace(name, encryptROMInTLWE(key, src)); + !inserted) + ERR_DIE("Invalid PlainPacket. Duplicate rom's key: " << name); + } + + // Encrypt bits + for (auto&& [name, src] : bits) { + auto [it, inserted] = tfhe.bits.emplace(name, encryptBits(key, src)); + if (!inserted) + ERR_DIE("Invalid PlainPacket. Duplicate bits's key: " << name); + } + + return tfhe; +} + +PlainPacket PlainPacket::fromTOML(const std::string& filepath) +{ + // FIXME: iyokan-packet should use this function + const auto root = toml::parse(filepath); + int numCycles = toml::find_or(root, "cycles", -1); + std::unordered_map> ram, rom, bits; + + auto parseEntries = [&root]( + std::unordered_map>& + name2bitvec, + const std::string& entryName) { + if (!root.contains(entryName)) + return; + const auto tables = + toml::find>(root, entryName); + for (const auto& table : tables) { + const auto name = toml::find(table, "name"); + const auto size = toml::find(table, "size"); + const auto bytes = + toml::find>(table, "bytes"); + + std::vector& v = name2bitvec[name]; + v.resize(size, 0_b); + auto it = v.begin(); + for (uint64_t byte : bytes) { + if (byte > 0xffu) + LOG_S(WARNING) + << "'bytes' field expects only <256 unsinged integer, " + "but got '" + << byte << "'. Only the lower 8bits is used."; + for (int i = 0; i < 8; i++) { + if (it == v.end()) + goto end; + *it++ = ((byte >> i) & 1u) != 0 ? 1_b : 0_b; + } + } + end:; // ';' is necessary since label is followed by expression. + } + }; + + parseEntries(ram, "ram"); // [[ram]] + parseEntries(rom, "rom"); // [[rom]] + parseEntries(bits, "bits"); // [[bits]] + + return PlainPacket{ram, rom, bits, numCycles}; +} + +PlainPacket TFHEPacket::decrypt(const TFHEpp::SecretKey& key) const +{ + PlainPacket plain{{}, {}, {}, numCycles}; + + // Decrypt RAM + for (auto&& [name, trlwes] : ram) + plain.ram.emplace(name, decryptRAM(key, trlwes)); + for (auto&& [name, tlwes] : ramInTLWE) + plain.ram.emplace(name, decryptRAMInTLWE(key, tlwes)); + + // Decrypt ROM + for (auto&& [name, trlwes] : rom) + plain.rom.emplace(name, decryptROM(key, trlwes)); + for (auto&& [name, tlwes] : romInTLWE) + plain.rom.emplace(name, decryptROMInTLWE(key, tlwes)); + + // Decrypt bits + for (auto&& [name, tlwes] : bits) { + auto [it, inserted] = plain.bits.emplace(name, decryptBits(key, tlwes)); + if (!inserted) + ERR_DIE("Invalid TFHEPacket. Duplicate bits's key: " << name); + } + + return plain; +} + +PlainPacket readPlainPacket(std::istream& is) +{ + return readFromArchive(is); +} + +PlainPacket readPlainPacket(const std::string& path) +{ + return readFromArchive(path); +} + +TFHEPacket readTFHEPacket(std::istream& is) +{ + return readFromArchive(is); +} + +TFHEPacket readTFHEPacket(const std::string& path) +{ + return readFromArchive(path); +} + +void readTFHEppBKey(TFHEppBKey& out, const std::string& path) +{ + out = readFromArchive(path); +} + +void writePlainPacket(std::ostream& os, const PlainPacket& pkt) +{ + writeToArchive(os, pkt); +} + +void writePlainPacket(const std::string& path, const PlainPacket& pkt) +{ + writeToArchive(path, pkt); +} + +void writeTFHEPacket(std::ostream& os, const TFHEPacket& pkt) +{ + writeToArchive(os, pkt); +} + +void writeTFHEPacket(const std::string& path, const TFHEPacket& pkt) +{ + writeToArchive(path, pkt); +} + +void writeTFHEppBKey(const std::string& path, const TFHEppBKey& bkey) +{ + writeToArchive(path, bkey); +} + +} // namespace nt diff --git a/src/packet_nt.hpp b/src/packet_nt.hpp new file mode 100644 index 0000000..8208c62 --- /dev/null +++ b/src/packet_nt.hpp @@ -0,0 +1,144 @@ +#ifndef VIRTUALSECUREPLATFORM_PACKET_NT_HPP +#define VIRTUALSECUREPLATFORM_PACKET_NT_HPP + +#include "error_nt.hpp" +#include "tfhepp_cufhe_wrapper.hpp" + +#include + +namespace nt { + +enum class Bit : bool {}; +inline constexpr Bit operator!(Bit l) noexcept +{ + return Bit(!static_cast(l)); +} +inline constexpr Bit operator|(Bit l, Bit r) noexcept +{ + return Bit(static_cast(l) | static_cast(r)); +} +inline constexpr Bit operator&(Bit l, Bit r) noexcept +{ + return Bit(static_cast(l) & static_cast(r)); +} +inline constexpr Bit operator^(Bit l, Bit r) noexcept +{ + return Bit(static_cast(l) ^ static_cast(r)); +} +inline constexpr Bit operator|=(Bit& l, Bit r) noexcept +{ + return l = l | r; +} +inline constexpr Bit operator&=(Bit& l, Bit r) noexcept +{ + return l = l & r; +} +inline constexpr Bit operator^=(Bit& l, Bit r) noexcept +{ + return l = l ^ r; +} +inline Bit operator"" _b(unsigned long long x) +{ + return Bit(x != 0); +} +inline std::ostream& operator<<(std::ostream& os, const Bit& bit) +{ + os << (bit == 1_b ? 1 : 0); + return os; +} + +struct TFHEppBKey { + std::shared_ptr gk; + std::shared_ptr ck; + + TFHEppBKey() + { + } + + TFHEppBKey(const TFHEpp::SecretKey& sk) + : gk(std::make_shared(sk)), + ck(std::make_shared(sk)) + { + } + + template + void serialize(Archive& ar) + { + ar(gk, ck); + } +}; + +struct TFHEPacket; + +struct PlainPacket { + std::unordered_map> ram; + std::unordered_map> rom; + std::unordered_map> bits; + std::optional numCycles; + + template + void serialize(Archive& ar) + { + ar(ram, rom, bits, numCycles); + } + + bool operator==(const PlainPacket& rhs) const; + TFHEPacket encrypt(const TFHEpp::SecretKey& key) const; + + static PlainPacket fromTOML(const std::string& filepath); +}; + +struct TFHEPacket { + std::unordered_map> ram; + std::unordered_map> ramInTLWE; + std::unordered_map> rom; + std::unordered_map> romInTLWE; + std::unordered_map> bits; + std::optional numCycles; + + template + void serialize(Archive& ar) + { + ar(ram, ramInTLWE, rom, romInTLWE, bits, numCycles); + } + + PlainPacket decrypt(const TFHEpp::SecretKey& key) const; +}; + +uint64_t bitvec2i(const std::vector& src, int start = 0, int end = -1); +std::vector encryptBits(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector encryptROM(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector encryptROMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector encryptRAM(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector encryptRAMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decrypt(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decryptBits(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decryptRAM(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decryptRAMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decryptROM(const TFHEpp::SecretKey& key, + const std::vector& src); +std::vector decryptROMInTLWE(const TFHEpp::SecretKey& key, + const std::vector& src); +PlainPacket readPlainPacket(std::istream& is); +PlainPacket readPlainPacket(const std::string& path); +TFHEPacket readTFHEPacket(std::istream& is); +TFHEPacket readTFHEPacket(const std::string& path); +void readTFHEppBKey(TFHEppBKey& out, const std::string& path); +void writePlainPacket(std::ostream& os, const PlainPacket& pkt); +void writePlainPacket(const std::string& path, const PlainPacket& pkt); +void writeTFHEPacket(std::ostream& os, const TFHEPacket& pkt); +void writeTFHEPacket(const std::string& path, const TFHEPacket& pkt); +void writeTFHEppBKey(const std::string& path, const TFHEppBKey& bkey); + +} // namespace nt + +#endif diff --git a/src/snapshot.cpp b/src/snapshot.cpp new file mode 100644 index 0000000..366bd84 --- /dev/null +++ b/src/snapshot.cpp @@ -0,0 +1,85 @@ +#include "iyokan_nt.hpp" + +#include +#include +#include +#include + +#include + +namespace nt { + +/* class Snapshot */ + +Snapshot::Snapshot(const RunParameter& pr, + const std::shared_ptr& alc) + : pr_(pr), alc_(alc) +{ +} + +Snapshot::Snapshot(const std::string& snapshotFile) : pr_(), alc_(nullptr) +{ + LOG_DBG_SCOPE("READ SNAPSHOT"); + + LOG_DBG << "OPEN"; + std::ifstream ifs{snapshotFile}; + if (!ifs) + ERR_DIE("Can't open a snapshot file to read from: " << snapshotFile); + cereal::PortableBinaryInputArchive ar{ifs}; + + // Read header + LOG_DBG << "READ HEADER"; + std::string header; + ar(header); + if (header != "IYSS") // IYokan SnapShot + ERR_DIE( + "Can't read the snapshot file; incorrect header: " << snapshotFile); + + // Read run parameters + LOG_DBG << "READ RUN PARAMS"; + ar(pr_); + + // Read allocator + LOG_DBG << "READ ALLOCATOR"; + alc_.reset(new Allocator(ar)); +} + +const RunParameter& Snapshot::getRunParam() const +{ + return pr_; +} + +const std::shared_ptr& Snapshot::getAllocator() const +{ + return alc_; +} + +void Snapshot::dump(const std::string& snapshotFile) const +{ + std::ofstream ofs{snapshotFile}; + if (!ofs) + ERR_DIE("Can't open a snapshot file to write in: " << snapshotFile); + cereal::PortableBinaryOutputArchive ar{ofs}; + + // Write header + // FIXME: much better way to store the header? + ar(std::string{"IYSS"}); // IYokan SnapShot + + // Serialize pr_ of class RunParameter + ar(pr_); + + // Serialize alc_ of class Allocator + alc_->dumpAllocatedData(ar); +} + +void Snapshot::updateCurrentCycle(int currentCycle) +{ + pr_.currentCycle = currentCycle; +} + +void Snapshot::updateNumCycles(int numCycles) +{ + pr_.numCycles = numCycles; +} + +} // namespace nt diff --git a/src/test0.cpp b/src/test0.cpp index 5f990c3..b21d339 100644 --- a/src/test0.cpp +++ b/src/test0.cpp @@ -1,529 +1,871 @@ -#include "iyokan.hpp" -#include "tfhepp_cufhe_wrapper.hpp" - -#include -#include - -template -void assertNetValid(Network&& net) -{ - error::Stack err; - net.checkValid(err); - if (err.empty()) - return; - - std::cerr << err.str() << std::endl; - assert(0); -} - -template -auto get(TaskNetwork& net, const std::string& kind, const std::string& portName, - int portBit) -{ - return net.template get( - kind, portName, portBit); -} - -template -typename NetworkBuilder::NetworkType readNetworkFromJSON(std::istream& is) -{ - // FIXME: Assume the stream `is` emits Iyokan-L1 JSON - NetworkBuilder builder; - IyokanL1JSONReader{}.read(builder, is); - return typename NetworkBuilder::NetworkType{std::move(builder)}; -} - -// Assume variable names 'NetworkBuilder' and 'net' -#define ASSERT_OUTPUT_EQ(portName, portBit, expected) \ - assert(getOutput(get(net, "output", portName, portBit)) == \ - (expected)) -#define SET_INPUT(portName, portBit, val) \ - setInput(get(net, "input", portName, portBit), val) - -template -void testNOT() -{ - NetworkBuilder builder; - int id0 = builder.INPUT("A", 0); - int id1 = builder.NOT(); - int id2 = builder.OUTPUT("out", 0); - builder.connect(id0, id1); - builder.connect(id1, id2); - - TaskNetwork net = std::move(builder); - auto out = get(net, "output", "out", 0); - - std::array, 8> invals{{{0, 1}, {0, 1}}}; - for (int i = 0; i < 2; i++) { - // Set inputs. - SET_INPUT("A", 0, std::get<0>(invals[i])); - - processAllGates(net); - - // Check if results are okay. - assert(getOutput(out) == std::get<1>(invals[i])); - - net.tick(); - } -} - -template -void testMUX() -{ - NetworkBuilder builder; - int id0 = builder.INPUT("A", 0); - int id1 = builder.INPUT("B", 0); - int id2 = builder.INPUT("S", 0); - int id3 = builder.MUX(); - int id4 = builder.OUTPUT("out", 0); - builder.connect(id0, id3); - builder.connect(id1, id3); - builder.connect(id2, id3); - builder.connect(id3, id4); - - TaskNetwork net = std::move(builder); - - std::array, 8> invals{{/*A,B, S, O*/ - {0, 0, 0, 0}, - {0, 0, 1, 0}, - {0, 1, 0, 0}, - {0, 1, 1, 1}, - {1, 0, 0, 1}, - {1, 0, 1, 0}, - {1, 1, 0, 1}, - {1, 1, 1, 1}}}; - for (int i = 0; i < 8; i++) { - // Set inputs. - SET_INPUT("A", 0, std::get<0>(invals[i])); - SET_INPUT("B", 0, std::get<1>(invals[i])); - SET_INPUT("S", 0, std::get<2>(invals[i])); - - processAllGates(net); - - // Check if results are okay. - ASSERT_OUTPUT_EQ("out", 0, std::get<3>(invals[i])); - - net.tick(); - } -} - -template -void testBinopGates() -{ - NetworkBuilder builder; - int id0 = builder.INPUT("in0", 0); - int id1 = builder.INPUT("in1", 0); - - std::unordered_map> - id2res; - -#define DEFINE_BINOP_GATE_TEST(name, e00, e01, e10, e11) \ - do { \ - int gateId = builder.name(); \ - int outputId = builder.OUTPUT("out_" #name, 0); \ - builder.connect(id0, gateId); \ - builder.connect(id1, gateId); \ - builder.connect(gateId, outputId); \ - id2res["out_" #name] = {e00, e01, e10, e11}; \ - } while (false); - DEFINE_BINOP_GATE_TEST(AND, 0, 0, 0, 1); - DEFINE_BINOP_GATE_TEST(NAND, 1, 1, 1, 0); - DEFINE_BINOP_GATE_TEST(ANDNOT, 0, 0, 1, 0); - DEFINE_BINOP_GATE_TEST(OR, 0, 1, 1, 1); - DEFINE_BINOP_GATE_TEST(ORNOT, 1, 0, 1, 1); - DEFINE_BINOP_GATE_TEST(XOR, 0, 1, 1, 0); - DEFINE_BINOP_GATE_TEST(XNOR, 1, 0, 0, 1); -#undef DEFINE_BINOP_GATE_TEST - - TaskNetwork net = std::move(builder); - - std::array, 4> invals{{{0, 0}, {0, 1}, {1, 0}, {1, 1}}}; - for (int i = 0; i < 4; i++) { - // Set inputs. - SET_INPUT("in0", 0, invals[i].first ? 1 : 0); - SET_INPUT("in1", 0, invals[i].second ? 1 : 0); - - processAllGates(net); - - // Check if results are okay. - for (auto&& [portName, res] : id2res) - ASSERT_OUTPUT_EQ(portName, 0, res[i]); - - net.tick(); - } -} - -template -void testFromJSONtest_pass_4bit() -{ - const std::string fileName = "test/iyokanl1-json/pass-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_in", 0, 0); - SET_INPUT("io_in", 1, 1); - SET_INPUT("io_in", 2, 1); - SET_INPUT("io_in", 3, 0); - - processAllGates(net); - - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 1); - ASSERT_OUTPUT_EQ("io_out", 2, 1); - ASSERT_OUTPUT_EQ("io_out", 3, 0); -} - -template -void testFromJSONtest_and_4bit() -{ - const std::string fileName = "test/iyokanl1-json/and-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_inA", 0, 0); - SET_INPUT("io_inA", 1, 0); - SET_INPUT("io_inA", 2, 1); - SET_INPUT("io_inA", 3, 1); - SET_INPUT("io_inB", 0, 0); - SET_INPUT("io_inB", 1, 1); - SET_INPUT("io_inB", 2, 0); - SET_INPUT("io_inB", 3, 1); - - processAllGates(net); - - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 0); - ASSERT_OUTPUT_EQ("io_out", 2, 0); - ASSERT_OUTPUT_EQ("io_out", 3, 1); -} - -template -void testFromJSONtest_and_4_2bit() -{ - const std::string fileName = "test/iyokanl1-json/and-4_2bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_inA", 0, 1); - SET_INPUT("io_inA", 1, 0); - SET_INPUT("io_inA", 2, 1); - SET_INPUT("io_inA", 3, 1); - SET_INPUT("io_inB", 0, 1); - SET_INPUT("io_inB", 1, 1); - SET_INPUT("io_inB", 2, 1); - SET_INPUT("io_inB", 3, 1); - - processAllGates(net); - - ASSERT_OUTPUT_EQ("io_out", 0, 1); - ASSERT_OUTPUT_EQ("io_out", 1, 0); -} - -template -void testFromJSONtest_mux_4bit() -{ - const std::string fileName = "test/iyokanl1-json/mux-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_inA", 0, 0); - SET_INPUT("io_inA", 1, 0); - SET_INPUT("io_inA", 2, 1); - SET_INPUT("io_inA", 3, 1); - SET_INPUT("io_inB", 0, 0); - SET_INPUT("io_inB", 1, 1); - SET_INPUT("io_inB", 2, 0); - SET_INPUT("io_inB", 3, 1); - - SET_INPUT("io_sel", 0, 0); - processAllGates(net); - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 0); - ASSERT_OUTPUT_EQ("io_out", 2, 1); - ASSERT_OUTPUT_EQ("io_out", 3, 1); - net.tick(); - - SET_INPUT("io_sel", 0, 1); - processAllGates(net); - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 1); - ASSERT_OUTPUT_EQ("io_out", 2, 0); - ASSERT_OUTPUT_EQ("io_out", 3, 1); -} - -template -void testFromJSONtest_addr_4bit() -{ - const std::string fileName = "test/iyokanl1-json/addr-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_inA", 0, 0); - SET_INPUT("io_inA", 1, 0); - SET_INPUT("io_inA", 2, 1); - SET_INPUT("io_inA", 3, 1); - SET_INPUT("io_inB", 0, 0); - SET_INPUT("io_inB", 1, 1); - SET_INPUT("io_inB", 2, 0); - SET_INPUT("io_inB", 3, 1); - - processAllGates(net); - - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 1); - ASSERT_OUTPUT_EQ("io_out", 2, 1); - ASSERT_OUTPUT_EQ("io_out", 3, 0); -} - -template -void testFromJSONtest_register_4bit() -{ - const std::string fileName = - "test/iyokanl1-json/register-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - SET_INPUT("io_in", 0, 0); - SET_INPUT("io_in", 1, 0); - SET_INPUT("io_in", 2, 1); - SET_INPUT("io_in", 3, 1); - - // 1: Reset all DFFs. - SET_INPUT("reset", 0, 1); - processAllGates(net); - net.tick(); - - // 2: Store values into DFFs. - SET_INPUT("reset", 0, 0); - processAllGates(net); - net.tick(); - - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 0); - ASSERT_OUTPUT_EQ("io_out", 2, 0); - ASSERT_OUTPUT_EQ("io_out", 3, 0); - - // 3: Get outputs. - SET_INPUT("reset", 0, 0); - processAllGates(net); - net.tick(); - - ASSERT_OUTPUT_EQ("io_out", 0, 0); - ASSERT_OUTPUT_EQ("io_out", 1, 0); - ASSERT_OUTPUT_EQ("io_out", 2, 1); - ASSERT_OUTPUT_EQ("io_out", 3, 1); -} - -template -void testSequentialCircuit() -{ - /* - B D - reset(0) >---> ANDNOT(4) >---> DFF(2) - ^ A v Q - | | - *--< NOT(3) <--*-----> OUTPUT(1) - A - */ - - NetworkBuilder builder; - int id0 = builder.INPUT("reset", 0); - int id1 = builder.OUTPUT("out", 0); - int id2 = builder.DFF(); - int id3 = builder.NOT(); - int id4 = builder.ANDNOT(); - builder.connect(id2, id1); - builder.connect(id4, id2); - builder.connect(id2, id3); - builder.connect(id3, id4); - builder.connect(id0, id4); - - TaskNetwork net = std::move(builder); - assertNetValid(net); - - auto dff = - std::dynamic_pointer_cast( - net.node(id2)->task()); - auto out = get(net, "output", "out", 0); - - // 1: - SET_INPUT("reset", 0, 1); - processAllGates(net); - - // 2: - net.tick(); - assert(getOutput(dff) == 0); - SET_INPUT("reset", 0, 0); - processAllGates(net); - ASSERT_OUTPUT_EQ("out", 0, 0); - - // 3: - net.tick(); - assert(getOutput(dff) == 1); - processAllGates(net); - ASSERT_OUTPUT_EQ("out", 0, 1); - - // 4: - net.tick(); - assert(getOutput(dff) == 0); - processAllGates(net); - ASSERT_OUTPUT_EQ("out", 0, 0); -} - -template -void testFromJSONtest_counter_4bit() -{ - const std::string fileName = - "test/iyokanl1-json/counter-4bit-iyokanl1.json"; - std::ifstream ifs{fileName}; - assert(ifs); - - auto net = readNetworkFromJSON(ifs); - assertNetValid(net); - - std::vector> outvals{{{0, 0, 0, 0}, - {1, 0, 0, 0}, - {0, 1, 0, 0}, - {1, 1, 0, 0}, - {0, 0, 1, 0}, - {1, 0, 1, 0}, - {0, 1, 1, 0}, - {1, 1, 1, 0}, - {0, 0, 0, 1}, - {1, 0, 0, 1}, - {0, 1, 0, 1}, - {1, 1, 0, 1}, - {0, 0, 1, 1}, - {1, 0, 1, 1}, - {0, 1, 1, 1}, - {1, 1, 1, 1}}}; - - SET_INPUT("reset", 0, 1); - processAllGates(net); - - SET_INPUT("reset", 0, 0); - for (size_t i = 0; i < outvals.size(); i++) { - net.tick(); - processAllGates(net); - ASSERT_OUTPUT_EQ("io_out", 0, outvals[i][0]); - ASSERT_OUTPUT_EQ("io_out", 1, outvals[i][1]); - ASSERT_OUTPUT_EQ("io_out", 2, outvals[i][2]); - ASSERT_OUTPUT_EQ("io_out", 3, outvals[i][3]); - } -} - -template -void testPrioritySetVisitor() -{ - NetworkBuilder builder; - int id0 = builder.INPUT("A", 0); - int id1 = builder.NOT(); - int id2 = builder.OUTPUT("out", 0); - builder.connect(id0, id1); - builder.connect(id1, id2); - - TaskNetwork net = std::move(builder); - auto depnode = get(net, "output", "out", 0)->depnode(); - assert(depnode->priority() == -1); - - // Set priority to each DepNode - GraphVisitor grvis; - net.visit(grvis); - PrioritySetVisitor privis{graph::doTopologicalSort(grvis.getMap())}; - net.visit(privis); - - assert(depnode->priority() == 2); -} - +//#include "iyokan.hpp" +//#include "tfhepp_cufhe_wrapper.hpp" // -#include "iyokan_plain.hpp" - -void processAllGates(PlainNetwork& net, - std::shared_ptr graph = nullptr) -{ - processAllGates(net, std::thread::hardware_concurrency(), graph); -} - -void setInput(std::shared_ptr task, int val) -{ - task->set(val != 0 ? 1_b : 0_b); -} - -int getOutput(std::shared_ptr task) -{ - return task->get() == 1_b ? 1 : 0; -} - -void testProgressGraphMaker() -{ - /* - B D - reset(0) >---> ANDNOT(4) >---> DFF(2) - ^ A v Q - | | - *--< NOT(3) <--*-----> OUTPUT(1) - A - */ - - PlainNetworkBuilder builder; - int id0 = builder.INPUT("reset", 0); - int id1 = builder.OUTPUT("out", 0); - int id2 = builder.DFF(); - int id3 = builder.NOT(); - int id4 = builder.ANDNOT(); - builder.connect(id2, id1); - builder.connect(id4, id2); - builder.connect(id2, id3); - builder.connect(id3, id4); - builder.connect(id0, id4); - - PlainNetwork net = std::move(builder); - assertNetValid(net); - - auto graph = std::make_shared(); - - processAllGates(net, graph); +//#include +//#include +// +// template +// void assertNetValid(Network&& net) +//{ +// error::Stack err; +// net.checkValid(err); +// if (err.empty()) +// return; +// +// std::cerr << err.str() << std::endl; +// assert(0); +//} +// +// template +// auto get(TaskNetwork& net, const std::string& kind, const std::string& +// portName, +// int portBit) +//{ +// return net.template get( +// kind, portName, portBit); +//} +// +// template +// typename NetworkBuilder::NetworkType readNetworkFromJSON(std::istream& is) +//{ +// // FIXME: Assume the stream `is` emits Iyokan-L1 JSON +// NetworkBuilder builder; +// IyokanL1JSONReader{}.read(builder, is); +// return typename NetworkBuilder::NetworkType{std::move(builder)}; +//} +// +//// Assume variable names 'NetworkBuilder' and 'net' +//#define ASSERT_OUTPUT_EQ(portName, portBit, expected) \ +// assert(getOutput(get(net, "output", portName, portBit)) == +// \ +// (expected)) +//#define SET_INPUT(portName, portBit, val) \ +// setInput(get(net, "input", portName, portBit), val) +// +// template +// void testNOT() +//{ +// NetworkBuilder builder; +// int id0 = builder.INPUT("A", 0); +// int id1 = builder.NOT(); +// int id2 = builder.OUTPUT("out", 0); +// builder.connect(id0, id1); +// builder.connect(id1, id2); +// +// TaskNetwork net = std::move(builder); +// auto out = get(net, "output", "out", 0); +// +// std::array, 8> invals{{{0, 1}, {0, 1}}}; +// for (int i = 0; i < 2; i++) { +// // Set inputs. +// SET_INPUT("A", 0, std::get<0>(invals[i])); +// +// processAllGates(net); +// +// // Check if results are okay. +// assert(getOutput(out) == std::get<1>(invals[i])); +// +// net.tick(); +// } +//} +// +// template +// void testMUX() +//{ +// NetworkBuilder builder; +// int id0 = builder.INPUT("A", 0); +// int id1 = builder.INPUT("B", 0); +// int id2 = builder.INPUT("S", 0); +// int id3 = builder.MUX(); +// int id4 = builder.OUTPUT("out", 0); +// builder.connect(id0, id3); +// builder.connect(id1, id3); +// builder.connect(id2, id3); +// builder.connect(id3, id4); +// +// TaskNetwork net = std::move(builder); +// +// std::array, 8> invals{{/*A,B, S, O*/ +// {0, 0, 0, 0}, +// {0, 0, 1, 0}, +// {0, 1, 0, 0}, +// {0, 1, 1, 1}, +// {1, 0, 0, 1}, +// {1, 0, 1, 0}, +// {1, 1, 0, 1}, +// {1, 1, 1, 1}}}; +// for (int i = 0; i < 8; i++) { +// // Set inputs. +// SET_INPUT("A", 0, std::get<0>(invals[i])); +// SET_INPUT("B", 0, std::get<1>(invals[i])); +// SET_INPUT("S", 0, std::get<2>(invals[i])); +// +// processAllGates(net); +// +// // Check if results are okay. +// ASSERT_OUTPUT_EQ("out", 0, std::get<3>(invals[i])); +// +// net.tick(); +// } +//} +// +// template +// void testBinopGates() +//{ +// NetworkBuilder builder; +// int id0 = builder.INPUT("in0", 0); +// int id1 = builder.INPUT("in1", 0); +// +// std::unordered_map> +// id2res; +// +//#define DEFINE_BINOP_GATE_TEST(name, e00, e01, e10, e11) \ +// do { \ +// int gateId = builder.name(); \ +// int outputId = builder.OUTPUT("out_" #name, 0); \ +// builder.connect(id0, gateId); \ +// builder.connect(id1, gateId); \ +// builder.connect(gateId, outputId); \ +// id2res["out_" #name] = {e00, e01, e10, e11}; \ +// } while (false); +// DEFINE_BINOP_GATE_TEST(AND, 0, 0, 0, 1); +// DEFINE_BINOP_GATE_TEST(NAND, 1, 1, 1, 0); +// DEFINE_BINOP_GATE_TEST(ANDNOT, 0, 0, 1, 0); +// DEFINE_BINOP_GATE_TEST(OR, 0, 1, 1, 1); +// DEFINE_BINOP_GATE_TEST(ORNOT, 1, 0, 1, 1); +// DEFINE_BINOP_GATE_TEST(XOR, 0, 1, 1, 0); +// DEFINE_BINOP_GATE_TEST(XNOR, 1, 0, 0, 1); +//#undef DEFINE_BINOP_GATE_TEST +// +// TaskNetwork net = std::move(builder); +// +// std::array, 4> invals{{{0, 0}, {0, 1}, {1, 0}, {1, +// 1}}}; for (int i = 0; i < 4; i++) { +// // Set inputs. +// SET_INPUT("in0", 0, invals[i].first ? 1 : 0); +// SET_INPUT("in1", 0, invals[i].second ? 1 : 0); +// +// processAllGates(net); +// +// // Check if results are okay. +// for (auto&& [portName, res] : id2res) +// ASSERT_OUTPUT_EQ(portName, 0, res[i]); +// +// net.tick(); +// } +//} +// +// template +// void testFromJSONtest_pass_4bit() +//{ +// const std::string fileName = "test/iyokanl1-json/pass-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_in", 0, 0); +// SET_INPUT("io_in", 1, 1); +// SET_INPUT("io_in", 2, 1); +// SET_INPUT("io_in", 3, 0); +// +// processAllGates(net); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 1); +// ASSERT_OUTPUT_EQ("io_out", 2, 1); +// ASSERT_OUTPUT_EQ("io_out", 3, 0); +//} +// +// template +// void testFromJSONtest_and_4bit() +//{ +// const std::string fileName = "test/iyokanl1-json/and-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_inA", 0, 0); +// SET_INPUT("io_inA", 1, 0); +// SET_INPUT("io_inA", 2, 1); +// SET_INPUT("io_inA", 3, 1); +// SET_INPUT("io_inB", 0, 0); +// SET_INPUT("io_inB", 1, 1); +// SET_INPUT("io_inB", 2, 0); +// SET_INPUT("io_inB", 3, 1); +// +// processAllGates(net); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 0); +// ASSERT_OUTPUT_EQ("io_out", 2, 0); +// ASSERT_OUTPUT_EQ("io_out", 3, 1); +//} +// +// template +// void testFromJSONtest_and_4_2bit() +//{ +// const std::string fileName = +// "test/iyokanl1-json/and-4_2bit-iyokanl1.json"; std::ifstream +// ifs{fileName}; assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_inA", 0, 1); +// SET_INPUT("io_inA", 1, 0); +// SET_INPUT("io_inA", 2, 1); +// SET_INPUT("io_inA", 3, 1); +// SET_INPUT("io_inB", 0, 1); +// SET_INPUT("io_inB", 1, 1); +// SET_INPUT("io_inB", 2, 1); +// SET_INPUT("io_inB", 3, 1); +// +// processAllGates(net); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 1); +// ASSERT_OUTPUT_EQ("io_out", 1, 0); +//} +// +// template +// void testFromJSONtest_mux_4bit() +//{ +// const std::string fileName = "test/iyokanl1-json/mux-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_inA", 0, 0); +// SET_INPUT("io_inA", 1, 0); +// SET_INPUT("io_inA", 2, 1); +// SET_INPUT("io_inA", 3, 1); +// SET_INPUT("io_inB", 0, 0); +// SET_INPUT("io_inB", 1, 1); +// SET_INPUT("io_inB", 2, 0); +// SET_INPUT("io_inB", 3, 1); +// +// SET_INPUT("io_sel", 0, 0); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 0); +// ASSERT_OUTPUT_EQ("io_out", 2, 1); +// ASSERT_OUTPUT_EQ("io_out", 3, 1); +// net.tick(); +// +// SET_INPUT("io_sel", 0, 1); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 1); +// ASSERT_OUTPUT_EQ("io_out", 2, 0); +// ASSERT_OUTPUT_EQ("io_out", 3, 1); +//} +// +// template +// void testFromJSONtest_addr_4bit() +//{ +// const std::string fileName = "test/iyokanl1-json/addr-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_inA", 0, 0); +// SET_INPUT("io_inA", 1, 0); +// SET_INPUT("io_inA", 2, 1); +// SET_INPUT("io_inA", 3, 1); +// SET_INPUT("io_inB", 0, 0); +// SET_INPUT("io_inB", 1, 1); +// SET_INPUT("io_inB", 2, 0); +// SET_INPUT("io_inB", 3, 1); +// +// processAllGates(net); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 1); +// ASSERT_OUTPUT_EQ("io_out", 2, 1); +// ASSERT_OUTPUT_EQ("io_out", 3, 0); +//} +// +// template +// void testFromJSONtest_register_4bit() +//{ +// const std::string fileName = +// "test/iyokanl1-json/register-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// SET_INPUT("io_in", 0, 0); +// SET_INPUT("io_in", 1, 0); +// SET_INPUT("io_in", 2, 1); +// SET_INPUT("io_in", 3, 1); +// +// // 1: Reset all DFFs. +// SET_INPUT("reset", 0, 1); +// processAllGates(net); +// net.tick(); +// +// // 2: Store values into DFFs. +// SET_INPUT("reset", 0, 0); +// processAllGates(net); +// net.tick(); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 0); +// ASSERT_OUTPUT_EQ("io_out", 2, 0); +// ASSERT_OUTPUT_EQ("io_out", 3, 0); +// +// // 3: Get outputs. +// SET_INPUT("reset", 0, 0); +// processAllGates(net); +// net.tick(); +// +// ASSERT_OUTPUT_EQ("io_out", 0, 0); +// ASSERT_OUTPUT_EQ("io_out", 1, 0); +// ASSERT_OUTPUT_EQ("io_out", 2, 1); +// ASSERT_OUTPUT_EQ("io_out", 3, 1); +//} +// +// template +// void testSequentialCircuit() +//{ +// /* +// B D +// reset(0) >---> ANDNOT(4) >---> DFF(2) +// ^ A v Q +// | | +// *--< NOT(3) <--*-----> OUTPUT(1) +// A +// */ +// +// NetworkBuilder builder; +// int id0 = builder.INPUT("reset", 0); +// int id1 = builder.OUTPUT("out", 0); +// int id2 = builder.DFF(); +// int id3 = builder.NOT(); +// int id4 = builder.ANDNOT(); +// builder.connect(id2, id1); +// builder.connect(id4, id2); +// builder.connect(id2, id3); +// builder.connect(id3, id4); +// builder.connect(id0, id4); +// +// TaskNetwork net = std::move(builder); +// assertNetValid(net); +// +// auto dff = +// std::dynamic_pointer_cast( +// net.node(id2)->task()); +// auto out = get(net, "output", "out", 0); +// +// // 1: +// SET_INPUT("reset", 0, 1); +// processAllGates(net); +// +// // 2: +// net.tick(); +// assert(getOutput(dff) == 0); +// SET_INPUT("reset", 0, 0); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("out", 0, 0); +// +// // 3: +// net.tick(); +// assert(getOutput(dff) == 1); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("out", 0, 1); +// +// // 4: +// net.tick(); +// assert(getOutput(dff) == 0); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("out", 0, 0); +//} +// +// template +// void testFromJSONtest_counter_4bit() +//{ +// const std::string fileName = +// "test/iyokanl1-json/counter-4bit-iyokanl1.json"; +// std::ifstream ifs{fileName}; +// assert(ifs); +// +// auto net = readNetworkFromJSON(ifs); +// assertNetValid(net); +// +// std::vector> outvals{{{0, 0, 0, 0}, +// {1, 0, 0, 0}, +// {0, 1, 0, 0}, +// {1, 1, 0, 0}, +// {0, 0, 1, 0}, +// {1, 0, 1, 0}, +// {0, 1, 1, 0}, +// {1, 1, 1, 0}, +// {0, 0, 0, 1}, +// {1, 0, 0, 1}, +// {0, 1, 0, 1}, +// {1, 1, 0, 1}, +// {0, 0, 1, 1}, +// {1, 0, 1, 1}, +// {0, 1, 1, 1}, +// {1, 1, 1, 1}}}; +// +// SET_INPUT("reset", 0, 1); +// processAllGates(net); +// +// SET_INPUT("reset", 0, 0); +// for (size_t i = 0; i < outvals.size(); i++) { +// net.tick(); +// processAllGates(net); +// ASSERT_OUTPUT_EQ("io_out", 0, outvals[i][0]); +// ASSERT_OUTPUT_EQ("io_out", 1, outvals[i][1]); +// ASSERT_OUTPUT_EQ("io_out", 2, outvals[i][2]); +// ASSERT_OUTPUT_EQ("io_out", 3, outvals[i][3]); +// } +//} +// +// template +// void testPrioritySetVisitor() +//{ +// NetworkBuilder builder; +// int id0 = builder.INPUT("A", 0); +// int id1 = builder.NOT(); +// int id2 = builder.OUTPUT("out", 0); +// builder.connect(id0, id1); +// builder.connect(id1, id2); +// +// TaskNetwork net = std::move(builder); +// auto depnode = get(net, "output", "out", 0)->depnode(); +// assert(depnode->priority() == -1); +// +// // Set priority to each DepNode +// GraphVisitor grvis; +// net.visit(grvis); +// PrioritySetVisitor privis{graph::doTopologicalSort(grvis.getMap())}; +// net.visit(privis); +// +// assert(depnode->priority() == 2); +//} +// +//// +//#include "iyokan_plain.hpp" +// +// void processAllGates(PlainNetwork& net, +// std::shared_ptr graph = nullptr) +//{ +// processAllGates(net, std::thread::hardware_concurrency(), graph); +//} +// +// void setInput(std::shared_ptr task, int val) +//{ +// task->set(val != 0 ? 1_b : 0_b); +//} +// +// int getOutput(std::shared_ptr task) +//{ +// return task->get() == 1_b ? 1 : 0; +//} +// +// void testProgressGraphMaker() +//{ +// /* +// B D +// reset(0) >---> ANDNOT(4) >---> DFF(2) +// ^ A v Q +// | | +// *--< NOT(3) <--*-----> OUTPUT(1) +// A +// */ +// +// PlainNetworkBuilder builder; +// int id0 = builder.INPUT("reset", 0); +// int id1 = builder.OUTPUT("out", 0); +// int id2 = builder.DFF(); +// int id3 = builder.NOT(); +// int id4 = builder.ANDNOT(); +// builder.connect(id2, id1); +// builder.connect(id4, id2); +// builder.connect(id2, id3); +// builder.connect(id3, id4); +// builder.connect(id0, id4); +// +// PlainNetwork net = std::move(builder); +// assertNetValid(net); +// +// auto graph = std::make_shared(); +// +// processAllGates(net, graph); +// +// std::stringstream ss; +// graph->dumpDOT(ss); +// std::string dot = ss.str(); +// assert(dot.find(fmt::sprintf("n%d [label = \"{INPUT|reset[0]}\"]", id0)) +// != +// std::string::npos); +// assert(dot.find(fmt::sprintf("n%d [label = \"{OUTPUT|out[0]}\"]", id1)) != +// std::string::npos); +// assert(dot.find(fmt::sprintf("n%d [label = \"{DFF}\"]", id2)) != +// std::string::npos); +// assert(dot.find(fmt::sprintf("n%d [label = \"{NOT}\"]", id3)) != +// std::string::npos); +// assert(dot.find(fmt::sprintf("n%d [label = \"{ANDNOT}\"]", id4)) != +// std::string::npos); +// assert(dot.find(fmt::sprintf("n%d -> n%d", id2, id1)) != +// std::string::npos); assert(dot.find(fmt::sprintf("n%d -> n%d", id4, id2)) +// != std::string::npos); assert(dot.find(fmt::sprintf("n%d -> n%d", id2, +// id3)) != std::string::npos); assert(dot.find(fmt::sprintf("n%d -> n%d", +// id0, id4)) != std::string::npos); assert(dot.find(fmt::sprintf("n%d -> +// n%d", id3, id4)) != std::string::npos); +//} +// +//#include "iyokan_tfhepp.hpp" +// +// class TFHEppTestHelper { +// private: +// std::shared_ptr sk_; +// std::shared_ptr gk_; +// std::shared_ptr ck_; +// TLWELvl0 zero_, one_; +// +// private: +// TFHEppTestHelper() +// { +// sk_ = std::make_shared(); +// gk_ = std::make_shared(*sk_); +// zero_ = TFHEpp::bootsSymEncrypt({0}, *sk_).at(0); +// one_ = TFHEpp::bootsSymEncrypt({1}, *sk_).at(0); +// } +// +// public: +// static TFHEppTestHelper& instance() +// { +// static TFHEppTestHelper inst; +// return inst; +// } +// +// void prepareCircuitKey() +// { +// ck_ = std::make_shared(*sk_); +// } +// +// TFHEppWorkerInfo wi() const +// { +// return TFHEppWorkerInfo{gk_, ck_}; +// } +// +// const std::shared_ptr& sk() const +// { +// return sk_; +// } +// +// const std::shared_ptr& gk() const +// { +// return gk_; +// } +// +// const TLWELvl0& zero() const +// { +// return zero_; +// } +// +// const TLWELvl0& one() const +// { +// return one_; +// } +//}; +// +// void processAllGates(TFHEppNetwork& net, +// std::shared_ptr graph = nullptr) +//{ +// processAllGates(net, std::thread::hardware_concurrency(), +// TFHEppTestHelper::instance().wi(), graph); +//} +// +// void setInput(std::shared_ptr task, int val) +//{ +// auto& h = TFHEppTestHelper::instance(); +// task->set(val ? h.one() : h.zero()); +//} +// +// int getOutput(std::shared_ptr task) +//{ +// return TFHEpp::bootsSymDecrypt({task->get()}, +// *TFHEppTestHelper::instance().sk())[0]; +//} +// +// void testTFHEppSerialization() +//{ +// auto& h = TFHEppTestHelper::instance(); +// const std::shared_ptr& sk = h.sk(); +// const std::shared_ptr& gk = h.wi().gateKey; +// +// // Test for secret key +// { +// // Dump +// writeToArchive("_test_sk", *sk); +// // Load +// auto sk2 = std::make_shared(); +// readFromArchive(*sk2, "_test_sk"); +// +// auto zero = TFHEpp::bootsSymEncrypt({0}, *sk2).at(0); +// auto one = TFHEpp::bootsSymEncrypt({1}, *sk2).at(0); +// TLWELvl0 res; +// TFHEpp::HomANDNY(res, zero, one, *gk); +// assert(TFHEpp::bootsSymDecrypt({res}, *sk2).at(0) == 1); +// } +// +// // Test for gate key +// { +// std::stringstream ss{std::ios::binary | std::ios::out | std::ios::in}; +// +// // Dump +// writeToArchive(ss, *gk); +// // Load +// auto gk2 = std::make_shared(); +// readFromArchive(*gk2, ss); +// +// auto zero = TFHEpp::bootsSymEncrypt({0}, *sk).at(0); +// auto one = TFHEpp::bootsSymEncrypt({1}, *sk).at(0); +// TLWELvl0 res; +// TFHEpp::HomANDNY(res, zero, one, *gk2); +// assert(TFHEpp::bootsSymDecrypt({res}, *sk).at(0) == 1); +// } +// +// // Test for TLWE level 0 +// { +// std::stringstream ss{std::ios::binary | std::ios::out | std::ios::in}; +// +// { +// auto zero = TFHEpp::bootsSymEncrypt({0}, *sk).at(0); +// auto one = TFHEpp::bootsSymEncrypt({1}, *sk).at(0); +// writeToArchive(ss, zero); +// writeToArchive(ss, one); +// ss.seekg(0); +// } +// +// { +// TLWELvl0 res, zero, one; +// readFromArchive(zero, ss); +// readFromArchive(one, ss); +// TFHEpp::HomANDNY(res, zero, one, *gk); +// assert(TFHEpp::bootsSymDecrypt({res}, *sk).at(0) == 1); +// } +// } +//} +// +//#ifdef IYOKAN_CUDA_ENABLED +//#include "iyokan_cufhe.hpp" +// +// class CUFHETestHelper { +// private: +// std::shared_ptr gk_; +// cufhe::Ctxt zero_, one_; +// +// private: +// CUFHETestHelper() +// { +// gk_ = std::make_shared(); +// ifftGateKey(*gk_, *TFHEppTestHelper::instance().gk()); +// setCtxtZero(zero_); +// setCtxtOne(one_); +// } +// +// public: +// class CUFHEManager { +// public: +// CUFHEManager() +// { +// cufhe::Initialize(*CUFHETestHelper::instance().gk_); +// } +// +// ~CUFHEManager() +// { +// cufhe::CleanUp(); +// } +// }; +// +// public: +// static CUFHETestHelper& instance() +// { +// static CUFHETestHelper inst; +// return inst; +// } +//}; +// +// void processAllGates(CUFHENetwork& net, +// std::shared_ptr graph = nullptr) +//{ +// processAllGates(net, 240, graph); +//} +// +// void setInput(std::shared_ptr task, int val) +//{ +// TLWELvl0 c; +// if (val) +// setTLWELvl0Trivial1(c); +// else +// setTLWELvl0Trivial0(c); +// task->set(c); +//} +// +// int getOutput(std::shared_ptr task) +//{ +// return decryptTLWELvl0(task->get(), *TFHEppTestHelper::instance().sk()); +//} +// +// void testBridgeBetweenCUFHEAndTFHEpp() +//{ +// auto& ht = TFHEppTestHelper::instance(); +// +// // FIXME: The network constructed here does not have any meanings anymore, +// // but it is enough to check if bridges work correctly. +// // We may need another network here. +// NetworkBuilderBase b0; +// NetworkBuilderBase b1; +// auto t0 = b0.addINPUT("in", 0, false); +// auto t1 = std::make_shared(); +// b1.addTask(NodeLabel{"tfhepp2cufhe", ""}, t1); +// auto t2 = std::make_shared(); +// b1.addTask(NodeLabel{"cufhe2tfhepp", ""}, t2); +// auto t3 = b0.addOUTPUT("out", 0, true); +// connectTasks(t1, t2); +// +// auto net0 = std::make_shared>(std::move(b0)); +// auto net1 = +// std::make_shared>(std::move(b1)); auto +// bridge0 = connectWithBridge(t0, t1); auto bridge1 = connectWithBridge(t2, +// t3); +// +// CUFHENetworkRunner runner{1, 1, ht.wi()}; +// runner.addNetwork(net0); +// runner.addNetwork(net1); +// runner.addBridge(bridge0); +// runner.addBridge(bridge1); +// +// t0->set(ht.one()); +// runner.run(false); +// assert(t3->get() == ht.one()); +// +// net0->tick(); +// net1->tick(); +// bridge0->tick(); +// bridge1->tick(); +// +// t0->set(ht.zero()); +// runner.run(false); +// assert(t3->get() == ht.zero()); +//} +//#endif +// +// void testBlueprint() +//{ +// using namespace blueprint; +// +// NetworkBlueprint blueprint{"test/config-toml/cahp-diamond.toml"}; +// +// { +// const auto& files = blueprint.files(); +// assert(files.size() == 1); +// assert(files[0].type == File::TYPE::YOSYS_JSON); +// assert(std::filesystem::canonical(files[0].path) == +// std::filesystem::canonical( +// "test/yosys-json/cahp-diamond-core-yosys.json")); +// assert(files[0].name == "core"); +// } +// +// { +// const auto& roms = blueprint.builtinROMs(); +// assert(roms.size() == 1); +// assert(roms[0].name == "rom"); +// assert(roms[0].inAddrWidth == 7); +// assert(roms[0].outRdataWidth == 32); +// } +// +// { +// const auto& rams = blueprint.builtinRAMs(); +// assert(rams.size() == 2); +// assert((rams[0].name == "ramA" && rams[1].name == "ramB") || +// (rams[1].name == "ramA" && rams[0].name == "ramB")); +// assert(rams[0].inAddrWidth == 8); +// assert(rams[0].inWdataWidth == 8); +// assert(rams[0].outRdataWidth == 8); +// assert(rams[1].inAddrWidth == 8); +// assert(rams[1].inWdataWidth == 8); +// assert(rams[1].outRdataWidth == 8); +// } +// +// { +// const auto& edges = blueprint.edges(); +// auto assertIn = [&edges](std::string fNodeName, std::string fPortName, +// std::string tNodeName, std::string tPortName, +// int size) { +// for (int i = 0; i < size; i++) { +// auto v = +// std::make_pair(Port{fNodeName, {"output", fPortName, i}}, +// Port{tNodeName, {"input", tPortName, i}}); +// auto it = std::find(edges.begin(), edges.end(), v); +// assert(it != edges.end()); +// } +// }; +// assertIn("core", "io_romAddr", "rom", "addr", 7); +// assertIn("rom", "rdata", "core", "io_romData", 32); +// assertIn("core", "io_memA_writeEnable", "ramA", "wren", 1); +// assertIn("core", "io_memA_address", "ramA", "addr", 8); +// assertIn("core", "io_memA_in", "ramA", "wdata", 8); +// assertIn("ramA", "rdata", "core", "io_memA_out", 8); +// assertIn("core", "io_memB_writeEnable", "ramB", "wren", 1); +// assertIn("core", "io_memB_address", "ramB", "addr", 8); +// assertIn("core", "io_memB_in", "ramB", "wdata", 8); +// assertIn("ramB", "rdata", "core", "io_memB_out", 8); +// } +// +// { +// const Port port = blueprint.at("reset").value(); +// assert(port.nodeName == "core"); +// assert(port.portLabel.kind == "input"); +// assert(port.portLabel.portName == "reset"); +// assert(port.portLabel.portBit == 0); +// } +// +// { +// const Port port = blueprint.at("finflag").value(); +// assert(port.nodeName == "core"); +// assert(port.portLabel.kind == "output"); +// assert(port.portLabel.portName == "io_finishFlag"); +// assert(port.portLabel.portBit == 0); +// } +// +// for (int ireg = 0; ireg < 16; ireg++) { +// for (int ibit = 0; ibit < 16; ibit++) { +// const Port port = +// blueprint.at(utility::fok("reg_x", ireg), ibit).value(); +// assert(port.nodeName == "core"); +// assert(port.portLabel.portName == +// utility::fok("io_regOut_x", ireg)); +// assert(port.portLabel.portBit == ibit); +// } +// } +//} + +#include "iyokan_nt.hpp" +#include "iyokan_nt_plain.hpp" +#include "iyokan_nt_tfhepp.hpp" +#include "packet_nt.hpp" +#include "tfhepp_cufhe_wrapper.hpp" - std::stringstream ss; - graph->dumpDOT(ss); - std::string dot = ss.str(); - assert(dot.find(fmt::sprintf("n%d [label = \"{INPUT|reset[0]}\"]", id0)) != - std::string::npos); - assert(dot.find(fmt::sprintf("n%d [label = \"{OUTPUT|out[0]}\"]", id1)) != - std::string::npos); - assert(dot.find(fmt::sprintf("n%d [label = \"{DFF}\"]", id2)) != - std::string::npos); - assert(dot.find(fmt::sprintf("n%d [label = \"{NOT}\"]", id3)) != - std::string::npos); - assert(dot.find(fmt::sprintf("n%d [label = \"{ANDNOT}\"]", id4)) != - std::string::npos); - assert(dot.find(fmt::sprintf("n%d -> n%d", id2, id1)) != std::string::npos); - assert(dot.find(fmt::sprintf("n%d -> n%d", id4, id2)) != std::string::npos); - assert(dot.find(fmt::sprintf("n%d -> n%d", id2, id3)) != std::string::npos); - assert(dot.find(fmt::sprintf("n%d -> n%d", id0, id4)) != std::string::npos); - assert(dot.find(fmt::sprintf("n%d -> n%d", id3, id4)) != std::string::npos); -} +#include +#include -#include "iyokan_tfhepp.hpp" +#include class TFHEppTestHelper { private: @@ -553,11 +895,6 @@ class TFHEppTestHelper { ck_ = std::make_shared(*sk_); } - TFHEppWorkerInfo wi() const - { - return TFHEppWorkerInfo{gk_, ck_}; - } - const std::shared_ptr& sk() const { return sk_; @@ -579,330 +916,131 @@ class TFHEppTestHelper { } }; -void processAllGates(TFHEppNetwork& net, - std::shared_ptr graph = nullptr) -{ - processAllGates(net, std::thread::hardware_concurrency(), - TFHEppTestHelper::instance().wi(), graph); -} - -void setInput(std::shared_ptr task, int val) -{ - auto& h = TFHEppTestHelper::instance(); - task->set(val ? h.one() : h.zero()); -} - -int getOutput(std::shared_ptr task) -{ - return TFHEpp::bootsSymDecrypt({task->get()}, - *TFHEppTestHelper::instance().sk())[0]; -} +namespace nt { -void testTFHEppSerialization() +void testAllocator() { - auto& h = TFHEppTestHelper::instance(); - const std::shared_ptr& sk = h.sk(); - const std::shared_ptr& gk = h.wi().gateKey; - - // Test for secret key { - // Dump - writeToArchive("_test_sk", *sk); - // Load - auto sk2 = std::make_shared(); - readFromArchive(*sk2, "_test_sk"); - - auto zero = TFHEpp::bootsSymEncrypt({0}, *sk2).at(0); - auto one = TFHEpp::bootsSymEncrypt({1}, *sk2).at(0); - TLWELvl0 res; - TFHEpp::HomANDNY(res, zero, one, *gk); - assert(TFHEpp::bootsSymDecrypt({res}, *sk2).at(0) == 1); - } - - // Test for gate key - { - std::stringstream ss{std::ios::binary | std::ios::out | std::ios::in}; - - // Dump - writeToArchive(ss, *gk); - // Load - auto gk2 = std::make_shared(); - readFromArchive(*gk2, ss); - - auto zero = TFHEpp::bootsSymEncrypt({0}, *sk).at(0); - auto one = TFHEpp::bootsSymEncrypt({1}, *sk).at(0); - TLWELvl0 res; - TFHEpp::HomANDNY(res, zero, one, *gk2); - assert(TFHEpp::bootsSymDecrypt({res}, *sk).at(0) == 1); - } - - // Test for TLWE level 0 - { - std::stringstream ss{std::ios::binary | std::ios::out | std::ios::in}; - + Allocator alc; { - auto zero = TFHEpp::bootsSymEncrypt({0}, *sk).at(0); - auto one = TFHEpp::bootsSymEncrypt({1}, *sk).at(0); - writeToArchive(ss, zero); - writeToArchive(ss, one); - ss.seekg(0); + TLWELvl0* zero = alc.make(); + TLWELvl0* one = alc.make(); + *zero = TFHEppTestHelper::instance().zero(); + *one = TFHEppTestHelper::instance().one(); } - { - TLWELvl0 res, zero, one; - readFromArchive(zero, ss); - readFromArchive(one, ss); - TFHEpp::HomANDNY(res, zero, one, *gk); - assert(TFHEpp::bootsSymDecrypt({res}, *sk).at(0) == 1); + TLWELvl0* zero = alc.get(0); + TLWELvl0* one = alc.get(1); + assert(*zero == TFHEppTestHelper::instance().zero()); + assert(*one == TFHEppTestHelper::instance().one()); } } -} - -#ifdef IYOKAN_CUDA_ENABLED -#include "iyokan_cufhe.hpp" - -class CUFHETestHelper { -private: - std::shared_ptr gk_; - cufhe::Ctxt zero_, one_; - -private: - CUFHETestHelper() { - gk_ = std::make_shared(); - ifftGateKey(*gk_, *TFHEppTestHelper::instance().gk()); - setCtxtZero(zero_); - setCtxtOne(one_); - } - -public: - class CUFHEManager { - public: - CUFHEManager() + Allocator alc; { - cufhe::Initialize(*CUFHETestHelper::instance().gk_); + TLWELvl0* zero = alc.make(); + TLWELvl0* one = alc.make(); + *zero = TFHEppTestHelper::instance().zero(); + *one = TFHEppTestHelper::instance().one(); } - - ~CUFHEManager() { - cufhe::CleanUp(); + TLWELvl0* zero = alc.get(0); + TLWELvl0* one = alc.get(1); + assert(*zero == TFHEppTestHelper::instance().zero()); + assert(*one == TFHEppTestHelper::instance().one()); } - }; - -public: - static CUFHETestHelper& instance() - { - static CUFHETestHelper inst; - return inst; } -}; - -void processAllGates(CUFHENetwork& net, - std::shared_ptr graph = nullptr) -{ - processAllGates(net, 240, graph); -} - -void setInput(std::shared_ptr task, int val) -{ - TLWELvl0 c; - if (val) - setTLWELvl0Trivial1(c); - else - setTLWELvl0Trivial0(c); - task->set(c); -} - -int getOutput(std::shared_ptr task) -{ - return decryptTLWELvl0(task->get(), *TFHEppTestHelper::instance().sk()); -} - -void testBridgeBetweenCUFHEAndTFHEpp() -{ - auto& ht = TFHEppTestHelper::instance(); - - // FIXME: The network constructed here does not have any meanings anymore, - // but it is enough to check if bridges work correctly. - // We may need another network here. - NetworkBuilderBase b0; - NetworkBuilderBase b1; - auto t0 = b0.addINPUT("in", 0, false); - auto t1 = std::make_shared(); - b1.addTask(NodeLabel{"tfhepp2cufhe", ""}, t1); - auto t2 = std::make_shared(); - b1.addTask(NodeLabel{"cufhe2tfhepp", ""}, t2); - auto t3 = b0.addOUTPUT("out", 0, true); - connectTasks(t1, t2); - - auto net0 = std::make_shared>(std::move(b0)); - auto net1 = std::make_shared>(std::move(b1)); - auto bridge0 = connectWithBridge(t0, t1); - auto bridge1 = connectWithBridge(t2, t3); - - CUFHENetworkRunner runner{1, 1, ht.wi()}; - runner.addNetwork(net0); - runner.addNetwork(net1); - runner.addBridge(bridge0); - runner.addBridge(bridge1); - - t0->set(ht.one()); - runner.run(false); - assert(t3->get() == ht.one()); - - net0->tick(); - net1->tick(); - bridge0->tick(); - bridge1->tick(); - - t0->set(ht.zero()); - runner.run(false); - assert(t3->get() == ht.zero()); } -#endif -void testBlueprint() +void testSnapshot() { - using namespace blueprint; - - NetworkBlueprint blueprint{"test/config-toml/cahp-diamond.toml"}; - { - const auto& files = blueprint.files(); - assert(files.size() == 1); - assert(files[0].type == File::TYPE::YOSYS_JSON); - assert(std::filesystem::canonical(files[0].path) == - std::filesystem::canonical( - "test/yosys-json/cahp-diamond-core-yosys.json")); - assert(files[0].name == "core"); + Allocator alc; + Bit* b0 = alc.make(); + *b0 = 0_b; + Bit* b1 = alc.make(); + *b1 = 1_b; + + std::ofstream ofs{"_test_snapshot"}; + assert(ofs); + cereal::PortableBinaryOutputArchive ar{ofs}; + alc.dumpAllocatedData(ar); } - { - const auto& roms = blueprint.builtinROMs(); - assert(roms.size() == 1); - assert(roms[0].name == "rom"); - assert(roms[0].inAddrWidth == 7); - assert(roms[0].outRdataWidth == 32); - } - - { - const auto& rams = blueprint.builtinRAMs(); - assert(rams.size() == 2); - assert((rams[0].name == "ramA" && rams[1].name == "ramB") || - (rams[1].name == "ramA" && rams[0].name == "ramB")); - assert(rams[0].inAddrWidth == 8); - assert(rams[0].inWdataWidth == 8); - assert(rams[0].outRdataWidth == 8); - assert(rams[1].inAddrWidth == 8); - assert(rams[1].inWdataWidth == 8); - assert(rams[1].outRdataWidth == 8); - } - - { - const auto& edges = blueprint.edges(); - auto assertIn = [&edges](std::string fNodeName, std::string fPortName, - std::string tNodeName, std::string tPortName, - int size) { - for (int i = 0; i < size; i++) { - auto v = - std::make_pair(Port{fNodeName, {"output", fPortName, i}}, - Port{tNodeName, {"input", tPortName, i}}); - auto it = std::find(edges.begin(), edges.end(), v); - assert(it != edges.end()); - } - }; - assertIn("core", "io_romAddr", "rom", "addr", 7); - assertIn("rom", "rdata", "core", "io_romData", 32); - assertIn("core", "io_memA_writeEnable", "ramA", "wren", 1); - assertIn("core", "io_memA_address", "ramA", "addr", 8); - assertIn("core", "io_memA_in", "ramA", "wdata", 8); - assertIn("ramA", "rdata", "core", "io_memA_out", 8); - assertIn("core", "io_memB_writeEnable", "ramB", "wren", 1); - assertIn("core", "io_memB_address", "ramB", "addr", 8); - assertIn("core", "io_memB_in", "ramB", "wdata", 8); - assertIn("ramB", "rdata", "core", "io_memB_out", 8); - } - - { - const Port port = blueprint.at("reset").value(); - assert(port.nodeName == "core"); - assert(port.portLabel.kind == "input"); - assert(port.portLabel.portName == "reset"); - assert(port.portLabel.portBit == 0); - } - - { - const Port port = blueprint.at("finflag").value(); - assert(port.nodeName == "core"); - assert(port.portLabel.kind == "output"); - assert(port.portLabel.portName == "io_finishFlag"); - assert(port.portLabel.portBit == 0); - } - - for (int ireg = 0; ireg < 16; ireg++) { - for (int ibit = 0; ibit < 16; ibit++) { - const Port port = - blueprint.at(utility::fok("reg_x", ireg), ibit).value(); - assert(port.nodeName == "core"); - assert(port.portLabel.portName == - utility::fok("io_regOut_x", ireg)); - assert(port.portLabel.portBit == ibit); - } + std::ifstream ifs{"_test_snapshot"}; + assert(ifs); + cereal::PortableBinaryInputArchive ar{ifs}; + Allocator alc{ar}; + assert(*alc.make() == 0_b); + assert(*alc.make() == 1_b); } } +} // namespace nt + int main() { - AsyncThread::setNumThreads(std::thread::hardware_concurrency()); - - testNOT(); - testMUX(); - testBinopGates(); - testFromJSONtest_pass_4bit(); - testFromJSONtest_and_4bit(); - testFromJSONtest_and_4_2bit(); - testFromJSONtest_mux_4bit(); - testFromJSONtest_addr_4bit(); - testFromJSONtest_register_4bit(); - testSequentialCircuit(); - testFromJSONtest_counter_4bit(); - testPrioritySetVisitor(); - - testNOT(); - testMUX(); - testBinopGates(); - testFromJSONtest_pass_4bit(); - testFromJSONtest_pass_4bit(); - testFromJSONtest_and_4bit(); - testFromJSONtest_and_4_2bit(); - testFromJSONtest_mux_4bit(); - testFromJSONtest_addr_4bit(); - testFromJSONtest_register_4bit(); - testSequentialCircuit(); - testFromJSONtest_counter_4bit(); - testPrioritySetVisitor(); - testTFHEppSerialization(); - -#ifdef IYOKAN_CUDA_ENABLED - { - CUFHETestHelper::CUFHEManager man; - - testNOT(); - testMUX(); - testBinopGates(); - testFromJSONtest_pass_4bit(); - testFromJSONtest_and_4bit(); - testFromJSONtest_and_4_2bit(); - testFromJSONtest_mux_4bit(); - testFromJSONtest_addr_4bit(); - testFromJSONtest_register_4bit(); - testSequentialCircuit(); - testFromJSONtest_counter_4bit(); - testPrioritySetVisitor(); - testBridgeBetweenCUFHEAndTFHEpp(); - } -#endif - - testProgressGraphMaker(); - testBlueprint(); + // AsyncThread::setNumThreads(std::thread::hardware_concurrency()); + + // testNOT(); + // testMUX(); + // testBinopGates(); + // testFromJSONtest_pass_4bit(); + // testFromJSONtest_and_4bit(); + // testFromJSONtest_and_4_2bit(); + // testFromJSONtest_mux_4bit(); + // testFromJSONtest_addr_4bit(); + // testFromJSONtest_register_4bit(); + // testSequentialCircuit(); + // testFromJSONtest_counter_4bit(); + // testPrioritySetVisitor(); + // + // testNOT(); + // testMUX(); + // testBinopGates(); + // testFromJSONtest_pass_4bit(); + // testFromJSONtest_pass_4bit(); + // testFromJSONtest_and_4bit(); + // testFromJSONtest_and_4_2bit(); + // testFromJSONtest_mux_4bit(); + // testFromJSONtest_addr_4bit(); + // testFromJSONtest_register_4bit(); + // testSequentialCircuit(); + // testFromJSONtest_counter_4bit(); + // testPrioritySetVisitor(); + // testTFHEppSerialization(); + // + //#ifdef IYOKAN_CUDA_ENABLED + // { + // CUFHETestHelper::CUFHEManager man; + // + // testNOT(); + // testMUX(); + // testBinopGates(); + // testFromJSONtest_pass_4bit(); + // testFromJSONtest_and_4bit(); + // testFromJSONtest_and_4_2bit(); + // testFromJSONtest_mux_4bit(); + // testFromJSONtest_addr_4bit(); + // testFromJSONtest_register_4bit(); + // testSequentialCircuit(); + // testFromJSONtest_counter_4bit(); + // testPrioritySetVisitor(); + // testBridgeBetweenCUFHEAndTFHEpp(); + // } + //#endif + + // testProgressGraphMaker(); + // testBlueprint(); + + using namespace nt; + + nt::error::initialize(); + + nt::testAllocator(); + nt::test0(); + // nt::plain::test0(); + nt::tfhepp::test0(); + + return 0; } diff --git a/thirdparty/loguru b/thirdparty/loguru new file mode 160000 index 0000000..323d0eb --- /dev/null +++ b/thirdparty/loguru @@ -0,0 +1 @@ +Subproject commit 323d0eb1b7ba0bda39d9b8494aca456639bfd2d5