Skip to content

Commit 87f5529

Browse files
Move ti.maximum()/minimum()/multiply() and reuse them
1 parent 1fb889d commit 87f5529

14 files changed

Lines changed: 1447 additions & 16 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ set(_elementwise_sources
118118
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
119119
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_or.cpp
120120
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_xor.cpp
121-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/maximum.cpp
122-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
123-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
121+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/maximum.cpp
122+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
123+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
124124
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
125125
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
126126
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@
135135
logical_not,
136136
logical_or,
137137
logical_xor,
138+
maximum,
139+
minimum,
140+
multiply,
138141
negative,
139142
positive,
140143
proj,
@@ -256,9 +259,12 @@
256259
"log2",
257260
"log10",
258261
"max",
262+
"maximum",
259263
"meshgrid",
260264
"min",
265+
"minimum",
261266
"moveaxis",
267+
"multiply",
262268
"permute_dims",
263269
"matmul",
264270
"matrix_transpose",

dpctl_ext/tensor/_elementwise_funcs.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,106 @@
13681368
)
13691369
del _logical_not_docstring
13701370

1371+
# B26: ==== MAXIMUM (x1, x2)
1372+
_maximum_docstring_ = r"""
1373+
maximum(x1, x2, /, \*, out=None, order='K')
1374+
1375+
Compares two input arrays `x1` and `x2` and returns a new array containing the
1376+
element-wise maxima.
1377+
1378+
Args:
1379+
x1 (usm_ndarray):
1380+
First input array. May have any data type.
1381+
x2 (usm_ndarray):
1382+
Second input array. May have any data type.
1383+
out (Union[usm_ndarray, None], optional):
1384+
Output array to populate.
1385+
Array must have the correct shape and the expected data type.
1386+
order ("C","F","A","K", optional):
1387+
Memory layout of the new output array, if parameter
1388+
`out` is ``None``.
1389+
Default: "K".
1390+
1391+
Returns:
1392+
usm_ndarray:
1393+
An array containing the element-wise maxima. The data type of
1394+
the returned array is determined by the Type Promotion Rules.
1395+
"""
1396+
maximum = BinaryElementwiseFunc(
1397+
"maximum",
1398+
ti._maximum_result_type,
1399+
ti._maximum,
1400+
_maximum_docstring_,
1401+
)
1402+
del _maximum_docstring_
1403+
1404+
# B27: ==== MINIMUM (x1, x2)
1405+
_minimum_docstring_ = r"""
1406+
minimum(x1, x2, /, \*, out=None, order='K')
1407+
1408+
Compares two input arrays `x1` and `x2` and returns a new array containing the
1409+
element-wise minima.
1410+
1411+
Args:
1412+
x1 (usm_ndarray):
1413+
First input array. May have any data type.
1414+
x2 (usm_ndarray):
1415+
Second input array. May have any data type.
1416+
out (Union[usm_ndarray, None], optional):
1417+
Output array to populate.
1418+
Array must have the correct shape and the expected data type.
1419+
order ("C","F","A","K", optional):
1420+
Memory layout of the new output array, if parameter
1421+
`out` is ``None``.
1422+
Default: "K".
1423+
1424+
Returns:
1425+
usm_ndarray:
1426+
An array containing the element-wise minima. The data type of
1427+
the returned array is determined by the Type Promotion Rules.
1428+
"""
1429+
minimum = BinaryElementwiseFunc(
1430+
"minimum",
1431+
ti._minimum_result_type,
1432+
ti._minimum,
1433+
_minimum_docstring_,
1434+
)
1435+
del _minimum_docstring_
1436+
1437+
# B19: ==== MULTIPLY (x1, x2)
1438+
_multiply_docstring_ = r"""
1439+
multiply(x1, x2, /, \*, out=None, order='K')
1440+
1441+
Calculates the product for each element `x1_i` of the input array `x1` with the
1442+
respective element `x2_i` of the input array `x2`.
1443+
1444+
Args:
1445+
x1 (usm_ndarray):
1446+
First input array. May have any data type.
1447+
x2 (usm_ndarray):
1448+
Second input array. May have any data type.
1449+
out (Union[usm_ndarray, None], optional):
1450+
Output array to populate.
1451+
Array must have the correct shape and the expected data type.
1452+
order ("C","F","A","K", optional):
1453+
Memory layout of the new output array, if parameter
1454+
`out` is ``None``.
1455+
Default: "K".
1456+
1457+
Returns:
1458+
usm_ndarray:
1459+
An array containing the element-wise products. The data type of
1460+
the returned array is determined by the Type Promotion Rules.
1461+
"""
1462+
multiply = BinaryElementwiseFunc(
1463+
"multiply",
1464+
ti._multiply_result_type,
1465+
ti._multiply,
1466+
_multiply_docstring_,
1467+
binary_inplace_fn=ti._multiply_inplace,
1468+
)
1469+
del _multiply_docstring_
1470+
13711471
# U25: ==== NEGATIVE (x)
13721472
_negative_docstring_ = r"""
13731473
negative(x, /, \*, out=None, order='K')

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/maximum.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
2830
///
2931
/// \file
3032
/// This file defines kernels for elementwise evaluation of MAXIMUM(x1, x2)
@@ -53,6 +55,7 @@
5355

5456
namespace dpctl::tensor::kernels::maximum
5557
{
58+
5659
using dpctl::tensor::ssize_t;
5760
namespace td_ns = dpctl::tensor::type_dispatch;
5861
namespace tu_ns = dpctl::tensor::type_utils;

dpctl_ext/tensor/libtensor/include/kernels/elementwise_functions/minimum.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
2626
// THE POSSIBILITY OF SUCH DAMAGE.
2727
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
2830
///
2931
/// \file
3032
/// This file defines kernels for elementwise evaluation of MINIMUM(x1, x2)
@@ -52,6 +54,7 @@
5254

5355
namespace dpctl::tensor::kernels::minimum
5456
{
57+
5558
using dpctl::tensor::ssize_t;
5659
namespace td_ns = dpctl::tensor::type_dispatch;
5760
namespace tu_ns = dpctl::tensor::type_utils;

0 commit comments

Comments
 (0)