Skip to content

Commit d2e9279

Browse files
Move ti.flip() to dpctl_ext.tensor and reuse it
1 parent ea7ea3c commit d2e9279

3 files changed

Lines changed: 35 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
broadcast_to,
6868
concat,
6969
expand_dims,
70+
flip,
7071
repeat,
7172
roll,
7273
)
@@ -92,6 +93,7 @@
9293
"expand_dims",
9394
"eye",
9495
"finfo",
96+
"flip",
9597
"from_numpy",
9698
"full",
9799
"full_like",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,38 @@ def expand_dims(X, /, *, axis=0):
405405
return dpt_ext.reshape(X, shape)
406406

407407

408+
def flip(X, /, *, axis=None):
409+
"""flip(x, axis)
410+
411+
Reverses the order of elements in an array `x` along the given `axis`.
412+
The shape of the array is preserved, but the elements are reordered.
413+
414+
Args:
415+
x (usm_ndarray): input array.
416+
axis (Optional[Union[int, Tuple[int,...]]]): axis (or axes) along
417+
which to flip.
418+
If `axis` is `None`, all input array axes are flipped.
419+
If `axis` is negative, the flipped axis is counted from the
420+
last dimension. If provided more than one axis, only the specified
421+
axes are flipped. Default: `None`.
422+
423+
Returns:
424+
usm_ndarray:
425+
A view of `x` with the entries of `axis` reversed.
426+
"""
427+
if not isinstance(X, dpt.usm_ndarray):
428+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
429+
X_ndim = X.ndim
430+
if axis is None:
431+
indexer = (np.s_[::-1],) * X_ndim
432+
else:
433+
axis = normalize_axis_tuple(axis, X_ndim)
434+
indexer = tuple(
435+
np.s_[::-1] if i in axis else np.s_[:] for i in range(X.ndim)
436+
)
437+
return X[indexer]
438+
439+
408440
def repeat(x, repeats, /, *, axis=None):
409441
"""repeat(x, repeats, axis=None)
410442

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ def flip(m, axis=None):
19201920
"""
19211921

19221922
m_usm = dpnp.get_usm_ndarray(m)
1923-
return dpnp_array._create_from_usm_ndarray(dpt.flip(m_usm, axis=axis))
1923+
return dpnp_array._create_from_usm_ndarray(dpt_ext.flip(m_usm, axis=axis))
19241924

19251925

19261926
def fliplr(m):

0 commit comments

Comments
 (0)