From e3894cbabf9afaff8d615dbc74e149bd233ed30f Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Tue, 26 Nov 2024 06:30:48 -0500 Subject: [PATCH 1/2] Added overloaded PlaintextImpl::Decode and removed a duplicate binding of PlaintextImpl::GetSchemeID --- src/lib/bindings.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lib/bindings.cpp b/src/lib/bindings.cpp index 6464425..d5fcee1 100644 --- a/src/lib/bindings.cpp +++ b/src/lib/bindings.cpp @@ -1050,7 +1050,7 @@ void bind_keys(py::module &m) .def_readwrite("secretKey", &KeyPair::secretKey) .def("good", &KeyPair::good,kp_good_docs); py::class_, std::shared_ptr>>(m, "EvalKey") - .def(py::init<>()) + .def(py::init<>()) .def("GetKeyTag", &EvalKeyImpl::GetKeyTag) .def("SetKeyTag", &EvalKeyImpl::SetKeyTag); py::class_>, std::shared_ptr>>>(m, "EvalKeyMap") @@ -1069,8 +1069,6 @@ void bind_encodings(py::module &m) ptx_GetSchemeID_docs) .def("GetLength", &PlaintextImpl::GetLength, ptx_GetLength_docs) - .def("GetSchemeID", &PlaintextImpl::GetSchemeID, - ptx_GetSchemeID_docs) .def("SetLength", &PlaintextImpl::SetLength, ptx_SetLength_docs, py::arg("newSize")) @@ -1080,7 +1078,9 @@ void bind_encodings(py::module &m) ptx_GetLogPrecision_docs) .def("Encode", &PlaintextImpl::Encode, ptx_Encode_docs) - .def("Decode", &PlaintextImpl::Decode, + .def("Decode", py::overload_cast<>(&PlaintextImpl::Decode), + ptx_Decode_docs) + .def("Decode", py::overload_cast(&PlaintextImpl::Decode), ptx_Decode_docs) .def("LowBound", &PlaintextImpl::LowBound, ptx_LowBound_docs) From 5c3b5487bc1ccba5907133b3cb1a03a27f6d4f9d Mon Sep 17 00:00:00 2001 From: Dmitriy Suponitskiy Date: Wed, 26 Mar 2025 14:58:25 -0400 Subject: [PATCH 2/2] Added a trampoline class PlaintextImpl_helper to override PlaintextImpl's virtual functions --- src/lib/bindings.cpp | 81 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/lib/bindings.cpp b/src/lib/bindings.cpp index d5fcee1..0949217 100644 --- a/src/lib/bindings.cpp +++ b/src/lib/bindings.cpp @@ -1057,9 +1057,88 @@ void bind_keys(py::module &m) .def(py::init<>()); } +// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class +class PlaintextImpl_helper : public PlaintextImpl +{ +public: + using PlaintextImpl::PlaintextImpl; // inherited constructors + + // the PlaintextImpl virtual functions' overrides + bool Encode() override { + PYBIND11_OVERRIDE_PURE( + bool, // return type + PlaintextImpl, // parent class + Encode // function name + // no arguments + ); + } + bool Decode() override { + PYBIND11_OVERRIDE_PURE( + bool, // return type + PlaintextImpl, // parent class + Decode // function name + // no arguments + ); + } + bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override { + PYBIND11_OVERRIDE( + bool, // return type + PlaintextImpl, // parent class + Decode, // function name + depth, scalingFactor, scalTech, executionMode // arguments + ); + } + size_t GetLength() const override { + PYBIND11_OVERRIDE_PURE( + size_t, // return type + PlaintextImpl, // parent class + GetLength // function name + // no arguments + ); + } + void SetLength(size_t newSize) override { + PYBIND11_OVERRIDE( + void, // return type + PlaintextImpl, // parent class + SetLength, // function name + newSize // arguments + ); + } + double GetLogError() const override { + PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError); + } + double GetLogPrecision() const override { + PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision); + } + const std::string& GetStringValue() const override { + PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue); + } + const std::vector& GetCoefPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector&, PlaintextImpl, GetCoefPackedValue); + } + const std::vector& GetPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector&, PlaintextImpl, GetPackedValue); + } + const std::vector>& GetCKKSPackedValue() const override { + PYBIND11_OVERRIDE(const std::vector>&, PlaintextImpl, GetCKKSPackedValue); + } + std::vector GetRealPackedValue() const override { + PYBIND11_OVERRIDE(std::vector, PlaintextImpl, GetRealPackedValue); + } + void SetStringValue(const std::string& str) override { + PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str); + } + void SetIntVectorValue(const std::vector& vec) override { + PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec); + } + std::string GetFormattedValues(int64_t precision) const override { + PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision); + } +}; + void bind_encodings(py::module &m) { - py::class_>(m, "Plaintext") + py::class_, PlaintextImpl_helper>(m, "Plaintext") .def("GetScalingFactor", &PlaintextImpl::GetScalingFactor, ptx_GetScalingFactor_docs) .def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,