Skip to content

Commit b1953df

Browse files
Move _any to _tensor_reductions_impl
1 parent bd3add0 commit b1953df

4 files changed

Lines changed: 213 additions & 3 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ set(_accumulator_sources
7272
set(_reduction_sources
7373
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/reduction_common.cpp
7474
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/all.cpp
75-
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp
75+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/any.cpp
7676
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmax.cpp
7777
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/argmin.cpp
7878
#${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/logsumexp.cpp
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#include <cstdint>
37+
#include <vector>
38+
39+
#include <sycl/sycl.hpp>
40+
41+
#include "dpnp4pybind11.hpp"
42+
#include <pybind11/pybind11.h>
43+
#include <pybind11/stl.h>
44+
45+
#include "kernels/reductions.hpp"
46+
#include "reduction_atomic_support.hpp"
47+
#include "reduction_over_axis.hpp"
48+
#include "utils/type_dispatch.hpp"
49+
50+
namespace dpctl::tensor::py_internal
51+
{
52+
53+
namespace py = pybind11;
54+
namespace td_ns = dpctl::tensor::type_dispatch;
55+
56+
namespace impl
57+
{
58+
59+
using dpctl::tensor::kernels::reduction_strided_impl_fn_ptr;
60+
static reduction_strided_impl_fn_ptr
61+
any_reduction_strided_dispatch_vector[td_ns::num_types];
62+
63+
using dpctl::tensor::kernels::reduction_contig_impl_fn_ptr;
64+
static reduction_contig_impl_fn_ptr
65+
any_reduction_axis1_contig_dispatch_vector[td_ns::num_types];
66+
static reduction_contig_impl_fn_ptr
67+
any_reduction_axis0_contig_dispatch_vector[td_ns::num_types];
68+
69+
template <typename fnT, typename srcTy>
70+
struct AnyStridedFactory
71+
{
72+
fnT get() const
73+
{
74+
using dstTy = std::int32_t;
75+
using ReductionOpT = sycl::logical_or<dstTy>;
76+
return dpctl::tensor::kernels::
77+
reduction_over_group_with_atomics_strided_impl<srcTy, dstTy,
78+
ReductionOpT>;
79+
}
80+
};
81+
82+
template <typename fnT, typename srcTy>
83+
struct AnyAxis1ContigFactory
84+
{
85+
fnT get() const
86+
{
87+
using dstTy = std::int32_t;
88+
using ReductionOpT = sycl::logical_or<dstTy>;
89+
return dpctl::tensor::kernels::
90+
reduction_axis1_over_group_with_atomics_contig_impl<srcTy, dstTy,
91+
ReductionOpT>;
92+
}
93+
};
94+
95+
template <typename fnT, typename srcTy>
96+
struct AnyAxis0ContigFactory
97+
{
98+
fnT get() const
99+
{
100+
using dstTy = std::int32_t;
101+
using ReductionOpT = sycl::logical_or<dstTy>;
102+
return dpctl::tensor::kernels::
103+
reduction_axis0_over_group_with_atomics_contig_impl<srcTy, dstTy,
104+
ReductionOpT>;
105+
}
106+
};
107+
108+
void populate_any_dispatch_vectors(void)
109+
{
110+
using td_ns::DispatchVectorBuilder;
111+
112+
DispatchVectorBuilder<reduction_strided_impl_fn_ptr, AnyStridedFactory,
113+
td_ns::num_types>
114+
any_dvb1;
115+
any_dvb1.populate_dispatch_vector(any_reduction_strided_dispatch_vector);
116+
117+
DispatchVectorBuilder<reduction_contig_impl_fn_ptr, AnyAxis1ContigFactory,
118+
td_ns::num_types>
119+
any_dvb2;
120+
any_dvb2.populate_dispatch_vector(
121+
any_reduction_axis1_contig_dispatch_vector);
122+
123+
DispatchVectorBuilder<reduction_contig_impl_fn_ptr, AnyAxis0ContigFactory,
124+
td_ns::num_types>
125+
any_dvb3;
126+
any_dvb3.populate_dispatch_vector(
127+
any_reduction_axis0_contig_dispatch_vector);
128+
};
129+
130+
using atomic_support::atomic_support_fn_ptr_t;
131+
using atomic_support::check_atomic_support;
132+
static atomic_support_fn_ptr_t any_atomic_support =
133+
check_atomic_support<std::int32_t>;
134+
135+
} // namespace impl
136+
137+
void init_any(py::module_ m)
138+
{
139+
using arrayT = dpctl::tensor::usm_ndarray;
140+
using event_vecT = std::vector<sycl::event>;
141+
{
142+
impl::populate_any_dispatch_vectors();
143+
using impl::any_reduction_axis0_contig_dispatch_vector;
144+
using impl::any_reduction_axis1_contig_dispatch_vector;
145+
using impl::any_reduction_strided_dispatch_vector;
146+
147+
using impl::any_atomic_support;
148+
149+
auto any_pyapi = [&](const arrayT &src, int trailing_dims_to_reduce,
150+
const arrayT &dst, sycl::queue &exec_q,
151+
const event_vecT &depends = {}) {
152+
return py_boolean_reduction(
153+
src, trailing_dims_to_reduce, dst, exec_q, depends,
154+
any_reduction_axis1_contig_dispatch_vector,
155+
any_reduction_axis0_contig_dispatch_vector,
156+
any_reduction_strided_dispatch_vector, any_atomic_support);
157+
};
158+
m.def("_any", any_pyapi, "", py::arg("src"),
159+
py::arg("trailing_dims_to_reduce"), py::arg("dst"),
160+
py::arg("sycl_queue"), py::arg("depends") = py::list());
161+
}
162+
}
163+
164+
} // namespace dpctl::tensor::py_internal
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===---------------------------------------------------------------------===//
30+
///
31+
/// \file
32+
/// This file defines functions of dpctl.tensor._tensor_reductions_impl
33+
/// extension.
34+
//===---------------------------------------------------------------------===//
35+
36+
#pragma once
37+
#include <pybind11/pybind11.h>
38+
39+
namespace py = pybind11;
40+
41+
namespace dpctl::tensor::py_internal
42+
{
43+
44+
extern void init_any(py::module_ m);
45+
46+
} // namespace dpctl::tensor::py_internal

dpctl_ext/tensor/libtensor/source/reductions/reduction_common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include <pybind11/pybind11.h>
3737

3838
#include "all.hpp"
39-
// #include "any.hpp"
39+
#include "any.hpp"
4040
// #include "argmax.hpp"
4141
// #include "argmin.hpp"
4242
// #include "logsumexp.hpp"
@@ -55,7 +55,7 @@ namespace dpctl::tensor::py_internal
5555
void init_reduction_functions(py::module_ m)
5656
{
5757
init_all(m);
58-
// init_any(m);
58+
init_any(m);
5959
// init_argmax(m);
6060
// init_argmin(m);
6161
// init_logsumexp(m);

0 commit comments

Comments
 (0)