Skip to content

Commit 97dc7e1

Browse files
Move ti.zeros() to dpctl_ext/tensor and reuse it in dpnp
1 parent 23d2229 commit 97dc7e1

3 files changed

Lines changed: 85 additions & 15 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
ones_like,
5151
tril,
5252
triu,
53+
zeros,
5354
)
5455
from dpctl_ext.tensor._indexing_functions import (
5556
extract,
@@ -105,4 +106,5 @@
105106
"tril",
106107
"triu",
107108
"where",
109+
"zeros",
108110
]

dpctl_ext/tensor/_ctors.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,7 @@ def eye(
10911091
n_cols = n_rows if n_cols is None else operator.index(n_cols)
10921092
k = operator.index(k)
10931093
if k >= n_cols or -k >= n_rows:
1094-
return dpt.zeros(
1094+
return dpt_ext.zeros(
10951095
(n_rows, n_cols),
10961096
dtype=dtype,
10971097
order=order,
@@ -1720,7 +1720,7 @@ def tril(x, /, *, k=0):
17201720
)
17211721
_manager.add_event_pair(hev, cpy_ev)
17221722
elif k < -shape[nd - 2]:
1723-
res = dpt.zeros(
1723+
res = dpt_ext.zeros(
17241724
x.shape,
17251725
dtype=x.dtype,
17261726
order=order,
@@ -1784,7 +1784,7 @@ def triu(x, /, *, k=0):
17841784

17851785
q = x.sycl_queue
17861786
if k > shape[nd - 1]:
1787-
res = dpt.zeros(
1787+
res = dpt_ext.zeros(
17881788
x.shape,
17891789
dtype=x.dtype,
17901790
order=order,
@@ -1821,3 +1821,72 @@ def triu(x, /, *, k=0):
18211821
_manager.add_event_pair(hev, triu_ev)
18221822

18231823
return res
1824+
1825+
1826+
def zeros(
1827+
shape,
1828+
*,
1829+
dtype=None,
1830+
order="C",
1831+
device=None,
1832+
usm_type="device",
1833+
sycl_queue=None,
1834+
):
1835+
"""
1836+
Returns a new :class:`dpctl.tensor.usm_ndarray` having a specified
1837+
shape and filled with zeros.
1838+
1839+
Args:
1840+
shape (Tuple[int], int):
1841+
Dimensions of the array to be created.
1842+
dtype (optional):
1843+
data type of the array. Can be typestring,
1844+
a :class:`numpy.dtype` object, :mod:`numpy` char string,
1845+
or a NumPy scalar type. Default: ``None``
1846+
order ("C", or "F"):
1847+
memory layout for the array. Default: ``"C"``
1848+
device (optional): array API concept of device where the output array
1849+
is created. ``device`` can be ``None``, a oneAPI filter selector
1850+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
1851+
a non-partitioned SYCL device, an instance of
1852+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
1853+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
1854+
Default: ``None``
1855+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
1856+
The type of SYCL USM allocation for the output array.
1857+
Default: ``"device"``
1858+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
1859+
The SYCL queue to use
1860+
for output array allocation and copying. ``sycl_queue`` and
1861+
``device`` are complementary arguments, i.e. use one or another.
1862+
If both are specified, a :exc:`TypeError` is raised unless both
1863+
imply the same underlying SYCL queue to be used. If both are
1864+
``None``, a cached queue targeting default-selected device is
1865+
used for allocation and population. Default: ``None``
1866+
1867+
Returns:
1868+
usm_ndarray:
1869+
Constructed array initialized with zeros.
1870+
"""
1871+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
1872+
raise ValueError(
1873+
"Unrecognized order keyword value, expecting 'F' or 'C'."
1874+
)
1875+
order = order[0].upper()
1876+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1877+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1878+
dtype = _get_dtype(dtype, sycl_queue)
1879+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
1880+
res = dpt.usm_ndarray(
1881+
shape,
1882+
dtype=dtype,
1883+
buffer=usm_type,
1884+
order=order,
1885+
buffer_ctor_kwargs={"queue": sycl_queue},
1886+
)
1887+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1888+
# populating new allocation, no dependent events
1889+
hev, zeros_ev = ti._zeros_usm_ndarray(res, sycl_queue)
1890+
_manager.add_event_pair(hev, zeros_ev)
1891+
1892+
return res

dpnp/dpnp_container.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,11 @@
3535
3636
"""
3737

38-
import dpctl.tensor as dpt
3938
import dpctl.utils as dpu
4039

4140
# TODO: revert to `import dpctl.tensor...`
4241
# when dpnp fully migrates dpctl/tensor
43-
import dpctl_ext.tensor as dpt_ext
42+
import dpctl_ext.tensor as dpt
4443
import dpnp
4544
from dpnp.dpnp_array import dpnp_array
4645

@@ -75,7 +74,7 @@ def arange(
7574
sycl_queue=sycl_queue, device=device
7675
)
7776

78-
array_obj = dpt_ext.arange(
77+
array_obj = dpt.arange(
7978
start,
8079
stop=stop,
8180
step=step,
@@ -103,7 +102,7 @@ def asarray(
103102

104103
"""Converts incoming 'x1' object to 'dpnp_array'."""
105104
if isinstance(x1, (list, tuple, range)):
106-
array_obj = dpt_ext.asarray(
105+
array_obj = dpt.asarray(
107106
x1,
108107
dtype=dtype,
109108
copy=copy,
@@ -122,7 +121,7 @@ def asarray(
122121
x1_obj, device=device, sycl_queue=sycl_queue
123122
)
124123

125-
array_obj = dpt_ext.asarray(
124+
array_obj = dpt.asarray(
126125
x1_obj,
127126
dtype=dtype,
128127
copy=copy,
@@ -143,7 +142,7 @@ def copy(x1, /, *, order="K"):
143142
if order is None:
144143
order = "K"
145144

146-
array_obj = dpt_ext.copy(dpnp.get_usm_ndarray(x1), order=order)
145+
array_obj = dpt.copy(dpnp.get_usm_ndarray(x1), order=order)
147146
return dpnp_array._create_from_usm_ndarray(array_obj)
148147

149148

@@ -165,7 +164,7 @@ def empty(
165164
order = "C"
166165

167166
"""Creates `dpnp_array` from uninitialized USM allocation."""
168-
array_obj = dpt_ext.empty(
167+
array_obj = dpt.empty(
169168
shape,
170169
dtype=dtype,
171170
order=order,
@@ -196,7 +195,7 @@ def eye(
196195
order = "C"
197196

198197
"""Creates `dpnp_array` with ones on the `k`th diagonal."""
199-
array_obj = dpt_ext.eye(
198+
array_obj = dpt.eye(
200199
N,
201200
M,
202201
k=k,
@@ -231,7 +230,7 @@ def full(
231230
fill_value = fill_value.get_array()
232231

233232
"""Creates `dpnp_array` having a specified shape, filled with fill_value."""
234-
array_obj = dpt_ext.full(
233+
array_obj = dpt.full(
235234
shape,
236235
fill_value,
237236
dtype=dtype,
@@ -260,7 +259,7 @@ def ones(
260259
order = "C"
261260

262261
"""Creates `dpnp_array` of ones with the given shape, dtype, and order."""
263-
array_obj = dpt_ext.ones(
262+
array_obj = dpt.ones(
264263
shape,
265264
dtype=dtype,
266265
order=order,
@@ -272,13 +271,13 @@ def ones(
272271

273272
def tril(x1, /, *, k=0):
274273
"""Creates `dpnp_array` as lower triangular part of an input array."""
275-
array_obj = dpt_ext.tril(dpnp.get_usm_ndarray(x1), k=k)
274+
array_obj = dpt.tril(dpnp.get_usm_ndarray(x1), k=k)
276275
return dpnp_array._create_from_usm_ndarray(array_obj)
277276

278277

279278
def triu(x1, /, *, k=0):
280279
"""Creates `dpnp_array` as upper triangular part of an input array."""
281-
array_obj = dpt_ext.triu(dpnp.get_usm_ndarray(x1), k=k)
280+
array_obj = dpt.triu(dpnp.get_usm_ndarray(x1), k=k)
282281
return dpnp_array._create_from_usm_ndarray(array_obj)
283282

284283

0 commit comments

Comments
 (0)