Skip to content

Commit 3883a1c

Browse files
Switch fully to dpctl_ext.tensor in dpnp
1 parent 3c428a6 commit 3883a1c

17 files changed

Lines changed: 152 additions & 167 deletions

dpnp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
# Borrowed from DPCTL
6565
with warnings.catch_warnings():
6666
warnings.simplefilter("ignore", DeprecationWarning)
67-
from dpctl.tensor import __array_api_version__, DLDeviceType
67+
from dpctl_ext.tensor import __array_api_version__, DLDeviceType
6868

6969
from .dpnp_array import dpnp_array as ndarray
7070
from .dpnp_array_api_info import __array_namespace_info__

dpnp/dpnp_algo/dpnp_arraycreation.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@
2929
import math
3030
import operator
3131

32-
import dpctl.tensor as dpt
3332
import dpctl.utils as dpu
3433
import numpy
3534

3635
# TODO: revert to `import dpctl.tensor...`
3736
# when dpnp fully migrates dpctl/tensor
38-
import dpctl_ext.tensor as dpt_ext
37+
import dpctl_ext.tensor as dpt
3938
import dpnp
4039
from dpnp.dpnp_array import dpnp_array
4140
from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device
@@ -53,7 +52,7 @@ def _as_usm_ndarray(a, usm_type, sycl_queue):
5352

5453
if isinstance(a, dpnp_array):
5554
a = a.get_array()
56-
return dpt_ext.asarray(a, usm_type=usm_type, sycl_queue=sycl_queue)
55+
return dpt.asarray(a, usm_type=usm_type, sycl_queue=sycl_queue)
5756

5857

