Skip to content

Commit 8ae933d

Browse files
Move ti.argmax()/argmin() to dpctl_ext.tensor and reuse them in dpnp
1 parent c6d600a commit 8ae933d

3 files changed

Lines changed: 52 additions & 95 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8585
from ._clip import clip
8686
from ._reduction import (
87+
argmax,
88+
argmin,
8789
max,
8890
min,
8991
)
@@ -102,6 +104,8 @@
102104
"all",
103105
"any",
104106
"arange",
107+
"argmax",
108+
"argmin",
105109
"argsort",
106110
"asarray",
107111
"asnumpy",

dpctl_ext/tensor/_reduction.py

Lines changed: 44 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -137,72 +137,6 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
137137
return out
138138

139139

140-
def max(x, /, *, axis=None, keepdims=False, out=None):
141-
"""
142-
Calculates the maximum value of the input array ``x``.
143-
144-
Args:
145-
x (usm_ndarray):
146-
input array.
147-
axis (Optional[int, Tuple[int, ...]]):
148-
axis or axes along which maxima must be computed. If a tuple
149-
of unique integers, the maxima are computed over multiple axes.
150-
If ``None``, the max is computed over the entire array.
151-
Default: ``None``.
152-
keepdims (Optional[bool]):
153-
if ``True``, the reduced axes (dimensions) are included in the
154-
result as singleton dimensions, so that the returned array remains
155-
compatible with the input arrays according to Array Broadcasting
156-
rules. Otherwise, if ``False``, the reduced axes are not included
157-
in the returned array. Default: ``False``.
158-
out (Optional[usm_ndarray]):
159-
the array into which the result is written.
160-
The data type of ``out`` must match the expected shape and the
161-
expected data type of the result.
162-
If ``None`` then a new array is returned. Default: ``None``.
163-
164-
Returns:
165-
usm_ndarray:
166-
an array containing the maxima. If the max was computed over the
167-
entire array, a zero-dimensional array is returned. The returned
168-
array has the same data type as ``x``.
169-
"""
170-
return _comparison_over_axis(x, axis, keepdims, out, tri._max_over_axis)
171-
172-
173-
def min(x, /, *, axis=None, keepdims=False, out=None):
174-
"""
175-
Calculates the minimum value of the input array ``x``.
176-
177-
Args:
178-
x (usm_ndarray):
179-
input array.
180-
axis (Optional[int, Tuple[int, ...]]):
181-
axis or axes along which minima must be computed. If a tuple
182-
of unique integers, the minima are computed over multiple axes.
183-
If ``None``, the min is computed over the entire array.
184-
Default: ``None``.
185-
keepdims (Optional[bool]):
186-
if ``True``, the reduced axes (dimensions) are included in the
187-
result as singleton dimensions, so that the returned array remains
188-
compatible with the input arrays according to Array Broadcasting
189-
rules. Otherwise, if ``False``, the reduced axes are not included
190-
in the returned array. Default: ``False``.
191-
out (Optional[usm_ndarray]):
192-
the array into which the result is written.
193-
The data type of ``out`` must match the expected shape and the
194-
expected data type of the result.
195-
If ``None`` then a new array is returned. Default: ``None``.
196-
197-
Returns:
198-
usm_ndarray:
199-
an array containing the minima. If the min was computed over the
200-
entire array, a zero-dimensional array is returned. The returned
201-
array has the same data type as ``x``.
202-
"""
203-
return _comparison_over_axis(x, axis, keepdims, out, tri._min_over_axis)
204-
205-
206140
def _search_over_axis(x, axis, keepdims, out, _reduction_fn):
207141
if not isinstance(x, dpt.usm_ndarray):
208142
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -376,18 +310,17 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
376310
return _search_over_axis(x, axis, keepdims, out, tri._argmin_over_axis)
377311

378312

