Skip to content

Commit 6c8ab10

Browse files
committed
Added template parameter for initial value type in accumulators
1 parent 11a4425 commit 6c8ab10

5 files changed

Lines changed: 42 additions & 43 deletions

File tree

docs/source/operator.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,21 @@ function. For example, the implementation of cumsum is as follows:
284284
arr,
285285
1);
286286
287+
Like reducers, accumulators accept a template parameter to specify the ``value_type``
288+
of the initial value of the accumulation. The ``value_type`` of the result is computed
289+
with the same rules as those for reducers:
290+
291+
.. code::
292+
293+
#include "xtensor/xarray.hpp"
294+
#include "xtensor/xaccumulator.hpp"
295+
296+
xt::xarray<int> arr = some_init_function({5, 5, 5});
297+
auto r1 = xt::cumsum<short>(a, 1);
298+
// r1 holds int values
299+
auto r2 = xt::cumsum<long int>(a, 1);
300+
// r2 hols long int values
301+
287302
Evaluation strategy
288303
-------------------
289304

docs/source/quickref/reducer.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ Sum
3434
// r4 holds long int values
3535
3636
auto r5 = xt::sum<short>(a, {1});
37-
// r5 hols int values
37+
// r5 holds int values
3838
39-
auto r6 = xt::sum<xt::big_promote_value_type<decltype(a)>>(a, {1});
39+
auto r6 = xt::sum<xt::big_promote_value_type_t<decltype(a)>>(a, {1});
4040
// r6 holds long long int values
4141
4242
Prod
@@ -50,7 +50,7 @@ Prod
5050
int r2 = xt::prod(a)();
5151
auto r3 = xt::prod(a, {0});
5252
auro r4 = xt::prod<long int>(a, {0});
53-
auto r5 = xt::prod<xt::big_promote_value_type<decltype(a)>>(a, {1});
53+
auto r5 = xt::prod<xt::big_promote_value_type_t<decltype(a)>>(a, {1});
5454
5555
Mean
5656
----

