Skip to content

Commit 8c15ddb

Browse files
Move ti.ones() to dpctl_ext/tensor and reuse it in dpnp
1 parent 0d84d7b commit 8c15ddb

7 files changed

Lines changed: 82 additions & 5 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
full_like,
4747
linspace,
4848
meshgrid,
49+
ones,
4950
tril,
5051
triu,
5152
)
@@ -88,6 +89,7 @@
8889
"linspace",
8990
"meshgrid",
9091
"nonzero",
92+
"ones",
9193
"place",
9294
"put",
9395
"put_along_axis",

dpctl_ext/tensor/_ctors.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,74 @@ def meshgrid(*arrays, indexing="xy"):
15181518
return output
15191519

15201520

1521+
def ones(
1522+
shape,
1523+
*,
1524+
dtype=None,
1525+
order="C",
1526+
device=None,
1527+
usm_type="device",
1528+
sycl_queue=None,
1529+
):
1530+
"""ones(shape, dtype=None, order="C", \
1531+
device=None, usm_type="device", sycl_queue=None)
1532+
1533+
Returns a new :class:`dpctl.tensor.usm_ndarray` having a specified
1534+
shape and filled with ones.
1535+
1536+
Args:
1537+
shape (Tuple[int], int):
1538+
Dimensions of the array to be created.
1539+
dtype (optional):
1540+
data type of the array. Can be typestring,
1541+
a :class:`numpy.dtype` object, :mod:`numpy` char string,
1542+
or a NumPy scalar type. Default: ``None``
1543+
order ("C", or "F"): memory layout for the array. Default: ``"C"``
1544+
device (optional): array API concept of device where the output array
1545+
is created. ``device`` can be ``None``, a oneAPI filter selector
1546+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
1547+
a non-partitioned SYCL device, an instance of
1548+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
1549+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
1550+
Default: ``None``
1551+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
1552+
The type of SYCL USM allocation for the output array.
1553+
Default: ``"device"``
1554+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
1555+
The SYCL queue to use
1556+
for output array allocation and copying. ``sycl_queue`` and
1557+
``device`` are complementary arguments, i.e. use one or another.
1558+
If both are specified, a :exc:`TypeError` is raised unless both
1559+
imply the same underlying SYCL queue to be used. If both are
1560+
``None``, a cached queue targeting default-selected device is
1561+
used for allocation and population. Default: ``None``
1562+
1563+
Returns:
1564+
usm_ndarray:
1565+
Created array initialized with ones.
1566+
"""
1567+
if not isinstance(order, str) or len(order) == 0 or order[0] not in "CcFf":
1568+
raise ValueError(
1569+
"Unrecognized order keyword value, expecting 'F' or 'C'."
1570+
)
1571+
order = order[0].upper()
1572+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1573+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1574+
dtype = _get_dtype(dtype, sycl_queue)
1575+
res = dpt.usm_ndarray(
1576+
shape,
1577+
dtype=dtype,
1578+
buffer=usm_type,
1579+
order=order,
1580+
buffer_ctor_kwargs={"queue": sycl_queue},
1581+
)
1582+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1583+
# populating new allocation, no dependent events
1584+
hev, full_ev = ti._full_usm_ndarray(1, res, sycl_queue)
1585+
_manager.add_event_pair(hev, full_ev)
1586+
return res
1587+
1588+
15211589
def tril(x, /, *, k=0):
15221590
"""
15231591
Returns the lower triangular part of a matrix (or a stack of matrices)

dpnp/dpnp_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def ones(
260260
order = "C"
261261

262262
"""Creates `dpnp_array` of ones with the given shape, dtype, and order."""
263-
array_obj = dpt.ones(
263+
array_obj = dpt_ext.ones(
264264
shape,
265265
dtype=dtype,
266266
order=order,

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3696,7 +3696,7 @@ def tri(
36963696
if usm_type is None:
36973697
usm_type = "device"
36983698

3699-
m = dpt.ones(
3699+
m = dpt_ext.ones(
37003700
(N, M),
37013701
dtype=_dtype,
37023702
device=device,

dpnp/tests/test_memory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import numpy
33
import pytest
44

5+
# TODO: revert to `import dpctl.tensor...`
6+
# when dpnp fully migrates dpctl/tensor
7+
import dpctl_ext.tensor as dpt_ext
58
import dpnp
69
import dpnp.memory as dpm
710

@@ -21,7 +24,7 @@ def test_wrong_input_type(self, x):
2124
dpm.create_data(x)
2225

2326
def test_wrong_usm_data(self):
24-
a = dpt.ones(10)
27+
a = dpt_ext.ones(10)
2528
d = IntUsmData(a.shape, buffer=a)
2629

2730
with pytest.raises(TypeError):

dpnp/tests/test_sycl_queue.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
import tempfile
33

44
import dpctl
5-
import dpctl.tensor as dpt
65
import numpy
76
import pytest
87
from dpctl.utils import ExecutionPlacementError
98
from numpy.testing import assert_array_equal, assert_raises
109

10+
# TODO: revert to `import dpctl.tensor...`
11+
# when dpnp fully migrates dpctl/tensor
12+
import dpctl_ext.tensor as dpt
1113
import dpnp
1214
import dpnp.linalg
1315
from dpnp.dpnp_array import dpnp_array

dpnp/tests/test_usm_type.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import tempfile
33
from math import prod
44

5-
import dpctl.tensor as dpt
65
import dpctl.utils as du
76
import numpy
87
import pytest
98

9+
# TODO: revert to `import dpctl.tensor...`
10+
# when dpnp fully migrates dpctl/tensor
11+
import dpctl_ext.tensor as dpt
1012
import dpnp
1113
from dpnp.dpnp_utils import get_usm_allocations
1214

0 commit comments

Comments
 (0)