Skip to content

Commit 0d84d7b

Browse files
Move ti.meshgrid() to dpctl_ext/tensor
1 parent f9f547a commit 0d84d7b

2 files changed

Lines changed: 77 additions & 0 deletions

File tree

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
full,
4646
full_like,
4747
linspace,
48+
meshgrid,
4849
tril,
4950
triu,
5051
)
@@ -85,6 +86,7 @@
8586
"iinfo",
8687
"isdtype",
8788
"linspace",
89+
"meshgrid",
8890
"nonzero",
8991
"place",
9092
"put",

dpctl_ext/tensor/_ctors.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,81 @@ def linspace(
14431443
return res if int_dt is None else dpt.astype(res, int_dt)
14441444

14451445

1446+
def meshgrid(*arrays, indexing="xy"):
1447+
"""
1448+
Creates list of :class:`dpctl.tensor.usm_ndarray` coordinate matrices
1449+
from vectors.
1450+
1451+
Args:
1452+
arrays (usm_ndarray):
1453+
an arbitrary number of one-dimensional arrays
1454+
representing grid coordinates. Each array should have the same
1455+
numeric data type.
1456+
indexing (``"xy"``, or ``"ij"``):
1457+
Cartesian (``"xy"``) or matrix (``"ij"``) indexing of output.
1458+
If provided zero or one one-dimensional vector(s) (i.e., the
1459+
zero- and one-dimensional cases, respectively), the ``indexing``
1460+
keyword has no effect and should be ignored. Default: ``"xy"``
1461+
1462+
Returns:
1463+
List[array]:
1464+
list of ``N`` arrays, where ``N`` is the number of
1465+
provided one-dimensional input arrays. Each returned array must
1466+
have rank ``N``.
1467+
For a set of ``n`` vectors with lengths ``N0``, ``N1``, ``N2``, ...
1468+
The cartesian indexing results in arrays of shape
1469+
``(N1, N0, N2, ...)``, while the
1470+
matrix indexing results in arrays of shape
1471+
``(N0, N1, N2, ...)``.
1472+
Default: ``"xy"``.
1473+
1474+
Raises:
1475+
ValueError: If vectors are not of the same data type, or are not
1476+
one-dimensional.
1477+
1478+
"""
1479+
ref_dt = None
1480+
ref_unset = True
1481+
for array in arrays:
1482+
if not isinstance(array, dpt.usm_ndarray):
1483+
raise TypeError(
1484+
f"Expected instance of dpt.usm_ndarray, got {type(array)}."
1485+
)
1486+
if array.ndim != 1:
1487+
raise ValueError("All arrays must be one-dimensional.")
1488+
if ref_unset:
1489+
ref_unset = False
1490+
ref_dt = array.dtype
1491+
else:
1492+
if not ref_dt == array.dtype:
1493+
raise ValueError(
1494+
"All arrays must be of the same numeric data type."
1495+
)
1496+
if indexing not in ["xy", "ij"]:
1497+
raise ValueError(
1498+
"Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
1499+
)
1500+
n = len(arrays)
1501+
if n == 0:
1502+
return []
1503+
1504+
sh = (-1,) + (1,) * (n - 1)
1505+
1506+
res = []
1507+
if n > 1 and indexing == "xy":
1508+
res.append(dpt_ext.reshape(arrays[0], (1, -1) + sh[2:], copy=True))
1509+
res.append(dpt_ext.reshape(arrays[1], sh, copy=True))
1510+
arrays, sh = arrays[2:], sh[-2:] + sh[:-2]
1511+
1512+
for array in arrays:
1513+
res.append(dpt_ext.reshape(array, sh, copy=True))
1514+
sh = sh[-1:] + sh[:-1]
1515+
1516+
output = dpt.broadcast_arrays(*res)
1517+
1518+
return output
1519+
1520+
14461521
def tril(x, /, *, k=0):
14471522
"""
14481523
Returns the lower triangular part of a matrix (or a stack of matrices)

0 commit comments

Comments
 (0)