Skip to content

Commit b8c5390

Browse files
Move ti.broadcast_to() to dpctl_ext/tensor and reuse it in dpctl_ext/tensor
1 parent a6c397e commit b8c5390

7 files changed

Lines changed: 79 additions & 23 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
take_along_axis,
6464
)
6565
from dpctl_ext.tensor._manipulation_functions import (
66+
broadcast_to,
6667
repeat,
6768
roll,
6869
)
@@ -76,6 +77,7 @@
7677
"asarray",
7778
"asnumpy",
7879
"astype",
80+
"broadcast_to",
7981
"can_cast",
8082
"copy",
8183
"clip",

dpctl_ext/tensor/_clip.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -205,9 +205,9 @@ def _clip_none(x, val, out, order, _binary_fn):
205205
order=order,
206206
)
207207
if x_shape != res_shape:
208-
x = dpt.broadcast_to(x, res_shape)
208+
x = dpt_ext.broadcast_to(x, res_shape)
209209
if val_ary.shape != res_shape:
210-
val_ary = dpt.broadcast_to(val_ary, res_shape)
210+
val_ary = dpt_ext.broadcast_to(val_ary, res_shape)
211211
_manager = SequentialOrderManager[exec_q]
212212
dep_evs = _manager.submitted_events
213213
ht_binary_ev, binary_ev = _binary_fn(
@@ -251,8 +251,8 @@ def _clip_none(x, val, out, order, _binary_fn):
251251
)
252252

253253
if x_shape != res_shape:
254-
x = dpt.broadcast_to(x, res_shape)
255-
buf = dpt.broadcast_to(buf, res_shape)
254+
x = dpt_ext.broadcast_to(x, res_shape)
255+
buf = dpt_ext.broadcast_to(buf, res_shape)
256256
ht_binary_ev, binary_ev = _binary_fn(
257257
src1=x,
258258
src2=buf,
@@ -580,11 +580,11 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
580580
order=order,
581581
)
582582
if x_shape != res_shape:
583-
x = dpt.broadcast_to(x, res_shape)
583+
x = dpt_ext.broadcast_to(x, res_shape)
584584
if a_min.shape != res_shape:
585-
a_min = dpt.broadcast_to(a_min, res_shape)
585+
a_min = dpt_ext.broadcast_to(a_min, res_shape)
586586
if a_max.shape != res_shape:
587-
a_max = dpt.broadcast_to(a_max, res_shape)
587+
a_max = dpt_ext.broadcast_to(a_max, res_shape)
588588
_manager = SequentialOrderManager[exec_q]
589589
dep_ev = _manager.submitted_events
590590
ht_binary_ev, binary_ev = ti._clip(
@@ -639,10 +639,10 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
639639
order=order,
640640
)
641641

642-
x = dpt.broadcast_to(x, res_shape)
642+
x = dpt_ext.broadcast_to(x, res_shape)
643643
if a_min.shape != res_shape:
644-
a_min = dpt.broadcast_to(a_min, res_shape)
645-
buf2 = dpt.broadcast_to(buf2, res_shape)
644+
a_min = dpt_ext.broadcast_to(a_min, res_shape)
645+
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
646646
ht_binary_ev, binary_ev = ti._clip(
647647
src=x,
648648
min=a_min,
@@ -695,10 +695,10 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
695695
order=order,
696696
)
697697

698-
x = dpt.broadcast_to(x, res_shape)
699-
buf1 = dpt.broadcast_to(buf1, res_shape)
698+
x = dpt_ext.broadcast_to(x, res_shape)
699+
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
700700
if a_max.shape != res_shape:
701-
a_max = dpt.broadcast_to(a_max, res_shape)
701+
a_max = dpt_ext.broadcast_to(a_max, res_shape)
702702
ht_binary_ev, binary_ev = ti._clip(
703703
src=x,
704704
min=buf1,
@@ -766,9 +766,9 @@ def clip(x, /, min=None, max=None, out=None, order="K"):
766766
order=order,
767767
)
768768

