Skip to content

Commit a6100b8

Browse files
Move _cos/_cosh to _tensor_elementwise_impl
1 parent e24b129 commit a6100b8

8 files changed

Lines changed: 960 additions & 6 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ set(_elementwise_sources
9292
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/ceil.cpp
9393
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/conj.cpp
9494
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/copysign.cpp
95-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
96-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
95+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cos.cpp
96+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/cosh.cpp
9797
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/equal.cpp
9898
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp.cpp
9999
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/exp2.cpp
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines kernels for elementwise evaluation of COS(x) function.
33+
//===---------------------------------------------------------------------===//
34+
35+
#pragma once
36+
#include <cmath>
37+
#include <complex>
38+
#include <cstddef>
39+
#include <cstdint>
40+
#include <limits>
41+
#include <type_traits>
42+
#include <vector>
43+
44+
#include <sycl/sycl.hpp>
45+
46+
#include "sycl_complex.hpp"
47+
#include "vec_size_util.hpp"
48+
49+
#include "kernels/dpctl_tensor_types.hpp"
50+
#include "kernels/elementwise_functions/common.hpp"
51+
52+
#include "utils/offset_utils.hpp"
53+
#include "utils/type_dispatch_building.hpp"
54+
#include "utils/type_utils.hpp"
55+
56+
namespace dpctl::tensor::kernels::cos
57+
{
58+
59+
using dpctl::tensor::ssize_t;
60+
namespace td_ns = dpctl::tensor::type_dispatch;
61+
62+
using dpctl::tensor::type_utils::is_complex;
63+
64+
template <typename argT, typename resT>
65+
struct CosFunctor
66+
{
67+
68+
// is function constant for given argT
69+
using is_constant = typename std::false_type;
70+
// constant value, if constant
71+
// constexpr resT constant_value = resT{};
72+
// is function defined for sycl::vec
73+
using supports_vec = typename std::false_type;
74+
// do both argTy and resTy support sugroup store/load operation
75+
using supports_sg_loadstore = typename std::negation<
76+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
77+
78+
resT operator()(const argT &in) const
79+
{
80+
if constexpr (is_complex<argT>::value) {
81+
using realT = typename argT::value_type;
82+
83+
static constexpr realT q_nan =
84+
std::numeric_limits<realT>::quiet_NaN();
85+
86+
realT const &in_re = std::real(in);
87+
realT const &in_im = std::imag(in);
88+
89+
const bool in_re_finite = std::isfinite(in_re);
90+
const bool in_im_finite = std::isfinite(in_im);
91+
92+
/*
93+
* Handle the nearly-non-exceptional cases where
94+
* real and imaginary parts of input are finite.
95+
*/
96+
if (in_re_finite && in_im_finite) {
97+
return exprm_ns::cos(exprm_ns::complex<realT>(in)); // cos(in);
98+
}
99+
100+
/*
101+
* since cos(in) = cosh(I * in), for special cases,
102+
* we return cosh(I * in).
103+
*/
104+
const realT x = -in_im;
105+
const realT y = in_re;
106+
107+
const bool xfinite = in_im_finite;
108+
const bool yfinite = in_re_finite;
109+
/*
110+
* cosh(+-0 +- I Inf) = dNaN + I sign(d(+-0, dNaN))0.
111+
* The sign of 0 in the result is unspecified. Choice = normally
112+
* the same as dNaN.
113+
*
114+
* cosh(+-0 +- I NaN) = d(NaN) + I sign(d(+-0, NaN))0.
115+
* The sign of 0 in the result is unspecified. Choice = normally
116+
* the same as d(NaN).
117+
*/
118+
if (x == realT(0) && !yfinite) {
119+
const realT y_m_y = (y - y);
120+
const realT res_im = sycl::copysign(realT(0), x * y_m_y);
121+
return resT{y_m_y, res_im};
122+
}
123+
124+
/*
125+
* cosh(+-Inf +- I 0) = +Inf + I (+-)(+-)0.
126+
*
127+
* cosh(NaN +- I 0) = d(NaN) + I sign(d(NaN, +-0))0.
128+
* The sign of 0 in the result is unspecified.
129+
*/
130+
if (y == realT(0) && !xfinite) {
131+
const realT res_im = sycl::copysign(realT(0), x) * y;
132+
return resT{x * x, res_im};
133+
}
134+
135+
/*
136+
* cosh(x +- I Inf) = dNaN + I dNaN.
137+
*
138+
* cosh(x + I NaN) = d(NaN) + I d(NaN).
139+
*/
140+
if (xfinite && !yfinite) {
141+
const realT y_m_y = (y - y);
142+
return resT{y_m_y, x * y_m_y};
143+
}
144+
145+
/*
146+
* cosh(+-Inf + I NaN) = +Inf + I d(NaN).
147+
*
148+
* cosh(+-Inf +- I Inf) = +Inf + I dNaN.
149+
* The sign of Inf in the result is unspecified. Choice = always +.
150+
*
151+
* cosh(+-Inf + I y) = +Inf cos(y) +- I Inf sin(y)
152+
*/
153+
if (std::isinf(x)) {
154+
if (!yfinite) {
155+
return resT{x * x, sycl::copysign(q_nan, x)};
156+
}
157+
return resT{(x * x) * sycl::cos(y), x * sycl::sin(y)};
158+
}
159+
160+
/*
161+
* cosh(NaN + I NaN) = d(NaN) + I d(NaN).
162+
*
163+
* cosh(NaN +- I Inf) = d(NaN) + I d(NaN).
164+
*
165+
* cosh(NaN + I y) = d(NaN) + I d(NaN).
166+
*/
167+
return resT{(x * x) * q_nan, (x + x) * q_nan};
168+
}
169+
else {
170+
static_assert(std::is_floating_point_v<argT> ||
171+
std::is_same_v<argT, sycl::half>);
172+
return sycl::cos(in);
173+
}
174+
}
175+
};
176+
177+
template <typename argTy,
178+
typename resTy = argTy,
179+
std::uint8_t vec_sz = 4u,
180+
std::uint8_t n_vecs = 2u,
181+
bool enable_sg_loadstore = true>
182+
using CosContigFunctor =
183+
elementwise_common::UnaryContigFunctor<argTy,
184+
resTy,
185+
CosFunctor<argTy, resTy>,
186+
vec_sz,
187+
n_vecs,
188+
enable_sg_loadstore>;
189+
190+
template <typename argTy, typename resTy, typename IndexerT>
191+
using CosStridedFunctor = elementwise_common::
192+
UnaryStridedFunctor<argTy, resTy, IndexerT, CosFunctor<argTy, resTy>>;
193+
194+
template <typename T>
195+
struct CosOutputType
196+
{
197+
using value_type = typename std::disjunction<
198+
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
199+
td_ns::TypeMapResultEntry<T, float, float>,
200+
td_ns::TypeMapResultEntry<T, double, double>,
201+
td_ns::TypeMapResultEntry<T, std::complex<float>, std::complex<float>>,
202+
td_ns::
203+
TypeMapResultEntry<T, std::complex<double>, std::complex<double>>,
204+
td_ns::DefaultResultEntry<void>>::result_type;
205+
206+
static constexpr bool is_defined = !std::is_same_v<value_type, void>;
207+
};
208+
209+
namespace hyperparam_detail
210+
{
211+
212+
namespace vsu_ns = dpctl::tensor::kernels::vec_size_utils;
213+
214+
using vsu_ns::ContigHyperparameterSetDefault;
215+
using vsu_ns::UnaryContigHyperparameterSetEntry;
216+
217+
template <typename argTy>
218+
struct CosContigHyperparameterSet
219+
{
220+
using value_type =
221+
typename std::disjunction<ContigHyperparameterSetDefault<4u, 2u>>;
222+
223+
constexpr static auto vec_sz = value_type::vec_sz;
224+
constexpr static auto n_vecs = value_type::n_vecs;
225+
};
226+
227+
} // end of namespace hyperparam_detail
228+
229+
template <typename T1, typename T2, std::uint8_t vec_sz, std::uint8_t n_vecs>
230+
class cos_contig_kernel;
231+
232+
template <typename argTy>
233+
sycl::event cos_contig_impl(sycl::queue &exec_q,
234+
std::size_t nelems,
235+
const char *arg_p,
236+
char *res_p,
237+
const std::vector<sycl::event> &depends = {})
238+
{
239+
using CosHS = hyperparam_detail::CosContigHyperparameterSet<argTy>;
240+
static constexpr std::uint8_t vec_sz = CosHS::vec_sz;
241+
static constexpr std::uint8_t n_vecs = CosHS::n_vecs;
242+
243+
return elementwise_common::unary_contig_impl<
244+
argTy, CosOutputType, CosContigFunctor, cos_contig_kernel, vec_sz,
245+
n_vecs>(exec_q, nelems, arg_p, res_p, depends);
246+
}
247+
248+
template <typename fnT, typename T>
249+
struct CosContigFactory
250+
{
251+
fnT get()
252+
{
253+
if constexpr (!CosOutputType<T>::is_defined) {
254+
fnT fn = nullptr;
255+
return fn;
256+
}
257+
else {
258+
fnT fn = cos_contig_impl<T>;
259+
return fn;
260+
}
261+
}
262+
};
263+
264+
template <typename fnT, typename T>
265+
struct CosTypeMapFactory
266+
{
267+
/*! @brief get typeid for output type of sycl::cos(T x) */
268+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
269+
{
270+
using rT = typename CosOutputType<T>::value_type;
271+
return td_ns::GetTypeid<rT>{}.get();
272+
}
273+
};
274+
275+
template <typename T1, typename T2, typename T3>
276+
class cos_strided_kernel;
277+
278+
template <typename argTy>
279+
sycl::event cos_strided_impl(sycl::queue &exec_q,
280+
std::size_t nelems,
281+
int nd,
282+
const ssize_t *shape_and_strides,
283+
const char *arg_p,
284+
ssize_t arg_offset,
285+
char *res_p,
286+
ssize_t res_offset,
287+
const std::vector<sycl::event> &depends,
288+
const std::vector<sycl::event> &additional_depends)
289+
{
290+
return elementwise_common::unary_strided_impl<
291+
argTy, CosOutputType, CosStridedFunctor, cos_strided_kernel>(
292+
exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
293+
res_offset, depends, additional_depends);
294+
}
295+
296+
template <typename fnT, typename T>
297+
struct CosStridedFactory
298+
{
299+
fnT get()
300+
{
301+
if constexpr (!CosOutputType<T>::is_defined) {
302+
fnT fn = nullptr;
303+
return fn;
304+
}
305+
else {
306+
fnT fn = cos_strided_impl<T>;
307+
return fn;
308+
}
309+
}
310+
};
311+
312+
} // namespace dpctl::tensor::kernels::cos

0 commit comments

Comments
 (0)