Skip to content

Commit 23e9434

Browse files
committed
add complex support
1 parent 662c1df commit 23e9434

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

src/wrapper.cpp

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using namespace seal;
99
namespace py = pybind11;
1010

1111
PYBIND11_MAKE_OPAQUE(std::vector<double>);
12+
PYBIND11_MAKE_OPAQUE(std::vector<std::complex<double>>);
1213
PYBIND11_MAKE_OPAQUE(std::vector<std::uint64_t>);
1314
PYBIND11_MAKE_OPAQUE(std::vector<std::int64_t>);
1415

@@ -18,6 +19,7 @@ PYBIND11_MODULE(seal, m)
1819
m.attr("__version__") = "4.1.2";
1920

2021
py::bind_vector<std::vector<double>>(m, "VectorDouble", py::buffer_protocol());
22+
py::bind_vector<std::vector<std::complex<double>>>(m, "VectorComplex", py::buffer_protocol());
2123
py::bind_vector<std::vector<std::uint64_t>>(m, "VectorUInt", py::buffer_protocol());
2224
py::bind_vector<std::vector<std::int64_t>>(m, "VectorInt", py::buffer_protocol());
2325

@@ -706,9 +708,53 @@ PYBIND11_MODULE(seal, m)
706708
py::class_<CKKSEncoder>(m, "CKKSEncoder")
707709
.def(py::init<const SEALContext &>())
708710
.def("slot_count", &CKKSEncoder::slot_count)
711+
.def("encode_complex", [](CKKSEncoder &encoder, const std::vector<std::complex<double>> &values, double scale, Plaintext &destination){
712+
encoder.encode(values, scale, destination);
713+
})
709714
.def("encode", [](CKKSEncoder &encoder, const std::vector<double> &values, double scale, Plaintext &destination){
710715
encoder.encode(values, scale, destination);
711716
})
717+
.def("encode_complex", [](CKKSEncoder &encoder, py::array_t<std::complex<double>> values, double scale){
718+
py::buffer_info buf = values.request();
719+
if (buf.ndim == 0)
720+
{
721+
auto *ptr = static_cast<std::complex<double> *>(buf.ptr);
722+
Plaintext pt;
723+
encoder.encode(ptr[0], scale, pt);
724+
return pt;
725+
}
726+
if (buf.ndim != 1)
727+
throw std::runtime_error("E101: Number of dimensions must be one");
728+
729+
auto *ptr = static_cast<std::complex<double> *>(buf.ptr);
730+
std::vector<std::complex<double>> vec(static_cast<std::size_t>(buf.shape[0]));
731+
732+
for (py::ssize_t i = 0; i < buf.shape[0]; i++)
733+
vec[static_cast<std::size_t>(i)] = ptr[i];
734+
735+
Plaintext pt;
736+
encoder.encode(vec, scale, pt);
737+
return pt;
738+
})
739+
.def("encode_complex", [](CKKSEncoder &encoder, py::array_t<std::complex<double>> values, double scale, Plaintext &destination){
740+
py::buffer_info buf = values.request();
741+
if (buf.ndim == 0)
742+
{
743+
auto *ptr = static_cast<std::complex<double> *>(buf.ptr);
744+
encoder.encode(ptr[0], scale, destination);
745+
return;
746+
}
747+
if (buf.ndim != 1)
748+
throw std::runtime_error("E101: Number of dimensions must be one");
749+
750+
auto *ptr = static_cast<std::complex<double> *>(buf.ptr);
751+
std::vector<std::complex<double>> vec(static_cast<std::size_t>(buf.shape[0]));
752+
753+
for (py::ssize_t i = 0; i < buf.shape[0]; i++)
754+
vec[static_cast<std::size_t>(i)] = ptr[i];
755+
756+
encoder.encode(vec, scale, destination);
757+
})
712758
.def("encode", [](CKKSEncoder &encoder, py::array_t<double> values, double scale){
713759
py::buffer_info buf = values.request();
714760
if (buf.ndim != 1)
@@ -724,6 +770,23 @@ PYBIND11_MODULE(seal, m)
724770
encoder.encode(vec, scale, pt);
725771
return pt;
726772
})
773+
.def("encode_complex", [](CKKSEncoder &encoder, py::iterable values, double scale){
774+
std::vector<std::complex<double>> vec;
775+
vec.reserve(py::len(values));
776+
for (const auto &value : values)
777+
vec.push_back(py::cast<std::complex<double>>(value));
778+
779+
Plaintext pt;
780+
encoder.encode(vec, scale, pt);
781+
return pt;
782+
})
783+
.def("encode_complex", [](CKKSEncoder &encoder, py::iterable values, double scale, Plaintext &destination){
784+
std::vector<std::complex<double>> vec;
785+
vec.reserve(py::len(values));
786+
for (const auto &value : values)
787+
vec.push_back(py::cast<std::complex<double>>(value));
788+
encoder.encode(vec, scale, destination);
789+
})
727790
.def("encode", [](CKKSEncoder &encoder, py::iterable values, double scale){
728791
std::vector<double> vec;
729792
vec.reserve(py::len(values));
@@ -749,6 +812,14 @@ PYBIND11_MODULE(seal, m)
749812
.def("encode", [](CKKSEncoder &encoder, double value, double scale, Plaintext &destination){
750813
encoder.encode(value, scale, destination);
751814
})
815+
.def("encode_complex", [](CKKSEncoder &encoder, std::complex<double> value, double scale){
816+
Plaintext pt;
817+
encoder.encode(value, scale, pt);
818+
return pt;
819+
})
820+
.def("encode_complex", [](CKKSEncoder &encoder, std::complex<double> value, double scale, Plaintext &destination){
821+
encoder.encode(value, scale, destination);
822+
})
752823
.def("encode", [](CKKSEncoder &encoder, std::int64_t value){
753824
Plaintext pt;
754825
encoder.encode(value, pt);
@@ -768,6 +839,19 @@ PYBIND11_MODULE(seal, m)
768839
for (auto i = 0; i < buf.shape[0]; i++)
769840
ptr[i] = destination[i];
770841

842+
return values;
843+
})
844+
.def("decode_complex", [](CKKSEncoder &encoder, const Plaintext &plain){
845+
std::vector<std::complex<double>> destination;
846+
encoder.decode(plain, destination);
847+
848+
py::array_t<std::complex<double>> values(destination.size());
849+
py::buffer_info buf = values.request();
850+
auto *ptr = static_cast<std::complex<double> *>(buf.ptr);
851+
852+
for (py::ssize_t i = 0; i < buf.shape[0]; i++)
853+
ptr[i] = destination[static_cast<std::size_t>(i)];
854+
771855
return values;
772856
});
773857

0 commit comments

Comments
 (0)