Skip to content

Commit 4e63cca

Browse files
Move ti.permute_dims() to dpctl_ext.tensor and reuse it
1 parent d2e9279 commit 4e63cca

5 files changed

Lines changed: 48 additions & 5 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
concat,
6969
expand_dims,
7070
flip,
71+
permute_dims,
7172
repeat,
7273
roll,
7374
)
@@ -101,6 +102,7 @@
101102
"isdtype",
102103
"linspace",
103104
"meshgrid",
105+
"permute_dims",
104106
"nonzero",
105107
"ones",
106108
"ones_like",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ def _make_empty_like_orderK(x, dt, usm_type, dev):
695695
for i in range(x.ndim)
696696
)
697697
R = R[sl]
698-
return dpt.permute_dims(R, inv_perm)
698+
return dpt_ext.permute_dims(R, inv_perm)
699699

700700

701701
def _empty_like_orderK(x, dt, usm_type=None, dev=None):
@@ -800,7 +800,7 @@ def _empty_like_pair_orderK(X1, X2, dt, res_shape, usm_type, dev):
800800
for i in range(nd1)
801801
)
802802
R = R[sl]
803-
return dpt.permute_dims(R, inv_perm)
803+
return dpt_ext.permute_dims(R, inv_perm)
804804

805805

806806
def _empty_like_triple_orderK(X1, X2, X3, dt, res_shape, usm_type, dev):
@@ -876,7 +876,7 @@ def _empty_like_triple_orderK(X1, X2, X3, dt, res_shape, usm_type, dev):
876876
for i in range(nd1)
877877
)
878878
R = R[sl]
879-
return dpt.permute_dims(R, inv_perm)
879+
return dpt_ext.permute_dims(R, inv_perm)
880880

881881

882882
def copy(usm_ary, /, *, order="K"):

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,43 @@ def flip(X, /, *, axis=None):
437437
return X[indexer]
438438

439439

440+
def permute_dims(X, /, axes):
441+
"""permute_dims(x, axes)
442+
443+
Permute the axes (dimensions) of an array; returns the permuted
444+
array as a view.
445+
446+
Args:
447+
x (usm_ndarray): input array.
448+
axes (Tuple[int, ...]): tuple containing permutation of
449+
`(0,1,...,N-1)` where `N` is the number of axes (dimensions)
450+
of `x`.
451+
Returns:
452+
usm_ndarray:
453+
An array with permuted axes.
454+
The returned array must has the same data type as `x`,
455+
is created on the same device as `x` and has the same USM allocation
456+
type as `x`.
457+
"""
458+
if not isinstance(X, dpt.usm_ndarray):
459+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
460+
axes = normalize_axis_tuple(axes, X.ndim, "axes")
461+
if not X.ndim == len(axes):
462+
raise ValueError(
463+
"The length of the passed axes does not match "
464+
"to the number of usm_ndarray dimensions."
465+
)
466+
newstrides = tuple(X.strides[i] for i in axes)
467+
newshape = tuple(X.shape[i] for i in axes)
468+
return dpt.usm_ndarray(
469+
shape=newshape,
470+
dtype=X.dtype,
471+
buffer=X,
472+
strides=newstrides,
473+
offset=X._element_offset,
474+
)
475+
476+
440477
def repeat(x, repeats, /, *, axis=None):
441478
"""repeat(x, repeats, axis=None)
442479

dpctl_ext/tensor/_reshape.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737
_unravel_index,
3838
)
3939

40+
# TODO: revert to `import dpctl.tensor...`
41+
# when dpnp fully migrates dpctl/tensor
42+
import dpctl_ext.tensor as dpt_ext
43+
4044
__doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
4145

4246

@@ -184,7 +188,7 @@ def reshape(X, /, shape, *, order="C", copy=None):
184188
src=X, dst=flat_res, sycl_queue=copy_q, depends=dep_evs
185189
)
186190
else:
187-
X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1))
191+
X_t = dpt_ext.permute_dims(X, range(X.ndim - 1, -1, -1))
188192
hev, r_e = _copy_usm_ndarray_for_reshape(
189193
src=X_t, dst=flat_res, sycl_queue=copy_q, depends=dep_evs
190194
)

dpnp/dpnp_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2283,7 +2283,7 @@ def transpose(self, *axes):
22832283
# self.transpose(None).shape == self.shape[::-1]
22842284
axes = tuple((ndim - x - 1) for x in range(ndim))
22852285

2286-
usm_res = dpt.permute_dims(self._array_obj, axes)
2286+
usm_res = dpt_ext.permute_dims(self._array_obj, axes)
22872287
return dpnp_array._create_from_usm_ndarray(usm_res)
22882288

22892289
def var(

0 commit comments

Comments
 (0)