Skip to content

Commit 9c88edb

Browse files
Move ti.moveaxis() to dpctl_ext.tensor and reuse it
1 parent 4e63cca commit 9c88edb

4 files changed

Lines changed: 55 additions & 2 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
concat,
6969
expand_dims,
7070
flip,
71+
moveaxis,
7172
permute_dims,
7273
repeat,
7374
roll,
@@ -102,6 +103,7 @@
102103
"isdtype",
103104
"linspace",
104105
"meshgrid",
106+
"moveaxis",
105107
"permute_dims",
106108
"nonzero",
107109
"ones",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,57 @@ def flip(X, /, *, axis=None):
437437
return X[indexer]
438438

439439

440+
def moveaxis(X, source, destination, /):
441+
"""moveaxis(x, source, destination)
442+
443+
Moves axes of an array to new positions.
444+
445+
Args:
446+
x (usm_ndarray): input array
447+
448+
source (int or a sequence of int):
449+
Original positions of the axes to move.
450+
These must be unique. If `x` has rank (i.e., number of
451+
dimensions) `N`, a valid `axis` must be in the
452+
half-open interval `[-N, N)`.
453+
454+
destination (int or a sequence of int):
455+
Destination positions for each of the original axes.
456+
These must also be unique. If `x` has rank
457+
(i.e., number of dimensions) `N`, a valid `axis` must be
458+
in the half-open interval `[-N, N)`.
459+
460+
Returns:
461+
usm_ndarray:
462+
Array with moved axes.
463+
The returned array must has the same data type as `x`,
464+
is created on the same device as `x` and has the same
465+
USM allocation type as `x`.
466+
467+
Raises:
468+
AxisError: if `axis` value is invalid.
469+
ValueError: if `src` and `dst` have not equal number of elements.
470+
"""
471+
if not isinstance(X, dpt.usm_ndarray):
472+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
473+
474+
source = normalize_axis_tuple(source, X.ndim, "source")
475+
destination = normalize_axis_tuple(destination, X.ndim, "destination")
476+
477+
if len(source) != len(destination):
478+
raise ValueError(
479+
"`source` and `destination` arguments must have "
480+
"the same number of elements"
481+
)
482+
483+
ind = [n for n in range(X.ndim) if n not in source]
484+
485+
for src, dst in sorted(zip(destination, source)):
486+
ind.insert(src, dst)
487+
488+
return dpt_ext.permute_dims(X, tuple(ind))
489+
490+
440491
def permute_dims(X, /, axes):
441492
"""permute_dims(x, axes)
442493

dpnp/dpnp_algo/dpnp_arraycreation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def dpnp_linspace(
256256
usm_res[-1, ...] = usm_stop
257257

258258
if axis != 0:
259-
usm_res = dpt.moveaxis(usm_res, 0, axis)
259+
usm_res = dpt_ext.moveaxis(usm_res, 0, axis)
260260

261261
if dpnp.issubdtype(dtype, dpnp.integer):
262262
dpt.floor(usm_res, out=usm_res)

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2408,7 +2408,7 @@ def moveaxis(a, source, destination):
24082408

24092409
usm_array = dpnp.get_usm_ndarray(a)
24102410
return dpnp_array._create_from_usm_ndarray(
2411-
dpt.moveaxis(usm_array, source, destination)
2411+
dpt_ext.moveaxis(usm_array, source, destination)
24122412
)
24132413

24142414

0 commit comments

Comments
 (0)