379-
def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
313+
def max(x, /, *, axis=None, keepdims=False, out=None):
380314
"""
381-
Counts the number of elements in the input array ``x`` which are non-zero.
315+
Calculates the maximum value of the input array ``x``.
382316
383317
Args:
384318
x (usm_ndarray):
385319
input array.
386320
axis (Optional[int, Tuple[int, ...]]):
387-
axis or axes along which to count. If a tuple of unique integers,
388-
the number of non-zero values are computed over multiple axes.
389-
If ``None``, the number of non-zero values is computed over the
390-
entire array.
321+
axis or axes along which maxima must be computed. If a tuple
322+
of unique integers, the maxima are computed over multiple axes.
323+
If ``None``, the max is computed over the entire array.
391324
Default: ``None``.
392325
keepdims (Optional[bool]):
393326
if ``True``, the reduced axes (dimensions) are included in the
@@ -397,23 +330,47 @@ def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
397330
in the returned array. Default: ``False``.
398331
out (Optional[usm_ndarray]):
399332
the array into which the result is written.
400-
The data type of ``out`` must match the expected shape and data
401-
type.
333+
The data type of ``out`` must match the expected shape and the
334+
expected data type of the result.
402335
If ``None`` then a new array is returned. Default: ``None``.
403336
404337
Returns:
405338
usm_ndarray:
406-
an array containing the count of non-zero values. If the sum was
407-
computed over the entire array, a zero-dimensional array is
408-
returned. The returned array will have the default array index data
409-
type.
339+
an array containing the maxima. If the max was computed over the
340+
entire array, a zero-dimensional array is returned. The returned
341+
array has the same data type as ``x``.
410342
"""
411-
if x.dtype != dpt.bool:
412-
x = dpt_ext.astype(x, dpt.bool, copy=False)
413-
return sum(
414-
x,
415-
axis=axis,
416-
dtype=ti.default_device_index_type(x.sycl_device),
417-
keepdims=keepdims,
418-
out=out,
419-
)
343+
return _comparison_over_axis(x, axis, keepdims, out, tri._max_over_axis)
344+
345+
346+
def min(x, /, *, axis=None, keepdims=False, out=None):
347+
"""
348+
Calculates the minimum value of the input array ``x``.
349+
350+
Args:
351+
x (usm_ndarray):
352+
input array.
353+
axis (Optional[int, Tuple[int, ...]]):
354+
axis or axes along which minima must be computed. If a tuple
355+
of unique integers, the minima are computed over multiple axes.
356+
If ``None``, the min is computed over the entire array.
357+
Default: ``None``.
358+
keepdims (Optional[bool]):
359+
if ``True``, the reduced axes (dimensions) are included in the
360+
result as singleton dimensions, so that the returned array remains
361+
compatible with the input arrays according to Array Broadcasting
362+
rules. Otherwise, if ``False``, the reduced axes are not included
363+
in the returned array. Default: ``False``.
364+
out (Optional[usm_ndarray]):
365+
the array into which the result is written.
366+
The data type of ``out`` must match the expected shape and the
367+
expected data type of the result.
368+
If ``None`` then a new array is returned. Default: ``None``.
369+
370+
Returns:
371+
usm_ndarray:
372+
an array containing the minima. If the min was computed over the
373+
entire array, a zero-dimensional array is returned. The returned
374+
array has the same data type as ``x``.
375+
"""
376+
return _comparison_over_axis(x, axis, keepdims, out, tri._min_over_axis)

dpnp/dpnp_iface_searching.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,10 @@
3939
4040
"""
4141

42-
import dpctl.tensor as dpt
43-
4442
# pylint: disable=no-name-in-module
4543
# TODO: revert to `import dpctl.tensor...`
4644
# when dpnp fully migrates dpctl/tensor
47-
import dpctl_ext.tensor as dpt_ext
45+
import dpctl_ext.tensor as dpt
4846
import dpctl_ext.tensor._tensor_impl as dti
4947
import dpnp
5048

@@ -376,13 +374,13 @@ def searchsorted(a, v, side="left", sorter=None):
376374

377375
usm_a = dpnp.get_usm_ndarray(a)
378376
if dpnp.isscalar(v):
379-
usm_v = dpt_ext.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
377+
usm_v = dpt.asarray(v, sycl_queue=a.sycl_queue, usm_type=a.usm_type)
380378
else:
381379
usm_v = dpnp.get_usm_ndarray(v)
382380

383381
usm_sorter = None if sorter is None else dpnp.get_usm_ndarray(sorter)
384382
return dpnp_array._create_from_usm_ndarray(
385-
dpt_ext.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
383+
dpt.searchsorted(usm_a, usm_v, side=side, sorter=usm_sorter)
386384
)
387385

388386

@@ -474,7 +472,5 @@ def where(condition, x=None, y=None, /, *, order="K", out=None):
474472
usm_condition = dpnp.get_usm_ndarray(condition)
475473

476474
usm_out = None if out is None else dpnp.get_usm_ndarray(out)
477-
usm_res = dpt_ext.where(
478-
usm_condition, usm_x, usm_y, order=order, out=usm_out
479-
)
475+
usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out)
480476
return dpnp.get_result_array(usm_res, out)

0 commit comments

Comments
 (0)