769-
x = dpt.broadcast_to(x, res_shape)
770-
buf1 = dpt.broadcast_to(buf1, res_shape)
771-
buf2 = dpt.broadcast_to(buf2, res_shape)
769+
x = dpt_ext.broadcast_to(x, res_shape)
770+
buf1 = dpt_ext.broadcast_to(buf1, res_shape)
771+
buf2 = dpt_ext.broadcast_to(buf2, res_shape)
772772
ht_, clip_ev = ti._clip(
773773
src=x,
774774
min=buf1,

dpctl_ext/tensor/_copy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
368368
rhs = vals
369369
else:
370370
rhs = dpt_ext.astype(vals, ary.dtype)
371-
rhs = dpt.broadcast_to(rhs, expected_vals_shape)
371+
rhs = dpt_ext.broadcast_to(rhs, expected_vals_shape)
372372
_manager = dpctl.utils.SequentialOrderManager[exec_q]
373373
dep_ev = _manager.submitted_events
374374
hev, put_ev = ti._put(

dpctl_ext/tensor/_ctors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,7 +1200,7 @@ def full(
12001200
usm_type=usm_type,
12011201
sycl_queue=sycl_queue,
12021202
)
1203-
return dpt_ext.copy(dpt.broadcast_to(X, shape), order=order)
1203+
return dpt_ext.copy(dpt_ext.broadcast_to(X, shape), order=order)
12041204
else:
12051205
_validate_fill_value(fill_value)
12061206

@@ -1307,7 +1307,7 @@ def full_like(
13071307
usm_type=usm_type,
13081308
sycl_queue=sycl_queue,
13091309
)
1310-
X = dpt.broadcast_to(X, sh)
1310+
X = dpt_ext.broadcast_to(X, sh)
13111311
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
13121312
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
13131313
# order copy after tasks populating X

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def put_vec_duplicates(vec, ind, vals):
341341
rhs = vals
342342
else:
343343
rhs = dpt_ext.astype(vals, x.dtype)
344-
rhs = dpt.broadcast_to(rhs, val_shape)
344+
rhs = dpt_ext.broadcast_to(rhs, val_shape)
345345

346346
_manager = dpctl.utils.SequentialOrderManager[exec_q]
347347
deps_ev = _manager.submitted_events

dpctl_ext/tensor/_manipulation_functions.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,60 @@ def _broadcast_shape_impl(shapes):
8686
return tuple(common_shape)
8787

8888

89+
def _broadcast_strides(X_shape, X_strides, res_ndim):
90+
"""
91+
Broadcasts strides to match the given dimensions;
92+
returns tuple type strides.
93+
"""
94+
out_strides = [0] * res_ndim
95+
X_shape_len = len(X_shape)
96+
str_dim = -X_shape_len
97+
for i in range(X_shape_len):
98+
shape_value = X_shape[i]
99+
if not shape_value == 1:
100+
out_strides[str_dim] = X_strides[i]
101+
str_dim += 1
102+
103+
return tuple(out_strides)
104+
105+
106+
def broadcast_to(X, /, shape):
107+
"""broadcast_to(x, shape)
108+
109+
Broadcast an array to a new `shape`; returns the broadcasted
110+
:class:`dpctl.tensor.usm_ndarray` as a view.
111+
112+
Args:
113+
x (usm_ndarray): input array
114+
shape (Tuple[int,...]): array shape. The `shape` must be
115+
compatible with `x` according to broadcasting rules.
116+
117+
Returns:
118+
usm_ndarray:
119+
An array with the specified `shape`.
120+
The output array is a view of the input array, and
121+
hence has the same data type, USM allocation type and
122+
device attributes.
123+
"""
124+
if not isinstance(X, dpt.usm_ndarray):
125+
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")
126+
127+
# Use numpy.broadcast_to to check the validity of the input
128+
# parameter 'shape'. Raise ValueError if 'X' is not compatible
129+
# with 'shape' according to NumPy's broadcasting rules.
130+
new_array = np.broadcast_to(
131+
np.broadcast_to(np.empty(tuple(), dtype="u1"), X.shape), shape
132+
)
133+
new_sts = _broadcast_strides(X.shape, X.strides, new_array.ndim)
134+
return dpt.usm_ndarray(
135+
shape=new_array.shape,
136+
dtype=X.dtype,
137+
buffer=X,
138+
strides=new_sts,
139+
offset=X._element_offset,
140+
)
141+
142+
89143
def repeat(x, repeats, /, *, axis=None):
90144
"""repeat(x, repeats, axis=None)
91145

dpctl_ext/tensor/_search_functions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,11 +389,11 @@ def where(condition, x1, x2, /, *, order="K", out=None):
389389
)
390390

391391
if condition_shape != res_shape:
392-
condition = dpt.broadcast_to(condition, res_shape)
392+
condition = dpt_ext.broadcast_to(condition, res_shape)
393393
if x1_shape != res_shape:
394-
x1 = dpt.broadcast_to(x1, res_shape)
394+
x1 = dpt_ext.broadcast_to(x1, res_shape)
395395
if x2_shape != res_shape:
396-
x2 = dpt.broadcast_to(x2, res_shape)
396+
x2 = dpt_ext.broadcast_to(x2, res_shape)
397397

398398
dep_evs = _manager.submitted_events
399399
hev, where_ev = ti._where(

0 commit comments

Comments
 (0)