Skip to content

Commit 6912311

Browse files
Move ti.argsort() to dpctl_ext.tensor and reuse it in dpnp
1 parent 40c2b84 commit 6912311

3 files changed

Lines changed: 131 additions & 9 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,12 @@
8686
unique_counts,
8787
unique_values,
8888
)
89-
from ._sorting import sort
89+
from ._sorting import argsort, sort
9090
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
9191

9292
__all__ = [
9393
"arange",
94+
"argsort",
9495
"asarray",
9596
"asnumpy",
9697
"astype",

dpctl_ext/tensor/_sorting.py

Lines changed: 125 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,19 @@
3838
import dpctl_ext.tensor._tensor_impl as ti
3939

4040
from ._numpy_helper import normalize_axis_index
41-
from ._tensor_sorting_impl import ( # _argsort_ascending,; _argsort_descending,; _radix_argsort_ascending,; _radix_argsort_descending,; _topk,
41+
from ._tensor_sorting_impl import (
42+
_argsort_ascending,
43+
_argsort_descending,
44+
_radix_argsort_ascending,
45+
_radix_argsort_descending,
4246
_radix_sort_ascending,
4347
_radix_sort_descending,
4448
_radix_sort_dtype_supported,
4549
_sort_ascending,
4650
_sort_descending,
4751
)
4852

49-
__all__ = ["sort"]
53+
__all__ = ["sort", "argsort"]
5054

5155

5256
def _get_mergesort_impl_fn(descending):
@@ -161,3 +165,122 @@ def sort(x, /, *, axis=-1, descending=False, stable=True, kind=None):
161165
inv_perm = sorted(range(nd), key=lambda d: perm[d])
162166
res = dpt_ext.permute_dims(res, inv_perm)
163167
return res
168+
169+
170+
def _get_mergeargsort_impl_fn(descending):
171+
return _argsort_descending if descending else _argsort_ascending
172+
173+
174+
def _get_radixargsort_impl_fn(descending):
175+
return _radix_argsort_descending if descending else _radix_argsort_ascending
176+
177+
178+
def argsort(x, axis=-1, descending=False, stable=True, kind=None):
179+
"""argsort(x, axis=-1, descending=False, stable=True)
180+
181+
Returns the indices that sort an array `x` along a specified axis.
182+
183+
Args:
184+
x (usm_ndarray):
185+
input array.
186+
axis (Optional[int]):
187+
axis along which to sort. If set to `-1`, the function
188+
must sort along the last axis. Default: `-1`.
189+
descending (Optional[bool]):
190+
sort order. If `True`, the array must be sorted in descending
191+
order (by value). If `False`, the array must be sorted in
192+
ascending order (by value). Default: `False`.
193+
stable (Optional[bool]):
194+
sort stability. If `True`, the returned array must maintain the
195+
relative order of `x` values which compare as equal. If `False`,
196+
the returned array may or may not maintain the relative order of
197+
`x` values which compare as equal. Default: `True`.
198+
kind (Optional[Literal["stable", "mergesort", "radixsort"]]):
199+
Sorting algorithm. The default is `"stable"`, which uses parallel
200+
merge-sort or parallel radix-sort algorithms depending on the
201+
array data type.
202+
203+
Returns:
204+
usm_ndarray:
205+
an array of indices. The returned array has the same shape as
206+
the input array `x`. The return array has default array index
207+
data type.
208+
"""
209+
if not isinstance(x, dpt.usm_ndarray):
210+
raise TypeError(
211+
f"Expected type dpctl.tensor.usm_ndarray, got {type(x)}"
212+
)
213+
nd = x.ndim
214+
if nd == 0:
215+
axis = normalize_axis_index(axis, ndim=1, msg_prefix="axis")
216+
return dpt_ext.zeros_like(
217+
x, dtype=ti.default_device_index_type(x.sycl_queue), order="C"
218+
)
219+
else:
220+
axis = normalize_axis_index(axis, ndim=nd, msg_prefix="axis")
221+
a1 = axis + 1
222+
if a1 == nd:
223+
perm = list(range(nd))
224+
arr = x
225+
else:
226+
perm = [i for i in range(nd) if i != axis] + [
227+
axis,
228+
]
229+
arr = dpt_ext.permute_dims(x, perm)
230+
if kind is None:
231+
kind = "stable"
232+
if not isinstance(kind, str) or kind not in [
233+
"stable",
234+
"radixsort",
235+
"mergesort",
236+
]:
237+
raise ValueError(
238+
"Unsupported kind value. Expected 'stable', 'mergesort', "
239+
f"or 'radixsort', but got '{kind}'"
240+
)
241+
if kind == "mergesort":
242+
impl_fn = _get_mergeargsort_impl_fn(descending)
243+
elif kind == "radixsort":
244+
if _radix_sort_dtype_supported(x.dtype.num):
245+
impl_fn = _get_radixargsort_impl_fn(descending)
246+
else:
247+
raise ValueError(f"Radix sort is not supported for {x.dtype}")
248+
else:
249+
dt = x.dtype
250+
if dt in [dpt.bool, dpt.uint8, dpt.int8, dpt.int16, dpt.uint16]:
251+
impl_fn = _get_radixargsort_impl_fn(descending)
252+
else:
253+
impl_fn = _get_mergeargsort_impl_fn(descending)
254+
exec_q = x.sycl_queue
255+
_manager = du.SequentialOrderManager[exec_q]
256+
dep_evs = _manager.submitted_events
257+
index_dt = ti.default_device_index_type(exec_q)
258+
if arr.flags.c_contiguous:
259+
res = dpt_ext.empty_like(arr, dtype=index_dt, order="C")
260+
ht_ev, impl_ev = impl_fn(
261+
src=arr,
262+
trailing_dims_to_sort=1,
263+
dst=res,
264+
sycl_queue=exec_q,
265+
depends=dep_evs,
266+
)
267+
_manager.add_event_pair(ht_ev, impl_ev)
268+
else:
269+
tmp = dpt_ext.empty_like(arr, order="C")
270+
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
271+
src=arr, dst=tmp, sycl_queue=exec_q, depends=dep_evs
272+
)
273+
_manager.add_event_pair(ht_ev, copy_ev)
274+
res = dpt_ext.empty_like(arr, dtype=index_dt, order="C")
275+
ht_ev, impl_ev = impl_fn(
276+
src=tmp,
277+
trailing_dims_to_sort=1,
278+
dst=res,
279+
sycl_queue=exec_q,
280+
depends=[copy_ev],
281+
)
282+
_manager.add_event_pair(ht_ev, impl_ev)
283+
if a1 != nd:
284+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
285+
res = dpt_ext.permute_dims(res, inv_perm)
286+
return res

dpnp/dpnp_iface_sorting.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,11 @@
4141

4242
from collections.abc import Sequence
4343

44-
import dpctl.tensor as dpt
45-
from dpctl.tensor._numpy_helper import normalize_axis_index
46-
4744
# TODO: revert to `import dpctl.tensor...`
4845
# when dpnp fully migrates dpctl/tensor
49-
import dpctl_ext.tensor as dpt_ext
46+
import dpctl_ext.tensor as dpt
5047
import dpnp
48+
from dpctl_ext.tensor._numpy_helper import normalize_axis_index
5149

5250
# pylint: disable=no-name-in-module
5351
from .dpnp_algo import (
@@ -87,7 +85,7 @@ def _wrap_sort_argsort(
8785

8886
usm_a = dpnp.get_usm_ndarray(a)
8987
if axis is None:
90-
usm_a = dpt_ext.reshape(usm_a, -1)
88+
usm_a = dpt.reshape(usm_a, -1)
9189
axis = -1
9290

9391
axis = normalize_axis_index(axis, ndim=usm_a.ndim)
@@ -404,7 +402,7 @@ def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None):
404402

405403
return _wrap_sort_argsort(
406404
a,
407-
dpt_ext.sort,
405+
dpt.sort,
408406
axis=axis,
409407
kind=kind,
410408
order=order,

0 commit comments

Comments
 (0)