Skip to content

Commit 080c7d8

Browse files
Move ti.concat() to dpctl_ext.tensor and reuse it
1 parent bd265da commit 080c7d8

3 files changed

Lines changed: 182 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
from dpctl_ext.tensor._manipulation_functions import (
6666
broadcast_arrays,
6767
broadcast_to,
68+
concat,
6869
repeat,
6970
roll,
7071
)
@@ -81,6 +82,7 @@
8182
"broadcast_arrays",
8283
"broadcast_to",
8384
"can_cast",
85+
"concat",
8486
"copy",
8587
"clip",
8688
"empty",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,54 @@
4040
import dpctl_ext.tensor._tensor_impl as ti
4141

4242
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
43+
from ._type_utils import _supported_dtype, _to_device_supported_dtype
4344

4445
__doc__ = (
4546
"Implementation module for array manipulation "
4647
"functions in :module:`dpctl.tensor`"
4748
)
4849

4950

51+
def _arrays_validation(arrays, check_ndim=True):
52+
n = len(arrays)
53+
if n == 0:
54+
raise TypeError("Missing 1 required positional argument: 'arrays'.")
55+
56+
if not isinstance(arrays, (list, tuple)):
57+
raise TypeError(f"Expected tuple or list type, got {type(arrays)}.")
58+
59+
for X in arrays:
60+
if not isinstance(X, dpt.usm_ndarray):
61+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
62+
63+
exec_q = dputils.get_execution_queue([X.sycl_queue for X in arrays])
64+
if exec_q is None:
65+
raise ValueError("All the input arrays must have same sycl queue.")
66+
67+
res_usm_type = dputils.get_coerced_usm_type([X.usm_type for X in arrays])
68+
if res_usm_type is None:
69+
raise ValueError("All the input arrays must have usm_type.")
70+
71+
X0 = arrays[0]
72+
_supported_dtype(Xi.dtype for Xi in arrays)
73+
74+
res_dtype = X0.dtype
75+
dev = exec_q.sycl_device
76+
for i in range(1, n):
77+
res_dtype = np.promote_types(res_dtype, arrays[i])
78+
res_dtype = _to_device_supported_dtype(res_dtype, dev)
79+
80+
if check_ndim:
81+
for i in range(1, n):
82+
if X0.ndim != arrays[i].ndim:
83+
raise ValueError(
84+
"All the input arrays must have same number of dimensions, "
85+
f"but the array at index 0 has {X0.ndim} dimension(s) and "
86+
f"the array at index {i} has {arrays[i].ndim} dimension(s)."
87+
)
88+
return res_dtype, res_usm_type, exec_q
89+
90+
5091
def _broadcast_shapes(*args):
5192
"""
5293
Broadcast the input shapes into a single shape;
@@ -112,6 +153,74 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
112153
return tuple(out_strides)
113154

114155

156+
def _check_same_shapes(X0_shape, axis, n, arrays):
157+
for i in range(1, n):
158+
Xi_shape = arrays[i].shape
159+
for j, X0j in enumerate(X0_shape):
160+
if X0j != Xi_shape[j] and j != axis:
161+
raise ValueError(
162+
"All the input array dimensions for the concatenation "
163+
f"axis must match exactly, but along dimension {j}, the "
164+
f"array at index 0 has size {X0j} and the array "
165+
f"at index {i} has size {Xi_shape[j]}."
166+
)
167+
168+
169+
def _concat_axis_None(arrays):
    """Implementation of concat(arrays, axis=None).

    Allocates a one-dimensional result whose length is the total number of
    elements of all inputs, then copies each input, in order, into its own
    consecutive segment of that result.
    """
    out_dtype, out_usm_type, queue = _arrays_validation(
        arrays, check_ndim=False
    )
    total_size = sum(arr.size for arr in arrays)
    res = dpt_ext.empty(
        total_size, dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    order_manager = dputils.SequentialOrderManager[queue]
    pending = order_manager.submitted_events
    offset = 0
    for arr in arrays:
        stop = offset + arr.size
        segment = res[offset:stop]
        if arr.flags.c_contiguous:
            # A C-contiguous input can be viewed as 1-D and copied directly.
            host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=dpt_ext.reshape(arr, -1),
                dst=segment,
                sycl_queue=queue,
                depends=pending,
            )
            order_manager.add_event_pair(host_ev, copy_ev)
        elif arr.dtype == out_dtype:
            host_ev, copy_ev = ti._copy_usm_ndarray_for_reshape(
                src=arr,
                dst=segment,
                sycl_queue=queue,
                depends=pending,
            )
            order_manager.add_event_pair(host_ev, copy_ev)
        else:
            # _copy_usm_ndarray_for_reshape requires src and dst to have
            # the same data type, so cast into a temporary first.
            casted = dpt_ext.empty_like(arr, dtype=out_dtype)
            cast_host_ev, cast_ev = ti._copy_usm_ndarray_into_usm_ndarray(
                src=arr, dst=casted, sycl_queue=queue, depends=pending
            )
            order_manager.add_event_pair(cast_host_ev, cast_ev)
            host_ev, reshape_ev = ti._copy_usm_ndarray_for_reshape(
                src=casted,
                dst=segment,
                sycl_queue=queue,
                depends=[cast_ev],
            )
            order_manager.add_event_pair(host_ev, reshape_ev)
        offset = stop

    return res
222+
223+
115224
def broadcast_arrays(*args):
116225
"""broadcast_arrays(*arrays)
117226
@@ -180,6 +289,76 @@ def broadcast_to(X, /, shape):
180289
)
181290

