Skip to content

Commit 5e7123d

Browse files
Update _usmarray.pyx to use dpctl_ext.tensor
1 parent 7f14dfc commit 5e7123d

1 file changed

Lines changed: 62 additions & 51 deletions

File tree

dpctl_ext/tensor/_usmarray.pyx

Lines changed: 62 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ import numpy as np
3737
from dpctl._backend cimport DPCTLSyclUSMRef
3838
from dpctl._sycl_device_factory cimport _cached_default_device
3939

40+
# TODO: remote it when dpnp fully migrates dpctl/tensor
41+
import dpctl_ext
42+
4043
from ._data_types import bool as dpt_bool
4144
from ._device import Device
4245
from ._print import usm_ndarray_repr, usm_ndarray_str
@@ -1143,7 +1146,9 @@ cdef class usm_ndarray:
11431146
return (
11441147
self.array_namespace_
11451148
if self.array_namespace_ is not None
1146-
else dpctl.tensor
1149+
# TODO: revert to `else dpctl.tensor`
1150+
# when dpnp fully migrates dpctl/tensor
1151+
else dpctl_ext.tensor
11471152
)
11481153

11491154
def __bool__(self):
@@ -1199,17 +1204,19 @@ cdef class usm_ndarray:
11991204
raise IndexError("only integer arrays are valid indices")
12001205

12011206
def __abs__(self):
1202-
return dpctl.tensor.abs(self)
1207+
# TODO: revert to `return dpctl.tensor...`
1208+
# when dpnp fully migrates dpctl/tensor
1209+
return dpctl_ext.tensor.abs(self)
12031210

12041211
def __add__(self, other):
12051212
"""
12061213
Implementation for operator.add
12071214
"""
1208-
return dpctl.tensor.add(self, other)
1215+
return dpctl_ext.tensor.add(self, other)
12091216

12101217
def __and__(self, other):
12111218
"Implementation for operator.and"
1212-
return dpctl.tensor.bitwise_and(self, other)
1219+
return dpctl_ext.tensor.bitwise_and(self, other)
12131220

12141221
def __dlpack__(
12151222
self, *, stream=None, max_version=None, dl_device=None, copy=None
@@ -1368,22 +1375,24 @@ cdef class usm_ndarray:
13681375
)
13691376

13701377
def __eq__(self, other):
1371-
return dpctl.tensor.equal(self, other)
1378+
# TODO: revert to `return dpctl.tensor...`
1379+
# when dpnp fully migrates dpctl/tensor
1380+
return dpctl_ext.tensor.equal(self, other)
13721381

13731382
def __floordiv__(self, other):
1374-
return dpctl.tensor.floor_divide(self, other)
1383+
return dpctl_ext.tensor.floor_divide(self, other)
13751384

13761385
def __ge__(self, other):
1377-
return dpctl.tensor.greater_equal(self, other)
1386+
return dpctl_ext.tensor.greater_equal(self, other)
13781387

13791388
def __gt__(self, other):
1380-
return dpctl.tensor.greater(self, other)
1389+
return dpctl_ext.tensor.greater(self, other)
13811390

13821391
def __invert__(self):
1383-
return dpctl.tensor.bitwise_invert(self)
1392+
return dpctl_ext.tensor.bitwise_invert(self)
13841393

13851394
def __le__(self, other):
1386-
return dpctl.tensor.less_equal(self, other)
1395+
return dpctl_ext.tensor.less_equal(self, other)
13871396

13881397
def __len__(self):
13891398
if (self.nd_):
@@ -1392,37 +1401,37 @@ cdef class usm_ndarray:
13921401
raise TypeError("len() of unsized object")
13931402

13941403
def __lshift__(self, other):
1395-
return dpctl.tensor.bitwise_left_shift(self, other)
1404+
return dpctl_ext.tensor.bitwise_left_shift(self, other)
13961405

13971406
def __lt__(self, other):
1398-
return dpctl.tensor.less(self, other)
1407+
return dpctl_ext.tensor.less(self, other)
13991408

