Skip to content

Commit bb16c19

Browse files
Move ti.stack() to dpctl_ext.tensor and reuse it
1 parent ba51636 commit bb16c19

3 files changed

Lines changed: 64 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
repeat,
7474
roll,
7575
squeeze,
76+
stack,
7677
)
7778
from dpctl_ext.tensor._reshape import reshape
7879

@@ -117,6 +118,7 @@
117118
"result_type",
118119
"roll",
119120
"squeeze",
121+
"stack",
120122
"take",
121123
"take_along_axis",
122124
"to_numpy",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,3 +873,64 @@ def squeeze(X, /, axis=None):
873873
return X
874874
else:
875875
return dpt_ext.reshape(X, new_shape)
876+
877+
878+
def stack(arrays, /, *, axis=0):
879+
"""
880+
stack(arrays, axis)
881+
882+
Joins a sequence of arrays along a new axis.
883+
884+
Args:
885+
arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
886+
input arrays to join. Each array must have the same shape.
887+
axis (int): axis along which the arrays will be joined. Providing
888+
an `axis` specified the index of the new axis in the dimensions
889+
of the output array. A valid axis must be on the interval
890+
`[-N, N)`, where `N` is the rank (number of dimensions) of `x`.
891+
Default: `0`.
892+
893+
Returns:
894+
usm_ndarray:
895+
An output array having rank `N+1`, where `N` is
896+
the rank (number of dimensions) of `x`. If the input arrays have
897+
different data types, array API Type Promotion Rules apply.
898+
899+
Raises:
900+
ValueError: if not all input arrays have the same shape
901+
IndexError: if provided an `axis` outside of the required interval.
902+
"""
903+
res_dtype, res_usm_type, exec_q = _arrays_validation(arrays)
904+
905+
n = len(arrays)
906+
X0 = arrays[0]
907+
res_ndim = X0.ndim + 1
908+
axis = normalize_axis_index(axis, res_ndim)
909+
X0_shape = X0.shape
910+
911+
for i in range(1, n):
912+
if X0_shape != arrays[i].shape:
913+
raise ValueError("All input arrays must have the same shape")
914+
915+
res_shape = tuple(
916+
X0_shape[i - 1 * (i >= axis)] if i != axis else n
917+
for i in range(res_ndim)
918+
)
919+
920+
res = dpt_ext.empty(
921+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
922+
)
923+
924+
_manager = dputils.SequentialOrderManager[exec_q]
925+
dep_evs = _manager.submitted_events
926+
for i in range(n):
927+
c_shapes_copy = tuple(
928+
i if j == axis else np.s_[:] for j in range(res_ndim)
929+
)
930+
_dst = res[c_shapes_copy]
931+
hev, cpy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
932+
src=arrays[i], dst=_dst, sycl_queue=exec_q, depends=dep_evs
933+
)
934+
_manager.add_event_pair(hev, cpy_ev)
935+
936+
return res

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3751,7 +3751,7 @@ def stack(arrays, /, *, axis=0, out=None, dtype=None, casting="same_kind"):
37513751
)
37523752

37533753
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
3754-
usm_res = dpt.stack(usm_arrays, axis=axis)
3754+
usm_res = dpt_ext.stack(usm_arrays, axis=axis)
37553755

37563756
res = dpnp_array._create_from_usm_ndarray(usm_res)
37573757
if dtype is not None:

0 commit comments

Comments
 (0)