Skip to content

Commit 893cdc3

Browse files
Move _radix_argsort_ascending/descending to _tensor_sorting_impl
1 parent 82d202c commit 893cdc3

5 files changed

Lines changed: 428 additions & 6 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ set(_accumulator_sources
7272
set(_sorting_sources
7373
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
7474
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
75-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
77-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
77+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_argsort.cpp
7878
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/searchsorted.cpp
7979
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/topk.cpp
8080
)
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33+
/// extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#pragma once
37+
38+
#include <cassert>
39+
#include <cstddef>
40+
#include <utility>
41+
#include <vector>
42+
43+
#include <sycl/sycl.hpp>
44+
45+
#include "dpnp4pybind11.hpp"
46+
#include <pybind11/pybind11.h>
47+
#include <pybind11/stl.h>
48+
49+
#include "utils/memory_overlap.hpp"
50+
#include "utils/output_validation.hpp"
51+
#include "utils/type_dispatch.hpp"
52+
53+
namespace td_ns = dpctl::tensor::type_dispatch;
54+
55+
namespace dpctl::tensor::py_internal
56+
{
57+
58+
template <typename sorting_contig_impl_fnT>
59+
std::pair<sycl::event, sycl::event>
60+
py_argsort(const dpctl::tensor::usm_ndarray &src,
61+
const int trailing_dims_to_sort,
62+
const dpctl::tensor::usm_ndarray &dst,
63+
sycl::queue &exec_q,
64+
const std::vector<sycl::event> &depends,
65+
const sorting_contig_impl_fnT &sort_contig_fns)
66+
{
67+
int src_nd = src.get_ndim();
68+
int dst_nd = dst.get_ndim();
69+
if (src_nd != dst_nd) {
70+
throw py::value_error("The input and output arrays must have "
71+
"the same array ranks");
72+
}
73+
int iteration_nd = src_nd - trailing_dims_to_sort;
74+
if (trailing_dims_to_sort <= 0 || iteration_nd < 0) {
75+
throw py::value_error("Trailing_dim_to_sort must be positive, but no "
76+
"greater than rank of the array being sorted");
77+
}
78+
79+
const py::ssize_t *src_shape_ptr = src.get_shape_raw();
80+
const py::ssize_t *dst_shape_ptr = dst.get_shape_raw();
81+
82+
bool same_shapes = true;
83+
std::size_t iter_nelems(1);
84+
85+
for (int i = 0; same_shapes && (i < iteration_nd); ++i) {
86+
auto src_shape_i = src_shape_ptr[i];
87+
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
88+
iter_nelems *= static_cast<std::size_t>(src_shape_i);
89+
}
90+
91+
std::size_t sort_nelems(1);
92+
for (int i = iteration_nd; same_shapes && (i < src_nd); ++i) {
93+
auto src_shape_i = src_shape_ptr[i];
94+
same_shapes = same_shapes && (src_shape_i == dst_shape_ptr[i]);
95+
sort_nelems *= static_cast<std::size_t>(src_shape_i);
96+
}
97+
98+
if (!same_shapes) {
99+
throw py::value_error(
100+
"Destination shape does not match the input shape");
101+
}
102+
103+
if (!dpctl::utils::queues_are_compatible(exec_q, {src, dst})) {
104+
throw py::value_error(
105+
"Execution queue is not compatible with allocation queues");
106+
}
107+
108+
dpctl::tensor::validation::CheckWritable::throw_if_not_writable(dst);
109+
110+
if ((iter_nelems == 0) || (sort_nelems == 0)) {
111+
// Nothing to do
112+
return std::make_pair(sycl::event(), sycl::event());
113+
}
114+
115+
// check that dst and src do not overlap
116+
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap();
117+
if (overlap(src, dst)) {
118+
throw py::value_error("Arrays index overlapping segments of memory");
119+
}
120+
121+
dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(
122+
dst, sort_nelems * iter_nelems);
123+
124+
int src_typenum = src.get_typenum();
125+
int dst_typenum = dst.get_typenum();
126+
127+
const auto &array_types = td_ns::usm_ndarray_types();
128+
int src_typeid = array_types.typenum_to_lookup_id(src_typenum);
129+
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);
130+
131+
if ((dst_typeid != static_cast<int>(td_ns::typenum_t::INT64)) &&
132+
(dst_typeid != static_cast<int>(td_ns::typenum_t::INT32)))
133+
{
134+
throw py::value_error(
135+
"Output index array must have data type int32 or int64");
136+
}
137+
138+
bool is_src_c_contig = src.is_c_contiguous();
139+
bool is_dst_c_contig = dst.is_c_contiguous();
140+
141+
if (is_src_c_contig && is_dst_c_contig) {
142+
if (sort_nelems > 1) {
143+
static constexpr py::ssize_t zero_offset = py::ssize_t(0);
144+
145+
auto fn = sort_contig_fns[src_typeid][dst_typeid];
146+
147+
if (fn == nullptr) {
148+
throw py::value_error(
149+
"Not implemented for dtypes of input arrays");
150+
}
151+
152+
sycl::event comp_ev =
153+
fn(exec_q, iter_nelems, sort_nelems, src.get_data(),
154+
dst.get_data(), zero_offset, zero_offset, zero_offset,
155+
zero_offset, depends);
156+
157+
sycl::event keep_args_alive_ev =
158+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {comp_ev});
159+
160+
return std::make_pair(keep_args_alive_ev, comp_ev);
161+
}
162+
else {
163+
assert(dst.get_size() == iter_nelems);
164+
int dst_elemsize = dst.get_elemsize();
165+
static constexpr int memset_val(0);
166+
167+
sycl::event fill_ev = exec_q.submit([&](sycl::handler &cgh) {
168+
cgh.depends_on(depends);
169+
170+
cgh.memset(reinterpret_cast<void *>(dst.get_data()), memset_val,
171+
iter_nelems * dst_elemsize);
172+
});
173+
174+
sycl::event keep_args_alive_ev =
175+
dpctl::utils::keep_args_alive(exec_q, {src, dst}, {fill_ev});
176+
177+
return std::make_pair(keep_args_alive_ev, fill_ev);
178+
}
179+
}
180+
181+
throw py::value_error(
182+
"Both source and destination arrays must be C-contiguous");
183+
}
184+
185+
} // namespace dpctl::tensor::py_internal
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33+
/// extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#include <cstddef>
37+
#include <cstdint>
38+
#include <type_traits>
39+
#include <utility>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "dpnp4pybind11.hpp"
45+
#include <pybind11/pybind11.h>
46+
#include <pybind11/stl.h>
47+
48+
#include "utils/type_dispatch.hpp"
49+
50+
#include "kernels/dpctl_tensor_types.hpp"
51+
#include "kernels/sorting/radix_sort.hpp"
52+
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
53+
54+
#include "py_argsort_common.hpp"
55+
#include "radix_argsort.hpp"
56+
#include "radix_sort_support.hpp"
57+
58+
namespace dpctl::tensor::py_internal
59+
{
60+
61+
namespace py = pybind11;
62+
namespace td_ns = dpctl::tensor::type_dispatch;
63+
namespace impl_ns = dpctl::tensor::kernels::radix_sort_details;
64+
65+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
66+
67+
static sort_contig_fn_ptr_t
68+
ascending_radix_argsort_contig_dispatch_table[td_ns::num_types]
69+
[td_ns::num_types];
70+
static sort_contig_fn_ptr_t
71+
descending_radix_argsort_contig_dispatch_table[td_ns::num_types]
72+
[td_ns::num_types];
73+
74+
namespace
75+
{
76+
77+
template <bool is_ascending, typename T, typename I>
78+
sycl::event argsort_axis1_contig_caller(sycl::queue &q,
79+
std::size_t iter_nelems,
80+
std::size_t sort_nelems,
81+
const char *arg_cp,
82+
char *res_cp,
83+
ssize_t iter_arg_offset,
84+
ssize_t iter_res_offset,
85+
ssize_t sort_arg_offset,
86+
ssize_t sort_res_offset,
87+
const std::vector<sycl::event> &depends)
88+
{
89+
using dpctl::tensor::kernels::radix_argsort_axis1_contig_impl;
90+
91+
return radix_argsort_axis1_contig_impl<T, I>(
92+
q, is_ascending, iter_nelems, sort_nelems, arg_cp, res_cp,
93+
iter_arg_offset, iter_res_offset, sort_arg_offset, sort_res_offset,
94+
depends);
95+
}
96+
97+
} // end of anonymous namespace
98+
99+
template <typename fnT, typename argTy, typename IndexTy>
100+
struct AscendingRadixArgSortContigFactory
101+
{
102+
fnT get()
103+
{
104+
if constexpr (RadixSortSupportVector<argTy>::is_defined &&
105+
(std::is_same_v<IndexTy, std::int64_t> ||
106+
std::is_same_v<IndexTy, std::int32_t>))
107+
{
108+
return argsort_axis1_contig_caller<
109+
/*ascending*/ true, argTy, IndexTy>;
110+
}
111+
else {
112+
return nullptr;
113+
}
114+
}
115+
};
116+
117+
template <typename fnT, typename argTy, typename IndexTy>
118+
struct DescendingRadixArgSortContigFactory
119+
{
120+
fnT get()
121+
{
122+
if constexpr (RadixSortSupportVector<argTy>::is_defined &&
123+
(std::is_same_v<IndexTy, std::int64_t> ||
124+
std::is_same_v<IndexTy, std::int32_t>))
125+
{
126+
return argsort_axis1_contig_caller<
127+
/*ascending*/ false, argTy, IndexTy>;
128+
}
129+
else {
130+
return nullptr;
131+
}
132+
}
133+
};
134+
135+
void init_radix_argsort_dispatch_tables(void)
136+
{
137+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
138+
139+
td_ns::DispatchTableBuilder<sort_contig_fn_ptr_t,
140+
AscendingRadixArgSortContigFactory,
141+
td_ns::num_types>
142+
dtb1;
143+
dtb1.populate_dispatch_table(ascending_radix_argsort_contig_dispatch_table);
144+
145+
td_ns::DispatchTableBuilder<sort_contig_fn_ptr_t,
146+
DescendingRadixArgSortContigFactory,
147+
td_ns::num_types>
148+
dtb2;
149+
dtb2.populate_dispatch_table(
150+
descending_radix_argsort_contig_dispatch_table);
151+
}
152+
153+
/*! @brief Initializes the radix argsort dispatch tables and registers
 *         `_radix_argsort_ascending` and `_radix_argsort_descending` in
 *         module `m`.
 *
 * Both Python functions take `src`, `trailing_dims_to_sort`, `dst`,
 * `sycl_queue` and an optional `depends` list (empty by default), and forward
 * to `py_argsort` with the ascending or descending dispatch table.
 *
 * @param m Module the functions are added to.
 */
void init_radix_argsort_functions(py::module_ m)
{
    using dpctl::tensor::usm_ndarray;
    using event_pair_t = std::pair<sycl::event, sycl::event>;

    // the tables must be filled before the bound functions can dispatch
    dpctl::tensor::py_internal::init_radix_argsort_dispatch_tables();

    m.def(
        "_radix_argsort_ascending",
        [](const usm_ndarray &src, const int trailing_dims_to_sort,
           const usm_ndarray &dst, sycl::queue &exec_q,
           const std::vector<sycl::event> &depends) -> event_pair_t {
            return dpctl::tensor::py_internal::py_argsort(
                src, trailing_dims_to_sort, dst, exec_q, depends,
                dpctl::tensor::py_internal::
                    ascending_radix_argsort_contig_dispatch_table);
        },
        py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"),
        py::arg("sycl_queue"), py::arg("depends") = py::list());

    m.def(
        "_radix_argsort_descending",
        [](const usm_ndarray &src, const int trailing_dims_to_sort,
           const usm_ndarray &dst, sycl::queue &exec_q,
           const std::vector<sycl::event> &depends) -> event_pair_t {
            return dpctl::tensor::py_internal::py_argsort(
                src, trailing_dims_to_sort, dst, exec_q, depends,
                dpctl::tensor::py_internal::
                    descending_radix_argsort_contig_dispatch_table);
        },
        py::arg("src"), py::arg("trailing_dims_to_sort"), py::arg("dst"),
        py::arg("sycl_queue"), py::arg("depends") = py::list());
}
189+
190+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)