Skip to content

Commit 23d2229

Browse files
Move ti.ones_like() to dpctl_ext/tensor
1 parent 8c15ddb commit 23d2229

2 files changed

Lines changed: 83 additions & 0 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
linspace,
4848
meshgrid,
4949
ones,
50+
ones_like,
5051
tril,
5152
triu,
5253
)
@@ -90,6 +91,7 @@
9091
"meshgrid",
9192
"nonzero",
9293
"ones",
94+
"ones_like",
9395
"place",
9496
"put",
9597
"put_along_axis",

dpctl_ext/tensor/_ctors.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,87 @@ def ones(
15861586
return res
15871587

15881588

1589+
def ones_like(
1590+
x, /, *, dtype=None, order="K", device=None, usm_type=None, sycl_queue=None
1591+
):
1592+
"""
1593+
Returns a new :class:`dpctl.tensor.usm_ndarray` filled with ones and
1594+
having the same `shape` as the input array `x`.
1595+
1596+
Args:
1597+
x (usm_ndarray):
1598+
Input array from which to derive the output array shape
1599+
dtype (optional):
1600+
data type of the array. Can be typestring,
1601+
a :class:`numpy.dtype` object, :mod:`numpy` char string,
1602+
or a NumPy scalar type. Default: `None`
1603+
order ("C", "F", "A", or "K"):
1604+
memory layout for the array. Default: ``"C"``
1605+
device (optional):
1606+
array API concept of device where the output array
1607+
is created. ``device`` can be ``None``, a oneAPI filter selector
1608+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
1609+
a non-partitioned SYCL device, an instance of
1610+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
1611+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
1612+
Default: ``None``
1613+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
1614+
The type of SYCL USM allocation for the output array.
1615+
Default: ``"device"``
1616+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
1617+
The SYCL queue to use
1618+
for output array allocation and copying. ``sycl_queue`` and
1619+
``device`` are complementary arguments, i.e. use one or another.
1620+
If both are specified, a :exc:`TypeError` is raised unless both
1621+
imply the same underlying SYCL queue to be used. If both are
1622+
``None``, a cached queue targeting default-selected device is
1623+
used for allocation and population. Default: ``None``
1624+
1625+
Returns:
1626+
usm_ndarray:
1627+
New array initialized with ones.
1628+
"""
1629+
if not isinstance(x, dpt.usm_ndarray):
1630+
raise TypeError(f"Expected instance of dpt.usm_ndarray, got {type(x)}.")
1631+
if (
1632+
not isinstance(order, str)
1633+
or len(order) == 0
1634+
or order[0] not in "CcFfAaKk"
1635+
):
1636+
raise ValueError(
1637+
"Unrecognized order keyword value, expecting 'C', 'F', 'A', or 'K'."
1638+
)
1639+
order = order[0].upper()
1640+
if dtype is None:
1641+
dtype = x.dtype
1642+
if usm_type is None:
1643+
usm_type = x.usm_type
1644+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1645+
if device is None and sycl_queue is None:
1646+
device = x.device
1647+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1648+
dtype = dpt.dtype(dtype)
1649+
order = _normalize_order(order, x)
1650+
if order == "K":
1651+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
1652+
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1653+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1654+
# populating new allocation, no dependent events
1655+
hev, full_ev = ti._full_usm_ndarray(1, res, sycl_queue)
1656+
_manager.add_event_pair(hev, full_ev)
1657+
return res
1658+
else:
1659+
sh = x.shape
1660+
return ones(
1661+
sh,
1662+
dtype=dtype,
1663+
order=order,
1664+
device=device,
1665+
usm_type=usm_type,
1666+
sycl_queue=sycl_queue,
1667+
)
1668+
1669+
15891670
def tril(x, /, *, k=0):
15901671
"""
15911672
Returns the lower triangular part of a matrix (or a stack of matrices)

0 commit comments

Comments
 (0)