@@ -9,6 +9,7 @@ using namespace seal;
99namespace py = pybind11;
1010
1111PYBIND11_MAKE_OPAQUE (std::vector<double >);
12+ PYBIND11_MAKE_OPAQUE (std::vector<std::complex <double >>);
1213PYBIND11_MAKE_OPAQUE (std::vector<std::uint64_t >);
1314PYBIND11_MAKE_OPAQUE (std::vector<std::int64_t >);
1415
@@ -18,6 +19,7 @@ PYBIND11_MODULE(seal, m)
1819 m.attr (" __version__" ) = " 4.1.2" ;
1920
2021 py::bind_vector<std::vector<double >>(m, " VectorDouble" , py::buffer_protocol ());
22+ py::bind_vector<std::vector<std::complex <double >>>(m, " VectorComplex" , py::buffer_protocol ());
2123 py::bind_vector<std::vector<std::uint64_t >>(m, " VectorUInt" , py::buffer_protocol ());
2224 py::bind_vector<std::vector<std::int64_t >>(m, " VectorInt" , py::buffer_protocol ());
2325
@@ -706,9 +708,53 @@ PYBIND11_MODULE(seal, m)
706708 py::class_<CKKSEncoder>(m, " CKKSEncoder" )
707709 .def (py::init<const SEALContext &>())
708710 .def (" slot_count" , &CKKSEncoder::slot_count)
711+ .def (" encode_complex" , [](CKKSEncoder &encoder, const std::vector<std::complex <double >> &values, double scale, Plaintext &destination){
712+ encoder.encode (values, scale, destination);
713+ })
709714 .def (" encode" , [](CKKSEncoder &encoder, const std::vector<double > &values, double scale, Plaintext &destination){
710715 encoder.encode (values, scale, destination);
711716 })
717+ .def (" encode_complex" , [](CKKSEncoder &encoder, py::array_t <std::complex <double >> values, double scale){
718+ py::buffer_info buf = values.request ();
719+ if (buf.ndim == 0 )
720+ {
721+ auto *ptr = static_cast <std::complex <double > *>(buf.ptr );
722+ Plaintext pt;
723+ encoder.encode (ptr[0 ], scale, pt);
724+ return pt;
725+ }
726+ if (buf.ndim != 1 )
727+ throw std::runtime_error (" E101: Number of dimensions must be one" );
728+
729+ auto *ptr = static_cast <std::complex <double > *>(buf.ptr );
730+ std::vector<std::complex <double >> vec (static_cast <std::size_t >(buf.shape [0 ]));
731+
732+ for (py::ssize_t i = 0 ; i < buf.shape [0 ]; i++)
733+ vec[static_cast <std::size_t >(i)] = ptr[i];
734+
735+ Plaintext pt;
736+ encoder.encode (vec, scale, pt);
737+ return pt;
738+ })
739+ .def (" encode_complex" , [](CKKSEncoder &encoder, py::array_t <std::complex <double >> values, double scale, Plaintext &destination){
740+ py::buffer_info buf = values.request ();
741+ if (buf.ndim == 0 )
742+ {
743+ auto *ptr = static_cast <std::complex <double > *>(buf.ptr );
744+ encoder.encode (ptr[0 ], scale, destination);
745+ return ;
746+ }
747+ if (buf.ndim != 1 )
748+ throw std::runtime_error (" E101: Number of dimensions must be one" );
749+
750+ auto *ptr = static_cast <std::complex <double > *>(buf.ptr );
751+ std::vector<std::complex <double >> vec (static_cast <std::size_t >(buf.shape [0 ]));
752+
753+ for (py::ssize_t i = 0 ; i < buf.shape [0 ]; i++)
754+ vec[static_cast <std::size_t >(i)] = ptr[i];
755+
756+ encoder.encode (vec, scale, destination);
757+ })
712758 .def (" encode" , [](CKKSEncoder &encoder, py::array_t <double > values, double scale){
713759 py::buffer_info buf = values.request ();
714760 if (buf.ndim != 1 )
@@ -724,6 +770,23 @@ PYBIND11_MODULE(seal, m)
724770 encoder.encode (vec, scale, pt);
725771 return pt;
726772 })
773+ .def (" encode_complex" , [](CKKSEncoder &encoder, py::iterable values, double scale){
774+ std::vector<std::complex <double >> vec;
775+ vec.reserve (py::len (values));
776+ for (const auto &value : values)
777+ vec.push_back (py::cast<std::complex <double >>(value));
778+
779+ Plaintext pt;
780+ encoder.encode (vec, scale, pt);
781+ return pt;
782+ })
783+ .def (" encode_complex" , [](CKKSEncoder &encoder, py::iterable values, double scale, Plaintext &destination){
784+ std::vector<std::complex <double >> vec;
785+ vec.reserve (py::len (values));
786+ for (const auto &value : values)
787+ vec.push_back (py::cast<std::complex <double >>(value));
788+ encoder.encode (vec, scale, destination);
789+ })
727790 .def (" encode" , [](CKKSEncoder &encoder, py::iterable values, double scale){
728791 std::vector<double > vec;
729792 vec.reserve (py::len (values));
@@ -749,6 +812,14 @@ PYBIND11_MODULE(seal, m)
749812 .def (" encode" , [](CKKSEncoder &encoder, double value, double scale, Plaintext &destination){
750813 encoder.encode (value, scale, destination);
751814 })
815+ .def (" encode_complex" , [](CKKSEncoder &encoder, std::complex <double > value, double scale){
816+ Plaintext pt;
817+ encoder.encode (value, scale, pt);
818+ return pt;
819+ })
820+ .def (" encode_complex" , [](CKKSEncoder &encoder, std::complex <double > value, double scale, Plaintext &destination){
821+ encoder.encode (value, scale, destination);
822+ })
752823 .def (" encode" , [](CKKSEncoder &encoder, std::int64_t value){
753824 Plaintext pt;
754825 encoder.encode (value, pt);
@@ -768,6 +839,19 @@ PYBIND11_MODULE(seal, m)
768839 for (auto i = 0 ; i < buf.shape [0 ]; i++)
769840 ptr[i] = destination[i];
770841
842+ return values;
843+ })
844+ .def (" decode_complex" , [](CKKSEncoder &encoder, const Plaintext &plain){
845+ std::vector<std::complex <double >> destination;
846+ encoder.decode (plain, destination);
847+
848+ py::array_t <std::complex <double >> values (destination.size ());
849+ py::buffer_info buf = values.request ();
850+ auto *ptr = static_cast <std::complex <double > *>(buf.ptr );
851+
852+ for (py::ssize_t i = 0 ; i < buf.shape [0 ]; i++)
853+ ptr[i] = destination[static_cast <std::size_t >(i)];
854+
771855 return values;
772856 });
773857
0 commit comments