14001409
def __matmul__(self, other):
1401-
return dpctl.tensor.matmul(self, other)
1410+
return dpctl_ext.tensor.matmul(self, other)
14021411

14031412
def __mod__(self, other):
1404-
return dpctl.tensor.remainder(self, other)
1413+
return dpctl_ext.tensor.remainder(self, other)
14051414

14061415
def __mul__(self, other):
1407-
return dpctl.tensor.multiply(self, other)
1416+
return dpctl_ext.tensor.multiply(self, other)
14081417

14091418
def __ne__(self, other):
1410-
return dpctl.tensor.not_equal(self, other)
1419+
return dpctl_ext.tensor.not_equal(self, other)
14111420

14121421
def __neg__(self):
1413-
return dpctl.tensor.negative(self)
1422+
return dpctl_ext.tensor.negative(self)
14141423

14151424
def __or__(self, other):
1416-
return dpctl.tensor.bitwise_or(self, other)
1425+
return dpctl_ext.tensor.bitwise_or(self, other)
14171426

14181427
def __pos__(self):
1419-
return dpctl.tensor.positive(self)
1428+
return dpctl_ext.tensor.positive(self)
14201429

14211430
def __pow__(self, other):
1422-
return dpctl.tensor.pow(self, other)
1431+
return dpctl_ext.tensor.pow(self, other)
14231432

14241433
def __rshift__(self, other):
1425-
return dpctl.tensor.bitwise_right_shift(self, other)
1434+
return dpctl_ext.tensor.bitwise_right_shift(self, other)
14261435

14271436
def __setitem__(self, key, rhs):
14281437
cdef tuple _meta
@@ -1467,7 +1476,7 @@ cdef class usm_ndarray:
14671476
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs)
14681477
else:
14691478
if hasattr(rhs, "__sycl_usm_array_interface__"):
1470-
from dpctl.tensor import asarray
1479+
from dpctl_ext.tensor import asarray
14711480
try:
14721481
rhs_ar = asarray(rhs)
14731482
_copy_from_usm_ndarray_to_usm_ndarray(Xv, rhs_ar)
@@ -1515,91 +1524,93 @@ cdef class usm_ndarray:
15151524
return
15161525

15171526
def __sub__(self, other):
1518-
return dpctl.tensor.subtract(self, other)
1527+
# TODO: revert to `return dpctl.tensor...`
1528+
# when dpnp fully migrates dpctl/tensor
1529+
return dpctl_ext.tensor.subtract(self, other)
15191530

15201531
def __truediv__(self, other):
1521-
return dpctl.tensor.divide(self, other)
1532+
return dpctl_ext.tensor.divide(self, other)
15221533

15231534
def __xor__(self, other):
1524-
return dpctl.tensor.bitwise_xor(self, other)
1535+
return dpctl_ext.tensor.bitwise_xor(self, other)
15251536

15261537
def __radd__(self, other):
1527-
return dpctl.tensor.add(other, self)
1538+
return dpctl_ext.tensor.add(other, self)
15281539

15291540
def __rand__(self, other):
1530-
return dpctl.tensor.bitwise_and(other, self)
1541+
return dpctl_ext.tensor.bitwise_and(other, self)
15311542

15321543
def __rfloordiv__(self, other):
1533-
return dpctl.tensor.floor_divide(other, self)
1544+
return dpctl_ext.tensor.floor_divide(other, self)
15341545

15351546
def __rlshift__(self, other):
1536-
return dpctl.tensor.bitwise_left_shift(other, self)
1547+
return dpctl_ext.tensor.bitwise_left_shift(other, self)
15371548

15381549
def __rmatmul__(self, other):
1539-
return dpctl.tensor.matmul(other, self)
1550+
return dpctl_ext.tensor.matmul(other, self)
15401551

15411552
def __rmod__(self, other):
1542-
return dpctl.tensor.remainder(other, self)
1553+
return dpctl_ext.tensor.remainder(other, self)
15431554

15441555
def __rmul__(self, other):
1545-
return dpctl.tensor.multiply(other, self)
1556+
return dpctl_ext.tensor.multiply(other, self)
15461557

