3131#include < pybind11/stl.h>
3232
3333#include " dot.hpp"
34+ #include " dot_common.hpp"
35+ #include " dotc.hpp"
36+ #include " dotu.hpp"
3437#include " gemm.hpp"
3538
3639namespace blas_ext = dpnp::backend::ext::blas;
3740namespace py = pybind11;
41+ namespace dot_ext = blas_ext::dot;
42+ using dot_ext::dot_impl_fn_ptr_t ;
3843
3944// populate dispatch tables
4045void init_dispatch_tables (void )
4146{
42- blas_ext::init_dot_dispatch_table ();
43- blas_ext::init_dotc_dispatch_table ();
44- blas_ext::init_dotu_dispatch_table ();
4547 blas_ext::init_gemm_batch_dispatch_table ();
4648 blas_ext::init_gemm_dispatch_table ();
4749}
4850
51+ static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
52+ static dot_impl_fn_ptr_t dotc_dispatch_vector[dpctl_td_ns::num_types];
53+ static dot_impl_fn_ptr_t dotu_dispatch_vector[dpctl_td_ns::num_types];
54+
4955PYBIND11_MODULE (_blas_impl, m)
5056{
5157 init_dispatch_tables ();
5258
59+ using arrayT = dpctl::tensor::usm_ndarray;
60+ using event_vecT = std::vector<sycl::event>;
61+
5362 {
54- m.def (" _dot" , &blas_ext::dot,
63+ dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t ,
64+ blas_ext::DotContigFactory>(
65+ dot_dispatch_vector);
66+
67+ auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68+ arrayT dst, const event_vecT &depends = {}) {
69+ return dot_ext::dot_func (exec_q, src1, src2, dst, depends,
70+ dot_dispatch_vector);
71+ };
72+
73+ m.def (" _dot" , dot_pypi,
5574 " Call `dot` from OneMKL BLAS library to return "
5675 " the dot product of two real-valued vectors." ,
5776 py::arg (" sycl_queue" ), py::arg (" vectorA" ), py::arg (" vectorB" ),
5877 py::arg (" result" ), py::arg (" depends" ) = py::list ());
5978 }
6079
6180 {
62- m.def (" _dotc" , &blas_ext::dotc,
81+ dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t ,
82+ blas_ext::DotcContigFactory>(
83+ dotc_dispatch_vector);
84+
85+ auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86+ arrayT dst, const event_vecT &depends = {}) {
87+ return dot_ext::dot_func (exec_q, src1, src2, dst, depends,
88+ dotc_dispatch_vector);
89+ };
90+
91+ m.def (" _dotc" , dotc_pypi,
6392 " Call `dotc` from OneMKL BLAS library to return "
6493 " the dot product of two complex vectors, "
6594 " conjugating the first vector." ,
@@ -68,7 +97,17 @@ PYBIND11_MODULE(_blas_impl, m)
6897 }
6998
7099 {
71- m.def (" _dotu" , &blas_ext::dotu,
100+ dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t ,
101+ blas_ext::DotuContigFactory>(
102+ dotu_dispatch_vector);
103+
104+ auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105+ arrayT dst, const event_vecT &depends = {}) {
106+ return dot_ext::dot_func (exec_q, src1, src2, dst, depends,
107+ dotu_dispatch_vector);
108+ };
109+
110+ m.def (" _dotu" , dotu_pypi,
72111 " Call `dotu` from OneMKL BLAS library to return "
73112 " the dot product of two complex vectors." ,
74113 py::arg (" sycl_queue" ), py::arg (" vectorA" ), py::arg (" vectorB" ),
0 commit comments