|
47 | 47 | ) |
48 | 48 |
|
49 | 49 |
|
def _broadcast_shapes(*args):
    """Resolve the arrays' shapes into one common broadcast shape.

    Gathers the ``shape`` attribute of every positional argument and
    delegates the actual broadcasting rules to
    ``_broadcast_shape_impl``, returning the resulting shape tuple.
    """
    return _broadcast_shape_impl([array.shape for array in args])
| 57 | + |
| 58 | + |
50 | 59 | def _broadcast_shape_impl(shapes): |
51 | 60 | if len(set(shapes)) == 1: |
52 | 61 | return shapes[0] |
@@ -103,6 +112,37 @@ def _broadcast_strides(X_shape, X_strides, res_ndim): |
103 | 112 | return tuple(out_strides) |
104 | 113 |
|
105 | 114 |
|
def broadcast_arrays(*args):
    """broadcast_arrays(*arrays)

    Broadcasts one or more :class:`dpctl.tensor.usm_ndarrays` against
    one another.

    Args:
        arrays (usm_ndarray): an arbitrary number of arrays to be
            broadcasted.

    Returns:
        List[usm_ndarray]:
            A list of broadcasted arrays. Each array
            must have the same shape. Each array must have the same `dtype`,
            `device` and `usm_type` attributes as its corresponding input
            array.

    Raises:
        ValueError: if called with no arguments.
        TypeError: if any argument is not a `usm_ndarray`.
    """
    if len(args) == 0:
        raise ValueError("`broadcast_arrays` requires at least one argument")
    for X in args:
        if not isinstance(X, dpt.usm_ndarray):
            raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

    shape = _broadcast_shapes(*args)

    # Fast path: every input already has the common shape, so no views
    # need to be constructed. Convert ``args`` (a tuple) to a list so
    # the return type matches the documented List[usm_ndarray] contract
    # on both branches.
    if all(X.shape == shape for X in args):
        return list(args)

    return [broadcast_to(X, shape) for X in args]
| 144 | + |
| 145 | + |
106 | 146 | def broadcast_to(X, /, shape): |
107 | 147 | """broadcast_to(x, shape) |
108 | 148 |
|
|
0 commit comments