15471558
def __ror__(self, other):
1548-
return dpctl.tensor.bitwise_or(other, self)
1559+
return dpctl_ext.tensor.bitwise_or(other, self)
15491560

15501561
def __rpow__(self, other):
1551-
return dpctl.tensor.pow(other, self)
1562+
return dpctl_ext.tensor.pow(other, self)
15521563

15531564
def __rrshift__(self, other):
1554-
return dpctl.tensor.bitwise_right_shift(other, self)
1565+
return dpctl_ext.tensor.bitwise_right_shift(other, self)
15551566

15561567
def __rsub__(self, other):
1557-
return dpctl.tensor.subtract(other, self)
1568+
return dpctl_ext.tensor.subtract(other, self)
15581569

15591570
def __rtruediv__(self, other):
1560-
return dpctl.tensor.divide(other, self)
1571+
return dpctl_ext.tensor.divide(other, self)
15611572

15621573
def __rxor__(self, other):
1563-
return dpctl.tensor.bitwise_xor(other, self)
1574+
return dpctl_ext.tensor.bitwise_xor(other, self)
15641575

15651576
def __iadd__(self, other):
1566-
return dpctl.tensor.add._inplace_op(self, other)
1577+
return dpctl_ext.tensor.add._inplace_op(self, other)
15671578

15681579
def __iand__(self, other):
1569-
return dpctl.tensor.bitwise_and._inplace_op(self, other)
1580+
return dpctl_ext.tensor.bitwise_and._inplace_op(self, other)
15701581

15711582
def __ifloordiv__(self, other):
1572-
return dpctl.tensor.floor_divide._inplace_op(self, other)
1583+
return dpctl_ext.tensor.floor_divide._inplace_op(self, other)
15731584

15741585
def __ilshift__(self, other):
1575-
return dpctl.tensor.bitwise_left_shift._inplace_op(self, other)
1586+
return dpctl_ext.tensor.bitwise_left_shift._inplace_op(self, other)
15761587

15771588
def __imatmul__(self, other):
1578-
return dpctl.tensor.matmul(self, other, out=self, dtype=self.dtype)
1589+
return dpctl_ext.tensor.matmul(self, other, out=self, dtype=self.dtype)
15791590

15801591
def __imod__(self, other):
1581-
return dpctl.tensor.remainder._inplace_op(self, other)
1592+
return dpctl_ext.tensor.remainder._inplace_op(self, other)
15821593

15831594
def __imul__(self, other):
1584-
return dpctl.tensor.multiply._inplace_op(self, other)
1595+
return dpctl_ext.tensor.multiply._inplace_op(self, other)
15851596

15861597
def __ior__(self, other):
1587-
return dpctl.tensor.bitwise_or._inplace_op(self, other)
1598+
return dpctl_ext.tensor.bitwise_or._inplace_op(self, other)
15881599

15891600
def __ipow__(self, other):
1590-
return dpctl.tensor.pow._inplace_op(self, other)
1601+
return dpctl_ext.tensor.pow._inplace_op(self, other)
15911602

15921603
def __irshift__(self, other):
1593-
return dpctl.tensor.bitwise_right_shift._inplace_op(self, other)
1604+
return dpctl_ext.tensor.bitwise_right_shift._inplace_op(self, other)
15941605

15951606
def __isub__(self, other):
1596-
return dpctl.tensor.subtract._inplace_op(self, other)
1607+
return dpctl_ext.tensor.subtract._inplace_op(self, other)
15971608

15981609
def __itruediv__(self, other):
1599-
return dpctl.tensor.divide._inplace_op(self, other)
1610+
return dpctl_ext.tensor.divide._inplace_op(self, other)
16001611

16011612
def __ixor__(self, other):
1602-
return dpctl.tensor.bitwise_xor._inplace_op(self, other)
1613+
return dpctl_ext.tensor.bitwise_xor._inplace_op(self, other)
16031614

16041615
def __str__(self):
16051616
return usm_ndarray_str(self)

0 commit comments

Comments
 (0)