Skip to content

Commit ba51636

Browse files
Move ti.squeeze() to dpctl_ext.tensor and reuse it
1 parent 9c88edb commit ba51636

3 files changed

Lines changed: 48 additions & 1 deletion

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
permute_dims,
7373
repeat,
7474
roll,
75+
squeeze,
7576
)
7677
from dpctl_ext.tensor._reshape import reshape
7778

@@ -115,6 +116,7 @@
115116
"reshape",
116117
"result_type",
117118
"roll",
119+
"squeeze",
118120
"take",
119121
"take_along_axis",
120122
"to_numpy",

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -828,3 +828,48 @@ def roll(x, /, shift, *, axis=None):
828828
)
829829
_manager.add_event_pair(ht_e, roll_ev)
830830
return res
831+
832+
833+
def squeeze(X, /, axis=None):
834+
"""squeeze(x, axis)
835+
836+
Removes singleton dimensions (axes) from array `x`.
837+
838+
Args:
839+
x (usm_ndarray): input array
840+
axis (Union[int, Tuple[int,...]]): axis (or axes) to squeeze.
841+
842+
Returns:
843+
usm_ndarray:
844+
Output array is a view, if possible,
845+
and a copy otherwise, but with all or a subset of the
846+
dimensions of length 1 removed. Output has the same data
847+
type as the input, is allocated on the same device as the
848+
input and has the same USM allocation type as the input
849+
array `x`.
850+
851+
Raises:
852+
ValueError: if the specified axis has a size greater than one.
853+
"""
854+
if not isinstance(X, dpt.usm_ndarray):
855+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
856+
X_shape = X.shape
857+
if axis is not None:
858+
axis = normalize_axis_tuple(axis, X.ndim if X.ndim != 0 else X.ndim + 1)
859+
new_shape = []
860+
for i, x in enumerate(X_shape):
861+
if i not in axis:
862+
new_shape.append(x)
863+
else:
864+
if x != 1:
865+
raise ValueError(
866+
"Cannot select an axis to squeeze out "
867+
"which has size not equal to one."
868+
)
869+
new_shape = tuple(new_shape)
870+
else:
871+
new_shape = tuple(axis for axis in X_shape if axis != 1)
872+
if new_shape == X.shape:
873+
return X
874+
else:
875+
return dpt_ext.reshape(X, new_shape)

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3663,7 +3663,7 @@ def squeeze(a, /, axis=None):
36633663
"""
36643664

36653665
usm_a = dpnp.get_usm_ndarray(a)
3666-
usm_res = dpt.squeeze(usm_a, axis=axis)
3666+
usm_res = dpt_ext.squeeze(usm_a, axis=axis)
36673667
return dpnp_array._create_from_usm_ndarray(usm_res)
36683668

36693669

0 commit comments

Comments
 (0)