include/xtensor/xaccumulator.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,12 @@ namespace xt
199199
{
200200
using init_type = typename F::init_value_type;
201201
using init_functor_type = typename F::init_functor_type;
202-
using return_type = std::conditional_t<std::is_same<init_type, void>::value, typename std::decay_t<E>::value_type, init_type>;
202+
using accumulate_functor_type = typename F::accumulate_functor_type;
203+
using expr_value_type = typename std::decay_t<E>::value_type;
204+
//using return_type = std::conditional_t<std::is_same<init_type, void>::value, typename std::decay_t<E>::value_type, init_type>;
205+
206+
using return_type = std::decay_t<decltype(std::declval<accumulate_functor_type>()(std::declval<init_type>(),
207+
std::declval<expr_value_type>()))>;
203208
using result_type = xaccumulator_return_type_t<std::decay_t<E>, return_type>;
204209

205210
if (axis >= e.dimension())
@@ -267,7 +272,11 @@ namespace xt
267272
inline auto accumulator_impl(F&& f, E&& e, evaluation_strategy::immediate_type)
268273
{
269274
using init_type = typename F::init_value_type;
270-
using return_type = std::conditional_t<std::is_same<init_type, void>::value, typename std::decay_t<E>::value_type, init_type>;
275+
using expr_value_type = typename std::decay_t<E>::value_type;
276+
using accumulate_functor_type = typename F::accumulate_functor_type;
277+
using return_type = std::decay_t<decltype(std::declval<accumulate_functor_type>()(std::declval<init_type>(),
278+
std::declval<expr_value_type>()))>;
279+
//using return_type = std::conditional_t<std::is_same<init_type, void>::value, typename std::decay_t<E>::value_type, init_type>;
271280
using result_type = xaccumulator_return_type_t<std::decay_t<E>, return_type>;
272281

273282
std::size_t sz = e.size();

include/xtensor/xmath.hpp

Lines changed: 12 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,37 +2357,6 @@ namespace detail {
23572357
return detail::make_xfunction<detail::nan_to_num_functor>(std::forward<E>(e));
23582358
}
23592359

2360-
#define XTENSOR_NAN_REDUCER_FUNCTION(NAME, FUNCTOR, RESULT_TYPE, NAN) \
2361-
template <class T = void, class E, class X, class EVS = DEFAULT_STRATEGY_REDUCERS, \
2362-
XTL_REQUIRES(xtl::negation<is_reducer_options<X>>)> \
2363-
inline auto NAME(E&& e, X&& axes, EVS es = EVS()) \
2364-
{ \
2365-
using result_type = std::conditional_t<std::is_same<T, void>::value, RESULT_TYPE, T>; \
2366-
using functor_type = FUNCTOR<result_type>; \
2367-
using init_functor_type = detail::nan_init<result_type, NAN>; \
2368-
return xt::reduce(make_xreducer_functor(functor_type(), init_functor_type()), std::forward<E>(e), \
2369-
std::forward<X>(axes), es); \
2370-
} \
2371-
\
2372-
template <class T = void, class E, class EVS = DEFAULT_STRATEGY_REDUCERS, \
2373-
XTL_REQUIRES(is_reducer_options<EVS>)> \
2374-
inline auto NAME(E&& e, EVS es = EVS()) \
2375-
{ \
2376-
using result_type = std::conditional_t<std::is_same<T, void>::value, RESULT_TYPE, T>; \
2377-
using functor_type = FUNCTOR<result_type>; \
2378-
using init_functor_type = detail::nan_init<result_type, NAN>; \
2379-
return xt::reduce(make_xreducer_functor(functor_type(), init_functor_type()), std::forward<E>(e), es); \
2380-
} \
2381-
\
2382-
template <class T = void, class E, class I, std::size_t N, class EVS = DEFAULT_STRATEGY_REDUCERS> \
2383-
inline auto NAME(E&& e, const I (&axes)[N], EVS es = EVS()) \
2384-
{ \
2385-
using result_type = std::conditional_t<std::is_same<T, void>::value, RESULT_TYPE, T>; \
2386-
using functor_type = FUNCTOR<result_type>; \
2387-
using init_functor_type = detail::nan_init<result_type, NAN>; \
2388-
return xt::reduce(make_xreducer_functor(functor_type(), init_functor_type()), std::forward<E>(e), axes, es); \
2389-
}
2390-
23912360
/**
23922361
* @ingroup nan_functions
23932362
* @brief Sum of elements over given axes, replacing nan with 0.
@@ -2397,7 +2366,9 @@ namespace detail {
23972366
* @param e an \ref xexpression
23982367
* @param axes the axes along which the sum is performed (optional)
23992368
* @param es evaluation strategy of the reducer (optional)
2400-
* @tparam T the result type. The default is `E::value_type`.
2369+
* @tparam T the value type used for internal computation. The default is
2370+
* `E::value_type`. `T` is also used for determining the value type
2371+
* of the result, which is the type of `T() + E::value_type()`.
24012372
* You can pass `big_promote_value_type_t<E>` to avoid overflow in computation.
24022373
* @return an \ref xreducer
24032374
*/
@@ -2412,14 +2383,14 @@ namespace detail {
24122383
* @param e an \ref xexpression
24132384
* @param axes the axes along which the sum is performed (optional)
24142385
* @param es evaluation strategy of the reducer (optional)
2415-
* @tparam T the result type. The default is `E::value_type`.
2386+
* @tparam T the value type used for internal computation. The default is
2387+
* `E::value_type`. `T` is also used for determining the value type
2388+
* of the result, which is the type of `T() * E::value_type()`.
24162389
* You can pass `big_promote_value_type_t<E>` to avoid overflow in computation.
24172390
* @return an \ref xreducer
24182391
*/
24192392
XTENSOR_REDUCER_FUNCTION(nanprod, detail::nan_multiplies, typename std::decay_t<E>::value_type, 1)
24202393

2421-
#undef XTENSOR_NAN_REDUCER_FUNCTION
2422-
24232394
#define COUNT_NON_ZEROS_CONTENT \
24242395
using value_type = typename std::decay_t<E>::value_type; \
24252396
using result_type = xt::detail::xreducer_size_type_t<value_type>; \
@@ -2507,7 +2478,9 @@ namespace detail {
25072478
* \em axis, replacing nan with 0.
25082479
* @param e an \ref xexpression
25092480
* @param axis the axis along which the elements are accumulated (optional)
2510-
* @tparam T the result type. The default is `E::value_type`.
2481+
* @tparam T the value type used for internal computation. The default is
2482+
* `E::value_type`. `T` is also used for determining the value type
2483+
* of the result, which is the type of `T() + E::value_type()`.
25112484
* You can pass `big_promote_value_type_t<E>` to avoid overflow in computation.
25122485
* @return an xaccumulator
25132486
*/
@@ -2533,7 +2506,9 @@ namespace detail {
25332506
* \em axis, replacing nan with 1.
25342507
* @param e an \ref xexpression
25352508
* @param axis the axis along which the elements are accumulated (optional)
2536-
* @tparam T the result type. The default is `E::value_type`.
2509+
* @tparam T the value type used for internal computation. The default is
2510+
* `E::value_type`. `T` is also used for determining the value type
2511+
* of the result, which is the type of `T() * E::value_type()`.
25372512
* You can pass `big_promote_value_type_t<E>` to avoid overflow in computation.
25382513
* @return an xaccumulator
25392514
*/

test/test_xaccumulator.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace xt
2525
xt::xarray<int> expected = { 1, 3, 6, 10};
2626
auto no_axis = cumsum(a);
2727
auto with_axis = cumsum(a, 0);
28-
bool promotion_works = std::is_same<decltype(no_axis)::value_type, short>::value;
28+
bool promotion_works = std::is_same<decltype(no_axis)::value_type, int>::value;
2929
EXPECT_TRUE(promotion_works);
3030
EXPECT_TRUE(all(equal(no_axis, expected)));
3131

0 commit comments

Comments
 (0)