3232// / This file defines functions of dpctl.tensor._tensor_impl extensions
3333// ===--------------------------------------------------------------------===//
3434
35+ #include < cassert>
3536#include < cstddef>
3637#include < exception>
3738#include < stdexcept>
5758#include " utils/output_validation.hpp"
5859#include " utils/sycl_alloc_utils.hpp"
5960
60- namespace dpctl
61- {
62- namespace tensor
63- {
64- namespace py_internal
61+ namespace dpctl ::tensor::py_internal
6562{
6663
6764namespace td_ns = dpctl::tensor::type_dispatch;
@@ -112,77 +109,64 @@ static gemm_batch_contig_impl_fn_ptr_t
112109
113110void init_dot_dispatch_tables (void )
114111{
115- using dpctl::tensor::py_internal::DotTypeMapFactory;
116112 td_ns::DispatchTableBuilder<int , DotTypeMapFactory, td_ns::num_types> dtb1;
117113 dtb1.populate_dispatch_table (dot_output_id_table);
118114
119- using dpctl::tensor::py_internal::GemmBatchAtomicFactory;
120115 td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t ,
121116 GemmBatchAtomicFactory, td_ns::num_types>
122117 dtb2;
123118 dtb2.populate_dispatch_table (gemm_batch_atomic_dispatch_table);
124119
125- using dpctl::tensor::py_internal::GemmBatchContigAtomicFactory;
126120 td_ns::DispatchTableBuilder<gemm_batch_contig_impl_fn_ptr_t ,
127121 GemmBatchContigAtomicFactory, td_ns::num_types>
128122 dtb3;
129123 dtb3.populate_dispatch_table (gemm_batch_contig_atomic_dispatch_table);
130124
131- using dpctl::tensor::py_internal::GemmAtomicFactory;
132125 td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t , GemmAtomicFactory,
133126 td_ns::num_types>
134127 dtb4;
135128 dtb4.populate_dispatch_table (gemm_atomic_dispatch_table);
136129
137- using dpctl::tensor::py_internal::GemmContigAtomicFactory;
138130 td_ns::DispatchTableBuilder<gemm_contig_impl_fn_ptr_t ,
139131 GemmContigAtomicFactory, td_ns::num_types>
140132 dtb5;
141133 dtb5.populate_dispatch_table (gemm_contig_atomic_dispatch_table);
142134
143- using dpctl::tensor::py_internal::GemmBatchTempsFactory;
144135 td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t , GemmBatchTempsFactory,
145136 td_ns::num_types>
146137 dtb6;
147138 dtb6.populate_dispatch_table (gemm_batch_temps_dispatch_table);
148139
149- using dpctl::tensor::py_internal::GemmBatchContigTempsFactory;
150140 td_ns::DispatchTableBuilder<gemm_batch_contig_impl_fn_ptr_t ,
151141 GemmBatchContigTempsFactory, td_ns::num_types>
152142 dtb7;
153143 dtb7.populate_dispatch_table (gemm_batch_contig_temps_dispatch_table);
154144
155- using dpctl::tensor::py_internal::GemmTempsFactory;
156145 td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t , GemmTempsFactory,
157146 td_ns::num_types>
158147 dtb8;
159148 dtb8.populate_dispatch_table (gemm_temps_dispatch_table);
160149
161- using dpctl::tensor::py_internal::GemmContigTempsFactory;
162150 td_ns::DispatchTableBuilder<gemm_contig_impl_fn_ptr_t ,
163151 GemmContigTempsFactory, td_ns::num_types>
164152 dtb9;
165153 dtb9.populate_dispatch_table (gemm_contig_temps_dispatch_table);
166154
167- using dpctl::tensor::py_internal::DotProductAtomicFactory;
168155 td_ns::DispatchTableBuilder<dot_product_impl_fn_ptr_t ,
169156 DotProductAtomicFactory, td_ns::num_types>
170157 dtb10;
171158 dtb10.populate_dispatch_table (dot_product_dispatch_table);
172159
173- using dpctl::tensor::py_internal::DotProductNoAtomicFactory;
174160 td_ns::DispatchTableBuilder<dot_product_impl_fn_ptr_t ,
175161 DotProductNoAtomicFactory, td_ns::num_types>
176162 dtb11;
177163 dtb11.populate_dispatch_table (dot_product_temps_dispatch_table);
178164
179- using dpctl::tensor::py_internal::DotProductContigAtomicFactory;
180165 td_ns::DispatchTableBuilder<dot_product_contig_impl_fn_ptr_t ,
181166 DotProductContigAtomicFactory, td_ns::num_types>
182167 dtb12;
183168 dtb12.populate_dispatch_table (dot_product_contig_dispatch_table);
184169
185- using dpctl::tensor::py_internal::DotProductContigNoAtomicFactory;
186170 td_ns::DispatchTableBuilder<dot_product_contig_impl_fn_ptr_t ,
187171 DotProductContigNoAtomicFactory,
188172 td_ns::num_types>
@@ -368,9 +352,6 @@ std::pair<sycl::event, sycl::event>
368352 dot_ev);
369353 }
370354 }
371- using dpctl::tensor::py_internal::simplify_iteration_space;
372- using dpctl::tensor::py_internal::simplify_iteration_space_3;
373-
374355 int inner_nd = inner_dims;
375356 const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims;
376357 using shT = std::vector<py::ssize_t >;
@@ -628,7 +609,7 @@ std::pair<sycl::event, sycl::event>
628609 shT outer_inner_x1_shape;
629610 shT batch_x1_strides;
630611 shT outer_inner_x1_strides;
631- dpctl::tensor::py_internal:: split_iteration_space (
612+ split_iteration_space (
632613 x1_shape_vec, x1_strides_vec, batch_dims,
633614 batch_dims + x1_outer_inner_dims,
634615 // 4 vectors modified
@@ -639,7 +620,7 @@ std::pair<sycl::event, sycl::event>
639620 shT outer_inner_x2_shape;
640621 shT batch_x2_strides;
641622 shT outer_inner_x2_strides;
642- dpctl::tensor::py_internal:: split_iteration_space (
623+ split_iteration_space (
643624 x2_shape_vec, x2_strides_vec, batch_dims,
644625 batch_dims + x2_outer_inner_dims,
645626 // 4 vectors modified
@@ -650,7 +631,7 @@ std::pair<sycl::event, sycl::event>
650631 shT outer_inner_dst_shape;
651632 shT batch_dst_strides;
652633 shT outer_inner_dst_strides;
653- dpctl::tensor::py_internal:: split_iteration_space (
634+ split_iteration_space (
654635 dst_shape_vec, dst_strides_vec, batch_dims,
655636 batch_dims + dst_outer_inner_dims,
656637 // 4 vectors modified
@@ -668,7 +649,6 @@ std::pair<sycl::event, sycl::event>
668649
669650 const py::ssize_t *shape = x1_shape_ptr;
670651
671- using dpctl::tensor::py_internal::simplify_iteration_space_3;
672652 simplify_iteration_space_3 (
673653 batch_dims, shape, batch_x1_strides, batch_x2_strides,
674654 batch_dst_strides,
@@ -830,37 +810,28 @@ py::object py_dot_result_type(const py::dtype &input1_dtype,
830810 return py::cast<py::object>(res);
831811 }
832812 else {
833- using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;
834-
835813 auto dst_typenum_t = static_cast <td_ns::typenum_t >(dst_typeid);
836- auto dt = _dtype_from_typenum (dst_typenum_t );
814+ auto dt = type_utils:: _dtype_from_typenum (dst_typenum_t );
837815
838816 return py::cast<py::object>(dt);
839817 }
840818}
841819
842820void init_dot (py::module_ m)
843821{
844- using dpctl::tensor::py_internal::init_dot_atomic_support_vector;
845822 init_dot_atomic_support_vector ();
846- using dpctl::tensor::py_internal::init_dot_dispatch_tables;
847823 init_dot_dispatch_tables ();
848824
849- using dpctl::tensor::py_internal::py_dot;
850825 m.def (" _dot" , &py_dot, " " , py::arg (" x1" ), py::arg (" x2" ),
851826 py::arg (" batch_dims" ), py::arg (" x1_outer_dims" ),
852827 py::arg (" x2_outer_dims" ), py::arg (" inner_dims" ), py::arg (" dst" ),
853828 py::arg (" sycl_queue" ), py::arg (" depends" ) = py::list ());
854829
855- using dpctl::tensor::py_internal::dot_output_id_table;
856830 auto dot_result_type_pyapi = [&](const py::dtype &dtype1,
857831 const py::dtype &dtype2) {
858- using dpctl::tensor::py_internal::py_dot_result_type;
859832 return py_dot_result_type (dtype1, dtype2, dot_output_id_table);
860833 };
861834 m.def (" _dot_result_type" , dot_result_type_pyapi, " " );
862835}
863836
864- } // namespace py_internal
865- } // namespace tensor
866- } // namespace dpctl
837+ } // namespace dpctl::tensor::py_internal
0 commit comments