@@ -873,3 +873,64 @@ def squeeze(X, /, axis=None):
873873 return X
874874 else :
875875 return dpt_ext .reshape (X , new_shape )
876+
877+
878+ def stack (arrays , / , * , axis = 0 ):
879+ """
880+ stack(arrays, axis)
881+
882+ Joins a sequence of arrays along a new axis.
883+
884+ Args:
885+ arrays (Union[List[usm_ndarray], Tuple[usm_ndarray,...]]):
886+ input arrays to join. Each array must have the same shape.
887+ axis (int): axis along which the arrays will be joined. Providing
888+ an `axis` specified the index of the new axis in the dimensions
889+ of the output array. A valid axis must be on the interval
890+ `[-N, N)`, where `N` is the rank (number of dimensions) of `x`.
891+ Default: `0`.
892+
893+ Returns:
894+ usm_ndarray:
895+ An output array having rank `N+1`, where `N` is
896+ the rank (number of dimensions) of `x`. If the input arrays have
897+ different data types, array API Type Promotion Rules apply.
898+
899+ Raises:
900+ ValueError: if not all input arrays have the same shape
901+ IndexError: if provided an `axis` outside of the required interval.
902+ """
903+ res_dtype , res_usm_type , exec_q = _arrays_validation (arrays )
904+
905+ n = len (arrays )
906+ X0 = arrays [0 ]
907+ res_ndim = X0 .ndim + 1
908+ axis = normalize_axis_index (axis , res_ndim )
909+ X0_shape = X0 .shape
910+
911+ for i in range (1 , n ):
912+ if X0_shape != arrays [i ].shape :
913+ raise ValueError ("All input arrays must have the same shape" )
914+
915+ res_shape = tuple (
916+ X0_shape [i - 1 * (i >= axis )] if i != axis else n
917+ for i in range (res_ndim )
918+ )
919+
920+ res = dpt_ext .empty (
921+ res_shape , dtype = res_dtype , usm_type = res_usm_type , sycl_queue = exec_q
922+ )
923+
924+ _manager = dputils .SequentialOrderManager [exec_q ]
925+ dep_evs = _manager .submitted_events
926+ for i in range (n ):
927+ c_shapes_copy = tuple (
928+ i if j == axis else np .s_ [:] for j in range (res_ndim )
929+ )
930+ _dst = res [c_shapes_copy ]
931+ hev , cpy_ev = ti ._copy_usm_ndarray_into_usm_ndarray (
932+ src = arrays [i ], dst = _dst , sycl_queue = exec_q , depends = dep_evs
933+ )
934+ _manager .add_event_pair (hev , cpy_ev )
935+
936+ return res
0 commit comments