5958
def _check_has_zero_val(a):
@@ -196,7 +195,7 @@ def dpnp_linspace(
196195

197196
if dpnp.isscalar(start) and dpnp.isscalar(stop):
198197
# Call linspace() function for scalars.
199-
usm_res = dpt_ext.linspace(
198+
usm_res = dpt.linspace(
200199
start,
201200
stop,
202201
num,
@@ -213,37 +212,35 @@ def dpnp_linspace(
213212
else:
214213
step = dpnp.nan
215214
else:
216-
usm_start = dpt_ext.asarray(
215+
usm_start = dpt.asarray(
217216
start,
218217
dtype=dt,
219218
usm_type=_usm_type,
220219
sycl_queue=sycl_queue_normalized,
221220
)
222-
usm_stop = dpt_ext.asarray(
221+
usm_stop = dpt.asarray(
223222
stop, dtype=dt, usm_type=_usm_type, sycl_queue=sycl_queue_normalized
224223
)
225224

226225
delta = usm_stop - usm_start
227226

228-
usm_res = dpt_ext.arange(
227+
usm_res = dpt.arange(
229228
0,
230229
stop=num,
231230
step=1,
232231
dtype=dt,
233232
usm_type=_usm_type,
234233
sycl_queue=sycl_queue_normalized,
235234
)
236-
usm_res = dpt_ext.reshape(
237-
usm_res, (-1,) + (1,) * delta.ndim, copy=False
238-
)
235+
usm_res = dpt.reshape(usm_res, (-1,) + (1,) * delta.ndim, copy=False)
239236

240237
if step_num > 0:
241238
step = delta / step_num
242239

243240
# Needed a special handling for denormal numbers (when step == 0),
244241
# see numpy#5437 for more details.
245242
# Note, dpt.where() is used to avoid a synchronization branch.
246-
usm_res = dpt_ext.where(
243+
usm_res = dpt.where(
247244
step == 0, (usm_res / step_num) * delta, usm_res * step
248245
)
249246
else:
@@ -256,17 +253,17 @@ def dpnp_linspace(
256253
usm_res[-1, ...] = usm_stop
257254

258255
if axis != 0:
259-
usm_res = dpt_ext.moveaxis(usm_res, 0, axis)
256+
usm_res = dpt.moveaxis(usm_res, 0, axis)
260257

261258
if dpnp.issubdtype(dtype, dpnp.integer):
262259
dpt.floor(usm_res, out=usm_res)
263260

264-
res = dpt_ext.astype(usm_res, dtype, copy=False)
261+
res = dpt.astype(usm_res, dtype, copy=False)
265262
res = dpnp_array._create_from_usm_ndarray(res)
266263

267264
if retstep is True:
268265
if dpnp.isscalar(step):
269-
step = dpt_ext.asarray(
266+
step = dpt.asarray(
270267
step, usm_type=res.usm_type, sycl_queue=res.sycl_queue
271268
)
272269
return res, dpnp_array._create_from_usm_ndarray(step)

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,27 @@
2929
import warnings
3030
from functools import wraps
3131

32-
import dpctl.tensor as dpt
33-
import dpctl.tensor._type_utils as dtu
3432
import dpctl.utils as dpu
3533
import numpy
36-
from dpctl.tensor._elementwise_common import (
37-
BinaryElementwiseFunc,
38-
UnaryElementwiseFunc,
39-
)
40-
from dpctl.tensor._scalar_utils import (
41-
_get_dtype,
42-
_get_shape,
43-
_validate_dtype,
44-
)
4534

4635
# pylint: disable=no-name-in-module
4736
# TODO: revert to `import dpctl.tensor...`
4837
# when dpnp fully migrates dpctl/tensor
49-
import dpctl_ext.tensor as dpt_ext
38+
import dpctl_ext.tensor as dpt
5039
import dpctl_ext.tensor._copy_utils as dtc
5140
import dpctl_ext.tensor._tensor_impl as dti
41+
import dpctl_ext.tensor._type_utils as dtu
5242
import dpnp
5343
import dpnp.backend.extensions.vm._vm_impl as vmi
44+
from dpctl_ext.tensor._elementwise_common import (
45+
BinaryElementwiseFunc,
46+
UnaryElementwiseFunc,
47+
)
48+
from dpctl_ext.tensor._scalar_utils import (
49+
_get_dtype,
50+
_get_shape,
51+
_validate_dtype,
52+
)
5453
from dpnp.dpnp_array import dpnp_array
5554
from dpnp.dpnp_utils import get_usm_allocations
5655
from dpnp.dpnp_utils.dpnp_utils_common import (
@@ -213,7 +212,7 @@ def __call__(
213212

214213
x_usm = dpnp.get_usm_ndarray(x)
215214
if dtype is not None:
216-
x_usm = dpt_ext.astype(x_usm, dtype, copy=False)
215+
x_usm = dpt.astype(x_usm, dtype, copy=False)
217216

218217
out = self._unpack_out_kw(out)
219218
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
@@ -467,7 +466,7 @@ def __call__(
467466
)
468467

469468
# Allocate a temporary buffer with the required dtype
470-
out[i] = dpt_ext.empty_like(res, dtype=res_dt)
469+
out[i] = dpt.empty_like(res, dtype=res_dt)
471470
elif (
472471
buf_dt is None
473472
and dti._array_overlap(x, res)
@@ -476,7 +475,7 @@ def __call__(
476475
# Allocate a temporary buffer to avoid memory overlapping.
477476
# Note if `buf_dt` is not None, a temporary copy of `x` will be
478477
# created, so the array overlap check isn't needed.
479-
out[i] = dpt_ext.empty_like(res)
478+
out[i] = dpt.empty_like(res)
480479

481480
_manager = dpu.SequentialOrderManager[exec_q]
482481
dep_evs = _manager.submitted_events
@@ -486,7 +485,7 @@ def __call__(
486485
if order == "K":
487486
buf = dtc._empty_like_orderK(x, buf_dt)
488487
else:
489-
buf = dpt_ext.empty_like(x, dtype=buf_dt, order=order)
488+
buf = dpt.empty_like(x, dtype=buf_dt, order=order)
490489

491490
ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
492491
src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
@@ -503,7 +502,7 @@ def __call__(
503502
if order == "K":
504503
out[i] = dtc._empty_like_orderK(x, res_dt)
505504
else:
506-
out[i] = dpt_ext.empty_like(x, dtype=res_dt, order=order)
505+
out[i] = dpt.empty_like(x, dtype=res_dt, order=order)
507506

508507
# Call the unary function with input and output arrays
509508
ht_unary_ev, unary_ev = self.get_implementation_function()(
@@ -713,24 +712,24 @@ def __call__(
713712

714713
if dtype is not None:
715714
if dpnp.isscalar(x1):
716-
x1_usm = dpt_ext.asarray(
715+
x1_usm = dpt.asarray(
717716
x1,
718717
dtype=dtype,
719718
sycl_queue=x2.sycl_queue,
720719
usm_type=x2.usm_type,
721720
)
722-
x2_usm = dpt_ext.astype(x2_usm, dtype, copy=False)
721+
x2_usm = dpt.astype(x2_usm, dtype, copy=False)
723722
elif dpnp.isscalar(x2):
724-
x1_usm = dpt_ext.astype(x1_usm, dtype, copy=False)
725-
x2_usm = dpt_ext.asarray(
723+
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
724+
x2_usm = dpt.asarray(
726725
x2,
727726
dtype=dtype,
728727
sycl_queue=x1.sycl_queue,
729728
usm_type=x1.usm_type,
730729
)
731730
else:
732-
x1_usm = dpt_ext.astype(x1_usm, dtype, copy=False)
733-
x2_usm = dpt_ext.astype(x2_usm, dtype, copy=False)
731+
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
732+
x2_usm = dpt.astype(x2_usm, dtype, copy=False)
734733

735734
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)
736735

@@ -1078,7 +1077,7 @@ def __call__(
10781077
)
10791078

10801079
# Allocate a temporary buffer with the required dtype
1081-
out[i] = dpt_ext.empty_like(res, dtype=res_dt)
1080+
out[i] = dpt.empty_like(res, dtype=res_dt)
10821081
else:
10831082
# If `dt` is not None, a temporary copy of `x` will be created,
10841083
# so the array overlap check isn't needed.
@@ -1094,7 +1093,7 @@ def __call__(
10941093
for x in x_to_check
10951094
):
10961095
# allocate a temporary buffer to avoid memory overlapping
1097-
out[i] = dpt_ext.empty_like(res)
1096+
out[i] = dpt.empty_like(res)
10981097

10991098
x1 = dpnp.as_usm_ndarray(x1, dtype=x1_dt, sycl_queue=exec_q)
11001099
x2 = dpnp.as_usm_ndarray(x2, dtype=x2_dt, sycl_queue=exec_q)
@@ -1127,7 +1126,7 @@ def __call__(
11271126
if order == "K":
11281127
buf = dtc._empty_like_orderK(x, buf_dt)
11291128
else:
1130-
buf = dpt_ext.empty_like(x, dtype=buf_dt, order=order)
1129+
buf = dpt.empty_like(x, dtype=buf_dt, order=order)
11311130

11321131
ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
11331132
src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
@@ -1146,7 +1145,7 @@ def __call__(
11461145
x1, x2, res_dt, res_shape, res_usm_type, exec_q
11471146
)
11481147
else:
1149-
out[i] = dpt_ext.empty(
1148+
out[i] = dpt.empty(
11501149
res_shape,
11511150
dtype=res_dt,
11521151
order=order,
@@ -1156,9 +1155,9 @@ def __call__(
11561155

11571156
# Broadcast shapes of input arrays
11581157
if x1.shape != res_shape:
1159-
x1 = dpt_ext.broadcast_to(x1, res_shape)
1158+
x1 = dpt.broadcast_to(x1, res_shape)
11601159
if x2.shape != res_shape:
1161-
x2 = dpt_ext.broadcast_to(x2, res_shape)
1160+
x2 = dpt.broadcast_to(x2, res_shape)
11621161

11631162
# Call the binary function with input and output arrays
11641163
ht_binary_ev, binary_ev = self.get_implementation_function()(
@@ -1326,7 +1325,7 @@ def __call__(self, x, /, decimals=0, out=None, *, dtype=None):
13261325
res_usm = dpt.divide(x_usm, 10**decimals, out=out_usm)
13271326

13281327
if dtype is not None:
1329-
res_usm = dpt_ext.astype(res_usm, dtype, copy=False)
1328+
res_usm = dpt.astype(res_usm, dtype, copy=False)
13301329

13311330
if out is not None and isinstance(out, dpnp_array):
13321331
return out

dpnp/dpnp_array.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@
3737

3838
import warnings
3939

40-
import dpctl.tensor as dpt
41-
4240
# TODO: revert to `import dpctl.tensor...`
4341
# when dpnp fully migrates dpctl/tensor
44-
import dpctl_ext.tensor as dpt_ext
42+
import dpctl_ext.tensor as dpt
4543
import dpctl_ext.tensor._type_utils as dtu
4644
import dpnp
4745
from dpctl_ext.tensor._numpy_helper import AxisError
@@ -777,7 +775,7 @@ def asnumpy(self):
777775
778776
"""
779777

780-
return dpt_ext.asnumpy(self._array_obj)
778+
return dpt.asnumpy(self._array_obj)
781779

782780
def astype(
783781
self,
@@ -2283,7 +2281,7 @@ def transpose(self, *axes):
22832281
# self.transpose(None).shape == self.shape[::-1]
22842282
axes = tuple((ndim - x - 1) for x in range(ndim))
22852283

2286-
usm_res = dpt_ext.permute_dims(self._array_obj, axes)
2284+
usm_res = dpt.permute_dims(self._array_obj, axes)
22872285
return dpnp_array._create_from_usm_ndarray(usm_res)
22882286

22892287
def var(

dpnp/dpnp_array_api_info.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@
3636
3737
"""
3838

39-
import dpctl.tensor as dpt
39+
# TODO: revert to `import dpctl.tensor...`
40+
# when dpnp fully migrates dpctl/tensor
41+
import dpctl_ext.tensor as dpt
4042

4143

4244
def __array_namespace_info__():

dpnp/dpnp_iface.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,16 @@
4545
import os
4646

4747
import dpctl
48-
import dpctl.tensor as dpt
4948
import dpctl.utils as dpu
5049
import numpy
51-
from dpctl.tensor._device import normalize_queue_device
5250

5351
# pylint: disable=no-name-in-module
5452
# TODO: revert to `import dpctl.tensor...`
5553
# when dpnp fully migrates dpctl/tensor
56-
import dpctl_ext.tensor as dpt_ext
54+
import dpctl_ext.tensor as dpt
5755
import dpctl_ext.tensor._tensor_impl as ti
5856
import dpnp
57+
from dpctl_ext.tensor._device import normalize_queue_device
5958

6059
from .dpnp_array import dpnp_array
6160
from .dpnp_utils import (
@@ -137,7 +136,7 @@ def asnumpy(a, order="C"):
137136
return a.asnumpy()
138137

139138
if isinstance(a, dpt.usm_ndarray):
140-
return dpt_ext.asnumpy(a)
139+
return dpt.asnumpy(a)
141140

142141
return numpy.asarray(a, order=order)
143142

@@ -191,7 +190,7 @@ def as_usm_ndarray(a, dtype=None, device=None, usm_type=None, sycl_queue=None):
191190
if is_supported_array_type(a):
192191
return get_usm_ndarray(a)
193192

194-
return dpt_ext.asarray(
193+
return dpt.asarray(
195194
a, dtype=dtype, device=device, usm_type=usm_type, sycl_queue=sycl_queue
196195
)
197196

0 commit comments

Comments
 (0)