2929import warnings
3030from functools import wraps
3131
32- import dpctl .tensor as dpt
33- import dpctl .tensor ._type_utils as dtu
3432import dpctl .utils as dpu
3533import numpy
36- from dpctl .tensor ._elementwise_common import (
37- BinaryElementwiseFunc ,
38- UnaryElementwiseFunc ,
39- )
40- from dpctl .tensor ._scalar_utils import (
41- _get_dtype ,
42- _get_shape ,
43- _validate_dtype ,
44- )
4534
4635# pylint: disable=no-name-in-module
4736# TODO: revert to `import dpctl.tensor...`
4837# when dpnp fully migrates dpctl/tensor
49- import dpctl_ext .tensor as dpt_ext
38+ import dpctl_ext .tensor as dpt
5039import dpctl_ext .tensor ._copy_utils as dtc
5140import dpctl_ext .tensor ._tensor_impl as dti
41+ import dpctl_ext .tensor ._type_utils as dtu
5242import dpnp
5343import dpnp .backend .extensions .vm ._vm_impl as vmi
44+ from dpctl_ext .tensor ._elementwise_common import (
45+ BinaryElementwiseFunc ,
46+ UnaryElementwiseFunc ,
47+ )
48+ from dpctl_ext .tensor ._scalar_utils import (
49+ _get_dtype ,
50+ _get_shape ,
51+ _validate_dtype ,
52+ )
5453from dpnp .dpnp_array import dpnp_array
5554from dpnp .dpnp_utils import get_usm_allocations
5655from dpnp .dpnp_utils .dpnp_utils_common import (
@@ -213,7 +212,7 @@ def __call__(
213212
214213 x_usm = dpnp .get_usm_ndarray (x )
215214 if dtype is not None :
216- x_usm = dpt_ext .astype (x_usm , dtype , copy = False )
215+ x_usm = dpt .astype (x_usm , dtype , copy = False )
217216
218217 out = self ._unpack_out_kw (out )
219218 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
@@ -467,7 +466,7 @@ def __call__(
467466 )
468467
469468 # Allocate a temporary buffer with the required dtype
470- out [i ] = dpt_ext .empty_like (res , dtype = res_dt )
469+ out [i ] = dpt .empty_like (res , dtype = res_dt )
471470 elif (
472471 buf_dt is None
473472 and dti ._array_overlap (x , res )
@@ -476,7 +475,7 @@ def __call__(
476475 # Allocate a temporary buffer to avoid memory overlapping.
477476 # Note if `buf_dt` is not None, a temporary copy of `x` will be
478477 # created, so the array overlap check isn't needed.
479- out [i ] = dpt_ext .empty_like (res )
478+ out [i ] = dpt .empty_like (res )
480479
481480 _manager = dpu .SequentialOrderManager [exec_q ]
482481 dep_evs = _manager .submitted_events
@@ -486,7 +485,7 @@ def __call__(
486485 if order == "K" :
487486 buf = dtc ._empty_like_orderK (x , buf_dt )
488487 else :
489- buf = dpt_ext .empty_like (x , dtype = buf_dt , order = order )
488+ buf = dpt .empty_like (x , dtype = buf_dt , order = order )
490489
491490 ht_copy_ev , copy_ev = dti ._copy_usm_ndarray_into_usm_ndarray (
492491 src = x , dst = buf , sycl_queue = exec_q , depends = dep_evs
@@ -503,7 +502,7 @@ def __call__(
503502 if order == "K" :
504503 out [i ] = dtc ._empty_like_orderK (x , res_dt )
505504 else :
506- out [i ] = dpt_ext .empty_like (x , dtype = res_dt , order = order )
505+ out [i ] = dpt .empty_like (x , dtype = res_dt , order = order )
507506
508507 # Call the unary function with input and output arrays
509508 ht_unary_ev , unary_ev = self .get_implementation_function ()(
@@ -713,24 +712,24 @@ def __call__(
713712
714713 if dtype is not None :
715714 if dpnp .isscalar (x1 ):
716- x1_usm = dpt_ext .asarray (
715+ x1_usm = dpt .asarray (
717716 x1 ,
718717 dtype = dtype ,
719718 sycl_queue = x2 .sycl_queue ,
720719 usm_type = x2 .usm_type ,
721720 )
722- x2_usm = dpt_ext .astype (x2_usm , dtype , copy = False )
721+ x2_usm = dpt .astype (x2_usm , dtype , copy = False )
723722 elif dpnp .isscalar (x2 ):
724- x1_usm = dpt_ext .astype (x1_usm , dtype , copy = False )
725- x2_usm = dpt_ext .asarray (
723+ x1_usm = dpt .astype (x1_usm , dtype , copy = False )
724+ x2_usm = dpt .asarray (
726725 x2 ,
727726 dtype = dtype ,
728727 sycl_queue = x1 .sycl_queue ,
729728 usm_type = x1 .usm_type ,
730729 )
731730 else :
732- x1_usm = dpt_ext .astype (x1_usm , dtype , copy = False )
733- x2_usm = dpt_ext .astype (x2_usm , dtype , copy = False )
731+ x1_usm = dpt .astype (x1_usm , dtype , copy = False )
732+ x2_usm = dpt .astype (x2_usm , dtype , copy = False )
734733
735734 res_usm = super ().__call__ (x1_usm , x2_usm , out = out_usm , order = order )
736735
@@ -1078,7 +1077,7 @@ def __call__(
10781077 )
10791078
10801079 # Allocate a temporary buffer with the required dtype
1081- out [i ] = dpt_ext .empty_like (res , dtype = res_dt )
1080+ out [i ] = dpt .empty_like (res , dtype = res_dt )
10821081 else :
10831082 # If `dt` is not None, a temporary copy of `x` will be created,
10841083 # so the array overlap check isn't needed.
@@ -1094,7 +1093,7 @@ def __call__(
10941093 for x in x_to_check
10951094 ):
10961095 # allocate a temporary buffer to avoid memory overlapping
1097- out [i ] = dpt_ext .empty_like (res )
1096+ out [i ] = dpt .empty_like (res )
10981097
10991098 x1 = dpnp .as_usm_ndarray (x1 , dtype = x1_dt , sycl_queue = exec_q )
11001099 x2 = dpnp .as_usm_ndarray (x2 , dtype = x2_dt , sycl_queue = exec_q )
@@ -1127,7 +1126,7 @@ def __call__(
11271126 if order == "K" :
11281127 buf = dtc ._empty_like_orderK (x , buf_dt )
11291128 else :
1130- buf = dpt_ext .empty_like (x , dtype = buf_dt , order = order )
1129+ buf = dpt .empty_like (x , dtype = buf_dt , order = order )
11311130
11321131 ht_copy_ev , copy_ev = dti ._copy_usm_ndarray_into_usm_ndarray (
11331132 src = x , dst = buf , sycl_queue = exec_q , depends = dep_evs
@@ -1146,7 +1145,7 @@ def __call__(
11461145 x1 , x2 , res_dt , res_shape , res_usm_type , exec_q
11471146 )
11481147 else :
1149- out [i ] = dpt_ext .empty (
1148+ out [i ] = dpt .empty (
11501149 res_shape ,
11511150 dtype = res_dt ,
11521151 order = order ,
@@ -1156,9 +1155,9 @@ def __call__(
11561155
11571156 # Broadcast shapes of input arrays
11581157 if x1 .shape != res_shape :
1159- x1 = dpt_ext .broadcast_to (x1 , res_shape )
1158+ x1 = dpt .broadcast_to (x1 , res_shape )
11601159 if x2 .shape != res_shape :
1161- x2 = dpt_ext .broadcast_to (x2 , res_shape )
1160+ x2 = dpt .broadcast_to (x2 , res_shape )
11621161
11631162 # Call the binary function with input and output arrays
11641163 ht_binary_ev , binary_ev = self .get_implementation_function ()(
@@ -1326,7 +1325,7 @@ def __call__(self, x, /, decimals=0, out=None, *, dtype=None):
13261325 res_usm = dpt .divide (x_usm , 10 ** decimals , out = out_usm )
13271326
13281327 if dtype is not None :
1329- res_usm = dpt_ext .astype (res_usm , dtype , copy = False )
1328+ res_usm = dpt .astype (res_usm , dtype , copy = False )
13301329
13311330 if out is not None and isinstance (out , dpnp_array ):
13321331 return out
0 commit comments