Skip to content

Commit 6b81e7a

Browse files
Move ti.cumulative_sum() and reuse it in dpnp
1 parent 4552e78 commit 6b81e7a

3 files changed

Lines changed: 315 additions & 2 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
)
8181
from dpctl_ext.tensor._reshape import reshape
8282

83+
from ._accumulation import cumulative_sum
8384
from ._clip import clip
8485
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
8586

@@ -94,6 +95,7 @@
9495
"concat",
9596
"copy",
9697
"clip",
98+
"cumulative_sum",
9799
"empty",
98100
"empty_like",
99101
"extract",

dpctl_ext/tensor/_accumulation.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
import dpctl
30+
import dpctl.tensor as dpt
31+
from dpctl.tensor._type_utils import ( # _default_accumulation_dtype_fp_types,
32+
_default_accumulation_dtype,
33+
_to_device_supported_dtype,
34+
)
35+
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
36+
37+
# TODO: revert to `import dpctl.tensor...`
38+
# when dpnp fully migrates dpctl/tensor
39+
import dpctl_ext.tensor as dpt_ext
40+
import dpctl_ext.tensor._tensor_accumulation_impl as tai
41+
import dpctl_ext.tensor._tensor_impl as ti
42+
43+
from ._numpy_helper import normalize_axis_index
44+
45+
46+
def _accumulate_common(
    x,
    axis,
    dtype,
    include_initial,
    out,
    _accumulate_fn,
    _accumulate_include_initial_fn,
    _dtype_supported,
    _default_accumulation_type_fn,
):
    """
    Shared driver for accumulation functions working along a single axis.

    Validates the arguments, moves the accumulation axis to the last
    position, selects the result data type, allocates the destination
    (or validates the user-provided one), submits the accumulation and
    any required copies, and returns the result in the original axis
    order.

    Args:
        x (usm_ndarray):
            input array. A zero-dimensional array is processed as
            a one-element one-dimensional array and the extra axis is
            squeezed out of the result.
        axis (Optional[int]):
            axis to accumulate along. `None` is accepted only when the
            (possibly expanded) input has a single dimension.
        dtype (Optional[dtype]):
            requested result data type. If `None`, the type is obtained
            from `_default_accumulation_type_fn`.
        include_initial (bool):
            if `True`, the result is one element longer along `axis`
            and `_accumulate_include_initial_fn` is used.
        out (Optional[usm_ndarray]):
            array to write the result into, or `None` to allocate one.
        _accumulate_fn (callable):
            called with keywords `src`, `trailing_dims_to_accumulate`,
            `dst`, `sycl_queue` and `depends`; must return a pair of
            events.
        _accumulate_include_initial_fn (callable):
            called with keywords `src`, `dst`, `sycl_queue` and
            `depends`; must return a pair of events.
        _dtype_supported (callable):
            predicate called as `(input_dtype, output_dtype)` to decide
            whether the accumulation can be invoked directly for that
            pair of types.
        _default_accumulation_type_fn (callable):
            called as `(input_dtype, sycl_queue)` to obtain the default
            accumulation data type.

    Returns:
        usm_ndarray:
            array holding the accumulation result.

    Raises:
        TypeError:
            if `x`, or a provided `out`, is not a `usm_ndarray`.
        ValueError:
            if `axis` is `None` for a multi-dimensional input, or if
            `out` is read-only, has an unexpected shape, or has an
            unexpected data type.
        RuntimeError:
            if `dtype` is `None` and the automatically determined type
            is rejected by `_dtype_supported`.
        ExecutionPlacementError:
            if no common execution queue exists for `x` and `out`.
    """
    if not isinstance(x, dpt.usm_ndarray):
        raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")

    # a 0-d input is handled as a 1-d array with a single element
    appended_axis = x.ndim == 0
    if appended_axis:
        x = x[dpt.newaxis]

    nd = x.ndim
    if axis is not None:
        axis = normalize_axis_index(axis, nd, "axis")
    elif nd > 1:
        raise ValueError(
            "`axis` cannot be `None` for array of dimension `{}`".format(nd)
        )
    else:
        axis = 0

    in_shape = x.shape
    if include_initial:
        # one extra slot along `axis` for the initial value
        res_shape = (
            in_shape[:axis] + (in_shape[axis] + 1,) + in_shape[axis + 1 :]
        )
    else:
        res_shape = in_shape

    # permutation that moves `axis` to the last position (the identity
    # permutation when `axis` is already last)
    axis_is_last = axis + 1 == nd
    perm = list(range(nd))
    perm.append(perm.pop(axis))
    arr = x if axis_is_last else dpt_ext.permute_dims(x, perm)

    exec_q = x.sycl_queue
    inp_dt = x.dtype
    usm_type = x.usm_type
    if dtype is None:
        res_dt = _default_accumulation_type_fn(inp_dt, exec_q)
    else:
        res_dt = _to_device_supported_dtype(
            dpt.dtype(dtype), exec_q.sycl_device
        )

    # checking now avoids unnecessary allocations
    direct_impl = _dtype_supported(inp_dt, res_dt)
    if dtype is None and not direct_impl:
        raise RuntimeError(
            "Automatically determined accumulation data type does not "
            "have direct implementation"
        )

    orig_out = out
    if out is None:
        out = dpt_ext.empty(
            res_shape, dtype=res_dt, usm_type=usm_type, sycl_queue=exec_q
        )
        if not axis_is_last:
            out = dpt_ext.permute_dims(out, perm)
    else:
        if not isinstance(out, dpt.usm_ndarray):
            raise TypeError(
                f"output array must be of usm_ndarray type, got {type(out)}"
            )
        if not out.flags.writable:
            raise ValueError("provided `out` array is read-only")
        # shape as supplied by the caller, before any axis is appended
        out_sh = out.shape
        if appended_axis and not include_initial:
            # append an axis to `out` if scalar
            out = out[dpt.newaxis, ...]
            orig_out = out
            final_res_sh = res_shape[1:]
        else:
            final_res_sh = res_shape
        if out_sh != final_res_sh:
            raise ValueError(
                "The shape of input and output arrays are inconsistent. "
                f"Expected output shape is {final_res_sh}, got {out_sh}"
            )
        if res_dt != out.dtype:
            raise ValueError(
                f"Output array of type {res_dt} is needed, " f"got {out.dtype}"
            )
        common_q = dpctl.utils.get_execution_queue((exec_q, out.sycl_queue))
        if common_q is None:
            raise ExecutionPlacementError(
                "Input and output allocation queues are not compatible"
            )
        # permute out array dims if necessary
        if not axis_is_last:
            out = dpt_ext.permute_dims(out, perm)
            orig_out = out
        # when `out` overlaps the input, accumulate into a scratch array;
        # the result is copied into the original memory afterwards
        if ti._array_overlap(x, out) and direct_impl:
            out = dpt_ext.empty_like(out)

    def _submit_scan(src, dst, deps):
        # Dispatch to the accumulation variant selected by
        # `include_initial` and return the pair of events it produces.
        if include_initial:
            return _accumulate_include_initial_fn(
                src=src, dst=dst, sycl_queue=exec_q, depends=deps
            )
        return _accumulate_fn(
            src=src,
            trailing_dims_to_accumulate=1,
            dst=dst,
            sycl_queue=exec_q,
            depends=deps,
        )

    _manager = SequentialOrderManager[exec_q]
    depends = _manager.submitted_events
    if direct_impl:
        ht_e, acc_ev = _submit_scan(arr, out, depends)
        _manager.add_event_pair(ht_e, acc_ev)
        if orig_out is not None and out is not orig_out:
            # Copy the out data from temporary buffer to original memory
            ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
                src=out, dst=orig_out, sycl_queue=exec_q, depends=[acc_ev]
            )
            _manager.add_event_pair(ht_e_cpy, cpy_e)
            out = orig_out
    elif _dtype_supported(res_dt, res_dt):
        # copy the input into a temporary of the result type first,
        # then accumulate from that temporary into `out`
        tmp = dpt_ext.empty(
            arr.shape, dtype=res_dt, usm_type=usm_type, sycl_queue=exec_q
        )
        ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=tmp, sycl_queue=exec_q, depends=depends
        )
        _manager.add_event_pair(ht_e_cpy, cpy_e)
        ht_e, acc_ev = _submit_scan(tmp, out, [cpy_e])
        _manager.add_event_pair(ht_e, acc_ev)
    else:
        # accumulate in the default accumulation type using temporaries
        # for both input and result, then copy the result into `out`
        buf_dt = _default_accumulation_type_fn(inp_dt, exec_q)
        tmp = dpt_ext.empty(
            arr.shape, dtype=buf_dt, usm_type=usm_type, sycl_queue=exec_q
        )
        ht_e_cpy, cpy_e = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr, dst=tmp, sycl_queue=exec_q, depends=depends
        )
        _manager.add_event_pair(ht_e_cpy, cpy_e)
        tmp_res = dpt_ext.empty(
            res_shape, dtype=buf_dt, usm_type=usm_type, sycl_queue=exec_q
        )
        if not axis_is_last:
            tmp_res = dpt_ext.permute_dims(tmp_res, perm)
        ht_e, acc_ev = _submit_scan(tmp, tmp_res, [cpy_e])
        _manager.add_event_pair(ht_e, acc_ev)
        ht_e_cpy2, cpy_e2 = ti._copy_usm_ndarray_into_usm_ndarray(
            src=tmp_res, dst=out, sycl_queue=exec_q, depends=[acc_ev]
        )
        _manager.add_event_pair(ht_e_cpy2, cpy_e2)

    if appended_axis:
        out = dpt_ext.squeeze(out)
    if not axis_is_last:
        # undo the permutation applied before the accumulation
        inv_perm = sorted(range(nd), key=perm.__getitem__)
        out = dpt_ext.permute_dims(out, inv_perm)

    return out
