Skip to content

Commit c6d600a

Browse files
Move _argmax/argmin_over_axis to _tensor_reductions_impl
1 parent aa313ff commit c6d600a

6 files changed

Lines changed: 658 additions & 6 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ set(_reduction_sources
7373
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
7474
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp
7575
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp
76-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
77-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
76+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
7878
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
7979
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/max.cpp
8080
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/min.cpp
Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#include <complex>
37+
#include <cstdint>
38+
#include <type_traits>
39+
#include <vector>
40+
41+
#include <sycl/sycl.hpp>
42+
43+
#include "dpnp4pybind11.hpp"
44+
#include <pybind11/pybind11.h>
45+
#include <pybind11/stl.h>
46+
47+
#include "kernels/reductions.hpp"
48+
#include "reduction_over_axis.hpp"
49+
#include "utils/sycl_utils.hpp"
50+
#include "utils/type_dispatch_building.hpp"
51+
52+
namespace dpctl::tensor::py_internal
53+
{
54+
55+
namespace py = pybind11;
56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
namespace su_ns = dpctl::tensor::sycl_utils;
58+
59+
namespace impl
60+
{
61+
62+
using dpctl::tensor::kernels::search_strided_impl_fn_ptr;
63+
static search_strided_impl_fn_ptr
64+
argmax_over_axis_strided_temps_dispatch_table[td_ns::num_types]
65+
[td_ns::num_types];
66+
67+
using dpctl::tensor::kernels::search_contig_impl_fn_ptr;
68+
static search_contig_impl_fn_ptr
69+
argmax_over_axis1_contig_temps_dispatch_table[td_ns::num_types]
70+
[td_ns::num_types];
71+
using dpctl::tensor::kernels::search_contig_impl_fn_ptr;
72+
static search_contig_impl_fn_ptr
73+
argmax_over_axis0_contig_temps_dispatch_table[td_ns::num_types]
74+
[td_ns::num_types];
75+
76+
template <typename argTy, typename outTy>
77+
struct TypePairSupportForArgmaxReductionTemps
78+
{
79+
80+
static constexpr bool is_defined = std::disjunction<
81+
td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,
82+
// input int8_t
83+
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,
84+
85+
// input uint8_t
86+
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::int64_t>,
87+
88+
// input int16_t
89+
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,
90+
91+
// input uint16_t
92+
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::int64_t>,
93+
94+
// input int32_t
95+
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,
96+
// input uint32_t
97+
td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::int64_t>,
98+
99+
// input int64_t
100+
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
101+
102+
// input uint32_t
103+
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::int64_t>,
104+
105+
// input half
106+
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, std::int64_t>,
107+
108+
// input float
109+
td_ns::TypePairDefinedEntry<argTy, float, outTy, std::int64_t>,
110+
111+
// input double
112+
td_ns::TypePairDefinedEntry<argTy, double, outTy, std::int64_t>,
113+
114+
// input std::complex
115+
td_ns::TypePairDefinedEntry<argTy,
116+
std::complex<float>,
117+
outTy,
118+
std::int64_t>,
119+
120+
td_ns::TypePairDefinedEntry<argTy,
121+
std::complex<double>,
122+
outTy,
123+
std::int64_t>,
124+
125+
// fall-through
126+
td_ns::NotDefinedEntry>::is_defined;
127+
};
128+
129+
template <typename fnT, typename srcTy, typename dstTy>
130+
struct ArgmaxOverAxisTempsStridedFactory
131+
{
132+
fnT get() const
133+
{
134+
if constexpr (TypePairSupportForArgmaxReductionTemps<srcTy,
135+
dstTy>::is_defined)
136+
{
137+
if constexpr (std::is_integral_v<srcTy> &&
138+
!std::is_same_v<srcTy, bool>) {
139+
// op for values
140+
using ReductionOpT = sycl::maximum<srcTy>;
141+
// op for indices
142+
using IndexOpT = sycl::minimum<dstTy>;
143+
return dpctl::tensor::kernels::
144+
search_over_group_temps_strided_impl<
145+
srcTy, dstTy, ReductionOpT, IndexOpT>;
146+
}
147+
else {
148+
// op for values
149+
using ReductionOpT = su_ns::Maximum<srcTy>;
150+
// op for indices
151+
using IndexOpT = sycl::minimum<dstTy>;
152+
return dpctl::tensor::kernels::
153+
search_over_group_temps_strided_impl<
154+
srcTy, dstTy, ReductionOpT, IndexOpT>;
155+
}
156+
}
157+
else {
158+
return nullptr;
159+
}
160+
}
161+
};
162+
163+
template <typename fnT, typename srcTy, typename dstTy>
164+
struct ArgmaxOverAxis1TempsContigFactory
165+
{
166+
fnT get() const
167+
{
168+
if constexpr (TypePairSupportForArgmaxReductionTemps<srcTy,
169+
dstTy>::is_defined)
170+
{
171+
if constexpr (std::is_integral_v<srcTy> &&
172+
!std::is_same_v<srcTy, bool>) {
173+
// op for values
174+
using ReductionOpT = sycl::maximum<srcTy>;
175+
// op for indices
176+
using IndexOpT = sycl::minimum<dstTy>;
177+
return dpctl::tensor::kernels::
178+
search_axis1_over_group_temps_contig_impl<
179+
srcTy, dstTy, ReductionOpT, IndexOpT>;
180+
}
181+
else {
182+
// op for values
183+
using ReductionOpT = su_ns::Maximum<srcTy>;
184+
// op for indices
185+
using IndexOpT = sycl::minimum<dstTy>;
186+
return dpctl::tensor::kernels::
187+
search_axis1_over_group_temps_contig_impl<
188+
srcTy, dstTy, ReductionOpT, IndexOpT>;
189+
}
190+
}
191+
else {
192+
return nullptr;
193+
}
194+
}
195+
};
196+
197+
template <typename fnT, typename srcTy, typename dstTy>
198+
struct ArgmaxOverAxis0TempsContigFactory
199+
{
200+
fnT get() const
201+
{
202+
if constexpr (TypePairSupportForArgmaxReductionTemps<srcTy,
203+
dstTy>::is_defined)
204+
{
205+
if constexpr (std::is_integral_v<srcTy> &&
206+
!std::is_same_v<srcTy, bool>) {
207+
// op for values
208+
using ReductionOpT = sycl::maximum<srcTy>;
209+
// op for indices
210+
using IndexOpT = sycl::minimum<dstTy>;
211+
return dpctl::tensor::kernels::
212+
search_axis0_over_group_temps_contig_impl<
213+
srcTy, dstTy, ReductionOpT, IndexOpT>;
214+
}
215+
else {
216+
// op for values
217+
using ReductionOpT = su_ns::Maximum<srcTy>;
218+
// op for indices
219+
using IndexOpT = sycl::minimum<dstTy>;
220+
return dpctl::tensor::kernels::
221+
search_axis0_over_group_temps_contig_impl<
222+
srcTy, dstTy, ReductionOpT, IndexOpT>;
223+
}
224+
}
225+
else {
226+
return nullptr;
227+
}
228+
}
229+
};
230+
231+
void populate_argmax_over_axis_dispatch_tables(void)
232+
{
233+
using td_ns::DispatchTableBuilder;
234+
235+
DispatchTableBuilder<search_strided_impl_fn_ptr,
236+
ArgmaxOverAxisTempsStridedFactory, td_ns::num_types>
237+
dtb1;
238+
dtb1.populate_dispatch_table(argmax_over_axis_strided_temps_dispatch_table);
239+
240+
DispatchTableBuilder<search_contig_impl_fn_ptr,
241+
ArgmaxOverAxis1TempsContigFactory, td_ns::num_types>
242+
dtb2;
243+
dtb2.populate_dispatch_table(argmax_over_axis1_contig_temps_dispatch_table);
244+
245+
DispatchTableBuilder<search_contig_impl_fn_ptr,
246+
ArgmaxOverAxis0TempsContigFactory, td_ns::num_types>
247+
dtb3;
248+
dtb3.populate_dispatch_table(argmax_over_axis0_contig_temps_dispatch_table);
249+
}
250+
251+
} // namespace impl
252+
253+
void init_argmax(py::module_ m)
254+
{
255+
using arrayT = dpctl::tensor::usm_ndarray;
256+
using event_vecT = std::vector<sycl::event>;
257+
{
258+
using impl::populate_argmax_over_axis_dispatch_tables;
259+
populate_argmax_over_axis_dispatch_tables();
260+
using impl::argmax_over_axis0_contig_temps_dispatch_table;
261+
using impl::argmax_over_axis1_contig_temps_dispatch_table;
262+
using impl::argmax_over_axis_strided_temps_dispatch_table;
263+
264+
auto argmax_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
265+
const arrayT &dst, sycl::queue &exec_q,
266+
const event_vecT &depends = {}) {
267+
using dpctl::tensor::py_internal::py_search_over_axis;
268+
return py_search_over_axis(
269+
src, trailing_dims_to_reduce, dst, exec_q, depends,
270+
argmax_over_axis_strided_temps_dispatch_table,
271+
argmax_over_axis0_contig_temps_dispatch_table,
272+
argmax_over_axis1_contig_temps_dispatch_table);
273+
};
274+
m.def("_argmax_over_axis", argmax_pyapi, "", py::arg("src"),
275+
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
276+
py::arg("sycl_queue"), py::arg("depends") = py::list());
277+
}
278+
}
279+
280+
} // namespace dpctl::tensor::py_internal
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <pybind11/pybind11.h>
38+
39+
namespace py = pybind11;
40+
41+
namespace dpctl::tensor::py_internal
42+
{
43+
44+
extern void init_argmax(py::module_ m);
45+
46+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)