Skip to content

Commit b62d836

Browse files
Move ti.sort() and reuse it in dpnp
1 parent 1f018ec commit b62d836

3 files changed

Lines changed: 166 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282

8383
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8484
from ._clip import clip
85+
from ._sorting import sort
8586
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8687

8788
__all__ = [
@@ -124,6 +125,7 @@
124125
"reshape",
125126
"result_type",
126127
"roll",
128+
"sort",
127129
"squeeze",
128130
"stack",
129131
"swapaxes",

dpctl_ext/tensor/_sorting.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
# import operator
30+
# from typing import NamedTuple
31+
32+
import dpctl.tensor as dpt
33+
import dpctl.utils as du
34+
35+
# TODO: revert to `import dpctl.tensor...`
36+
# when dpnp fully migrates dpctl/tensor
37+
import dpctl_ext.tensor as dpt_ext
38+
import dpctl_ext.tensor._tensor_impl as ti
39+
40+
from ._numpy_helper import normalize_axis_index
41+
from ._tensor_sorting_impl import ( # _argsort_ascending,; _argsort_descending,; _radix_argsort_ascending,; _radix_argsort_descending,; _topk,
42+
_radix_sort_ascending,
43+
_radix_sort_descending,
44+
_radix_sort_dtype_supported,
45+
_sort_ascending,
46+
_sort_descending,
47+
)
48+
49+
__all__ = ["sort"]
50+
51+
52+
def _get_mergesort_impl_fn(descending):
53+
return _sort_descending if descending else _sort_ascending
54+
55+
56+
def _get_radixsort_impl_fn(descending):
57+
return _radix_sort_descending if descending else _radix_sort_ascending
58+
59+
60+
def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
61+
"""sort(x, axis=-1, descending=False, stable=True)
62+
63+
Returns a sorted copy of an input array `x`.
64+
65+
Args:
66+
x (usm_ndarray):
67+
input array.
68+
axis (Optional[int]):
69+
axis along which to sort. If set to `-1`, the function
70+
must sort along the last axis. Default: `-1`.
71+
descending (Optional[bool]):
72+
sort order. If `True`, the array must be sorted in descending
73+
order (by value). If `False`, the array must be sorted in
74+
ascending order (by value). Default: `False`.
75+
stable (Optional[bool]):
76+
sort stability. If `True`, the returned array must maintain the
77+
relative order of `x` values which compare as equal. If `False`,
78+
the returned array may or may not maintain the relative order of
79+
`x` values which compare as equal. Default: `True`.
80+
kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
81+
Sorting algorithm. The default is `"stable"`, which uses parallel
82+
merge-sort or parallel radix-sort algorithms depending on the
83+
array data type.
84+
Returns:
85+
usm_ndarray:
86+
a sorted array. The returned array has the same data type and
87+
the same shape as the input array `x`.
88+
"""
89+
if not isinstance(x, dpt.usm_ndarray):
90+
raise TypeError(
91+
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
92+
)
93+
nd = x.ndim
94+
if nd == 0:
95+
axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
96+
return dpt_ext.copy(x, order="C")
97+
else:
98+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
99+
a1 = axis + 1
100+
if a1 == nd:
101+
perm = list(range(nd))
102+
arr = x
103+
else:
104+
perm = [i for i in range(nd) if i != axis] + [
105+
axis,
106+
]
107+
arr = dpt_ext.permute_dims(x, perm)
108+
if kind is None:
109+
kind = "stable"
110+
if not isinstance(kind, str) or kind not in [
111+
"stable",
112+
"radixsort",
113+
"mergesort",
114+
]:
115+
raise ValueError(
116+
"Unsupported kind value. Expected 'stable', 'mergesort', "
117+
f"or 'radixsort', but got '{kind}'"
118+
)
119+
if kind == "mergesort":
120+
impl_fn = _get_mergesort_impl_fn(descending)
121+
elif kind == "radixsort":
122+
if _radix_sort_dtype_supported(x.dtype.num):
123+
impl_fn = _get_radixsort_impl_fn(descending)
124+
else:
125+
raise ValueError(f"Radix sort is not supported for {x.dtype}")
126+
else:
127+
dt = x.dtype
128+
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
129+
impl_fn = _get_radixsort_impl_fn(descending)
130+
else:
131+
impl_fn = _get_mergesort_impl_fn(descending)
132+
exec_q = x.sycl_queue
133+
_manager = du.SequentialOrderManager[exec_q]
134+
dep_evs = _manager.submitted_events
135+
if arr.flags.c_contiguous:
136+
res = dpt_ext.empty_like(arr, order="C")
137+
ht_ev, impl_ev = impl_fn(
138+
src=arr,
139+
trailing_dims_to_sort=1,
140+
dst=res,
141+
sycl_queue=exec_q,
142+
depends=dep_evs,
143+
)
144+
_manager.add_event_pair(ht_ev, impl_ev)
145+
else:
146+
tmp = dpt_ext.empty_like(arr, order="C")
147+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
148+
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
149+
)
150+
_manager.add_event_pair(ht_ev, copy_ev)
151+
res = dpt_ext.empty_like(arr, order="C")
152+
ht_ev, impl_ev = impl_fn(
153+
src=tmp,
154+
trailing_dims_to_sort=1,
155+
dst=res,
156+
sycl_queue=exec_q,
157+
depends=[copy_ev],
158+
)
159+
_manager.add_event_pair(ht_ev, impl_ev)
160+
if a1 != nd:
161+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
162+
res = dpt_ext.permute_dims(res, inv_perm)
163+
return res

dpnp/dpnp_iface_sorting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None):
404404

405405
return _wrap_sort_argsort(
406406
a,
407-
dpt.sort,
407+
dpt_ext.sort,
408408
axis=axis,
409409
kind=kind,
410410
order=order,

0 commit comments

Comments
 (0)