Skip to content

Commit 40c2b84

Browse files
Move _argsort_ascending/descending to _tensor_sorting_impl
1 parent 893cdc3 commit 40c2b84

2 files changed

Lines changed: 206 additions & 0 deletions

File tree

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33+
/// extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#include <cstdint>
37+
#include <type_traits>
38+
39+
#include <sycl/sycl.hpp>
40+
41+
#include "dpnp4pybind11.hpp"
42+
#include <pybind11/pybind11.h>
43+
#include <pybind11/stl.h>
44+
45+
#include "utils/rich_comparisons.hpp"
46+
#include "utils/type_dispatch.hpp"
47+
48+
#include "kernels/sorting/merge_sort.hpp"
49+
#include "kernels/sorting/sort_impl_fn_ptr_t.hpp"
50+
51+
#include "merge_argsort.hpp"
52+
#include "py_argsort_common.hpp"
53+
54+
namespace dpctl::tensor::py_internal
55+
{
56+
57+
namespace py = pybind11;
58+
namespace td_ns = dpctl::tensor::type_dispatch;
59+
60+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
61+
static sort_contig_fn_ptr_t
62+
ascending_argsort_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
63+
static sort_contig_fn_ptr_t
64+
descending_argsort_contig_dispatch_table[td_ns::num_types]
65+
[td_ns::num_types];
66+
67+
template <typename fnT, typename argTy, typename IndexTy>
68+
struct AscendingArgSortContigFactory
69+
{
70+
fnT get()
71+
{
72+
if constexpr (std::is_same_v<IndexTy, std::int64_t> ||
73+
std::is_same_v<IndexTy, std::int32_t>)
74+
{
75+
using dpctl::tensor::rich_comparisons::AscendingSorter;
76+
using Comp = typename AscendingSorter<argTy>::type;
77+
78+
using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl;
79+
return stable_argsort_axis1_contig_impl<argTy, IndexTy, Comp>;
80+
}
81+
else {
82+
return nullptr;
83+
}
84+
}
85+
};
86+
87+
template <typename fnT, typename argTy, typename IndexTy>
88+
struct DescendingArgSortContigFactory
89+
{
90+
fnT get()
91+
{
92+
if constexpr (std::is_same_v<IndexTy, std::int64_t> ||
93+
std::is_same_v<IndexTy, std::int32_t>)
94+
{
95+
using dpctl::tensor::rich_comparisons::DescendingSorter;
96+
using Comp = typename DescendingSorter<argTy>::type;
97+
98+
using dpctl::tensor::kernels::stable_argsort_axis1_contig_impl;
99+
return stable_argsort_axis1_contig_impl<argTy, IndexTy, Comp>;
100+
}
101+
else {
102+
return nullptr;
103+
}
104+
}
105+
};
106+
107+
void init_merge_argsort_dispatch_tables(void)
108+
{
109+
using dpctl::tensor::kernels::sort_contig_fn_ptr_t;
110+
111+
td_ns::DispatchTableBuilder<sort_contig_fn_ptr_t,
112+
AscendingArgSortContigFactory, td_ns::num_types>
113+
dtb1;
114+
dtb1.populate_dispatch_table(ascending_argsort_contig_dispatch_table);
115+
116+
td_ns::DispatchTableBuilder<
117+
sort_contig_fn_ptr_t, DescendingArgSortContigFactory, td_ns::num_types>
118+
dtb2;
119+
dtb2.populate_dispatch_table(descending_argsort_contig_dispatch_table);
120+
}
121+
122+
void init_merge_argsort_functions(py::module_ m)
123+
{
124+
dpctl::tensor::py_internal::init_merge_argsort_dispatch_tables();
125+
126+
auto py_argsort_ascending = [](const dpctl::tensor::usm_ndarray &src,
127+
const int trailing_dims_to_sort,
128+
const dpctl::tensor::usm_ndarray &dst,
129+
sycl::queue &exec_q,
130+
const std::vector<sycl::event> &depends)
131+
-> std::pair<sycl::event, sycl::event> {
132+
return dpctl::tensor::py_internal::py_argsort(
133+
src, trailing_dims_to_sort, dst, exec_q, depends,
134+
dpctl::tensor::py_internal::
135+
ascending_argsort_contig_dispatch_table);
136+
};
137+
m.def("_argsort_ascending", py_argsort_ascending, py::arg("src"),
138+
py::arg("trailing_dims_to_sort"), py::arg("dst"),
139+
py::arg("sycl_queue"), py::arg("depends") = py::list());
140+
141+
auto py_argsort_descending = [](const dpctl::tensor::usm_ndarray &src,
142+
const int trailing_dims_to_sort,
143+
const dpctl::tensor::usm_ndarray &dst,
144+
sycl::queue &exec_q,
145+
const std::vector<sycl::event> &depends)
146+
-> std::pair<sycl::event, sycl::event> {
147+
return dpctl::tensor::py_internal::py_argsort(
148+
src, trailing_dims_to_sort, dst, exec_q, depends,
149+
dpctl::tensor::py_internal::
150+
descending_argsort_contig_dispatch_table);
151+
};
152+
m.def("_argsort_descending", py_argsort_descending, py::arg("src"),
153+
py::arg("trailing_dims_to_sort"), py::arg("dst"),
154+
py::arg("sycl_queue"), py::arg("depends") = py::list());
155+
156+
return;
157+
}
158+
159+
} // namespace dpctl::tensor::py_internal
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_sorting_impl
33+
/// extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#pragma once
37+
38+
#include <pybind11/pybind11.h>
39+
40+
namespace py = pybind11;
41+
42+
namespace dpctl::tensor::py_internal
43+
{
44+
45+
extern void init_merge_argsort_functions(py::module_);
46+
47+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)