@@ -359,6 +359,52 @@ def concat(arrays, /, *, axis=0):
359359 return res
360360
361361
362+ def expand_dims (X , / , * , axis = 0 ):
363+ """expand_dims(x, axis)
364+
365+ Expands the shape of an array by inserting a new axis (dimension)
366+ of size one at the position specified by axis.
367+
368+ Args:
369+ x (usm_ndarray):
370+ input array
371+ axis (Union[int, Tuple[int]]):
372+ axis position in the expanded axes (zero-based). If `x` has rank
373+ (i.e, number of dimensions) `N`, a valid `axis` must reside
374+ in the closed-interval `[-N-1, N]`. If provided a negative
375+ `axis`, the `axis` position at which to insert a singleton
376+ dimension is computed as `N + axis + 1`. Hence, if
377+ provided `-1`, the resolved axis position is `N` (i.e.,
378+ a singleton dimension must be appended to the input array `x`).
379+ If provided `-N-1`, the resolved axis position is `0` (i.e., a
380+ singleton dimension is prepended to the input array `x`).
381+
382+ Returns:
383+ usm_ndarray:
384+ Returns a view, if possible, and a copy otherwise with the number
385+ of dimensions increased.
386+ The expanded array has the same data type as the input array `x`.
387+ The expanded array is located on the same device as the input
388+ array, and has the same USM allocation type.
389+
390+ Raises:
391+ IndexError: if `axis` value is invalid.
392+ """
393+ if not isinstance (X , dpt .usm_ndarray ):
394+ raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
395+
396+ if type (axis ) not in (tuple , list ):
397+ axis = (axis ,)
398+
399+ out_ndim = len (axis ) + X .ndim
400+ axis = normalize_axis_tuple (axis , out_ndim )
401+
402+ shape_it = iter (X .shape )
403+ shape = tuple (1 if ax in axis else next (shape_it ) for ax in range (out_ndim ))
404+
405+ return dpt_ext .reshape (X , shape )
406+
407+
362408def repeat (x , repeats , / , * , axis = None ):
363409 """repeat(x, repeats, axis=None)
364410
0 commit comments