Skip to content

Commit eee65dd

Browse files
authored
Merge pull request #199 from openfheorg/187-fixes-for-openfhe
Updates required for the new OpenFHE release v1.3.0
2 parents f154bde + 5c3b548 commit eee65dd

1 file changed

Lines changed: 84 additions & 5 deletions

File tree

src/lib/bindings.cpp

Lines changed: 84 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,16 +1050,95 @@ void bind_keys(py::module &m)
10501050
.def_readwrite("secretKey", &KeyPair<DCRTPoly>::secretKey)
10511051
.def("good", &KeyPair<DCRTPoly>::good,kp_good_docs);
10521052
py::class_<EvalKeyImpl<DCRTPoly>, std::shared_ptr<EvalKeyImpl<DCRTPoly>>>(m, "EvalKey")
1053-
.def(py::init<>())
1053+
.def(py::init<>())
10541054
.def("GetKeyTag", &EvalKeyImpl<DCRTPoly>::GetKeyTag)
10551055
.def("SetKeyTag", &EvalKeyImpl<DCRTPoly>::SetKeyTag);
10561056
py::class_<std::map<usint, EvalKey<DCRTPoly>>, std::shared_ptr<std::map<usint, EvalKey<DCRTPoly>>>>(m, "EvalKeyMap")
10571057
.def(py::init<>());
10581058
}
10591059

1060+
// PlaintextImpl is an abstract class, so we should use a helper (trampoline) class
1061+
class PlaintextImpl_helper : public PlaintextImpl
1062+
{
1063+
public:
1064+
using PlaintextImpl::PlaintextImpl; // inherited constructors
1065+
1066+
// the PlaintextImpl virtual functions' overrides
1067+
bool Encode() override {
1068+
PYBIND11_OVERRIDE_PURE(
1069+
bool, // return type
1070+
PlaintextImpl, // parent class
1071+
Encode // function name
1072+
// no arguments
1073+
);
1074+
}
1075+
bool Decode() override {
1076+
PYBIND11_OVERRIDE_PURE(
1077+
bool, // return type
1078+
PlaintextImpl, // parent class
1079+
Decode // function name
1080+
// no arguments
1081+
);
1082+
}
1083+
bool Decode(size_t depth, double scalingFactor, ScalingTechnique scalTech, ExecutionMode executionMode) override {
1084+
PYBIND11_OVERRIDE(
1085+
bool, // return type
1086+
PlaintextImpl, // parent class
1087+
Decode, // function name
1088+
depth, scalingFactor, scalTech, executionMode // arguments
1089+
);
1090+
}
1091+
size_t GetLength() const override {
1092+
PYBIND11_OVERRIDE_PURE(
1093+
size_t, // return type
1094+
PlaintextImpl, // parent class
1095+
GetLength // function name
1096+
// no arguments
1097+
);
1098+
}
1099+
void SetLength(size_t newSize) override {
1100+
PYBIND11_OVERRIDE(
1101+
void, // return type
1102+
PlaintextImpl, // parent class
1103+
SetLength, // function name
1104+
newSize // arguments
1105+
);
1106+
}
1107+
double GetLogError() const override {
1108+
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogError);
1109+
}
1110+
double GetLogPrecision() const override {
1111+
PYBIND11_OVERRIDE(double, PlaintextImpl, GetLogPrecision);
1112+
}
1113+
const std::string& GetStringValue() const override {
1114+
PYBIND11_OVERRIDE(const std::string&, PlaintextImpl, GetStringValue);
1115+
}
1116+
const std::vector<int64_t>& GetCoefPackedValue() const override {
1117+
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetCoefPackedValue);
1118+
}
1119+
const std::vector<int64_t>& GetPackedValue() const override {
1120+
PYBIND11_OVERRIDE(const std::vector<int64_t>&, PlaintextImpl, GetPackedValue);
1121+
}
1122+
const std::vector<std::complex<double>>& GetCKKSPackedValue() const override {
1123+
PYBIND11_OVERRIDE(const std::vector<std::complex<double>>&, PlaintextImpl, GetCKKSPackedValue);
1124+
}
1125+
std::vector<double> GetRealPackedValue() const override {
1126+
PYBIND11_OVERRIDE(std::vector<double>, PlaintextImpl, GetRealPackedValue);
1127+
}
1128+
void SetStringValue(const std::string& str) override {
1129+
PYBIND11_OVERRIDE(void, PlaintextImpl, SetStringValue, str);
1130+
}
1131+
void SetIntVectorValue(const std::vector<int64_t>& vec) override {
1132+
PYBIND11_OVERRIDE(void, PlaintextImpl, SetIntVectorValue, vec);
1133+
}
1134+
std::string GetFormattedValues(int64_t precision) const override {
1135+
PYBIND11_OVERRIDE(std::string, PlaintextImpl, GetFormattedValues, precision);
1136+
}
1137+
};
1138+
10601139
void bind_encodings(py::module &m)
10611140
{
1062-
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>>(m, "Plaintext")
1141+
py::class_<PlaintextImpl, std::shared_ptr<PlaintextImpl>, PlaintextImpl_helper>(m, "Plaintext")
10631142
.def("GetScalingFactor", &PlaintextImpl::GetScalingFactor,
10641143
ptx_GetScalingFactor_docs)
10651144
.def("SetScalingFactor", &PlaintextImpl::SetScalingFactor,
@@ -1069,8 +1148,6 @@ void bind_encodings(py::module &m)
10691148
ptx_GetSchemeID_docs)
10701149
.def("GetLength", &PlaintextImpl::GetLength,
10711150
ptx_GetLength_docs)
1072-
.def("GetSchemeID", &PlaintextImpl::GetSchemeID,
1073-
ptx_GetSchemeID_docs)
10741151
.def("SetLength", &PlaintextImpl::SetLength,
10751152
ptx_SetLength_docs,
10761153
py::arg("newSize"))
@@ -1080,7 +1157,9 @@ void bind_encodings(py::module &m)
10801157
ptx_GetLogPrecision_docs)
10811158
.def("Encode", &PlaintextImpl::Encode,
10821159
ptx_Encode_docs)
1083-
.def("Decode", &PlaintextImpl::Decode,
1160+
.def("Decode", py::overload_cast<>(&PlaintextImpl::Decode),
1161+
ptx_Decode_docs)
1162+
.def("Decode", py::overload_cast<size_t, double, ScalingTechnique, ExecutionMode>(&PlaintextImpl::Decode),
10841163
ptx_Decode_docs)
10851164
.def("LowBound", &PlaintextImpl::LowBound,
10861165
ptx_LowBound_docs)

0 commit comments

Comments
 (0)