@@ -137,72 +137,6 @@ def _comparison_over_axis(x, axis, keepdims, out, _reduction_fn):
137137 return out
138138
139139
140- def max (x , / , * , axis = None , keepdims = False , out = None ):
141- """
142- Calculates the maximum value of the input array ``x``.
143-
144- Args:
145- x (usm_ndarray):
146- input array.
147- axis (Optional[int, Tuple[int, ...]]):
148- axis or axes along which maxima must be computed. If a tuple
149- of unique integers, the maxima are computed over multiple axes.
150- If ``None``, the max is computed over the entire array.
151- Default: ``None``.
152- keepdims (Optional[bool]):
153- if ``True``, the reduced axes (dimensions) are included in the
154- result as singleton dimensions, so that the returned array remains
155- compatible with the input arrays according to Array Broadcasting
156- rules. Otherwise, if ``False``, the reduced axes are not included
157- in the returned array. Default: ``False``.
158- out (Optional[usm_ndarray]):
159- the array into which the result is written.
160- The data type of ``out`` must match the expected shape and the
161- expected data type of the result.
162- If ``None`` then a new array is returned. Default: ``None``.
163-
164- Returns:
165- usm_ndarray:
166- an array containing the maxima. If the max was computed over the
167- entire array, a zero-dimensional array is returned. The returned
168- array has the same data type as ``x``.
169- """
170- return _comparison_over_axis (x , axis , keepdims , out , tri ._max_over_axis )
171-
172-
173- def min (x , / , * , axis = None , keepdims = False , out = None ):
174- """
175- Calculates the minimum value of the input array ``x``.
176-
177- Args:
178- x (usm_ndarray):
179- input array.
180- axis (Optional[int, Tuple[int, ...]]):
181- axis or axes along which minima must be computed. If a tuple
182- of unique integers, the minima are computed over multiple axes.
183- If ``None``, the min is computed over the entire array.
184- Default: ``None``.
185- keepdims (Optional[bool]):
186- if ``True``, the reduced axes (dimensions) are included in the
187- result as singleton dimensions, so that the returned array remains
188- compatible with the input arrays according to Array Broadcasting
189- rules. Otherwise, if ``False``, the reduced axes are not included
190- in the returned array. Default: ``False``.
191- out (Optional[usm_ndarray]):
192- the array into which the result is written.
193- The data type of ``out`` must match the expected shape and the
194- expected data type of the result.
195- If ``None`` then a new array is returned. Default: ``None``.
196-
197- Returns:
198- usm_ndarray:
199- an array containing the minima. If the min was computed over the
200- entire array, a zero-dimensional array is returned. The returned
201- array has the same data type as ``x``.
202- """
203- return _comparison_over_axis (x , axis , keepdims , out , tri ._min_over_axis )
204-
205-
206140def _search_over_axis (x , axis , keepdims , out , _reduction_fn ):
207141 if not isinstance (x , dpt .usm_ndarray ):
208142 raise TypeError (f"Expected dpctl.tensor.usm_ndarray, got { type (x )} " )
@@ -376,18 +310,17 @@ def argmin(x, /, *, axis=None, keepdims=False, out=None):
376310 return _search_over_axis (x , axis , keepdims , out , tri ._argmin_over_axis )
377311
378312
379- def count_nonzero (x , / , * , axis = None , keepdims = False , out = None ):
313+ def max (x , / , * , axis = None , keepdims = False , out = None ):
380314 """
381- Counts the number of elements in the input array ``x`` which are non-zero .
315+ Calculates the maximum value of the input array ``x``.
382316
383317 Args:
384318 x (usm_ndarray):
385319 input array.
386320 axis (Optional[int, Tuple[int, ...]]):
387- axis or axes along which to count. If a tuple of unique integers,
388- the number of non-zero values are computed over multiple axes.
389- If ``None``, the number of non-zero values is computed over the
390- entire array.
321+ axis or axes along which maxima must be computed. If a tuple
322+ of unique integers, the maxima are computed over multiple axes.
323+ If ``None``, the max is computed over the entire array.
391324 Default: ``None``.
392325 keepdims (Optional[bool]):
393326 if ``True``, the reduced axes (dimensions) are included in the
@@ -397,23 +330,47 @@ def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
397330 in the returned array. Default: ``False``.
398331 out (Optional[usm_ndarray]):
399332 the array into which the result is written.
400- The data type of ``out`` must match the expected shape and data
401- type.
333+ The data type of ``out`` must match the expected shape and the
334+ expected data type of the result .
402335 If ``None`` then a new array is returned. Default: ``None``.
403336
404337 Returns:
405338 usm_ndarray:
406- an array containing the count of non-zero values. If the sum was
407- computed over the entire array, a zero-dimensional array is
408- returned. The returned array will have the default array index data
409- type.
339+ an array containing the maxima. If the max was computed over the
340+ entire array, a zero-dimensional array is returned. The returned
341+ array has the same data type as ``x``.
410342 """
411- if x .dtype != dpt .bool :
412- x = dpt_ext .astype (x , dpt .bool , copy = False )
413- return sum (
414- x ,
415- axis = axis ,
416- dtype = ti .default_device_index_type (x .sycl_device ),
417- keepdims = keepdims ,
418- out = out ,
419- )
343+ return _comparison_over_axis (x , axis , keepdims , out , tri ._max_over_axis )
344+
345+
346+ def min (x , / , * , axis = None , keepdims = False , out = None ):
347+ """
348+ Calculates the minimum value of the input array ``x``.
349+
350+ Args:
351+ x (usm_ndarray):
352+ input array.
353+ axis (Optional[int, Tuple[int, ...]]):
354+ axis or axes along which minima must be computed. If a tuple
355+ of unique integers, the minima are computed over multiple axes.
356+ If ``None``, the min is computed over the entire array.
357+ Default: ``None``.
358+ keepdims (Optional[bool]):
359+ if ``True``, the reduced axes (dimensions) are included in the
360+ result as singleton dimensions, so that the returned array remains
361+ compatible with the input arrays according to Array Broadcasting
362+ rules. Otherwise, if ``False``, the reduced axes are not included
363+ in the returned array. Default: ``False``.
364+ out (Optional[usm_ndarray]):
365+ the array into which the result is written.
366+ The data type of ``out`` must match the expected shape and the
367+ expected data type of the result.
368+ If ``None`` then a new array is returned. Default: ``None``.
369+
370+ Returns:
371+ usm_ndarray:
372+ an array containing the minima. If the min was computed over the
373+ entire array, a zero-dimensional array is returned. The returned
374+ array has the same data type as ``x``.
375+ """
376+ return _comparison_over_axis (x , axis , keepdims , out , tri ._min_over_axis )
0 commit comments