Skip to content

Commit 82d202c

Browse files
Move ti.unique_counts() and ti.unique_values() to dpctl_ext.tensor
1 parent b62d836 commit 82d202c

3 files changed

Lines changed: 282 additions & 2 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@
8282

8383
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8484
from ._clip import clip
85+
from ._set_functions import ( # isin,; unique_all,; unique_inverse,
86+
unique_counts,
87+
unique_values,
88+
)
8589
from ._sorting import sort
8690
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8791

@@ -135,6 +139,8 @@
135139
"to_numpy",
136140
"tril",
137141
"triu",
142+
"unique_counts",
143+
"unique_values",
138144
"unstack",
139145
"where",
140146
"zeros",

dpctl_ext/tensor/_set_functions.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
from typing import NamedTuple
30+
31+
import dpctl.tensor as dpt
32+
import dpctl.utils as du
33+
from dpctl.tensor._tensor_elementwise_impl import _not_equal, _subtract
34+
35+
# TODO: revert to `import dpctl.tensor...`
36+
# when dpnp fully migrates dpctl/tensor
37+
import dpctl_ext.tensor as dpt_ext
38+
39+
from ._tensor_impl import (
40+
_copy_usm_ndarray_into_usm_ndarray,
41+
_extract,
42+
_full_usm_ndarray,
43+
_linspace_step,
44+
default_device_index_type,
45+
mask_positions,
46+
)
47+
from ._tensor_sorting_impl import (
48+
_sort_ascending,
49+
)
50+
51+
52+
class UniqueCountsResult(NamedTuple):
    """Result container returned by :func:`unique_counts`.

    A two-field named tuple ``(values, counts)``.
    """

    # array of unique elements
    values: dpt.usm_ndarray
    # array with the number of occurrences of each entry of `values`
    counts: dpt.usm_ndarray
