Skip to content

Commit 5af94c8

Browse files
Move _cumprod_over_axis to dpctl_ext.tensor._tensor_accumulation_impl
1 parent 6b81e7a commit 5af94c8

4 files changed

Lines changed: 408 additions & 3 deletions

File tree

dpctl_ext/tensor/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ set(_tensor_impl_sources
6666
set(_accumulator_sources
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/accumulators_common.cpp
    # cumulative_logsumexp is intentionally excluded from the build for now
    #${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_logsumexp.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_prod.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/accumulators/cumulative_sum.cpp
)
7272
set(_tensor_accumulation_impl_sources

dpctl_ext/tensor/libtensor/source/accumulators/accumulators_common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
#include <pybind11/pybind11.h>
3737

3838
// #include "cumulative_logsumexp.hpp"
39-
// #include "cumulative_prod.hpp"
39+
#include "cumulative_prod.hpp"
4040
#include "cumulative_sum.hpp"
4141

4242
namespace py = pybind11;
@@ -48,7 +48,7 @@ namespace dpctl::tensor::py_internal
4848
// Registers all accumulator (scan) functions on the extension module.
void init_accumulator_functions(py::module_ m)
{
    // cumulative_logsumexp is intentionally disabled (its source is not
    // built; see the accumulator sources list in CMakeLists.txt).
    // init_cumulative_logsumexp(m);
    init_cumulative_prod(m);
    init_cumulative_sum(m);
}

Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2026, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
//
29+
//===----------------------------------------------------------------------===//
30+
///
31+
/// \file
/// This file defines functions of the
/// dpctl_ext.tensor._tensor_accumulation_impl extension.
34+
//===----------------------------------------------------------------------===//
35+
36+
#include <complex>
37+
#include <cstdint>
38+
#include <type_traits>
39+
#include <vector>
40+
41+
#include <sycl/sycl.hpp>
42+
43+
#include "dpnp4pybind11.hpp"
44+
#include <pybind11/numpy.h>
45+
#include <pybind11/pybind11.h>
46+
47+
#include "accumulate_over_axis.hpp"
48+
#include "kernels/accumulators.hpp"
49+
#include "utils/type_dispatch_building.hpp"
50+
51+
namespace py = pybind11;
52+
53+
namespace dpctl::tensor::py_internal
54+
{
55+
56+
namespace td_ns = dpctl::tensor::type_dispatch;
57+
58+
namespace impl
59+
{
60+
61+
using dpctl::tensor::kernels::accumulators::accumulate_1d_contig_impl_fn_ptr_t;
62+
static accumulate_1d_contig_impl_fn_ptr_t
63+
cumprod_1d_contig_dispatch_table[td_ns::num_types][td_ns::num_types];
64+
65+
using dpctl::tensor::kernels::accumulators::accumulate_strided_impl_fn_ptr_t;
66+
static accumulate_strided_impl_fn_ptr_t
67+
cumprod_strided_dispatch_table[td_ns::num_types][td_ns::num_types];
68+
69+
static accumulate_1d_contig_impl_fn_ptr_t
70+
cumprod_1d_include_initial_contig_dispatch_table[td_ns::num_types]
71+
[td_ns::num_types];
72+
73+
static accumulate_strided_impl_fn_ptr_t
74+
cumprod_include_initial_strided_dispatch_table[td_ns::num_types]
75+
[td_ns::num_types];
76+
77+
// Compile-time predicate: is_defined is true exactly when a cumprod kernel
// is instantiated for the (argTy -> outTy) type pair.  Each input type may
// accumulate either into itself or into selected wider types of the same
// kind, per the entries listed below.
template <typename argTy, typename outTy>
struct TypePairSupportDataForProdAccumulation
{
    static constexpr bool is_defined = std::disjunction<
        td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
        td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int32_t>,
        td_ns::TypePairDefinedEntry<argTy, bool, outTy, std::int64_t>,

        // input int8_t
        td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,
        td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int64_t>,

        // input uint8_t
        td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,
        td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint64_t>,

        // input int16_t
        td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,
        td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int64_t>,

        // input uint16_t
        td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,
        td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint64_t>,

        // input int32_t
        td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int64_t>,

        // input uint32_t
        td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint32_t>,
        td_ns::TypePairDefinedEntry<argTy, std::uint32_t, outTy, std::uint64_t>,

        // input int64_t
        td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,

        // input uint64_t
        td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,

        // input half
        td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,

        // input float
        td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,

        // input double
        td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,

        // input std::complex
        td_ns::TypePairDefinedEntry<argTy,
                                    std::complex<float>,
                                    outTy,
                                    std::complex<float>>,

        td_ns::TypePairDefinedEntry<argTy,
                                    std::complex<double>,
                                    outTy,
                                    std::complex<double>>,

        // fall-through
        td_ns::NotDefinedEntry>::is_defined;
};
142+
143+
// Scan operation used by the cumulative product: logical AND when the
// destination type is bool, ordinary multiplication otherwise.
template <typename T>
using CumProdScanOpT = std::conditional_t<std::is_same_v<T, bool>,
                                          sycl::logical_and<T>,
                                          sycl::multiplies<T>>;
147+
148+
template <typename fnT, typename srcTy, typename dstTy>
149+
struct CumProd1DContigFactory
150+
{
151+
fnT get()
152+
{
153+
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
154+
dstTy>::is_defined)
155+
{
156+
using ScanOpT = CumProdScanOpT<dstTy>;
157+
static constexpr bool include_initial = false;
158+
if constexpr (std::is_same_v<srcTy, dstTy>) {
159+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
160+
fnT fn = dpctl::tensor::kernels::accumulators::
161+
accumulate_1d_contig_impl<srcTy, dstTy,
162+
NoOpTransformer<dstTy>, ScanOpT,
163+
include_initial>;
164+
return fn;
165+
}
166+
else {
167+
using dpctl::tensor::kernels::accumulators::CastTransformer;
168+
fnT fn = dpctl::tensor::kernels::accumulators::
169+
accumulate_1d_contig_impl<srcTy, dstTy,
170+
CastTransformer<srcTy, dstTy>,
171+
ScanOpT, include_initial>;
172+
return fn;
173+
}
174+
}
175+
else {
176+
return nullptr;
177+
}
178+
}
179+
};
180+
181+
template <typename fnT, typename srcTy, typename dstTy>
182+
struct CumProd1DIncludeInitialContigFactory
183+
{
184+
fnT get()
185+
{
186+
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
187+
dstTy>::is_defined)
188+
{
189+
using ScanOpT = CumProdScanOpT<dstTy>;
190+
static constexpr bool include_initial = true;
191+
if constexpr (std::is_same_v<srcTy, dstTy>) {
192+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
193+
fnT fn = dpctl::tensor::kernels::accumulators::
194+
accumulate_1d_contig_impl<srcTy, dstTy,
195+
NoOpTransformer<dstTy>, ScanOpT,
196+
include_initial>;
197+
return fn;
198+
}
199+
else {
200+
using dpctl::tensor::kernels::accumulators::CastTransformer;
201+
fnT fn = dpctl::tensor::kernels::accumulators::
202+
accumulate_1d_contig_impl<srcTy, dstTy,
203+
CastTransformer<srcTy, dstTy>,
204+
ScanOpT, include_initial>;
205+
return fn;
206+
}
207+
}
208+
else {
209+
return nullptr;
210+
}
211+
}
212+
};
213+
214+
template <typename fnT, typename srcTy, typename dstTy>
215+
struct CumProdStridedFactory
216+
{
217+
fnT get()
218+
{
219+
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
220+
dstTy>::is_defined)
221+
{
222+
using ScanOpT = CumProdScanOpT<dstTy>;
223+
static constexpr bool include_initial = false;
224+
if constexpr (std::is_same_v<srcTy, dstTy>) {
225+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
226+
fnT fn = dpctl::tensor::kernels::accumulators::
227+
accumulate_strided_impl<srcTy, dstTy,
228+
NoOpTransformer<dstTy>, ScanOpT,
229+
include_initial>;
230+
return fn;
231+
}
232+
else {
233+
using dpctl::tensor::kernels::accumulators::CastTransformer;
234+
fnT fn = dpctl::tensor::kernels::accumulators::
235+
accumulate_strided_impl<srcTy, dstTy,
236+
CastTransformer<srcTy, dstTy>,
237+
ScanOpT, include_initial>;
238+
return fn;
239+
}
240+
}
241+
else {
242+
return nullptr;
243+
}
244+
}
245+
};
246+
247+
template <typename fnT, typename srcTy, typename dstTy>
248+
struct CumProdIncludeInitialStridedFactory
249+
{
250+
fnT get()
251+
{
252+
if constexpr (TypePairSupportDataForProdAccumulation<srcTy,
253+
dstTy>::is_defined)
254+
{
255+
using ScanOpT = CumProdScanOpT<dstTy>;
256+
static constexpr bool include_initial = true;
257+
if constexpr (std::is_same_v<srcTy, dstTy>) {
258+
using dpctl::tensor::kernels::accumulators::NoOpTransformer;
259+
fnT fn = dpctl::tensor::kernels::accumulators::
260+
accumulate_strided_impl<srcTy, dstTy,
261+
NoOpTransformer<dstTy>, ScanOpT,
262+
include_initial>;
263+
return fn;
264+
}
265+
else {
266+
using dpctl::tensor::kernels::accumulators::CastTransformer;
267+
fnT fn = dpctl::tensor::kernels::accumulators::
268+
accumulate_strided_impl<srcTy, dstTy,
269+
CastTransformer<srcTy, dstTy>,
270+
ScanOpT, include_initial>;
271+
return fn;
272+
}
273+
}
274+
else {
275+
return nullptr;
276+
}
277+
}
278+
};
279+
280+
void populate_cumprod_dispatch_tables(void)
281+
{
282+
td_ns::DispatchTableBuilder<accumulate_1d_contig_impl_fn_ptr_t,
283+
CumProd1DContigFactory, td_ns::num_types>
284+
dtb1;
285+
dtb1.populate_dispatch_table(cumprod_1d_contig_dispatch_table);
286+
287+
td_ns::DispatchTableBuilder<accumulate_strided_impl_fn_ptr_t,
288+
CumProdStridedFactory, td_ns::num_types>
289+
dtb2;
290+
dtb2.populate_dispatch_table(cumprod_strided_dispatch_table);
291+
292+
td_ns::DispatchTableBuilder<accumulate_1d_contig_impl_fn_ptr_t,
293+
CumProd1DIncludeInitialContigFactory,
294+
td_ns::num_types>
295+
dtb3;
296+
dtb3.populate_dispatch_table(
297+
cumprod_1d_include_initial_contig_dispatch_table);
298+
299+
td_ns::DispatchTableBuilder<accumulate_strided_impl_fn_ptr_t,
300+
CumProdIncludeInitialStridedFactory,
301+
td_ns::num_types>
302+
dtb4;
303+
dtb4.populate_dispatch_table(
304+
cumprod_include_initial_strided_dispatch_table);
305+
306+
return;
307+
}
308+
309+
} // namespace impl
310+
311+
/// Fills the cumprod dispatch tables and registers the cumulative-product
/// Python API on module \p m.
///
/// Exposed functions:
///  - _cumprod_over_axis(src, trailing_dims_to_accumulate, dst, sycl_queue,
///    depends=[])
///  - _cumprod_final_axis_include_initial(src, dst, sycl_queue, depends=[])
///  - _cumprod_dtype_supported(arg_dtype, out_dtype)
void init_cumulative_prod(py::module_ m)
{
    using arrayT = dpctl::tensor::usm_ndarray;
    using event_vecT = std::vector<sycl::event>;

    using impl::populate_cumprod_dispatch_tables;
    populate_cumprod_dispatch_tables();

    using impl::cumprod_1d_contig_dispatch_table;
    using impl::cumprod_strided_dispatch_table;
    // The lambdas below are stored by pybind11 and invoked long after this
    // function returns, so they must not capture anything by reference.
    // They are capture-less: the dispatch tables have static storage
    // duration and are referenced directly.  (The original used [&], which
    // happened to capture nothing but invited dangling captures.)
    auto cumprod_pyapi = [](const arrayT &src, int trailing_dims_to_accumulate,
                            const arrayT &dst, sycl::queue &exec_q,
                            const event_vecT &depends = {}) {
        using dpctl::tensor::py_internal::py_accumulate_over_axis;
        return py_accumulate_over_axis(
            src, trailing_dims_to_accumulate, dst, exec_q, depends,
            cumprod_strided_dispatch_table, cumprod_1d_contig_dispatch_table);
    };
    m.def("_cumprod_over_axis", cumprod_pyapi, "", py::arg("src"),
          py::arg("trailing_dims_to_accumulate"), py::arg("dst"),
          py::arg("sycl_queue"), py::arg("depends") = py::list());

    using impl::cumprod_1d_include_initial_contig_dispatch_table;
    using impl::cumprod_include_initial_strided_dispatch_table;
    auto cumprod_include_initial_pyapi =
        [](const arrayT &src, const arrayT &dst, sycl::queue &exec_q,
           const event_vecT &depends = {}) {
            using dpctl::tensor::py_internal::
                py_accumulate_final_axis_include_initial;
            return py_accumulate_final_axis_include_initial(
                src, dst, exec_q, depends,
                cumprod_include_initial_strided_dispatch_table,
                cumprod_1d_include_initial_contig_dispatch_table);
        };
    m.def("_cumprod_final_axis_include_initial", cumprod_include_initial_pyapi,
          "", py::arg("src"), py::arg("dst"), py::arg("sycl_queue"),
          py::arg("depends") = py::list());

    // Reports whether a cumprod kernel exists for the given dtype pair.
    auto cumprod_dtype_supported = [](const py::dtype &input_dtype,
                                      const py::dtype &output_dtype) {
        using dpctl::tensor::py_internal::py_accumulate_dtype_supported;
        return py_accumulate_dtype_supported(input_dtype, output_dtype,
                                             cumprod_strided_dispatch_table);
    };
    m.def("_cumprod_dtype_supported", cumprod_dtype_supported, "",
          py::arg("arg_dtype"), py::arg("out_dtype"));
}
358+
359+
} // namespace dpctl::tensor::py_internal

0 commit comments

Comments
 (0)