Skip to content

Commit a7cbfdc

Browse files
Reuse dpctl_ext.tensor.broadcast_to() in dpnp
1 parent b8c5390 commit a7cbfdc

3 files changed

Lines changed: 5 additions & 6 deletions

File tree

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,9 +1156,9 @@ def __call__(
11561156

11571157
# Broadcast shapes of input arrays
11581158
if x1.shape != res_shape:
1159-
x1 = dpt.broadcast_to(x1, res_shape)
1159+
x1 = dpt_ext.broadcast_to(x1, res_shape)
11601160
if x2.shape != res_shape:
1161-
x2 = dpt.broadcast_to(x2, res_shape)
1161+
x2 = dpt_ext.broadcast_to(x2, res_shape)
11621162

11631163
# Call the binary function with input and output arrays
11641164
ht_binary_ev, binary_ev = self.get_implementation_function()(

dpnp/dpnp_algo/dpnp_fill.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,12 @@
2828

2929
from numbers import Number
3030

31-
import dpctl.tensor as dpt
3231
import dpctl.utils as dpu
3332
from dpctl.tensor._ctors import _cast_fill_val
3433

3534
# TODO: revert to `from dpctl.tensor...`
3635
# when dpnp fully migrates dpctl/tensor
37-
import dpctl_ext.tensor as dpt_ext
36+
import dpctl_ext.tensor as dpt
3837
import dpnp
3938
from dpctl_ext.tensor._tensor_impl import (
4039
_copy_usm_ndarray_into_usm_ndarray,
@@ -56,7 +55,7 @@ def dpnp_fill(arr, val):
5655
raise dpu.ExecutionPlacementError(
5756
"Input arrays have incompatible queues."
5857
)
59-
a_val = dpt_ext.astype(val, arr.dtype)
58+
a_val = dpt.astype(val, arr.dtype)
6059
a_val = dpt.broadcast_to(a_val, arr.shape)
6160
_manager = dpu.SequentialOrderManager[exec_q]
6261
dep_evs = _manager.submitted_events

dpnp/dpnp_iface_manipulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,7 @@ def broadcast_to(array, /, shape, subok=False):
11781178
raise NotImplementedError(f"subok={subok} is currently not supported")
11791179

11801180
usm_array = dpnp.get_usm_ndarray(array)
1181-
new_array = dpt.broadcast_to(usm_array, shape)
1181+
new_array = dpt_ext.broadcast_to(usm_array, shape)
11821182
return dpnp_array._create_from_usm_ndarray(new_array)
11831183

11841184

0 commit comments

Comments
 (0)