Skip to content

Commit 8b010b0

Browse files
Move ti.searchsorted() and reuse it in dpnp
1 parent 148b6e5 commit 8b010b0

4 files changed

Lines changed: 196 additions & 3 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282

8383
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8484
from ._clip import clip
85+
from ._searchsorted import searchsorted
8586
from ._set_functions import ( # isin,; unique_all,; unique_inverse,
8687
unique_counts,
8788
unique_values,
@@ -130,6 +131,7 @@
130131
"reshape",
131132
"result_type",
132133
"roll",
134+
"searchsorted",
133135
"sort",
134136
"squeeze",
135137
"stack",

dpctl_ext/tensor/_searchsorted.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
30+
from typing import Literal, Union
31+
32+
import dpctl
33+
import dpctl.utils as du
34+
35+
# TODO: revert to `from ._usmarray import...`
36+
# when dpnp fully migrates dpctl/tensor
37+
from dpctl.tensor._usmarray import usm_ndarray
38+
39+
from ._copy_utils import _empty_like_orderK
40+
from ._ctors import empty
41+
from ._tensor_impl import _copy_usm_ndarray_into_usm_ndarray as ti_copy
42+
from ._tensor_impl import _take as ti_take
43+
from ._tensor_impl import (
44+
default_device_index_type as ti_default_device_index_type,
45+
)
46+
from ._tensor_sorting_impl import _searchsorted_left, _searchsorted_right
47+
from ._type_utils import isdtype, result_type
48+
49+
50+
def searchsorted(
51+
x1: usm_ndarray,
52+
x2: usm_ndarray,
53+
/,
54+
*,
55+
side: Literal["left", "right"] = "left",
56+
sorter: Union[usm_ndarray, None] = None,
57+
) -> usm_ndarray:
58+
"""searchsorted(x1, x2, side='left', sorter=None)
59+
60+
Finds the indices into `x1` such that, if the corresponding elements
61+
in `x2` were inserted before the indices, the order of `x1`, when sorted
62+
in ascending order, would be preserved.
63+
64+
Args:
65+
x1 (usm_ndarray):
66+
input array. Must be a one-dimensional array. If `sorter` is
67+
`None`, must be sorted in ascending order; otherwise, `sorter` must
68+
be an array of indices that sort `x1` in ascending order.
69+
x2 (usm_ndarray):
70+
array containing search values.
71+
side (Literal["left", "right]):
72+
argument controlling which index is returned if a value lands
73+
exactly on an edge. If `x2` is an array of rank `N` where
74+
`v = x2[n, m, ..., j]`, the element `ret[n, m, ..., j]` in the
75+
return array `ret` contains the position `i` such that
76+
if `side="left"`, it is the first index such that
77+
`x1[i-1] < v <= x1[i]`, `0` if `v <= x1[0]`, and `x1.size`
78+
if `v > x1[-1]`;
79+
and if `side="right"`, it is the first position `i` such that
80+
`x1[i-1] <= v < x1[i]`, `0` if `v < x1[0]`, and `x1.size`
81+
if `v >= x1[-1]`. Default: `"left"`.
82+
sorter (Optional[usm_ndarray]):
83+
array of indices that sort `x1` in ascending order. The array must
84+
have the same shape as `x1` and have an integral data type.
85+
Out of bound index values of `sorter` array are treated using
86+
`"wrap"` mode documented in :py:func:`dpctl.tensor.take`.
87+
Default: `None`.
88+
"""
89+
if not isinstance(x1, usm_ndarray):
90+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x1)}")
91+
if not isinstance(x2, usm_ndarray):
92+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x2)}")
93+
if sorter is not None and not isinstance(sorter, usm_ndarray):
94+
raise TypeError(
95+
f"Expected dpctl.tensor.usm_ndarray, got {type(sorter)}"
96+
)
97+
98+
if side not in ["left", "right"]:
99+
raise ValueError(
100+
"Unrecognized value of 'side' keyword argument. "
101+
"Expected either 'left' or 'right'"
102+
)
103+
104+
if sorter is None:
105+
q = du.get_execution_queue([x1.sycl_queue, x2.sycl_queue])
106+
else:
107+
q = du.get_execution_queue(
108+
[x1.sycl_queue, x2.sycl_queue, sorter.sycl_queue]
109+
)
110+
if q is None:
111+
raise du.ExecutionPlacementError(
112+
"Execution placement can not be unambiguously "
113+
"inferred from input arguments."
114+
)
115+
116+
if x1.ndim != 1:
117+
raise ValueError("First argument array must be one-dimensional")
118+
119+
x1_dt = x1.dtype
120+
x2_dt = x2.dtype
121+
122+
_manager = du.SequentialOrderManager[q]
123+
dep_evs = _manager.submitted_events
124+
ev = dpctl.SyclEvent()
125+
if sorter is not None:
126+
if not isdtype(sorter.dtype, "integral"):
127+
raise ValueError(
128+
f"Sorter array must have integral data type, got {sorter.dtype}"
129+
)
130+
if x1.shape != sorter.shape:
131+
raise ValueError(
132+
"Sorter array must be one-dimension with the same "
133+
"shape as the first argument array"
134+
)
135+
res = empty(x1.shape, dtype=x1_dt, usm_type=x1.usm_type, sycl_queue=q)
136+
ind = (sorter,)
137+
axis = 0
138+
wrap_out_of_bound_indices_mode = 0
139+
ht_ev, ev = ti_take(
140+
x1,
141+
ind,
142+
res,
143+
axis,
144+
wrap_out_of_bound_indices_mode,
145+
sycl_queue=q,
146+
depends=dep_evs,
147+
)
148+
x1 = res
149+
_manager.add_event_pair(ht_ev, ev)
150+
151+
if x1_dt != x2_dt:
152+
dt = result_type(x1, x2)
153+
if x1_dt != dt:
154+
x1_buf = _empty_like_orderK(x1, dt)
155+
dep_evs = _manager.submitted_events
156+
ht_ev, ev = ti_copy(
157+
src=x1, dst=x1_buf, sycl_queue=q, depends=dep_evs
158+
)
159+
_manager.add_event_pair(ht_ev, ev)
160+
x1 = x1_buf
161+
if x2_dt != dt:
162+
x2_buf = _empty_like_orderK(x2, dt)
163+
dep_evs = _manager.submitted_events
164+
ht_ev, ev = ti_copy(
165+
src=x2, dst=x2_buf, sycl_queue=q, depends=dep_evs
166+
)
167+
_manager.add_event_pair(ht_ev, ev)
168+
x2 = x2_buf
169+
170+
dst_usm_type = du.get_coerced_usm_type([x1.usm_type, x2.usm_type])
171+
index_dt = ti_default_device_index_type(q)
172+
173+
dst = _empty_like_orderK(x2, index_dt, usm_type=dst_usm_type)
174+
175+
dep_evs = _manager.submitted_events
176+
if side == "left":
177+
ht_ev, s_ev = _searchsorted_left(
178+
hay=x1,
179+
needles=x2,
180+
positions=dst,
181+
sycl_queue=q,
182+
depends=dep_evs,
183+
)
184+
else:
185+
ht_ev, s_ev = _searchsorted_right(
186+
hay=x1, needles=x2, positions=dst, sycl_queue=q, depends=dep_evs
187+
)
188+
_manager.add_event_pair(ht_ev, s_ev)
189+
return dst

dpnp/dpnp_iface_manipulation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,10 @@ def _get_first_nan_index(usm_a):
378378
true_val = dpt_ext.asarray(
379379
True, sycl_queue=usm_a.sycl_queue, usm_type=usm_a.usm_type
380380
)
381-
return dpt.searchsorted(dpt.isnan(usm_a), true_val, side="left")
382-
return dpt.searchsorted(usm_a, usm_a[-1], side="left")
381+
return dpt_ext.searchsorted(
382+
dpt.isnan(usm_a), true_val, side="left"
383+
)
384+
return dpt_ext.searchsorted(usm_a, usm_a[-1], side="left")
383385
return None
384386

385387
usm_ar = dpnp.get_usm_ndarray(ar)

dpnp/dpnp_iface_searching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def searchsorted(a, v, side="left", sorter=None):
382382

383383
usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
384384
return dpnp_array._create_from_usm_ndarray(
385-
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
385+
dpt_ext.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
386386
)
387387

388388

0 commit comments

Comments
 (0)