Skip to content

Commit 662c1df

Browse files
committed
forward functions
1 parent 4824c2e commit 662c1df

1 file changed

Lines changed: 178 additions & 8 deletions

File tree

src/wrapper.cpp

Lines changed: 178 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ using namespace seal;
99
namespace py = pybind11;
1010

1111
PYBIND11_MAKE_OPAQUE(std::vector<double>);
12+
PYBIND11_MAKE_OPAQUE(std::vector<std::uint64_t>);
1213
PYBIND11_MAKE_OPAQUE(std::vector<std::int64_t>);
1314

1415
PYBIND11_MODULE(seal, m)
1516
{
1617
m.doc() = "Microsoft SEAL for Python, from https://github.com/Huelse/SEAL-Python";
17-
m.attr("__version__") = "4.0.0";
18+
m.attr("__version__") = "4.1.2";
1819

1920
py::bind_vector<std::vector<double>>(m, "VectorDouble", py::buffer_protocol());
21+
py::bind_vector<std::vector<std::uint64_t>>(m, "VectorUInt", py::buffer_protocol());
2022
py::bind_vector<std::vector<std::int64_t>>(m, "VectorInt", py::buffer_protocol());
2123

2224
// encryptionparams.h
@@ -26,6 +28,37 @@ PYBIND11_MODULE(seal, m)
2628
.value("ckks", scheme_type::ckks)
2729
.value("bgv", scheme_type::bgv);
2830

31+
// serialization.h
32+
py::enum_<compr_mode_type>(m, "compr_mode_type")
33+
.value("none", compr_mode_type::none)
34+
#ifdef SEAL_USE_ZLIB
35+
.value("zlib", compr_mode_type::zlib)
36+
#endif
37+
#ifdef SEAL_USE_ZSTD
38+
.value("zstd", compr_mode_type::zstd)
39+
#endif
40+
;
41+
42+
// memorymanager.h
43+
py::class_<MemoryPoolHandle>(m, "MemoryPoolHandle")
44+
.def(py::init<>())
45+
.def_static("Global", &MemoryPoolHandle::Global)
46+
#ifndef _M_CEE
47+
.def_static("ThreadLocal", &MemoryPoolHandle::ThreadLocal)
48+
#endif
49+
.def_static("New", &MemoryPoolHandle::New, py::arg("clear_on_destruction") = false)
50+
.def("pool_count", &MemoryPoolHandle::pool_count)
51+
.def("alloc_byte_count", &MemoryPoolHandle::alloc_byte_count)
52+
.def("use_count", &MemoryPoolHandle::use_count)
53+
.def("is_initialized", [](const MemoryPoolHandle &pool){
54+
return static_cast<bool>(pool);
55+
});
56+
57+
py::class_<MemoryManager>(m, "MemoryManager")
58+
.def_static("GetPool", [](){
59+
return MemoryManager::GetPool();
60+
});
61+
2962
// encryptionparams.h
3063
py::class_<EncryptionParameters>(m, "EncryptionParameters")
3164
.def(py::init<scheme_type>())
@@ -43,11 +76,27 @@ PYBIND11_MODULE(seal, m)
4376
parms.save(out);
4477
out.close();
4578
})
79+
.def("save", [](const EncryptionParameters &parms, std::string &path, compr_mode_type compr_mode){
80+
std::ofstream out(path, std::ios::binary);
81+
parms.save(out, compr_mode);
82+
out.close();
83+
})
4684
.def("load", [](EncryptionParameters &parms, std::string &path){
4785
std::ifstream in(path, std::ios::binary);
4886
parms.load(in);
4987
in.close();
5088
})
89+
.def("load_bytes", [](EncryptionParameters &parms, py::bytes data){
90+
std::string raw = data;
91+
parms.load(reinterpret_cast<const seal_byte *>(raw.data()), raw.size());
92+
})
93+
.def("save_size", py::overload_cast<compr_mode_type>(&EncryptionParameters::save_size, py::const_),
94+
py::arg("compr_mode")=Serialization::compr_mode_default)
95+
.def("to_bytes", [](const EncryptionParameters &parms, compr_mode_type compr_mode){
96+
std::stringstream out(std::ios::binary | std::ios::out);
97+
parms.save(out, compr_mode);
98+
return py::bytes(out.str());
99+
}, py::arg("compr_mode")=Serialization::compr_mode_default)
51100
.def(py::pickle(
52101
[](const EncryptionParameters &parms){
53102
std::stringstream out(std::ios::binary | std::ios::out);
@@ -59,6 +108,7 @@ PYBIND11_MODULE(seal, m)
59108
throw std::runtime_error("(Pickle) Invalid input tuple!");
60109
std::string str = t[0].cast<std::string>();
61110
std::stringstream in(std::ios::binary | std::ios::in);
111+
in.str(str);
62112
EncryptionParameters parms;
63113
parms.load(in);
64114
return parms;
@@ -94,6 +144,9 @@ PYBIND11_MODULE(seal, m)
94144
// context.h
95145
py::class_<EncryptionParameterQualifiers, std::unique_ptr<EncryptionParameterQualifiers, py::nodelete>>(m, "EncryptionParameterQualifiers")
96146
.def("parameters_set", &EncryptionParameterQualifiers::parameters_set)
147+
.def_readwrite("parameter_error", &EncryptionParameterQualifiers::parameter_error)
148+
.def("parameter_error_name", &EncryptionParameterQualifiers::parameter_error_name)
149+
.def("parameter_error_message", &EncryptionParameterQualifiers::parameter_error_message)
97150
.def_readwrite("using_fft", &EncryptionParameterQualifiers::using_fft)
98151
.def_readwrite("using_ntt", &EncryptionParameterQualifiers::using_ntt)
99152
.def_readwrite("using_batching", &EncryptionParameterQualifiers::using_batching)
@@ -119,6 +172,8 @@ PYBIND11_MODULE(seal, m)
119172
.def("first_context_data", &SEALContext::first_context_data)
120173
.def("last_context_data", &SEALContext::last_context_data)
121174
.def("parameters_set", &SEALContext::parameters_set)
175+
.def("parameter_error_name", &SEALContext::parameter_error_name)
176+
.def("parameter_error_message", &SEALContext::parameter_error_message)
122177
.def("first_parms_id", &SEALContext::first_parms_id)
123178
.def("last_parms_id", &SEALContext::last_parms_id)
124179
.def("using_keyswitching", &SEALContext::using_keyswitching)
@@ -171,7 +226,8 @@ PYBIND11_MODULE(seal, m)
171226
.def("bit_count", &Modulus::bit_count)
172227
.def("value", &Modulus::value)
173228
.def("is_zero", &Modulus::is_zero)
174-
.def("is_prime", &Modulus::is_prime);
229+
.def("is_prime", &Modulus::is_prime)
230+
.def("reduce", &Modulus::reduce);
175231
//save & load
176232

177233
// modulus.h
@@ -213,19 +269,30 @@ PYBIND11_MODULE(seal, m)
213269
plain.save(out);
214270
out.close();
215271
})
272+
.def("save", [](const Plaintext &plain, const std::string &path, compr_mode_type compr_mode){
273+
std::ofstream out(path, std::ios::binary);
274+
plain.save(out, compr_mode);
275+
out.close();
276+
})
216277
.def("load", [](Plaintext &plain, const SEALContext &context, const std::string &path){
217278
std::ifstream in(path, std::ios::binary);
218279
plain.load(context, in);
219280
in.close();
220281
})
282+
.def("load_bytes", [](Plaintext &plain, const SEALContext &context, py::bytes data){
283+
std::string raw = data;
284+
plain.load(context, reinterpret_cast<const seal_byte *>(raw.data()), raw.size());
285+
})
221286
.def("save_size", [](const Plaintext &plain){
222287
return plain.save_size();
223288
})
224-
.def("to_string", [](const Plaintext &plain){
289+
.def("save_size", py::overload_cast<compr_mode_type>(&Plaintext::save_size, py::const_),
290+
py::arg("compr_mode")=Serialization::compr_mode_default)
291+
.def("to_bytes", [](const Plaintext &plain, compr_mode_type compr_mode){
225292
std::stringstream out(std::ios::binary | std::ios::out);
226-
plain.save(out);
293+
plain.save(out, compr_mode);
227294
return py::bytes(out.str());
228-
});
295+
}, py::arg("compr_mode")=Serialization::compr_mode_default);
229296

230297
// ciphertext.h
231298
py::class_<Ciphertext>(m, "Ciphertext")
@@ -250,19 +317,30 @@ PYBIND11_MODULE(seal, m)
250317
cipher.save(out);
251318
out.close();
252319
})
320+
.def("save", [](const Ciphertext &cipher, const std::string &path, compr_mode_type compr_mode){
321+
std::ofstream out(path, std::ios::binary);
322+
cipher.save(out, compr_mode);
323+
out.close();
324+
})
253325
.def("load", [](Ciphertext &cipher, const SEALContext &context, const std::string &path){
254326
std::ifstream in(path, std::ios::binary);
255327
cipher.load(context, in);
256328
in.close();
257329
})
330+
.def("load_bytes", [](Ciphertext &cipher, const SEALContext &context, py::bytes data){
331+
std::string raw = data;
332+
cipher.load(context, reinterpret_cast<const seal_byte *>(raw.data()), raw.size());
333+
})
258334
.def("save_size", [](const Ciphertext &cipher){
259335
return cipher.save_size();
260336
})
261-
.def("to_string", [](const Ciphertext &cipher){
337+
.def("save_size", py::overload_cast<compr_mode_type>(&Ciphertext::save_size, py::const_),
338+
py::arg("compr_mode")=Serialization::compr_mode_default)
339+
.def("to_string", [](const Ciphertext &cipher, compr_mode_type compr_mode){
262340
std::stringstream out(std::ios::binary | std::ios::out);
263-
cipher.save(out);
341+
cipher.save(out, compr_mode);
264342
return py::bytes(out.str());
265-
});
343+
}, py::arg("compr_mode")=Serialization::compr_mode_default);
266344

