diff --git a/CMakeLists.txt b/CMakeLists.txt index 749057f..828075b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,14 @@ set(OPENFHE_PYTHON_VERSION_PATCH 0) set(OPENFHE_PYTHON_VERSION_TWEAK 0) set(OPENFHE_PYTHON_VERSION ${OPENFHE_PYTHON_VERSION_MAJOR}.${OPENFHE_PYTHON_VERSION_MINOR}.${OPENFHE_PYTHON_VERSION_PATCH}.${OPENFHE_PYTHON_VERSION_TWEAK}) +# OpenFHE version can be specified externally (-DOPENFHE_REQUIRED_VERSION=1.3.0) +if(NOT DEFINED OPENFHE_REQUIRED_VERSION) + set(OPENFHE_REQUIRED_VERSION "1.3.0" CACHE STRING "Required OpenFHE version") +else() + # User provided OPENFHE_REQUIRED_VERSION via -D + message(STATUS "Using user-specified OpenFHE version: ${OPENFHE_REQUIRED_VERSION}") +endif() + set(CMAKE_CXX_STANDARD 17) option( BUILD_STATIC "Set to ON to include static versions of the library" OFF) @@ -15,7 +23,9 @@ if(APPLE) set(CMAKE_CXX_VISIBILITY_PRESET default) endif() -find_package(OpenFHE 1.3.0 REQUIRED) +find_package(OpenFHE ${OPENFHE_REQUIRED_VERSION} REQUIRED) +message(STATUS "Building with OpenFHE version: ${OPENFHE_REQUIRED_VERSION}") + set(PYBIND11_FINDPYTHON ON) find_package(pybind11 REQUIRED) @@ -66,20 +76,13 @@ pybind11_add_module(openfhe ### Python installation # Allow the user to specify the path to Python executable (if not provided, find it) option(PYTHON_EXECUTABLE_PATH "Path to Python executable" "") - -if(NOT PYTHON_EXECUTABLE_PATH) - # Find Python and its development components - find_package(Python REQUIRED COMPONENTS Interpreter Development) -else() - # Set Python_EXECUTABLE to the specified path +if(PYTHON_EXECUTABLE_PATH) set(Python_EXECUTABLE "${PYTHON_EXECUTABLE_PATH}") endif() - -# Find Python interpreter -find_package(PythonInterp REQUIRED) +find_package(Python REQUIRED COMPONENTS Interpreter Development) # Check Python version -if(${PYTHON_VERSION_MAJOR} EQUAL 3 AND ${PYTHON_VERSION_MINOR} GREATER_EQUAL 10) +if(${Python_VERSION_MAJOR} EQUAL 3 AND ${Python_VERSION_MINOR} GREATER_EQUAL 10) execute_process( COMMAND "${Python_EXECUTABLE}" -c "from sys import exec_prefix; print(exec_prefix)" OUTPUT_VARIABLE PYTHON_SITE_PACKAGES @@ -101,3 +104,5 @@ else() endif() message("***** INSTALL IS AT ${Python_Install_Location}; to change, run cmake with -DCMAKE_INSTALL_PREFIX=/your/path") install(TARGETS openfhe LIBRARY DESTINATION ${Python_Install_Location}) +install(FILES ${CMAKE_SOURCE_DIR}/__init__.py DESTINATION ${Python_Install_Location}) + diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..db9b485 --- /dev/null +++ b/__init__.py @@ -0,0 +1,48 @@ +import os +import ctypes + + +def load_shared_library(libname, paths): + for path in paths: + lib_path = os.path.join(path, libname) + if os.path.exists(lib_path): + return ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL) + + raise FileNotFoundError( + f"Shared library {libname} not found in {paths}" + ) + +# Search LD_LIBRARY_PATH +ld_paths = os.environ.get("LD_LIBRARY_PATH", "").split(":") + +if not any(ld_paths): + # Path to the bundled `lib/` directory inside site-packages + package_dir = os.path.abspath(os.path.dirname(__file__)) + internal_lib_dir = [os.path.join(package_dir, 'lib')] + + # Shared libraries required + shared_libs = [ + 'libgomp.so', + 'libOPENFHEcore.so.1', + 'libOPENFHEbinfhe.so.1', + 'libOPENFHEpke.so.1', + ] + + for libname in shared_libs: + load_shared_library(libname, internal_lib_dir) + + from .openfhe import * + +else: + # Shared libraries required + # skip 'libgomp.so' if LD_LIBRARY_PATH is set as we should get it from the libgomp.so location + shared_libs = [ + 'libOPENFHEcore.so.1', + 'libOPENFHEbinfhe.so.1', + 'libOPENFHEpke.so.1', + ] + + for libname in shared_libs: + load_shared_library(libname, ld_paths) + + # from .openfhe import * diff --git a/openfhe/__init__.py b/openfhe/__init__.py deleted file mode 100644 index 26b7231..0000000 --- a/openfhe/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from openfhe.openfhe import * diff --git a/setup.py b/setup.py deleted file mode 100755 index af09a2c..0000000 --- a/setup.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import subprocess -import sys -from setuptools import setup, Extension -from setuptools.command.sdist import sdist as _sdist -from setuptools.command.build_ext import build_ext as _build_ext -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel -import glob -import shutil - -__version__ = '0.9.0' -OPENFHE_PATH = 'openfhe/' -OPENFHE_LIB = 'openfhe.so' - -class CMakeExtension(Extension): - def __init__(self, name, sourcedir=''): - super().__init__(name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) - -class CMakeBuild(_build_ext): - - def run(self): - for ext in self.extensions: - self.build_cmake(ext) - - def build_cmake(self, ext): - if os.path.exists(OPENFHE_PATH + OPENFHE_LIB): - return - extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - print(extdir) - cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, - '-DPYTHON_EXECUTABLE=' + sys.executable] - - cfg = 'Debug' if self.debug else 'Release' - build_args = ['--config', cfg] - - build_temp = os.path.abspath(self.build_temp) - os.makedirs(build_temp, exist_ok=True) - - num_cores = os.cpu_count() or 1 - build_args += ['--parallel', str(num_cores)] - - subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, cwd=build_temp) - subprocess.check_call(['cmake', '--build', '.', '--target', ext.name] + build_args, cwd=build_temp) - - so_files = glob.glob(os.path.join(extdir, '*.so')) - if not so_files: - raise RuntimeError("Cannot find any built .so file in " + extdir) - - src_file = so_files[0] - dst_file = os.path.join('openfhe', OPENFHE_LIB) - shutil.move(src_file, dst_file) - -# Run build_ext before sdist -class SDist(_sdist): - def run(self): - if os.path.exists(OPENFHE_PATH + OPENFHE_LIB): - os.remove(OPENFHE_PATH + OPENFHE_LIB) - self.run_command('build_ext') - super().run() - -setup( - name='openfhe', - version=__version__, - description='Python wrapper for OpenFHE C++ library.', - author='OpenFHE Team', - author_email='contact@openfhe.org', - url='https://github.com/openfheorg/openfhe-python', - license='BSD-2-Clause', - packages=['openfhe'], - package_data={'openfhe': ['*.so', '*.pyi']}, - ext_modules=[CMakeExtension('openfhe', sourcedir='')], - cmdclass={ - 'build_ext': CMakeBuild, - 'sdist': SDist - }, - include_package_data=True, - python_requires=">=3.6", - install_requires=['pybind11', 'pybind11-global', 'pybind11-stubgen'], - tests_require = ['pytest'], -) diff --git a/src/include/pke/cryptocontext_wrapper.h b/src/include/pke/cryptocontext_wrapper.h index cf26791..8981991 100644 --- a/src/include/pke/cryptocontext_wrapper.h +++ b/src/include/pke/cryptocontext_wrapper.h @@ -54,6 +54,7 @@ Plaintext MultipartyDecryptFusionWrapper(CryptoContext& self,const std const std::shared_ptr>> GetEvalSumKeyMapWrapper(CryptoContext& self, const std::string &id); PlaintextModulus GetPlaintextModulusWrapper(CryptoContext& self); +uint32_t GetBatchSizeWrapper(CryptoContext& self); double GetModulusWrapper(CryptoContext& self); void RemoveElementWrapper(Ciphertext& self, uint32_t index); double GetScalingFactorRealWrapper(CryptoContext& self, uint32_t l); diff --git a/src/lib/bindings.cpp b/src/lib/bindings.cpp index 86ba7cf..cab098c 100644 --- a/src/lib/bindings.cpp +++ b/src/lib/bindings.cpp @@ -166,6 +166,7 @@ void bind_crypto_context(py::module &m) //.def("GetCryptoParameters", &CryptoContextImpl::GetCryptoParameters) .def("GetRingDimension", &CryptoContextImpl::GetRingDimension, cc_GetRingDimension_docs) .def("GetPlaintextModulus", &GetPlaintextModulusWrapper, cc_GetPlaintextModulus_docs) + .def("GetBatchSize", &GetBatchSizeWrapper) .def("GetModulus", &GetModulusWrapper, cc_GetModulus_docs) .def("GetModulusCKKS", &GetModulusCKKSWrapper) .def("GetScalingFactorReal", &GetScalingFactorRealWrapper, cc_GetScalingFactorReal_docs) @@ -868,101 +869,91 @@ void bind_crypto_context(py::module &m) cc_InsertEvalAutomorphismKey_docs, py::arg("evalKeyMap"), py::arg("keyTag") = "") - .def_static( - "ClearEvalAutomorphismKeys", []() - { CryptoContextImpl::ClearEvalAutomorphismKeys(); }, + .def_static("ClearEvalAutomorphismKeys", []() { + CryptoContextImpl::ClearEvalAutomorphismKeys(); + }, cc_ClearEvalAutomorphismKeys_docs) // it is safer to return by value instead of by reference (GetEvalMultKeyVector returns a const reference to std::vector) - .def_static("GetEvalMultKeyVector", - [](const std::string& keyTag) { - return CryptoContextImpl::GetEvalMultKeyVector(keyTag); + .def_static("GetEvalMultKeyVector", [](const std::string& keyTag) { + return CryptoContextImpl::GetEvalMultKeyVector(keyTag); }, cc_GetEvalMultKeyVector_docs, py::arg("keyTag") = "") .def_static("GetEvalAutomorphismKeyMap", &CryptoContextImpl::GetEvalAutomorphismKeyMapPtr, cc_GetEvalAutomorphismKeyMap_docs, py::arg("keyTag") = "") - .def_static( - "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") - { - std::ofstream outfile(filename, std::ios::out | std::ios::binary); - bool res = CryptoContextImpl::SerializeEvalMultKey(outfile, sertype, keyTag); - outfile.close(); - return res; }, + .def_static("SerializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") { + std::ofstream outfile(filename, std::ios::out | std::ios::binary); + bool res = CryptoContextImpl::SerializeEvalMultKey(outfile, sertype, keyTag); + outfile.close(); + return res; + }, cc_SerializeEvalMultKey_docs, py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "") - .def_static( // SerializeEvalMultKey - JSON - "SerializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") - { - std::ofstream outfile(filename, std::ios::out | std::ios::binary); - bool res = CryptoContextImpl::SerializeEvalMultKey(outfile, sertype, keyTag); - outfile.close(); - return res; }, + .def_static("SerializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") { + std::ofstream outfile(filename, std::ios::out | std::ios::binary); + bool res = CryptoContextImpl::SerializeEvalMultKey(outfile, sertype, keyTag); + outfile.close(); + return res; + }, cc_SerializeEvalMultKey_docs, py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "") - .def_static( // SerializeEvalAutomorphismKey - Binary - "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") - { - std::ofstream outfile(filename, std::ios::out | std::ios::binary); - bool res = CryptoContextImpl::SerializeEvalAutomorphismKey(outfile, sertype, keyTag); - outfile.close(); - return res; }, + .def_static("SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype, std::string keyTag = "") { + std::ofstream outfile(filename, std::ios::out | std::ios::binary); + bool res = CryptoContextImpl::SerializeEvalAutomorphismKey(outfile, sertype, keyTag); + outfile.close(); + return res; + }, cc_SerializeEvalAutomorphismKey_docs, py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "") - .def_static( // SerializeEvalAutomorphismKey - JSON - "SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") - { - std::ofstream outfile(filename, std::ios::out | std::ios::binary); - bool res = CryptoContextImpl::SerializeEvalAutomorphismKey(outfile, sertype, keyTag); - outfile.close(); - return res; }, + .def_static("SerializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype, std::string keyTag = "") { + std::ofstream outfile(filename, std::ios::out | std::ios::binary); + bool res = CryptoContextImpl::SerializeEvalAutomorphismKey(outfile, sertype, keyTag); + outfile.close(); + return res; + }, cc_SerializeEvalAutomorphismKey_docs, py::arg("filename"), py::arg("sertype"), py::arg("keyTag") = "") - .def_static("DeserializeEvalMultKey", // DeserializeEvalMultKey - Binary - [](const std::string &filename, const SerType::SERBINARY &sertype) - { - std::ifstream emkeys(filename, std::ios::in | std::ios::binary); - if (!emkeys.is_open()) { - std::cerr << "I cannot read serialization from " << filename << std::endl; - } - bool res = CryptoContextImpl::DeserializeEvalMultKey(emkeys, sertype); - return res; - }, - cc_DeserializeEvalMultKey_docs, - py::arg("filename"), py::arg("sertype")) - .def_static("DeserializeEvalMultKey", // DeserializeEvalMultKey - JSON - [](const std::string &filename, const SerType::SERJSON &sertype) - { - std::ifstream emkeys(filename, std::ios::in | std::ios::binary); - if (!emkeys.is_open()) { - std::cerr << "I cannot read serialization from " << filename << std::endl; - } - bool res = CryptoContextImpl::DeserializeEvalMultKey(emkeys, sertype); - return res; }, - cc_DeserializeEvalMultKey_docs, - py::arg("filename"), py::arg("sertype")) - .def_static("DeserializeEvalAutomorphismKey", // DeserializeEvalAutomorphismKey - Binary - [](const std::string &filename, const SerType::SERBINARY &sertype) - { - std::ifstream erkeys(filename, std::ios::in | std::ios::binary); - if (!erkeys.is_open()) { - std::cerr << "I cannot read serialization from " << filename << std::endl; - } - bool res = CryptoContextImpl::DeserializeEvalAutomorphismKey(erkeys, sertype); - return res; }, - cc_DeserializeEvalAutomorphismKey_docs, - py::arg("filename"), py::arg("sertype")) - .def_static("DeserializeEvalAutomorphismKey", // DeserializeEvalAutomorphismKey - JSON - [](const std::string &filename, const SerType::SERJSON &sertype) - { - std::ifstream erkeys(filename, std::ios::in | std::ios::binary); - if (!erkeys.is_open()) { - std::cerr << "I cannot read serialization from " << filename << std::endl; - } - bool res = CryptoContextImpl::DeserializeEvalAutomorphismKey(erkeys, sertype); - return res; }, - cc_DeserializeEvalAutomorphismKey_docs, - py::arg("filename"), py::arg("sertype")); + .def_static("DeserializeEvalMultKey", [](const std::string &filename, const SerType::SERBINARY &sertype) { + std::ifstream emkeys(filename, std::ios::in | std::ios::binary); + if (!emkeys.is_open()) { + std::cerr << "I cannot read serialization from " << filename << std::endl; + } + bool res = CryptoContextImpl::DeserializeEvalMultKey(emkeys, sertype); + return res; + }, + cc_DeserializeEvalMultKey_docs, + py::arg("filename"), py::arg("sertype")) + .def_static("DeserializeEvalMultKey", [](const std::string &filename, const SerType::SERJSON &sertype) { + std::ifstream emkeys(filename, std::ios::in | std::ios::binary); + if (!emkeys.is_open()) { + std::cerr << "I cannot read serialization from " << filename << std::endl; + } + bool res = CryptoContextImpl::DeserializeEvalMultKey(emkeys, sertype); + return res; + }, + cc_DeserializeEvalMultKey_docs, + py::arg("filename"), py::arg("sertype")) + .def_static("DeserializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERBINARY &sertype) { + std::ifstream erkeys(filename, std::ios::in | std::ios::binary); + if (!erkeys.is_open()) { + std::cerr << "I cannot read serialization from " << filename << std::endl; + } + bool res = CryptoContextImpl::DeserializeEvalAutomorphismKey(erkeys, sertype); + return res; + }, + cc_DeserializeEvalAutomorphismKey_docs, + py::arg("filename"), py::arg("sertype")) + .def_static("DeserializeEvalAutomorphismKey", [](const std::string &filename, const SerType::SERJSON &sertype) { + std::ifstream erkeys(filename, std::ios::in | std::ios::binary); + if (!erkeys.is_open()) { + std::cerr << "I cannot read serialization from " << filename << std::endl; + } + bool res = CryptoContextImpl::DeserializeEvalAutomorphismKey(erkeys, sertype); + return res; + }, + cc_DeserializeEvalAutomorphismKey_docs, + py::arg("filename"), py::arg("sertype")); // Generator Functions m.def("GenCryptoContext", &GenCryptoContext, @@ -1159,6 +1150,7 @@ void bind_keys(py::module &m) .def("SetKeyTag", &PublicKeyImpl::SetKeyTag); py::class_, std::shared_ptr>>(m, "PrivateKey") .def(py::init<>()) + .def("GetCryptoContext", &PrivateKeyImpl::GetCryptoContext) .def("GetKeyTag", &PrivateKeyImpl::GetKeyTag) .def("SetKeyTag", &PrivateKeyImpl::SetKeyTag); py::class_>(m, "KeyPair") @@ -1302,61 +1294,97 @@ void bind_encodings(py::module &m) .def("SetStringValue", &PlaintextImpl::SetStringValue) .def("SetIntVectorValue", &PlaintextImpl::SetIntVectorValue) .def("GetFormattedValues", &PlaintextImpl::GetFormattedValues) - .def("__repr__", [](const PlaintextImpl &p) - { - std::stringstream ss; - ss << ""; - return ss.str(); }) - .def("__str__", [](const PlaintextImpl &p) - { - std::stringstream ss; - ss << p; - return ss.str(); }); + .def("__repr__", [](const PlaintextImpl &p) { + std::stringstream ss; + ss << "<Plaintext Object: " << p << ">"; + return ss.str(); + }) + .def("__str__", [](const PlaintextImpl &p) { + std::stringstream ss; + ss << p; + return ss.str(); + }); } -void bind_ciphertext(py::module &m) -{ - py::class_<CiphertextImpl<DCRTPoly>, - std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext") - .def(py::init<>()) - .def( - "__add__", - [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) { - return a + b; - }, - py::is_operator(), pybind11::keep_alive<0, 1>()) - // .def(py::self + py::self); - // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth) - // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth) - .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs) - .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs, - py::arg("level")) - .def("Clone", &CiphertextImpl<DCRTPoly>::Clone) - .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs) - // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel) - // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel) - // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor) - // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor) - .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots) - .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots) - .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg) - .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg) - .def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly> & { - return self.GetElements(); - }, - py::return_value_policy::reference_internal) - .def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly> & { - return self.GetElements(); - }, - py::return_value_policy::reference_internal) - .def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly> &elems) { - self.SetElements(elems); - }) - .def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly> &&elems) { - self.SetElements(std::move(elems)); - }); +void bind_ciphertext(py::module &m) { + py::class_<CiphertextImpl<DCRTPoly>, std::shared_ptr<CiphertextImpl<DCRTPoly>>>(m, "Ciphertext") + .def(py::init<>()) + .def("__add__", [](const Ciphertext<DCRTPoly> &a, const Ciphertext<DCRTPoly> &b) { + return a + b; + }, + py::is_operator(), pybind11::keep_alive<0, 1>()) + // .def(py::self + py::self); + // .def("GetDepth", &CiphertextImpl<DCRTPoly>::GetDepth) + // .def("SetDepth", &CiphertextImpl<DCRTPoly>::SetDepth) + .def("GetLevel", &CiphertextImpl<DCRTPoly>::GetLevel, ctx_GetLevel_docs) + .def("SetLevel", &CiphertextImpl<DCRTPoly>::SetLevel, ctx_SetLevel_docs, + py::arg("level")) + .def("Clone", &CiphertextImpl<DCRTPoly>::Clone) + .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs) + // .def("GetHopLevel", &CiphertextImpl<DCRTPoly>::GetHopLevel) + // .def("SetHopLevel", &CiphertextImpl<DCRTPoly>::SetHopLevel) + // .def("GetScalingFactor", &CiphertextImpl<DCRTPoly>::GetScalingFactor) + // .def("SetScalingFactor", &CiphertextImpl<DCRTPoly>::SetScalingFactor) + .def("GetSlots", &CiphertextImpl<DCRTPoly>::GetSlots) + .def("SetSlots", &CiphertextImpl<DCRTPoly>::SetSlots) + .def("GetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::GetNoiseScaleDeg) + .def("SetNoiseScaleDeg", &CiphertextImpl<DCRTPoly>::SetNoiseScaleDeg) + .def("GetCryptoContext", &CiphertextImpl<DCRTPoly>::GetCryptoContext) + .def("GetEncodingType", &CiphertextImpl<DCRTPoly>::GetEncodingType) + .def("GetElements", [](const CiphertextImpl<DCRTPoly>& self) -> const std::vector<DCRTPoly>& { + return self.GetElements(); + }, + py::return_value_policy::reference_internal) + .def("GetElementsMutable", [](CiphertextImpl<DCRTPoly>& self) -> std::vector<DCRTPoly>& { + return self.GetElements(); + }, + py::return_value_policy::reference_internal) + .def("SetElements", [](CiphertextImpl<DCRTPoly>& self, const std::vector<DCRTPoly>& elems) { + self.SetElements(elems); + }) + .def("SetElementsMove", [](CiphertextImpl<DCRTPoly>& self, std::vector<DCRTPoly>&& elems) { + self.SetElements(std::move(elems)); + }); } +// void bind_ciphertext(py::module &m) { +// using CiphertextImplDCRT = CiphertextImpl<DCRTPoly>; +// using CiphertextDCRT = Ciphertext<DCRTPoly>; // shared_ptr<CiphertextImpl<DCRTPoly>> + +// // Bind CiphertextImpl<DCRTPoly> and expose it to Python as "Ciphertext" +// py::class_<CiphertextImplDCRT, std::shared_ptr<CiphertextImplDCRT>>(m, "Ciphertext") +// .def(py::init<>()) +// .def("__add__", [](const CiphertextDCRT &a, const CiphertextDCRT &b) { +// return a + b; +// }, +// py::is_operator(), pybind11::keep_alive<0, 1>()) +// .def("GetLevel", &CiphertextImplDCRT::GetLevel, ctx_GetLevel_docs) +// .def("SetLevel", &CiphertextImplDCRT::SetLevel, ctx_SetLevel_docs, py::arg("level")) +// .def("Clone", &CiphertextImplDCRT::Clone) +// .def("RemoveElement", &RemoveElementWrapper, cc_RemoveElement_docs) +// .def("GetSlots", &CiphertextImplDCRT::GetSlots) +// .def("SetSlots", &CiphertextImplDCRT::SetSlots) +// .def("GetNoiseScaleDeg", &CiphertextImplDCRT::GetNoiseScaleDeg) +// .def("SetNoiseScaleDeg", &CiphertextImplDCRT::SetNoiseScaleDeg) +// .def("GetCryptoContext", &CiphertextImplDCRT::GetCryptoContext) +// .def("GetEncodingType", &CiphertextImplDCRT::GetEncodingType) +// .def("GetElements", [](const CiphertextImplDCRT& self) -> const std::vector<DCRTPoly>& { +// return self.GetElements(); +// }, py::return_value_policy::reference_internal) +// .def("GetElementsMutable", [](CiphertextImplDCRT& self) -> std::vector<DCRTPoly>& { +// return self.GetElements(); +// }, py::return_value_policy::reference_internal) +// .def("SetElements", [](CiphertextImplDCRT& self, const std::vector<DCRTPoly>& elems) { +// self.SetElements(elems); +// }) +// .def("SetElementsMove", [](CiphertextImplDCRT& self, std::vector<DCRTPoly>&& elems) { +// self.SetElements(std::move(elems)); +// }); + +// // Bind the shared_ptr alias (Ciphertext<DCRTPoly>) so it picks up the methods above +// py::class_<CiphertextDCRT>(m, "_CiphertextAlias"); // hidden helper; not necessary for users +// } + void bind_schemes(py::module &m){ /*Bind schemes specific functionalities like bootstrapping functions and multiparty*/ py::class_<FHECKKSRNS>(m, "FHECKKSRNS") @@ -1409,14 +1437,13 @@ void bind_sch_swch_params(py::module &m) .def("SetRingDimension", &SchSwchParams::SetRingDimension) .def("SetScalingModSize", &SchSwchParams::SetScalingModSize) .def("SetBatchSize", &SchSwchParams::SetBatchSize) - .def("__str__",[](const SchSwchParams &params) { - std::stringstream stream; - stream << params; - return stream.str(); - }); + .def("__str__", [](const SchSwchParams &params) { + std::stringstream stream; + stream << params; + return stream.str(); + }); } - PYBIND11_MODULE(openfhe, m) { m.doc() = "Open-Source Fully Homomorphic Encryption Library"; diff --git a/src/lib/pke/cryptocontext_wrapper.cpp b/src/lib/pke/cryptocontext_wrapper.cpp index d8241c6..68bf965 100644 --- a/src/lib/pke/cryptocontext_wrapper.cpp +++ b/src/lib/pke/cryptocontext_wrapper.cpp @@ -34,10 +34,9 @@ Ciphertext<DCRTPoly> EvalFastRotationPrecomputeWrapper(CryptoContext<DCRTPoly> &self,ConstCiphertext<DCRTPoly> ciphertext) { std::shared_ptr<std::vector<DCRTPoly>> precomp = self->EvalFastRotationPrecompute(ciphertext); std::vector<DCRTPoly> elements = *(precomp.get()); - CiphertextImpl<DCRTPoly> cipherdigits = CiphertextImpl<DCRTPoly>(self); - std::shared_ptr<CiphertextImpl<DCRTPoly>> cipherdigitsPtr = std::make_shared<CiphertextImpl<DCRTPoly>>(cipherdigits); - cipherdigitsPtr->SetElements(elements); - return cipherdigitsPtr; + std::shared_ptr<CiphertextImpl<DCRTPoly>> cipherdigits = std::make_shared<CiphertextImpl<DCRTPoly>>(self); + cipherdigits->SetElements(std::move(elements)); + return cipherdigits; } Ciphertext<DCRTPoly> EvalFastRotationWrapper(CryptoContext<DCRTPoly>& self,ConstCiphertext<DCRTPoly> ciphertext, uint32_t index, uint32_t m,ConstCiphertext<DCRTPoly> digits) { @@ -78,6 +77,10 @@ PlaintextModulus GetPlaintextModulusWrapper(CryptoContext<DCRTPoly>& self){ return self->GetCryptoParameters()->GetPlaintextModulus(); } +uint32_t GetBatchSizeWrapper(CryptoContext<DCRTPoly>& self){ + return self->GetCryptoParameters()->GetBatchSize(); +} + double GetModulusWrapper(CryptoContext<DCRTPoly>& self){ return self->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble(); }