Skip to content

Commit bd3add0

Browse files
Move ti.all() to dpctl_ext.tensor and reuse it in dpnp
1 parent c4f2496 commit bd3add0

4 files changed

Lines changed: 141 additions & 3 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
unstack,
8080
)
8181
from dpctl_ext.tensor._reshape import reshape
82+
from dpctl_ext.tensor._utility_functions import all
8283

8384
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8485
from ._clip import clip
@@ -94,6 +95,7 @@
9495
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
9596

9697
__all__ = [
98+
"all",
9799
"arange",
98100
"argsort",
99101
"asarray",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def repeat(x, repeats, /, *, axis=None):
624624
"'repeats' array must be broadcastable to the size of "
625625
"the repeated axis"
626626
)
627-
if not dpt.all(repeats >= 0):
627+
if not dpt_ext.all(repeats >= 0):
628628
raise ValueError("'repeats' elements must be positive")
629629

630630
elif isinstance(repeats, (tuple, list, range)):
@@ -646,7 +646,7 @@ def repeat(x, repeats, /, *, axis=None):
646646
repeats = dpt_ext.asarray(
647647
repeats, dtype=dpt.int64, usm_type=usm_type, sycl_queue=exec_q
648648
)
649-
if not dpt.all(repeats >= 0):
649+
if not dpt_ext.all(repeats >= 0):
650650
raise ValueError("`repeats` elements must be positive")
651651
else:
652652
raise TypeError(
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
30+
import dpctl.tensor as dpt
31+
import dpctl.utils as du
32+
33+
# TODO: revert to `import dpctl.tensor...`
34+
# when dpnp fully migrates dpctl/tensor
35+
import dpctl_ext.tensor as dpt_ext
36+
import dpctl_ext.tensor._tensor_impl as ti
37+
import dpctl_ext.tensor._tensor_reductions_impl as tri
38+
39+
from ._numpy_helper import normalize_axis_tuple
40+
41+
42+
def _boolean_reduction(x, axis, keepdims, func):
43+
if not isinstance(x, dpt.usm_ndarray):
44+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
45+
46+
nd = x.ndim
47+
if axis is None:
48+
red_nd = nd
49+
# case of a scalar
50+
if red_nd == 0:
51+
return dpt_ext.astype(x, dpt.bool)
52+
x_tmp = x
53+
res_shape = ()
54+
perm = list(range(nd))
55+
else:
56+
if not isinstance(axis, (tuple, list)):
57+
axis = (axis,)
58+
axis = normalize_axis_tuple(axis, nd, "axis")
59+
60+
red_nd = len(axis)
61+
# check for axis=()
62+
if red_nd == 0:
63+
return dpt_ext.astype(x, dpt.bool)
64+
perm = [i for i in range(nd) if i not in axis] + list(axis)
65+
x_tmp = dpt_ext.permute_dims(x, perm)
66+
res_shape = x_tmp.shape[: nd - red_nd]
67+
68+
exec_q = x.sycl_queue
69+
res_usm_type = x.usm_type
70+
71+
_manager = du.SequentialOrderManager[exec_q]
72+
dep_evs = _manager.submitted_events
73+
# always allocate the temporary as
74+
# int32 and usm-device to ensure that atomic updates
75+
# are supported
76+
res_tmp = dpt_ext.empty(
77+
res_shape,
78+
dtype=dpt.int32,
79+
usm_type="device",
80+
sycl_queue=exec_q,
81+
)
82+
hev0, ev0 = func(
83+
src=x_tmp,
84+
trailing_dims_to_reduce=red_nd,
85+
dst=res_tmp,
86+
sycl_queue=exec_q,
87+
depends=dep_evs,
88+
)
89+
_manager.add_event_pair(hev0, ev0)
90+
91+
# copy to boolean result array
92+
res = dpt_ext.empty(
93+
res_shape,
94+
dtype=dpt.bool,
95+
usm_type=res_usm_type,
96+
sycl_queue=exec_q,
97+
)
98+
hev1, ev1 = ti._copy_usm_ndarray_into_usm_ndarray(
99+
src=res_tmp, dst=res, sycl_queue=exec_q, depends=[ev0]
100+
)
101+
_manager.add_event_pair(hev1, ev1)
102+
103+
if keepdims:
104+
res_shape = res_shape + (1,) * red_nd
105+
inv_perm = sorted(range(nd), key=lambda d: perm[d])
106+
res = dpt_ext.permute_dims(dpt_ext.reshape(res, res_shape), inv_perm)
107+
return res
108+
109+
110+
def all(x, /, *, axis=None, keepdims=False):
111+
"""
112+
all(x, axis=None, keepdims=False)
113+
114+
Tests whether all input array elements evaluate to True along a given axis.
115+
116+
Args:
117+
x (usm_ndarray): Input array.
118+
axis (Optional[Union[int, Tuple[int,...]]]): Axis (or axes)
119+
along which to perform a logical AND reduction.
120+
When `axis` is `None`, a logical AND reduction
121+
is performed over all dimensions of `x`.
122+
If `axis` is negative, the axis is counted from
123+
the last dimension to the first.
124+
Default: `None`.
125+
keepdims (bool, optional): If `True`, the reduced axes are included
126+
in the result as singleton dimensions, and the result is
127+
broadcastable to the input array shape.
128+
If `False`, the reduced axes are not included in the result.
129+
Default: `False`.
130+
131+
Returns:
132+
usm_ndarray:
133+
An array with a data type of `bool`
134+
containing the results of the logical AND reduction.
135+
"""
136+
return _boolean_reduction(x, axis, keepdims, tri._all)

dpnp/dpnp_iface_logic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):
214214
dpnp.check_limitations(where=where)
215215

216216
usm_a = dpnp.get_usm_ndarray(a)
217-
usm_res = dpt.all(usm_a, axis=axis, keepdims=keepdims)
217+
usm_res = dpt_ext.all(usm_a, axis=axis, keepdims=keepdims)
218218

219219
# TODO: temporary solution until dpt.all supports out parameter
220220
return dpnp.get_result_array(usm_res, out)

0 commit comments

Comments
 (0)