Skip to content

Commit 99c75cb

Browse files
Use function from dpctl_ext.tensor in tensor python files
1 parent cab0b36 commit 99c75cb

4 files changed

Lines changed: 5 additions & 5 deletions

File tree

dpctl_ext/tensor/_ctors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _copy_through_host_walker(seq_o, usm_res):
360360
)
361361
is None
362362
):
363-
usm_res[...] = dpt.asnumpy(seq_o).copy()
363+
usm_res[...] = dpt_ext.asnumpy(seq_o).copy()
364364
return
365365
else:
366366
usm_res[...] = seq_o
@@ -1440,7 +1440,7 @@ def linspace(
14401440
)
14411441
_manager.add_event_pair(hev, la_ev)
14421442

1443-
return res if int_dt is None else dpt.astype(res, int_dt)
1443+
return res if int_dt is None else dpt_ext.astype(res, int_dt)
14441444

14451445

14461446
def meshgrid(*arrays, indexing="xy"):

dpctl_ext/tensor/_indexing_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def place(arr, mask, vals):
190190
if vals.dtype == arr.dtype:
191191
rhs = vals
192192
else:
193-
rhs = dpt.astype(vals, arr.dtype)
193+
rhs = dpt_ext.astype(vals, arr.dtype)
194194
hev, pl_ev = ti._place(
195195
dst=arr,
196196
cumsum=cumsum,

dpctl_ext/tensor/_reduction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
506506
type.
507507
"""
508508
if x.dtype != dpt.bool:
509-
x = dpt.astype(x, dpt.bool, copy=False)
509+
x = dpt_ext.astype(x, dpt.bool, copy=False)
510510
return sum(
511511
x,
512512
axis=axis,

dpctl_ext/tensor/_utility_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
489489
slice(None) if i != axis else slice(None, -1) for i in range(x_nd)
490490
)
491491

492-
diff_op = dpt.not_equal if x.dtype == dpt.bool else dpt.subtract
492+
diff_op = dpt_ext.not_equal if x.dtype == dpt.bool else dpt_ext.subtract
493493
if n > 1:
494494
arr_tmp0 = diff_op(arr[sl0], arr[sl1])
495495
arr_tmp1 = diff_op(arr_tmp0[sl0], arr_tmp0[sl1])

0 commit comments

Comments
 (0)