Skip to content

Commit d4cda7c

Browse files
Move ti.abs() and reuse _abs in dpnp
1 parent 58fdef3 commit d4cda7c

4 files changed

Lines changed: 355 additions & 2 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@
8383

8484
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
8585
from ._clip import clip
86+
from ._elementwise_funcs import (
87+
abs,
88+
)
8689
from ._reduction import (
8790
argmax,
8891
argmin,
@@ -106,6 +109,7 @@
106109
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
107110

108111
__all__ = [
112+
"abs",
109113
"all",
110114
"any",
111115
"arange",
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
import dpctl
30+
import dpctl.tensor as dpt
31+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
32+
33+
# TODO: revert to `import dpctl.tensor...`
34+
# when dpnp fully migrates dpctl/tensor
35+
import dpctl_ext.tensor as dpt_ext
36+
import dpctl_ext.tensor._tensor_impl as ti
37+
38+
from ._copy_utils import _empty_like_orderK
39+
from ._type_utils import (
40+
_acceptance_fn_default_unary,
41+
_all_data_types,
42+
_find_buf_dtype,
43+
)
44+
45+
46+
class UnaryElementwiseFunc:
47+
"""
48+
Class that implements unary element-wise functions.
49+
50+
Args:
51+
name (str):
52+
Name of the unary function
53+
result_type_resovler_fn (callable):
54+
Function that takes dtype of the input and
55+
returns the dtype of the result if the
56+
implementation functions supports it, or
57+
returns `None` otherwise.
58+
unary_dp_impl_fn (callable):
59+
Data-parallel implementation function with signature
60+
`impl_fn(src: usm_ndarray, dst: usm_ndarray,
61+
sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
62+
where the `src` is the argument array, `dst` is the
63+
array to be populated with function values, effectively
64+
evaluating `dst = func(src)`.
65+
The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
66+
The first event corresponds to data-management host tasks,
67+
including lifetime management of argument Python objects to ensure
68+
that their associated USM allocation is not freed before offloaded
69+
computational tasks complete execution, while the second event
70+
corresponds to computational tasks associated with function
71+
evaluation.
72+
acceptance_fn (callable, optional):
73+
Function to influence type promotion behavior of this unary
74+
function. The function takes 4 arguments:
75+
arg_dtype - Data type of the first argument
76+
buf_dtype - Data type the argument would be cast to
77+
res_dtype - Data type of the output array with function values
78+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
79+
evaluation is carried out.
80+
The function is invoked when the argument of the unary function
81+
requires casting, e.g. the argument of `dpctl.tensor.log` is an
82+
array with integral data type.
83+
docs (str):
84+
Documentation string for the unary function.
85+
"""
86+
87+
def __init__(
88+
self,
89+
name,
90+
result_type_resolver_fn,
91+
unary_dp_impl_fn,
92+
docs,
93+
acceptance_fn=None,
94+
):
95+
self.__name__ = "UnaryElementwiseFunc"
96+
self.name_ = name
97+
self.result_type_resolver_fn_ = result_type_resolver_fn
98+
self.types_ = None
99+
self.unary_fn_ = unary_dp_impl_fn
100+
self.__doc__ = docs
101+
if callable(acceptance_fn):
102+
self.acceptance_fn_ = acceptance_fn
103+
else:
104+
self.acceptance_fn_ = _acceptance_fn_default_unary
105+
106+
def __str__(self):
107+
return f"<{self.__name__} '{self.name_}'>"
108+
109+
def __repr__(self):
110+
return f"<{self.__name__} '{self.name_}'>"
111+
112+
def get_implementation_function(self):
113+
"""Returns the implementation function for
114+
this elementwise unary function.
115+
116+
"""
117+
return self.unary_fn_
118+
119+
def get_type_result_resolver_function(self):
120+
"""Returns the type resolver function for this
121+
elementwise unary function.
122+
"""
123+
return self.result_type_resolver_fn_
124+
125+
def get_type_promotion_path_acceptance_function(self):
126+
"""Returns the acceptance function for this
127+
elementwise binary function.
128+
129+
Acceptance function influences the type promotion
130+
behavior of this unary function.
131+
The function takes 4 arguments:
132+
arg_dtype - Data type of the first argument
133+
buf_dtype - Data type the argument would be cast to
134+
res_dtype - Data type of the output array with function values
135+
sycl_dev - The :class:`dpctl.SyclDevice` where the function
136+
evaluation is carried out.
137+
The function is invoked when the argument of the unary function
138+
requires casting, e.g. the argument of `dpctl.tensor.log` is an
139+
array with integral data type.
140+
"""
141+
return self.acceptance_fn_
142+
143+
@property
144+
def nin(self):
145+
"""Returns the number of arguments treated as inputs."""
146+
return 1
147+
148+
@property
149+
def nout(self):
150+
"""Returns the number of arguments treated as outputs."""
151+
return 1
152+
153+
@property
154+
def types(self):
155+
"""Returns information about types supported by
156+
implementation function, using NumPy's character
157+
encoding for data types, e.g.
158+
159+
:Example:
160+
.. code-block:: python
161+
162+
dpctl.tensor.sin.types
163+
# Outputs: ['e->e', 'f->f', 'd->d', 'F->F', 'D->D']
164+
"""
165+
types = self.types_
166+
if not types:
167+
types = []
168+
for dt1 in _all_data_types(True, True):
169+
dt2 = self.result_type_resolver_fn_(dt1)
170+
if dt2:
171+
types.append(f"{dt1.char}->{dt2.char}")
172+
self.types_ = types
173+
return types
174+
175+
def __call__(self, x, /, *, out=None, order="K"):
176+
if not isinstance(x, dpt.usm_ndarray):
177+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
178+
179+
if order not in ["C", "F", "K", "A"]:
180+
order = "K"
181+
buf_dt, res_dt = _find_buf_dtype(
182+
x.dtype,
183+
self.result_type_resolver_fn_,
184+
x.sycl_device,
185+
acceptance_fn=self.acceptance_fn_,
186+
)
187+
if res_dt is None:
188+
raise ValueError(
189+
f"function '{self.name_}' does not support input type "
190+
f"({x.dtype}), "
191+
"and the input could not be safely coerced to any "
192+
"supported types according to the casting rule ''safe''."
193+
)
194+
195+
orig_out = out
196+
if out is not None:
197+
if not isinstance(out, dpt.usm_ndarray):
198+
raise TypeError(
199+
f"output array must be of usm_ndarray type, got {type(out)}"
200+
)
201+
202+
if not out.flags.writable:
203+
raise ValueError("provided `out` array is read-only")
204+
205+
if out.shape != x.shape:
206+
raise ValueError(
207+
"The shape of input and output arrays are inconsistent. "
208+
f"Expected output shape is {x.shape}, got {out.shape}"
209+
)
210+
211+
if res_dt != out.dtype:
212+
raise ValueError(
213+
f"Output array of type {res_dt} is needed, "
214+
f"got {out.dtype}"
215+
)
216+
217+
if (
218+
buf_dt is None
219+
and ti._array_overlap(x, out)
220+
and not ti._same_logical_tensors(x, out)
221+
):
222+
# Allocate a temporary buffer to avoid memory overlapping.
223+
# Note if `buf_dt` is not None, a temporary copy of `x` will be
224+
# created, so the array overlap check isn't needed.
225+
out = dpt_ext.empty_like(out)
226+
227+
if (
228+
dpctl.utils.get_execution_queue((x.sycl_queue, out.sycl_queue))
229+
is None
230+
):
231+
raise ExecutionPlacementError(
232+
"Input and output allocation queues are not compatible"
233+
)
234+
235+
exec_q = x.sycl_queue
236+
_manager = SequentialOrderManager[exec_q]
237+
if buf_dt is None:
238+
if out is None:
239+
if order == "K":
240+
out = _empty_like_orderK(x, res_dt)
241+
else:
242+
if order == "A":
243+
order = "F" if x.flags.f_contiguous else "C"
244+
out = dpt_ext.empty_like(x, dtype=res_dt, order=order)
245+
246+
dep_evs = _manager.submitted_events
247+
ht_unary_ev, unary_ev = self.unary_fn_(
248+
x, out, sycl_queue=exec_q, depends=dep_evs
249+
)
250+
_manager.add_event_pair(ht_unary_ev, unary_ev)
251+
252+
if not (orig_out is None or orig_out is out):
253+
# Copy the out data from temporary buffer to original memory
254+
ht_copy_ev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
255+
src=out, dst=orig_out, sycl_queue=exec_q, depends=[unary_ev]
256+
)
257+
_manager.add_event_pair(ht_copy_ev, cpy_ev)
258+
out = orig_out
259+
260+
return out
261+
262+
if order == "K":
263+
buf = _empty_like_orderK(x, buf_dt)
264+
else:
265+
if order == "A":
266+
order = "F" if x.flags.f_contiguous else "C"
267+
buf = dpt_ext.empty_like(x, dtype=buf_dt, order=order)
268+
269+
dep_evs = _manager.submitted_events
270+
ht_copy_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
271+
src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
272+
)
273+
_manager.add_event_pair(ht_copy_ev, copy_ev)
274+
if out is None:
275+
if order == "K":
276+
out = _empty_like_orderK(buf, res_dt)
277+
else:
278+
out = dpt_ext.empty_like(buf, dtype=res_dt, order=order)
279+
280+
ht, uf_ev = self.unary_fn_(
281+
buf, out, sycl_queue=exec_q, depends=[copy_ev]
282+
)
283+
_manager.add_event_pair(ht, uf_ev)
284+
285+
return out
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
# TODO: revert to `import dpctl.tensor...`
30+
# when dpnp fully migrates dpctl/tensor
31+
import dpctl_ext.tensor._tensor_elementwise_impl as ti
32+
33+
from ._elementwise_common import UnaryElementwiseFunc
34+
35+
# U01: ==== ABS (x)
36+
_abs_docstring_ = r"""
37+
abs(x, /, \*, out=None, order='K')
38+
39+
Calculates the absolute value for each element `x_i` of input array `x`.
40+
41+
Args:
42+
x (usm_ndarray):
43+
Input array. May have any data type.
44+
out (Union[usm_ndarray, None], optional):
45+
Output array to populate.
46+
Array must have the correct shape and the expected data type.
47+
order ("C","F","A","K", optional):
48+
Memory layout of the new output array,
49+
if parameter `out` is ``None``.
50+
Default: `"K"`.
51+
52+
Returns:
53+
usm_ndarray:
54+
An array containing the element-wise absolute values.
55+
For complex input, the absolute value is its magnitude.
56+
If `x` has a real-valued data type, the returned array has the
57+
same data type as `x`. If `x` has a complex floating-point data type,
58+
the returned array has a real-valued floating-point data type whose
59+
precision matches the precision of `x`.
60+
"""
61+
62+
abs = UnaryElementwiseFunc("abs", ti._abs_result_type, ti._abs, _abs_docstring_)
63+
del _abs_docstring_

dpnp/dpnp_iface_mathematical.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
# TODO: revert to `import dpctl.tensor...`
5959
# when dpnp fully migrates dpctl/tensor
6060
import dpctl_ext.tensor as dpt
61+
import dpctl_ext.tensor._tensor_elementwise_impl as ti_ext
6162
import dpctl_ext.tensor._type_utils as dtu
6263
import dpnp
6364
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
@@ -384,8 +385,8 @@ def _validate_interp_param(param, name, exec_q, usm_type, dtype=None):
384385

385386
abs = DPNPUnaryFunc(
386387
"abs",
387-
ti._abs_result_type,
388-
ti._abs,
388+
ti_ext._abs_result_type,
389+
ti_ext._abs,
389390
_ABS_DOCSTRING,
390391
mkl_fn_to_call="_mkl_abs_to_call",
391392
mkl_impl_fn="_abs",

0 commit comments

Comments
 (0)