@@ -437,6 +437,43 @@ def flip(X, /, *, axis=None):
437437 return X [indexer ]
438438
439439
440+ def permute_dims (X , / , axes ):
441+ """permute_dims(x, axes)
442+
443+ Permute the axes (dimensions) of an array; returns the permuted
444+ array as a view.
445+
446+ Args:
447+ x (usm_ndarray): input array.
448+ axes (Tuple[int, ...]): tuple containing permutation of
449+ `(0,1,...,N-1)` where `N` is the number of axes (dimensions)
450+ of `x`.
451+ Returns:
452+ usm_ndarray:
453+ An array with permuted axes.
454+ The returned array must has the same data type as `x`,
455+ is created on the same device as `x` and has the same USM allocation
456+ type as `x`.
457+ """
458+ if not isinstance (X , dpt .usm_ndarray ):
459+ raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
460+ axes = normalize_axis_tuple (axes , X .ndim , "axes" )
461+ if not X .ndim == len (axes ):
462+ raise ValueError (
463+ "The length of the passed axes does not match "
464+ "to the number of usm_ndarray dimensions."
465+ )
466+ newstrides = tuple (X .strides [i ] for i in axes )
467+ newshape = tuple (X .shape [i ] for i in axes )
468+ return dpt .usm_ndarray (
469+ shape = newshape ,
470+ dtype = X .dtype ,
471+ buffer = X ,
472+ strides = newstrides ,
473+ offset = X ._element_offset ,
474+ )
475+
476+
440477def repeat (x , repeats , / , * , axis = None ):
441478 """repeat(x, repeats, axis=None)
442479
0 commit comments