Skip to content

Commit bd265da

Browse files
Move ti.broadcast_arrays() to dpctl_ext.tensor and reuse it
1 parent a7cbfdc commit bd265da

6 files changed

Lines changed: 46 additions & 4 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
take_along_axis,
6464
)
6565
from dpctl_ext.tensor._manipulation_functions import (
66+
broadcast_arrays,
6667
broadcast_to,
6768
repeat,
6869
roll,
@@ -77,6 +78,7 @@
7778
"asarray",
7879
"asnumpy",
7980
"astype",
81+
"broadcast_arrays",
8082
"broadcast_to",
8183
"can_cast",
8284
"copy",

dpctl_ext/tensor/_copy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def _prepare_indices_arrays(inds, q, usm_type):
306306
)
307307

308308
# broadcast
309-
inds = dpt.broadcast_arrays(*inds)
309+
inds = dpt_ext.broadcast_arrays(*inds)
310310

311311
return inds
312312

dpctl_ext/tensor/_ctors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,7 @@ def meshgrid(*arrays, indexing="xy"):
15131513
res.append(dpt_ext.reshape(array, sh, copy=True))
15141514
sh = sh[-1:] + sh[:-1]
15151515

1516-
output = dpt.broadcast_arrays(*res)
1516+
output = dpt_ext.broadcast_arrays(*res)
15171517

15181518
return output
15191519

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,15 @@
4747
)
4848

4949

50+
def _broadcast_shapes(*args):
51+
"""
52+
Broadcast the input shapes into a single shape;
53+
returns tuple broadcasted shape.
54+
"""
55+
array_shapes = [array.shape for array in args]
56+
return _broadcast_shape_impl(array_shapes)
57+
58+
5059
def _broadcast_shape_impl(shapes):
5160
if len(set(shapes)) == 1:
5261
return shapes[0]
@@ -103,6 +112,37 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
103112
return tuple(out_strides)
104113

105114

115+
def broadcast_arrays(*args):
116+
"""broadcast_arrays(*arrays)
117+
118+
Broadcasts one or more :class:`dpctl.tensor.usm_ndarrays` against
119+
one another.
120+
121+
Args:
122+
arrays (usm_ndarray): an arbitrary number of arrays to be
123+
broadcasted.
124+
125+
Returns:
126+
List[usm_ndarray]:
127+
A list of broadcasted arrays. Each array
128+
must have the same shape. Each array must have the same `dtype`,
129+
`device` and `usm_type` attributes as its corresponding input
130+
array.
131+
"""
132+
if len(args) == 0:
133+
raise ValueError("`broadcast_arrays` requires at least one argument")
134+
for X in args:
135+
if not isinstance(X, dpt.usm_ndarray):
136+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
137+
138+
shape = _broadcast_shapes(*args)
139+
140+
if all(X.shape == shape for X in args):
141+
return args
142+
143+
return [broadcast_to(X, shape) for X in args]
144+
145+
106146
def broadcast_to(X, /, shape):
107147
"""broadcast_to(x, shape)
108148

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3131,7 +3131,7 @@ def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
31313131
output[1] = dpt_ext.reshape(output[1], (-1, 1) + s0[2:])
31323132

31333133
if not sparse:
3134-
output = dpt.broadcast_arrays(*output)
3134+
output = dpt_ext.broadcast_arrays(*output)
31353135

31363136
if copy:
31373137
output = [dpt_ext.copy(x) for x in output]

dpnp/dpnp_iface_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def choose(a, choices, out=None, mode="wrap"):
260260
choices,
261261
)
262262
)
263-
arrs_broadcast = dpt.broadcast_arrays(inds, *choices)
263+
arrs_broadcast = dpt_ext.broadcast_arrays(inds, *choices)
264264
inds = arrs_broadcast[0]
265265
choices = tuple(arrs_broadcast[1:])
266266

0 commit comments

Comments
 (0)