Skip to content

Commit 3bd1b0e

Browse files
committed
fix includes and namespaces in dot.cpp
1 parent fc34f30 commit 3bd1b0e

1 file changed

Lines changed: 7 additions & 36 deletions

File tree

  • dpctl_ext/tensor/libtensor/source/linalg_functions

dpctl_ext/tensor/libtensor/source/linalg_functions/dot.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
/// This file defines functions of dpctl.tensor._tensor_impl extensions
3333
//===--------------------------------------------------------------------===//
3434

35+
#include <cassert>
3536
#include <cstddef>
3637
#include <exception>
3738
#include <stdexcept>
@@ -57,11 +58,7 @@
5758
#include "utils/output_validation.hpp"
5859
#include "utils/sycl_alloc_utils.hpp"
5960

60-
namespace dpctl
61-
{
62-
namespace tensor
63-
{
64-
namespace py_internal
61+
namespace dpctl::tensor::py_internal
6562
{
6663

6764
namespace td_ns = dpctl::tensor::type_dispatch;
@@ -112,77 +109,64 @@ static gemm_batch_contig_impl_fn_ptr_t
112109

113110
void init_dot_dispatch_tables(void)
114111
{
115-
using dpctl::tensor::py_internal::DotTypeMapFactory;
116112
td_ns::DispatchTableBuilder<int, DotTypeMapFactory, td_ns::num_types> dtb1;
117113
dtb1.populate_dispatch_table(dot_output_id_table);
118114

119-
using dpctl::tensor::py_internal::GemmBatchAtomicFactory;
120115
td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t,
121116
GemmBatchAtomicFactory, td_ns::num_types>
122117
dtb2;
123118
dtb2.populate_dispatch_table(gemm_batch_atomic_dispatch_table);
124119

125-
using dpctl::tensor::py_internal::GemmBatchContigAtomicFactory;
126120
td_ns::DispatchTableBuilder<gemm_batch_contig_impl_fn_ptr_t,
127121
GemmBatchContigAtomicFactory, td_ns::num_types>
128122
dtb3;
129123
dtb3.populate_dispatch_table(gemm_batch_contig_atomic_dispatch_table);
130124

131-
using dpctl::tensor::py_internal::GemmAtomicFactory;
132125
td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t, GemmAtomicFactory,
133126
td_ns::num_types>
134127
dtb4;
135128
dtb4.populate_dispatch_table(gemm_atomic_dispatch_table);
136129

137-
using dpctl::tensor::py_internal::GemmContigAtomicFactory;
138130
td_ns::DispatchTableBuilder<gemm_contig_impl_fn_ptr_t,
139131
GemmContigAtomicFactory, td_ns::num_types>
140132
dtb5;
141133
dtb5.populate_dispatch_table(gemm_contig_atomic_dispatch_table);
142134

143-
using dpctl::tensor::py_internal::GemmBatchTempsFactory;
144135
td_ns::DispatchTableBuilder<gemm_batch_impl_fn_ptr_t, GemmBatchTempsFactory,
145136
td_ns::num_types>
146137
dtb6;
147138
dtb6.populate_dispatch_table(gemm_batch_temps_dispatch_table);
148139

149-
using dpctl::tensor::py_internal::GemmBatchContigTempsFactory;
150140
td_ns::DispatchTableBuilder<gemm_batch_contig_impl_fn_ptr_t,
151141
GemmBatchContigTempsFactory, td_ns::num_types>
152142
dtb7;
153143
dtb7.populate_dispatch_table(gemm_batch_contig_temps_dispatch_table);
154144

155-
using dpctl::tensor::py_internal::GemmTempsFactory;
156145
td_ns::DispatchTableBuilder<gemm_impl_fn_ptr_t, GemmTempsFactory,
157146
td_ns::num_types>
158147
dtb8;
159148
dtb8.populate_dispatch_table(gemm_temps_dispatch_table);
160149

161-
using dpctl::tensor::py_internal::GemmContigTempsFactory;
162150
td_ns::DispatchTableBuilder<gemm_contig_impl_fn_ptr_t,
163151
GemmContigTempsFactory, td_ns::num_types>
164152
dtb9;
165153
dtb9.populate_dispatch_table(gemm_contig_temps_dispatch_table);
166154

167-
using dpctl::tensor::py_internal::DotProductAtomicFactory;
168155
td_ns::DispatchTableBuilder<dot_product_impl_fn_ptr_t,
169156
DotProductAtomicFactory, td_ns::num_types>
170157
dtb10;
171158
dtb10.populate_dispatch_table(dot_product_dispatch_table);
172159

173-
using dpctl::tensor::py_internal::DotProductNoAtomicFactory;
174160
td_ns::DispatchTableBuilder<dot_product_impl_fn_ptr_t,
175161
DotProductNoAtomicFactory, td_ns::num_types>
176162
dtb11;
177163
dtb11.populate_dispatch_table(dot_product_temps_dispatch_table);
178164

179-
using dpctl::tensor::py_internal::DotProductContigAtomicFactory;
180165
td_ns::DispatchTableBuilder<dot_product_contig_impl_fn_ptr_t,
181166
DotProductContigAtomicFactory, td_ns::num_types>
182167
dtb12;
183168
dtb12.populate_dispatch_table(dot_product_contig_dispatch_table);
184169

185-
using dpctl::tensor::py_internal::DotProductContigNoAtomicFactory;
186170
td_ns::DispatchTableBuilder<dot_product_contig_impl_fn_ptr_t,
187171
DotProductContigNoAtomicFactory,
188172
td_ns::num_types>
@@ -368,9 +352,6 @@ std::pair<sycl::event, sycl::event>
368352
dot_ev);
369353
}
370354
}
371-
using dpctl::tensor::py_internal::simplify_iteration_space;
372-
using dpctl::tensor::py_internal::simplify_iteration_space_3;
373-
374355
int inner_nd = inner_dims;
375356
const py::ssize_t *inner_shape_ptr = x1_shape_ptr + batch_dims;
376357
using shT = std::vector<py::ssize_t>;
@@ -628,7 +609,7 @@ std::pair<sycl::event, sycl::event>
628609
shT outer_inner_x1_shape;
629610
shT batch_x1_strides;
630611
shT outer_inner_x1_strides;
631-
dpctl::tensor::py_internal::split_iteration_space(
612+
split_iteration_space(
632613
x1_shape_vec, x1_strides_vec, batch_dims,
633614
batch_dims + x1_outer_inner_dims,
634615
// 4 vectors modified
@@ -639,7 +620,7 @@ std::pair<sycl::event, sycl::event>
639620
shT outer_inner_x2_shape;
640621
shT batch_x2_strides;
641622
shT outer_inner_x2_strides;
642-
dpctl::tensor::py_internal::split_iteration_space(
623+
split_iteration_space(
643624
x2_shape_vec, x2_strides_vec, batch_dims,
644625
batch_dims + x2_outer_inner_dims,
645626
// 4 vectors modified
@@ -650,7 +631,7 @@ std::pair<sycl::event, sycl::event>
650631
shT outer_inner_dst_shape;
651632
shT batch_dst_strides;
652633
shT outer_inner_dst_strides;
653-
dpctl::tensor::py_internal::split_iteration_space(
634+
split_iteration_space(
654635
dst_shape_vec, dst_strides_vec, batch_dims,
655636
batch_dims + dst_outer_inner_dims,
656637
// 4 vectors modified
@@ -668,7 +649,6 @@ std::pair<sycl::event, sycl::event>
668649

669650
const py::ssize_t *shape = x1_shape_ptr;
670651

671-
using dpctl::tensor::py_internal::simplify_iteration_space_3;
672652
simplify_iteration_space_3(
673653
batch_dims, shape, batch_x1_strides, batch_x2_strides,
674654
batch_dst_strides,
@@ -830,37 +810,28 @@ py::object py_dot_result_type(const py::dtype &input1_dtype,
830810
return py::cast<py::object>(res);
831811
}
832812
else {
833-
using dpctl::tensor::py_internal::type_utils::_dtype_from_typenum;
834-
835813
auto dst_typenum_t = static_cast<td_ns::typenum_t>(dst_typeid);
836-
auto dt = _dtype_from_typenum(dst_typenum_t);
814+
auto dt = type_utils::_dtype_from_typenum(dst_typenum_t);
837815

838816
return py::cast<py::object>(dt);
839817
}
840818
}
841819

842820
void init_dot(py::module_ m)
843821
{
844-
using dpctl::tensor::py_internal::init_dot_atomic_support_vector;
845822
init_dot_atomic_support_vector();
846-
using dpctl::tensor::py_internal::init_dot_dispatch_tables;
847823
init_dot_dispatch_tables();
848824

849-
using dpctl::tensor::py_internal::py_dot;
850825
m.def("_dot", &py_dot, "", py::arg("x1"), py::arg("x2"),
851826
py::arg("batch_dims"), py::arg("x1_outer_dims"),
852827
py::arg("x2_outer_dims"), py::arg("inner_dims"), py::arg("dst"),
853828
py::arg("sycl_queue"), py::arg("depends") = py::list());
854829

855-
using dpctl::tensor::py_internal::dot_output_id_table;
856830
auto dot_result_type_pyapi = [&](const py::dtype &dtype1,
857831
const py::dtype &dtype2) {
858-
using dpctl::tensor::py_internal::py_dot_result_type;
859832
return py_dot_result_type(dtype1, dtype2, dot_output_id_table);
860833
};
861834
m.def("_dot_result_type", dot_result_type_pyapi, "");
862835
}
863836

864-
} // namespace py_internal
865-
} // namespace tensor
866-
} // namespace dpctl
837+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)