Skip to content

Commit b893245

Browse files
authored
Update BLAS backend (#1790)
* rename varibales in dot/dotc/dotu * remove additional typename for dotc and dotu * remove additional typename for dot * update gemm routine * add using throw_if_not_writable and throw_if_not_ample * remove duplicate * using DispatchVectorBuilder * reduce duplication for dot functions * address comments
1 parent 7584c86 commit b893245

12 files changed

Lines changed: 622 additions & 816 deletions

File tree

dpnp/backend/extensions/blas/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
set(python_module_name _blas_impl)
2828
set(_module_src
2929
${CMAKE_CURRENT_SOURCE_DIR}/blas_py.cpp
30-
${CMAKE_CURRENT_SOURCE_DIR}/dot.cpp
31-
${CMAKE_CURRENT_SOURCE_DIR}/dotc.cpp
32-
${CMAKE_CURRENT_SOURCE_DIR}/dotu.cpp
3330
${CMAKE_CURRENT_SOURCE_DIR}/gemm.cpp
3431
${CMAKE_CURRENT_SOURCE_DIR}/gemm_batch.cpp
3532
)

dpnp/backend/extensions/blas/blas_py.cpp

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,35 +31,64 @@
3131
#include <pybind11/stl.h>
3232

3333
#include "dot.hpp"
34+
#include "dot_common.hpp"
35+
#include "dotc.hpp"
36+
#include "dotu.hpp"
3437
#include "gemm.hpp"
3538

3639
namespace blas_ext = dpnp::backend::ext::blas;
3740
namespace py = pybind11;
41+
namespace dot_ext = blas_ext::dot;
42+
using dot_ext::dot_impl_fn_ptr_t;
3843

3944
// populate dispatch tables
4045
void init_dispatch_tables(void)
4146
{
42-
blas_ext::init_dot_dispatch_table();
43-
blas_ext::init_dotc_dispatch_table();
44-
blas_ext::init_dotu_dispatch_table();
4547
blas_ext::init_gemm_batch_dispatch_table();
4648
blas_ext::init_gemm_dispatch_table();
4749
}
4850

51+
static dot_impl_fn_ptr_t dot_dispatch_vector[dpctl_td_ns::num_types];
52+
static dot_impl_fn_ptr_t dotc_dispatch_vector[dpctl_td_ns::num_types];
53+
static dot_impl_fn_ptr_t dotu_dispatch_vector[dpctl_td_ns::num_types];
54+
4955
PYBIND11_MODULE(_blas_impl, m)
5056
{
5157
init_dispatch_tables();
5258

59+
using arrayT = dpctl::tensor::usm_ndarray;
60+
using event_vecT = std::vector<sycl::event>;
61+
5362
{
54-
m.def("_dot", &blas_ext::dot,
63+
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
64+
blas_ext::DotContigFactory>(
65+
dot_dispatch_vector);
66+
67+
auto dot_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
68+
arrayT dst, const event_vecT &depends = {}) {
69+
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
70+
dot_dispatch_vector);
71+
};
72+
73+
m.def("_dot", dot_pypi,
5574
"Call `dot` from OneMKL BLAS library to return "
5675
"the dot product of two real-valued vectors.",
5776
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),
5877
py::arg("result"), py::arg("depends") = py::list());
5978
}
6079

6180
{
62-
m.def("_dotc", &blas_ext::dotc,
81+
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
82+
blas_ext::DotcContigFactory>(
83+
dotc_dispatch_vector);
84+
85+
auto dotc_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
86+
arrayT dst, const event_vecT &depends = {}) {
87+
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
88+
dotc_dispatch_vector);
89+
};
90+
91+
m.def("_dotc", dotc_pypi,
6392
"Call `dotc` from OneMKL BLAS library to return "
6493
"the dot product of two complex vectors, "
6594
"conjugating the first vector.",
@@ -68,7 +97,17 @@ PYBIND11_MODULE(_blas_impl, m)
6897
}
6998

7099
{
71-
m.def("_dotu", &blas_ext::dotu,
100+
dot_ext::init_dot_dispatch_vector<dot_impl_fn_ptr_t,
101+
blas_ext::DotuContigFactory>(
102+
dotu_dispatch_vector);
103+
104+
auto dotu_pypi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
105+
arrayT dst, const event_vecT &depends = {}) {
106+
return dot_ext::dot_func(exec_q, src1, src2, dst, depends,
107+
dotu_dispatch_vector);
108+
};
109+
110+
m.def("_dotu", dotu_pypi,
72111
"Call `dotu` from OneMKL BLAS library to return "
73112
"the dot product of two complex vectors.",
74113
py::arg("sycl_queue"), py::arg("vectorA"), py::arg("vectorB"),

dpnp/backend/extensions/blas/dot.cpp

Lines changed: 0 additions & 238 deletions
This file was deleted.

0 commit comments

Comments
 (0)