Skip to content

Commit ea7ea3c

Browse files
Move ti.expand_dims() to dpctl_ext.tensor and reuse it
1 parent 080c7d8 commit ea7ea3c

3 files changed

Lines changed: 49 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
broadcast_arrays,
6767
broadcast_to,
6868
concat,
69+
expand_dims,
6970
repeat,
7071
roll,
7172
)
@@ -88,6 +89,7 @@
8889
"empty",
8990
"empty_like",
9091
"extract",
92+
"expand_dims",
9193
"eye",
9294
"finfo",
9395
"from_numpy",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,52 @@ def concat(arrays, /, *, axis=0):
359359
return res
360360

361361

362+
def expand_dims(X, /, *, axis=0):
363+
"""expand_dims(x, axis)
364+
365+
Expands the shape of an array by inserting a new axis (dimension)
366+
of size one at the position specified by axis.
367+
368+
Args:
369+
x (usm_ndarray):
370+
input array
371+
axis (Union[int, Tuple[int]]):
372+
axis position in the expanded axes (zero-based). If `x` has rank
373+
(i.e, number of dimensions) `N`, a valid `axis` must reside
374+
in the closed-interval `[-N-1, N]`. If provided a negative
375+
`axis`, the `axis` position at which to insert a singleton
376+
dimension is computed as `N + axis + 1`. Hence, if
377+
provided `-1`, the resolved axis position is `N` (i.e.,
378+
a singleton dimension must be appended to the input array `x`).
379+
If provided `-N-1`, the resolved axis position is `0` (i.e., a
380+
singleton dimension is prepended to the input array `x`).
381+
382+
Returns:
383+
usm_ndarray:
384+
Returns a view, if possible, and a copy otherwise with the number
385+
of dimensions increased.
386+
The expanded array has the same data type as the input array `x`.
387+
The expanded array is located on the same device as the input
388+
array, and has the same USM allocation type.
389+
390+
Raises:
391+
IndexError: if `axis` value is invalid.
392+
"""
393+
if not isinstance(X, dpt.usm_ndarray):
394+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
395+
396+
if type(axis) not in (tuple, list):
397+
axis = (axis,)
398+
399+
out_ndim = len(axis) + X.ndim
400+
axis = normalize_axis_tuple(axis, out_ndim)
401+
402+
shape_it = iter(X.shape)
403+
shape = tuple(1 if ax in axis else next(shape_it) for ax in range(out_ndim))
404+
405+
return dpt_ext.reshape(X, shape)
406+
407+
362408
def repeat(x, repeats, /, *, axis=None):
363409
"""repeat(x, repeats, axis=None)
364410

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@ def expand_dims(a, axis):
18491849
"""
18501850

18511851
usm_a = dpnp.get_usm_ndarray(a)
1852-
usm_res = dpt.expand_dims(usm_a, axis=axis)
1852+
usm_res = dpt_ext.expand_dims(usm_a, axis=axis)
18531853
return dpnp_array._create_from_usm_ndarray(usm_res)
18541854

18551855

0 commit comments

Comments
 (0)