Skip to content

Commit a6c397e

Browse files
Move ti.zeros_like() to dpctl_ext/tensor
1 parent 97dc7e1 commit a6c397e

2 files changed

Lines changed: 86 additions & 0 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
tril,
5252
triu,
5353
zeros,
54+
zeros_like,
5455
)
5556
from dpctl_ext.tensor._indexing_functions import (
5657
extract,
@@ -107,4 +108,5 @@
107108
"triu",
108109
"where",
109110
"zeros",
111+
"zeros_like",
110112
]

dpctl_ext/tensor/_ctors.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1890,3 +1890,87 @@ def zeros(
18901890
_manager.add_event_pair(hev, zeros_ev)
18911891

18921892
return res
1893+
1894+
1895+
def zeros_like(
1896+
x, /, *, dtype=None, order="K", device=None, usm_type=None, sycl_queue=None
1897+
):
1898+
"""
1899+
Creates :class:`dpctl.tensor.usm_ndarray` from USM allocation
1900+
initialized with zeros.
1901+
1902+
Args:
1903+
x (usm_ndarray):
1904+
Input array from which to derive the shape of the
1905+
output array.
1906+
dtype (optional):
1907+
data type of the array. Can be typestring,
1908+
a :class:`numpy.dtype` object, :mod:`numpy` char string, or a
1909+
NumPy scalar type. If `None`, output array has the same data
1910+
type as the input array. Default: ``None``
1911+
order ("C", or "F"):
1912+
memory layout for the array. Default: ``"C"``
1913+
device (optional):
1914+
array API concept of device where the output array
1915+
is created. ``device`` can be ``None``, a oneAPI filter selector
1916+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
1917+
a non-partitioned SYCL device, an instance of
1918+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
1919+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
1920+
Default: ``None``
1921+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
1922+
The type of SYCL USM allocation for the output array.
1923+
Default: ``"device"``
1924+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
1925+
The SYCL queue to use
1926+
for output array allocation and copying. ``sycl_queue`` and
1927+
``device`` are complementary arguments, i.e. use one or another.
1928+
If both are specified, a :exc:`TypeError` is raised unless both
1929+
imply the same underlying SYCL queue to be used. If both are
1930+
``None``, a cached queue targeting default-selected device is
1931+
used for allocation and population. Default: ``None``
1932+
1933+
Returns:
1934+
usm_ndarray:
1935+
New array initialized with zeros.
1936+
"""
1937+
if not isinstance(x, dpt.usm_ndarray):
1938+
raise TypeError(f"Expected instance of dpt.usm_ndarray, got {type(x)}.")
1939+
if (
1940+
not isinstance(order, str)
1941+
or len(order) == 0
1942+
or order[0] not in "CcFfAaKk"
1943+
):
1944+
raise ValueError(
1945+
"Unrecognized order keyword value, expecting 'C', 'F', 'A', or 'K'."
1946+
)
1947+
order = order[0].upper()
1948+
if dtype is None:
1949+
dtype = x.dtype
1950+
if usm_type is None:
1951+
usm_type = x.usm_type
1952+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1953+
if device is None and sycl_queue is None:
1954+
device = x.device
1955+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1956+
dtype = dpt.dtype(dtype)
1957+
order = _normalize_order(order, x)
1958+
if order == "K":
1959+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
1960+
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1961+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1962+
# populating new allocation, no dependent events
1963+
hev, full_ev = ti._full_usm_ndarray(0, res, sycl_queue)
1964+
_manager.add_event_pair(hev, full_ev)
1965+
return res
1966+
else:
1967+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
1968+
sh = x.shape
1969+
return zeros(
1970+
sh,
1971+
dtype=dtype,
1972+
order=order,
1973+
device=device,
1974+
usm_type=usm_type,
1975+
sycl_queue=sycl_queue,
1976+
)

0 commit comments

Comments
 (0)