Skip to content

Commit 58fdef3

Browse files
Initialize _tensor_elementwise_impl extension and move _abs
1 parent 4872bb6 commit 58fdef3

11 files changed

Lines changed: 1315 additions & 1 deletion

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,81 @@ set(_accumulator_sources
6969
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
7070
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
7171
)
72+
set(_elementwise_sources
73+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_common.cpp
74+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/elementwise_functions_type_utils.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/abs.cpp
76+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acos.cpp
77+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/acosh.cpp
78+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/add.cpp
79+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/angle.cpp
80+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asin.cpp
81+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/asinh.cpp
82+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan.cpp
83+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atan2.cpp
84+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/atanh.cpp
85+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_and.cpp
86+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_invert.cpp
87+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_left_shift.cpp
88+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_or.cpp
89+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_right_shift.cpp
90+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/bitwise_xor.cpp
91+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cbrt.cpp
92+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
93+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
94+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
95+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
96+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
97+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
98+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
99+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
100+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/expm1.cpp
101+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor_divide.cpp
102+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/floor.cpp
103+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater_equal.cpp
104+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/greater.cpp
105+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/hypot.cpp
106+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/imag.cpp
107+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isfinite.cpp
108+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isinf.cpp
109+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/isnan.cpp
110+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less_equal.cpp
111+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/less.cpp
112+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log.cpp
113+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log1p.cpp
114+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log2.cpp
115+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/log10.cpp
116+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logaddexp.cpp
117+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_and.cpp
118+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_not.cpp
119+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_or.cpp
120+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/logical_xor.cpp
121+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/maximum.cpp
122+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
123+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
124+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
125+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
126+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
127+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
128+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp
129+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/proj.cpp
130+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/real.cpp
131+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/reciprocal.cpp
132+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/remainder.cpp
133+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/round.cpp
134+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/rsqrt.cpp
135+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sign.cpp
136+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/signbit.cpp
137+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sin.cpp
138+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sinh.cpp
139+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/sqrt.cpp
140+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/square.cpp
141+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/subtract.cpp
142+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tan.cpp
143+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/tanh.cpp
144+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/true_divide.cpp
145+
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/trunc.cpp
146+
)
72147
set(_reduction_sources
73148
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
74149
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp
@@ -95,6 +170,10 @@ set(_tensor_accumulation_impl_sources
95170
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_accumulation.cpp
96171
${_accumulator_sources}
97172
)
173+
set(_tensor_elementwise_impl_sources
174+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_elementwise.cpp
175+
${_elementwise_sources}
176+
)
98177
set(_tensor_reductions_impl_sources
99178
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_reductions.cpp
100179
${_reduction_sources}
@@ -131,6 +210,12 @@ add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_accumulation_i
131210
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
132211
list(APPEND _py_trgts ${python_module_name})
133212

213+
set(python_module_name _tensor_elementwise_impl)
214+
pybind11_add_module(${python_module_name} MODULE ${_tensor_elementwise_impl_sources})
215+
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_elementwise_impl_sources})
216+
target_link_libraries(${python_module_name} PRIVATE ${_static_lib_trgt})
217+
list(APPEND _py_trgts ${python_module_name})
218+
134219
set(python_module_name _tensor_reductions_impl)
135220
pybind11_add_module(${python_module_name} MODULE ${_tensor_reductions_impl_sources})
136221
add_sycl_to_target(TARGET ${python_module_name} SOURCES ${_tensor_reductions_impl_sources})
@@ -157,7 +242,7 @@ set(_no_fast_math_sources
157242
)
158243
list(
159244
APPEND _no_fast_math_sources
160-
# ${_elementwise_sources}
245+
${_elementwise_sources}
161246
${_reduction_sources}
162247
${_sorting_sources}
163248
# ${_linalg_sources}
@@ -175,6 +260,19 @@ endforeach()
175260

176261
set(_compiler_definitions "")
177262

263+
foreach(_src_fn ${_elementwise_sources})
264+
get_source_file_property(_cmpl_options_defs ${_src_fn} COMPILE_DEFINITIONS)
265+
if(${_cmpl_options_defs})
266+
set(_combined_options_defs ${_cmpl_options_defs} "${_compiler_definitions}")
267+
else()
268+
set(_combined_options_defs "${_compiler_definitions}")
269+
endif()
270+
set_source_files_properties(
271+
${_src_fn}
272+
PROPERTIES COMPILE_DEFINITIONS "${_combined_options_defs}"
273+
)
274+
endforeach()
275+
178276
set(_linker_options "LINKER:${DPNP_LDFLAGS}")
179277
foreach(python_module_name ${_py_trgts})
180278
target_compile_options(
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of ABS(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <complex>
37+
#include <cstddef>
38+
#include <cstdint>
39+
#include <type_traits>
40+
#include <vector>
41+
42+
#include <sycl/sycl.hpp>
43+
44+
#include "cabs_impl.hpp"
45+
#include "vec_size_util.hpp"
46+
47+
#include "kernels/dpctl_tensor_types.hpp"
48+
#include "kernels/elementwise_functions/common.hpp"
49+
50+
#include "utils/offset_utils.hpp"
51+
#include "utils/type_dispatch_building.hpp"
52+
#include "utils/type_utils.hpp"
53+
54+
namespace dpctl::tensor::kernels::abs
55+
{
56+
57+
namespace td_ns = dpctl::tensor::type_dispatch;
58+
59+
using dpctl::tensor::ssize_t;
60+
using dpctl::tensor::type_utils::is_complex;
61+
62+
template <typename argT, typename resT>
63+
struct AbsFunctor
64+
{
65+
66+
using is_constant = typename std::false_type;
67+
// constexpr resT constant_value = resT{};
68+
using supports_vec = typename std::false_type;
69+
using supports_sg_loadstore = typename std::negation<
70+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
71+
72+
resT operator()(const argT &x) const
73+
{
74+
75+
if constexpr (std::is_same_v<argT, bool> ||
76+
(std::is_integral<argT>::value &&
77+
std::is_unsigned<argT>::value))
78+
{
79+
static_assert(std::is_same_v<resT, argT>);
80+
return x;
81+
}
82+
else {
83+
if constexpr (is_complex<argT>::value) {
84+
return detail::cabs(x);
85+
}
86+
else if constexpr (std::is_same_v<argT, sycl::half> ||
87+
std::is_floating_point_v<argT>)
88+
{
89+
return (sycl::signbit(x) ? -x : x);
90+
}
91+
else {
92+
return sycl::abs(x);
93+
}
94+
}
95+
}
96+
};
97+
98+
template <typename argT,
99+
typename resT = argT,
100+
std::uint8_t vec_sz = 4u,
101+
std::uint8_t n_vecs = 2u,
102+
bool enable_sg_loadstore = true>
103+
using AbsContigFunctor =
104+
elementwise_common::UnaryContigFunctor<argT,
105+
resT,
106+
AbsFunctor<argT, resT>,
107+
vec_sz,
108+
n_vecs,
109+
enable_sg_loadstore>;
110+
111+
template <typename T>
112+
struct AbsOutputType
113+
{
114+
using value_type = typename std::disjunction<
115+
td_ns::TypeMapResultEntry<T, bool>,
116+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
117+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
118+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
119+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
120+
td_ns::TypeMapResultEntry<T, std::int8_t>,
121+
td_ns::TypeMapResultEntry<T, std::int16_t>,
122+
td_ns::TypeMapResultEntry<T, std::int32_t>,
123+
td_ns::TypeMapResultEntry<T, std::int64_t>,
124+
td_ns::TypeMapResultEntry<T, sycl::half>,
125+
td_ns::TypeMapResultEntry<T, float>,
126+
td_ns::TypeMapResultEntry<T, double>,
127+
td_ns::TypeMapResultEntry<T, std::complex<float>, float>,
128+
td_ns::TypeMapResultEntry<T, std::complex<double>, double>,
129+
td_ns::DefaultResultEntry<void>>::result_type;
130+
131+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
132+
};
133+
134+
namespace hyperparam_detail
135+
{
136+
137+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
138+
139+
using vsu_ns::ContigHyperparameterSetDefault;
140+
141+
template <typename argTy>
142+
struct AbsContigHyperparameterSet
143+
{
144+
using value_type =
145+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
146+
147+
constexpr static auto vec_sz = value_type::vec_sz;
148+
constexpr static auto n_vecs = value_type::n_vecs;
149+
};
150+
151+
} // namespace hyperparam_detail
152+
153+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
154+
class abs_contig_kernel;
155+
156+
template <typename argTy>
157+
sycl::event abs_contig_impl(sycl::queue &exec_q,
158+
std::size_t nelems,
159+
const char *arg_p,
160+
char *res_p,
161+
const std::vector<sycl::event> &depends = {})
162+
{
163+
using AbsHS = hyperparam_detail::AbsContigHyperparameterSet<argTy>;
164+
static constexpr std::uint8_t vec_sz = AbsHS::vec_sz;
165+
static constexpr std::uint8_t n_vec = AbsHS::n_vecs;
166+
167+
return elementwise_common::unary_contig_impl<
168+
argTy, AbsOutputType, AbsContigFunctor, abs_contig_kernel, vec_sz,
169+
n_vec>(exec_q, nelems, arg_p, res_p, depends);
170+
}
171+
172+
template <typename fnT, typename T>
173+
struct AbsContigFactory
174+
{
175+
fnT get()
176+
{
177+
if constexpr (!AbsOutputType<T>::is_defined) {
178+
fnT fn = nullptr;
179+
return fn;
180+
}
181+
else {
182+
fnT fn = abs_contig_impl<T>;
183+
return fn;
184+
}
185+
}
186+
};
187+
188+
template <typename fnT, typename T>
189+
struct AbsTypeMapFactory
190+
{
191+
/*! @brief get typeid for output type of abs(T x) */
192+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
193+
{
194+
using rT = typename AbsOutputType<T>::value_type;
195+
return td_ns::GetTypeid<rT>{}.get();
196+
}
197+
};
198+
199+
template <typename argTy, typename resTy, typename IndexerT>
200+
using AbsStridedFunctor = elementwise_common::
201+
UnaryStridedFunctor<argTy, resTy, IndexerT, AbsFunctor<argTy, resTy>>;
202+
203+
template <typename T1, typename T2, typename T3>
204+
class abs_strided_kernel;
205+
206+
template <typename argTy>
207+
sycl::event abs_strided_impl(sycl::queue &exec_q,
208+
std::size_t nelems,
209+
int nd,
210+
const ssize_t *shape_and_strides,
211+
const char *arg_p,
212+
ssize_t arg_offset,
213+
char *res_p,
214+
ssize_t res_offset,
215+
const std::vector<sycl::event> &depends,
216+
const std::vector<sycl::event> &additional_depends)
217+
{
218+
return elementwise_common::unary_strided_impl<
219+
argTy, AbsOutputType, AbsStridedFunctor, abs_strided_kernel>(
220+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
221+
res_offset, depends, additional_depends);
222+
}
223+
224+
template <typename fnT, typename T>
225+
struct AbsStridedFactory
226+
{
227+
fnT get()
228+
{
229+
if constexpr (!AbsOutputType<T>::is_defined) {
230+
fnT fn = nullptr;
231+
return fn;
232+
}
233+
else {
234+
fnT fn = abs_strided_impl<T>;
235+
return fn;
236+
}
237+
}
238+
};
239+
240+
} // namespace dpctl::tensor::kernels::abs

0 commit comments

Comments
 (0)