Skip to content

Commit f9f547a

Browse files
Move ti.full_like() to dpctl_ext/tensor
1 parent effcbe8 commit f9f547a

2 files changed

Lines changed: 119 additions & 0 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
empty_like,
4444
eye,
4545
full,
46+
full_like,
4647
linspace,
4748
tril,
4849
triu,
@@ -80,6 +81,7 @@
8081
"finfo",
8182
"from_numpy",
8283
"full",
84+
"full_like",
8385
"iinfo",
8486
"isdtype",
8587
"linspace",

dpctl_ext/tensor/_ctors.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,123 @@ def full(
12231223
return res
12241224

12251225

1226+
def full_like(
1227+
x,
1228+
/,
1229+
fill_value,
1230+
*,
1231+
dtype=None,
1232+
order="K",
1233+
device=None,
1234+
usm_type=None,
1235+
sycl_queue=None,
1236+
):
1237+
"""full_like(x, fill_value, dtype=None, order="K", \
1238+
device=None, usm_type=None, sycl_queue=None)
1239+
1240+
Returns a new :class:`dpctl.tensor.usm_ndarray` filled with `fill_value`
1241+
and having the same `shape` as the input array `x`.
1242+
1243+
Args:
1244+
x (usm_ndarray): Input array from which to derive the output array
1245+
shape.
1246+
fill_value: the value to fill output array with
1247+
dtype (optional):
1248+
data type of the array. Can be typestring,
1249+
a :class:`numpy.dtype` object, :mod:`numpy` char string, or a
1250+
NumPy scalar type. If ``dtype`` is ``None``, the output array data
1251+
type is inferred from ``x``. Default: ``None``
1252+
order ("C", "F", "A", or "K"):
1253+
memory layout for the array. Default: ``"K"``
1254+
device (optional):
1255+
array API concept of device where the output array
1256+
is created. ``device`` can be ``None``, a oneAPI filter selector
1257+
string, an instance of :class:`dpctl.SyclDevice` corresponding to
1258+
a non-partitioned SYCL device, an instance of
1259+
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device` object
1260+
returned by :attr:`dpctl.tensor.usm_ndarray.device`.
1261+
Default: ``None``
1262+
usm_type (``"device"``, ``"shared"``, ``"host"``, optional):
1263+
The type of SYCL USM allocation for the output array.
1264+
Default: ``"device"``
1265+
sycl_queue (:class:`dpctl.SyclQueue`, optional):
1266+
The SYCL queue to use
1267+
for output array allocation and copying. ``sycl_queue`` and
1268+
``device`` are complementary arguments, i.e. use one or another.
1269+
If both are specified, a :exc:`TypeError` is raised unless both
1270+
imply the same underlying SYCL queue to be used. If both are
1271+
``None``, a cached queue targeting default-selected device is
1272+
used for allocation and population. Default: ``None``
1273+
1274+
Returns:
1275+
usm_ndarray:
1276+
New array initialized with given value.
1277+
"""
1278+
if not isinstance(x, dpt.usm_ndarray):
1279+
raise TypeError(f"Expected instance of dpt.usm_ndarray, got {type(x)}.")
1280+
if (
1281+
not isinstance(order, str)
1282+
or len(order) == 0
1283+
or order[0] not in "CcFfAaKk"
1284+
):
1285+
raise ValueError(
1286+
"Unrecognized order keyword value, expecting 'C', 'F', 'A', or 'K'."
1287+
)
1288+
order = order[0].upper()
1289+
if dtype is None:
1290+
dtype = x.dtype
1291+
if usm_type is None:
1292+
usm_type = x.usm_type
1293+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
1294+
if device is None and sycl_queue is None:
1295+
device = x.device
1296+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
1297+
sh = x.shape
1298+
dtype = dpt.dtype(dtype)
1299+
order = _normalize_order(order, x)
1300+
if order == "K":
1301+
_ensure_native_dtype_device_support(dtype, sycl_queue.sycl_device)
1302+
if isinstance(fill_value, (dpt.usm_ndarray, np.ndarray, tuple, list)):
1303+
X = dpt_ext.asarray(
1304+
fill_value,
1305+
dtype=dtype,
1306+
order=order,
1307+
usm_type=usm_type,
1308+
sycl_queue=sycl_queue,
1309+
)
1310+
X = dpt.broadcast_to(X, sh)
1311+
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1312+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1313+
# order copy after tasks populating X
1314+
dep_evs = _manager.submitted_events
1315+
hev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
1316+
src=X, dst=res, sycl_queue=sycl_queue, depends=dep_evs
1317+
)
1318+
_manager.add_event_pair(hev, copy_ev)
1319+
return res
1320+
else:
1321+
_validate_fill_value(fill_value)
1322+
1323+
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
1324+
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)
1325+
fill_value = _cast_fill_val(fill_value, dtype)
1326+
_manager = dpctl.utils.SequentialOrderManager[sycl_queue]
1327+
# populating new allocation, no dependent events
1328+
hev, full_ev = ti._full_usm_ndarray(fill_value, res, sycl_queue)
1329+
_manager.add_event_pair(hev, full_ev)
1330+
return res
1331+
else:
1332+
return full(
1333+
sh,
1334+
fill_value,
1335+
dtype=dtype,
1336+
order=order,
1337+
device=device,
1338+
usm_type=usm_type,
1339+
sycl_queue=sycl_queue,
1340+
)
1341+
1342+
12261343
def linspace(
12271344
start,
12281345
stop,

0 commit comments

Comments
 (0)