267345
// secretkey.h
268346
py::class_<SecretKey>(m, "SecretKey")
@@ -408,15 +486,32 @@ PYBIND11_MODULE(seal, m)
408486
encryptor.encrypt_zero(encrypted);
409487
return encrypted;
410488
})
489+
.def("encrypt_zero", [](const Encryptor &encryptor, Ciphertext &destination){
490+
encryptor.encrypt_zero(destination);
491+
})
492+
.def("encrypt_zero", [](const Encryptor &encryptor, parms_id_type parms_id){
493+
Ciphertext encrypted;
494+
encryptor.encrypt_zero(parms_id, encrypted);
495+
return encrypted;
496+
})
497+
.def("encrypt_zero", [](const Encryptor &encryptor, parms_id_type parms_id, Ciphertext &destination){
498+
encryptor.encrypt_zero(parms_id, destination);
499+
})
411500
.def("encrypt", [](const Encryptor &encryptor, const Plaintext &plain){
412501
Ciphertext encrypted;
413502
encryptor.encrypt(plain, encrypted);
414503
return encrypted;
415504
})
505+
.def("encrypt", [](const Encryptor &encryptor, const Plaintext &plain, Ciphertext &destination){
506+
encryptor.encrypt(plain, destination);
507+
})
416508
.def("encrypt_symmetric", [](const Encryptor &encryptor, const Plaintext &plain){
417509
Ciphertext encrypted;
418510
encryptor.encrypt_symmetric(plain, encrypted);
419511
return encrypted;
512+
})
513+
.def("encrypt_symmetric", [](const Encryptor &encryptor, const Plaintext &plain, Ciphertext &destination){
514+
encryptor.encrypt_symmetric(plain, destination);
420515
});
421516