55+
56+
57+
def unique_values(x: dpt.usm_ndarray) -> dpt.usm_ndarray:
    """unique_values(x)

    Returns the unique elements of an input array `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        usm_ndarray
            an array containing the set of unique elements in `x`. The
            returned array has the same data type as `x`.
    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    q = x.device.sycl_queue
    flat = x if x.ndim == 1 else dpt_ext.reshape(x, (x.size,), order="C")
    if flat.size == 0:
        # nothing to deduplicate
        return flat

    sorted_vals = dpt_ext.empty_like(flat, order="C")
    order_manager = du.SequentialOrderManager[q]
    pending_evs = order_manager.submitted_events
    if flat.flags.c_contiguous:
        sort_input = flat
        sort_deps = pending_evs
    else:
        # sort from a C-contiguous copy of the input
        sort_input = dpt_ext.empty_like(flat, order="C")
        host_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_input, sycl_queue=q, depends=pending_evs
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]
    host_ev, sort_ev = _sort_ascending(
        src=sort_input,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=q,
        depends=sort_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    # run_start[i] is True where sorted_vals[i] differs from its predecessor;
    # the very first element always starts a run
    run_start = dpt_ext.empty(flat.shape, dtype="?", sycl_queue=q)
    host_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=run_start[1:],
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_manager.add_event_pair(host_ev, neq_ev)
    # writing into new allocation, no dependencies
    host_ev, first_ev = _full_usm_ndarray(
        fill_value=True, dst=run_start[0], sycl_queue=q
    )
    order_manager.add_event_pair(host_ev, first_ev)

    positions = dpt_ext.empty(sorted_vals.shape, dtype=dpt.int64, sycl_queue=q)
    # synchronizing call
    n_uniques = mask_positions(
        run_start, positions, sycl_queue=q, depends=[first_ev, neq_ev]
    )
    if n_uniques == flat.size:
        # every element is distinct: the sorted array is already the answer
        return sorted_vals

    result = dpt_ext.empty(
        n_uniques, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=q
    )
    host_ev, extract_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=result,
        sycl_queue=q,
    )
    order_manager.add_event_pair(host_ev, extract_ev)
    return result
140+
141+
142+
def unique_counts(x: dpt.usm_ndarray) -> UniqueCountsResult:
    """unique_counts(x)

    Returns the unique elements of an input array `x` and the corresponding
    counts for each unique element in `x`.

    Args:
        x (usm_ndarray):
            input array. Inputs with more than one dimension are flattened.
    Returns:
        tuple[usm_ndarray, usm_ndarray]
            a namedtuple `(values, counts)` whose

            * first element has the field name `values` and is an array
              containing the unique elements of `x`. This array has the
              same data type as `x`.
            * second element has the field name `counts` and is an array
              containing the number of times each unique element occurs in
              `x`. This array has the same shape as `values` and has the
              default array index data type.
    Raises:
        TypeError:
            if `x` is not an instance of `dpctl.tensor.usm_ndarray`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
    q = x.device.sycl_queue
    x_usm_type = x.usm_type
    flat = x if x.ndim == 1 else dpt_ext.reshape(x, (x.size,), order="C")
    ind_dt = default_device_index_type(q)
    if flat.size == 0:
        return UniqueCountsResult(flat, dpt_ext.empty_like(flat, dtype=ind_dt))

    sorted_vals = dpt_ext.empty_like(flat, order="C")
    order_manager = du.SequentialOrderManager[q]
    pending_evs = order_manager.submitted_events
    if flat.flags.c_contiguous:
        sort_input = flat
        sort_deps = pending_evs
    else:
        # sort from a C-contiguous copy of the input
        sort_input = dpt_ext.empty_like(flat, order="C")
        host_ev, copy_ev = _copy_usm_ndarray_into_usm_ndarray(
            src=flat, dst=sort_input, sycl_queue=q, depends=pending_evs
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        sort_deps = [copy_ev]
    host_ev, sort_ev = _sort_ascending(
        src=sort_input,
        trailing_dims_to_sort=1,
        dst=sorted_vals,
        sycl_queue=q,
        depends=sort_deps,
    )
    order_manager.add_event_pair(host_ev, sort_ev)

    # run_start[i] is True where sorted_vals[i] differs from its predecessor;
    # the very first element always starts a run
    run_start = dpt_ext.empty(sorted_vals.shape, dtype="?", sycl_queue=q)
    host_ev, neq_ev = _not_equal(
        src1=sorted_vals[:-1],
        src2=sorted_vals[1:],
        dst=run_start[1:],
        sycl_queue=q,
        depends=[sort_ev],
    )
    order_manager.add_event_pair(host_ev, neq_ev)
    # no dependency, since we write into new allocation
    host_ev, first_ev = _full_usm_ndarray(
        fill_value=True, dst=run_start[0], sycl_queue=q
    )
    order_manager.add_event_pair(host_ev, first_ev)

    positions = dpt_ext.empty(run_start.shape, dtype=dpt.int64, sycl_queue=q)
    # synchronizing call
    n_uniques = mask_positions(
        run_start, positions, sycl_queue=q, depends=[first_ev, neq_ev]
    )
    if n_uniques == flat.size:
        # every element is distinct, so each one occurs exactly once
        all_ones = dpt_ext.ones(
            n_uniques, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=q
        )
        return UniqueCountsResult(sorted_vals, all_ones)

    values = dpt_ext.empty(
        n_uniques, dtype=x.dtype, usm_type=x_usm_type, sycl_queue=q
    )
    # populate unique values
    host_ev, vals_ev = _extract(
        src=sorted_vals,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=values,
        sycl_queue=q,
    )
    order_manager.add_event_pair(host_ev, vals_ev)

    # boundaries[:-1] receives entries of the 0, 1, 2, ... index sequence
    # selected with the same `positions` used for the values above, and
    # boundaries[-1] is set to x.size; counts are the differences between
    # consecutive boundaries
    boundaries = dpt_ext.empty(
        n_uniques + 1, dtype=ind_dt, usm_type=x_usm_type, sycl_queue=q
    )
    seq = dpt_ext.empty(x.size, dtype=ind_dt, sycl_queue=q)
    # writing into new allocation, no dependency
    host_ev, seq_ev = _linspace_step(start=0, dt=1, dst=seq, sycl_queue=q)
    order_manager.add_event_pair(host_ev, seq_ev)
    host_ev, starts_ev = _extract(
        src=seq,
        cumsum=positions,
        axis_start=0,
        axis_end=1,
        dst=boundaries[:-1],
        sycl_queue=q,
        depends=[seq_ev],
    )
    order_manager.add_event_pair(host_ev, starts_ev)
    # no dependency, writing into disjoint segment of new allocation
    host_ev, end_ev = _full_usm_ndarray(
        x.size, dst=boundaries[-1], sycl_queue=q
    )
    order_manager.add_event_pair(host_ev, end_ev)

    counts = dpt_ext.empty_like(boundaries[1:])
    host_ev, diff_ev = _subtract(
        src1=boundaries[1:],
        src2=boundaries[:-1],
        dst=counts,
        sycl_queue=q,
        depends=[end_ev, starts_ev],
    )
    order_manager.add_event_pair(host_ev, diff_ev)
    return UniqueCountsResult(values, counts)

dpnp/dpnp_iface_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,12 +386,12 @@ def _get_first_nan_index(usm_a):
386386

387387
num_of_flags = (return_index, return_inverse, return_counts).count(True)
388388
if num_of_flags == 0:
389-
usm_res = dpt.unique_values(usm_ar)
389+
usm_res = dpt_ext.unique_values(usm_ar)
390390
usm_res = (usm_res,) # cast to a tuple to align with other cases
391391
elif num_of_flags == 1 and return_inverse:
392392
usm_res = dpt.unique_inverse(usm_ar)
393393
elif num_of_flags == 1 and return_counts:
394-
usm_res = dpt.unique_counts(usm_ar)
394+
usm_res = dpt_ext.unique_counts(usm_ar)
395395
else:
396396
usm_res = dpt.unique_all(usm_ar)
397397

0 commit comments

Comments
 (0)