182291

292+
def concat(arrays, /, *, axis=0):
    """concat(arrays, axis)

    Joins a sequence of arrays along an existing axis.

    Args:
        arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
            input arrays to join. The arrays must have the same shape,
            except in the dimension specified by `axis`.
        axis (Optional[int]): axis along which the arrays will be joined.
            If `axis` is `None`, the arrays are flattened before
            concatenation. If `axis` is negative, it is understood as
            being counted from the last dimension. Default: `0`.

    Returns:
        usm_ndarray:
            An output array containing the concatenated
            values. The output array data type is determined by Type
            Promotion Rules of array API.

    All input arrays must have the same device attribute. The output array
    is allocated on that same device, and data movement operations are
    scheduled on a queue underlying the device. The USM allocation type
    of the output array is determined by USM allocation type promotion
    rules.
    """
    if axis is None:
        return _concat_axis_None(arrays)

    out_dtype, out_usm_type, queue = _arrays_validation(arrays)
    first = arrays[0]
    ndim = first.ndim

    axis = normalize_axis_index(axis, ndim)
    first_shape = first.shape
    _check_same_shapes(first_shape, axis, len(arrays), arrays)

    # Only the extent along the concatenation axis grows; every other
    # dimension is taken from the first array.
    joined_extent = sum(arr.shape[axis] for arr in arrays)
    out_shape = tuple(
        joined_extent if dim == axis else extent
        for dim, extent in enumerate(first_shape)
    )

    res = dpt_ext.empty(
        out_shape, dtype=out_dtype, usm_type=out_usm_type, sycl_queue=queue
    )

    order_manager = dputils.SequentialOrderManager[queue]
    pending = order_manager.submitted_events
    offset = 0
    for arr in arrays:
        stop = offset + arr.shape[axis]
        # Full slices everywhere except along the concatenation axis.
        index = [slice(None)] * ndim
        index[axis] = slice(offset, stop)
        host_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=arr,
            dst=res[tuple(index)],
            sycl_queue=queue,
            depends=pending,
        )
        order_manager.add_event_pair(host_ev, copy_ev)
        offset = stop

    return res
360+
361+
183362
def repeat(x, repeats, /, *, axis=None):
184363
"""repeat(x, repeats, axis=None)
185364

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1416,7 +1416,7 @@ def concatenate(
14161416
)
14171417

14181418
usm_arrays = [dpnp.get_usm_ndarray(x) for x in arrays]
1419-
usm_res = dpt.concat(usm_arrays, axis=axis)
1419+
usm_res = dpt_ext.concat(usm_arrays, axis=axis)
14201420

14211421
res = dpnp_array._create_from_usm_ndarray(usm_res)
14221422
if dtype is not None:

0 commit comments

Comments
 (0)