@@ -1443,6 +1443,81 @@ def linspace(
14431443 return res if int_dt is None else dpt .astype (res , int_dt )
14441444
14451445
1446+ def meshgrid (* arrays , indexing = "xy" ):
1447+ """
1448+ Creates list of :class:`dpctl.tensor.usm_ndarray` coordinate matrices
1449+ from vectors.
1450+
1451+ Args:
1452+ arrays (usm_ndarray):
1453+ an arbitrary number of one-dimensional arrays
1454+ representing grid coordinates. Each array should have the same
1455+ numeric data type.
1456+ indexing (``"xy"``, or ``"ij"``):
1457+ Cartesian (``"xy"``) or matrix (``"ij"``) indexing of output.
1458+ If provided zero or one one-dimensional vector(s) (i.e., the
1459+ zero- and one-dimensional cases, respectively), the ``indexing``
1460+ keyword has no effect and should be ignored. Default: ``"xy"``
1461+
1462+ Returns:
1463+ List[array]:
1464+ list of ``N`` arrays, where ``N`` is the number of
1465+ provided one-dimensional input arrays. Each returned array must
1466+ have rank ``N``.
1467+ For a set of ``n`` vectors with lengths ``N0``, ``N1``, ``N2``, ...
1468+ The cartesian indexing results in arrays of shape
1469+ ``(N1, N0, N2, ...)``, while the
1470+ matrix indexing results in arrays of shape
1471+ ``(N0, N1, N2, ...)``.
1472+ Default: ``"xy"``.
1473+
1474+ Raises:
1475+ ValueError: If vectors are not of the same data type, or are not
1476+ one-dimensional.
1477+
1478+ """
1479+ ref_dt = None
1480+ ref_unset = True
1481+ for array in arrays :
1482+ if not isinstance (array , dpt .usm_ndarray ):
1483+ raise TypeError (
1484+ f"Expected instance of dpt.usm_ndarray, got { type (array )} ."
1485+ )
1486+ if array .ndim != 1 :
1487+ raise ValueError ("All arrays must be one-dimensional." )
1488+ if ref_unset :
1489+ ref_unset = False
1490+ ref_dt = array .dtype
1491+ else :
1492+ if not ref_dt == array .dtype :
1493+ raise ValueError (
1494+ "All arrays must be of the same numeric data type."
1495+ )
1496+ if indexing not in ["xy" , "ij" ]:
1497+ raise ValueError (
1498+ "Unrecognized indexing keyword value, expecting 'xy' or 'ij.'"
1499+ )
1500+ n = len (arrays )
1501+ if n == 0 :
1502+ return []
1503+
1504+ sh = (- 1 ,) + (1 ,) * (n - 1 )
1505+
1506+ res = []
1507+ if n > 1 and indexing == "xy" :
1508+ res .append (dpt_ext .reshape (arrays [0 ], (1 , - 1 ) + sh [2 :], copy = True ))
1509+ res .append (dpt_ext .reshape (arrays [1 ], sh , copy = True ))
1510+ arrays , sh = arrays [2 :], sh [- 2 :] + sh [:- 2 ]
1511+
1512+ for array in arrays :
1513+ res .append (dpt_ext .reshape (array , sh , copy = True ))
1514+ sh = sh [- 1 :] + sh [:- 1 ]
1515+
1516+ output = dpt .broadcast_arrays (* res )
1517+
1518+ return output
1519+
1520+
14461521def tril (x , / , * , k = 0 ):
14471522 """
14481523 Returns the lower triangular part of a matrix (or a stack of matrices)
0 commit comments