422517
// evaluator.h
@@ -611,6 +706,9 @@ PYBIND11_MODULE(seal, m)
611706
py::class_<CKKSEncoder>(m, "CKKSEncoder")
612707
.def(py::init<const SEALContext &>())
613708
.def("slot_count", &CKKSEncoder::slot_count)
709+
.def("encode", [](CKKSEncoder &encoder, const std::vector<double> &values, double scale, Plaintext &destination){
710+
encoder.encode(values, scale, destination);
711+
})
614712
.def("encode", [](CKKSEncoder &encoder, py::array_t<double> values, double scale){
615713
py::buffer_info buf = values.request();
616714
if (buf.ndim != 1)
@@ -626,11 +724,39 @@ PYBIND11_MODULE(seal, m)
626724
encoder.encode(vec, scale, pt);
627725
return pt;
628726
})
727+
.def("encode", [](CKKSEncoder &encoder, py::iterable values, double scale){
728+
std::vector<double> vec;
729+
vec.reserve(py::len(values));
730+
for (const auto &value : values)
731+
vec.push_back(py::cast<double>(value));
732+
733+
Plaintext pt;
734+
encoder.encode(vec, scale, pt);
735+
return pt;
736+
})
737+
.def("encode", [](CKKSEncoder &encoder, py::iterable values, double scale, Plaintext &destination){
738+
std::vector<double> vec;
739+
vec.reserve(py::len(values));
740+
for (const auto &value : values)
741+
vec.push_back(py::cast<double>(value));
742+
encoder.encode(vec, scale, destination);
743+
})
629744
.def("encode", [](CKKSEncoder &encoder, double value, double scale){
630745
Plaintext pt;
631746
encoder.encode(value, scale, pt);
632747
return pt;
633748
})
749+
.def("encode", [](CKKSEncoder &encoder, double value, double scale, Plaintext &destination){
750+
encoder.encode(value, scale, destination);
751+
})
752+
.def("encode", [](CKKSEncoder &encoder, std::int64_t value){
753+
Plaintext pt;
754+
encoder.encode(value, pt);
755+
return pt;
756+
})
757+
.def("encode", [](CKKSEncoder &encoder, std::int64_t value, Plaintext &destination){
758+
encoder.encode(value, destination);
759+
})
634760
.def("decode", [](CKKSEncoder &encoder, const Plaintext &plain){
635761
std::vector<double> destination;
636762
encoder.decode(plain, destination);
@@ -660,6 +786,12 @@ PYBIND11_MODULE(seal, m)
660786
py::class_<BatchEncoder>(m, "BatchEncoder")
661787
.def(py::init<const SEALContext &>())
662788
.def("slot_count", &BatchEncoder::slot_count)
789+
.def("encode", [](BatchEncoder &encoder, const std::vector<std::int64_t> &values, Plaintext &destination){
790+
encoder.encode(values, destination);
791+
})
792+
.def("encode", [](BatchEncoder &encoder, const std::vector<std::uint64_t> &values, Plaintext &destination){
793+
encoder.encode(values, destination);
794+
})
663795
.def("encode", [](BatchEncoder &encoder, py::array_t<std::int64_t> values){
664796
py::buffer_info buf = values.request();
665797
if (buf.ndim != 1)
@@ -675,6 +807,44 @@ PYBIND11_MODULE(seal, m)
675807
encoder.encode(vec, pt);
676808
return pt;
677809
})
810+
.def("encode", [](BatchEncoder &encoder, py::array_t<std::uint64_t> values){
811+
py::buffer_info buf = values.request();
812+
if (buf.ndim != 1)
813+
throw std::runtime_error("E101: Number of dimensions must be one");
814+
815+
auto *ptr = static_cast<std::uint64_t *>(buf.ptr);
816+
std::vector<std::uint64_t> vec(static_cast<std::size_t>(buf.shape[0]));
817+
818+
for (py::ssize_t i = 0; i < buf.shape[0]; i++)
819+
vec[static_cast<std::size_t>(i)] = ptr[i];
820+
821+
Plaintext pt;
822+
encoder.encode(vec, pt);
823+
return pt;
824+
})
825+
.def("encode", [](BatchEncoder &encoder, py::iterable values){
826+
std::vector<std::int64_t> vec;
827+
vec.reserve(py::len(values));
828+
for (const auto &value : values)
829+
vec.push_back(py::cast<std::int64_t>(value));
830+
831+
Plaintext pt;
832+
encoder.encode(vec, pt);
833+
return pt;
834+
})
835+
.def("decode_uint64", [](BatchEncoder &encoder, const Plaintext &plain){
836+
std::vector<std::uint64_t> destination;
837+
encoder.decode(plain, destination);
838+
839+
py::array_t<std::uint64_t> values(destination.size());
840+
py::buffer_info buf = values.request();
841+
auto *ptr = static_cast<std::uint64_t *>(buf.ptr);
842+
843+
for (py::ssize_t i = 0; i < buf.shape[0]; i++)
844+
ptr[i] = destination[static_cast<std::size_t>(i)];
845+
846+
return values;
847+
})
678848
.def("decode", [](BatchEncoder &encoder, const Plaintext &plain){
679849
std::vector<std::int64_t> destination;
680850
encoder.decode(plain, destination);

0 commit comments

Comments
 (0)