233+
234+
235+
def cumulative_sum(
    x, /, *, axis=None, dtype=None, include_initial=False, out=None
):
    """
    cumulative_sum(x, /, *, axis=None, dtype=None, include_initial=False,
                   out=None)

    Computes the cumulative sum of the elements of the input array `x`
    along an axis.

    Args:
        x (usm_ndarray):
            input array.
        axis (Optional[int]):
            axis along which the cumulative sum is computed.
            `None` is accepted only when `x` has at most one dimension,
            in which case the sum runs over the whole array; for an array
            with more than one dimension an explicit `axis` is required.
            Default: `None`.
        dtype (Optional[dtype]):
            data type of the returned array. If `None`, the data type is
            chosen from the "kind" of the data type of `x`:

                * real- or complex-valued floating-point input: the
                  result has the same data type as `x`.
                * signed integral input: the result has the default
                  signed integral type for the device where `x` is
                  allocated.
                * unsigned integral input: the result has the default
                  unsigned integral type for the device where `x` is
                  allocated.
                * boolean input: the result has the default signed
                  integral type for the device where `x` is allocated.

            If the data type (either specified or resolved) differs from
            the data type of `x`, the input array elements are cast to
            that data type before the cumulative sum is computed.
            Default: `None`.
        include_initial (bool):
            whether to include the initial value (i.e., the additive
            identity, zero) as the first value along the accumulation
            axis of the output. Default: `False`.
        out (Optional[usm_ndarray]):
            array into which the result is written. `out` must have the
            expected shape of the result and the expected data type of
            the result (that is, `dtype` when it is provided).
            If `None`, a new array is returned. Default: `None`.

    Returns:
        usm_ndarray:
            array of cumulative sums, with the data type described for
            the `dtype` argument above.

            The shape of the returned array is:

                * the shape of `x`, if `include_initial` is `False`
                * the shape of `x` with the size of the accumulation
                  axis increased from `N` to `N+1`, if `include_initial`
                  is `True`
    """
    return _accumulate_common(
        x,
        axis,
        dtype,
        include_initial,
        out,
        _accumulate_fn=tai._cumsum_over_axis,
        _accumulate_include_initial_fn=tai._cumsum_final_axis_include_initial,
        _dtype_supported=tai._cumsum_dtype_supported,
        _default_accumulation_type_fn=_default_accumulation_dtype,
    )

dpnp/dpnp_iface_mathematical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,7 +1218,7 @@ def cumsum(a, axis=None, dtype=None, out=None):
12181218
return dpnp_wrap_reduction_call(
12191219
usm_a,
12201220
out,
1221-
dpt.cumulative_sum,
1221+
dpt_ext.cumulative_sum,
12221222
_get_reduction_res_dt(a, dtype),
12231223
axis=axis,
12241224
dtype=dtype,
@@ -1403,7 +1403,7 @@ def cumulative_sum(
14031403
return dpnp_wrap_reduction_call(
14041404
dpnp.get_usm_ndarray(x),
14051405
out,
1406-
dpt.cumulative_sum,
1406+
dpt_ext.cumulative_sum,
14071407
_get_reduction_res_dt(x, dtype),
14081408
axis=axis,
14091409
dtype=dtype,

0 commit comments

Comments
 (0)