Skip to content

Commit dba06f8

Browse files
authored
Merge pull request #2277 from adriendelsalle/remove-reducer-default-promoted-type
Remove reducer big_promote_type
2 parents 2e3f1e6 + 0063981 commit dba06f8

6 files changed

Lines changed: 162 additions & 90 deletions

File tree

include/xtensor/xexpression_traits.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,19 @@ namespace xt
180180

181181
template <class... C>
182182
using common_tensor_type_t = typename common_tensor_type<C...>::type;
183+
184+
/**************************
185+
* big_promote_value_type *
186+
**************************/
187+
188+
template <class E>
189+
struct big_promote_value_type
190+
{
191+
using type = xtl::big_promote_type_t<typename std::decay_t<E>::value_type>;
192+
};
193+
194+
template <class E>
195+
using big_promote_value_type_t = typename big_promote_value_type<E>::type;
183196
}
184197

185198
#endif

include/xtensor/xmath.hpp

Lines changed: 122 additions & 77 deletions
Large diffs are not rendered by default.

include/xtensor/xoperation.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ constexpr auto operator OP(const std::complex<T1>& arg1, const std::complex<T2>&
4949
{ \
5050
using result_type = typename xtl::promote_type_t<std::complex<T1>, std::complex<T2>>; \
5151
return (result_type(arg1) OP result_type(arg2)); \
52+
} \
53+
\
54+
template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
55+
constexpr auto operator OP(const T1& arg1, const std::complex<T2>& arg2) \
56+
{ \
57+
using result_type = typename xtl::promote_type_t<T1, std::complex<T2>>; \
58+
return (result_type(arg1) OP result_type(arg2)); \
59+
} \
60+
\
61+
template <class T1, class T2, XTL_REQUIRES(xtl::negation<std::is_same<T1, T2>>)> \
62+
constexpr auto operator OP(const std::complex<T1>& arg1, const T2& arg2) \
63+
{ \
64+
using result_type = typename xtl::promote_type_t<std::complex<T1>, T2>; \
65+
return (result_type(arg1) OP result_type(arg2)); \
5266
}
5367

5468
#define BINARY_OPERATOR_FUNCTOR(NAME, OP) \

test/test_xaccumulator.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace xt
2222
TEST(xaccumulator, one_d)
2323
{
2424
xt::xarray<short> a = { short(1), short(2), short(3), short(4)};
25-
xt::xarray<long> expected = { 1, 3, 6, 10};
25+
xt::xarray<int> expected = { 1, 3, 6, 10};
2626
auto no_axis = cumsum(a);
2727
auto with_axis = cumsum(a, 0);
28-
bool promotion_works = std::is_same<decltype(no_axis)::value_type, long long>::value;
28+
bool promotion_works = std::is_same<decltype(no_axis)::value_type, short>::value;
2929
EXPECT_TRUE(promotion_works);
3030
EXPECT_TRUE(all(equal(no_axis, expected)));
3131

@@ -208,7 +208,7 @@ namespace xt
208208
TEST(xaccumulator, xfixed)
209209
{
210210
xtensor_fixed<float, xshape<2, 4, 3>> a = xt::random::rand<float>({2, 4, 3});
211-
auto res = cumsum(a, 1);
211+
auto res = cumsum<double>(a, 1);
212212

213213
bool truth = std::is_same<decltype(res), xtensor_fixed<double, xshape<2, 4, 3>>>::value;
214214
EXPECT_TRUE(truth);

test/test_xmath_result_type.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ void check_promoted_types(E&& e)
125125
CHECK_RESULT_TYPE(2.0 * auchar, double);
126126
CHECK_RESULT_TYPE(sqrt(auchar), double);
127127
CHECK_RESULT_TYPE(abs(auchar), unsigned char);
128-
CHECK_RESULT_TYPE(sum(auchar), unsigned long long);
128+
CHECK_RESULT_TYPE(sum(auchar), int);
129129
CHECK_RESULT_TYPE(mean(auchar), double);
130130
CHECK_RESULT_TYPE(minmax(auchar), ARRAY_TYPE(unsigned char));
131131
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(auchar);
@@ -141,7 +141,7 @@ void check_promoted_types(E&& e)
141141
CHECK_RESULT_TYPE(2.0 * ashort, double);
142142
CHECK_RESULT_TYPE(sqrt(ashort), double);
143143
CHECK_RESULT_TYPE(abs(ashort), decltype(std::abs(short{})));
144-
CHECK_RESULT_TYPE(sum(ashort), long long);
144+
CHECK_RESULT_TYPE(sum(ashort), int);
145145
CHECK_RESULT_TYPE(mean(ashort), double);
146146
CHECK_RESULT_TYPE(minmax(ashort), ARRAY_TYPE(short));
147147
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(ashort);
@@ -157,7 +157,7 @@ void check_promoted_types(E&& e)
157157
CHECK_RESULT_TYPE(2.0 * aushort, double);
158158
CHECK_RESULT_TYPE(sqrt(aushort), double);
159159
CHECK_RESULT_TYPE(abs(aushort), unsigned short);
160-
CHECK_RESULT_TYPE(sum(aushort), unsigned long long);
160+
CHECK_RESULT_TYPE(sum(aushort), int);
161161
CHECK_RESULT_TYPE(mean(aushort), double);
162162
CHECK_RESULT_TYPE(minmax(aushort), ARRAY_TYPE(unsigned short));
163163
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(aushort);
@@ -173,7 +173,7 @@ void check_promoted_types(E&& e)
173173
CHECK_RESULT_TYPE(2.0 * aint, double);
174174
CHECK_RESULT_TYPE(sqrt(aint), double);
175175
CHECK_RESULT_TYPE(abs(aint), int);
176-
CHECK_RESULT_TYPE(sum(aint), long long);
176+
CHECK_RESULT_TYPE(sum(aint), int);
177177
CHECK_RESULT_TYPE(mean(aint), double);
178178
CHECK_RESULT_TYPE(minmax(aint), ARRAY_TYPE(int));
179179
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(aint);
@@ -189,7 +189,7 @@ void check_promoted_types(E&& e)
189189
CHECK_RESULT_TYPE(2.0 * auint, double);
190190
CHECK_RESULT_TYPE(sqrt(auint), double);
191191
CHECK_RESULT_TYPE(abs(auint), unsigned int);
192-
CHECK_RESULT_TYPE(sum(auint), unsigned long long);
192+
CHECK_RESULT_TYPE(sum(auint), unsigned int);
193193
CHECK_RESULT_TYPE(mean(auint), double);
194194
CHECK_RESULT_TYPE(minmax(auint), ARRAY_TYPE(unsigned int));
195195
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(auint);
@@ -237,7 +237,7 @@ void check_promoted_types(E&& e)
237237
CHECK_RESULT_TYPE(2.0 * afloat, double);
238238
CHECK_RESULT_TYPE(sqrt(afloat), float);
239239
CHECK_RESULT_TYPE(abs(afloat), float);
240-
CHECK_RESULT_TYPE(sum(afloat), double);
240+
CHECK_RESULT_TYPE(sum(afloat), float);
241241
CHECK_RESULT_TYPE(mean(afloat), double);
242242
CHECK_RESULT_TYPE(minmax(afloat), ARRAY_TYPE(float));
243243
CHECK_TEMPLATED_RESULT_TYPE_FOR_ALL(afloat);
@@ -268,7 +268,7 @@ void check_promoted_types(E&& e)
268268
CHECK_RESULT_TYPE(2.0f * afcomplex, std::complex<float>);
269269
CHECK_RESULT_TYPE(sqrt(afcomplex), std::complex<float>);
270270
CHECK_RESULT_TYPE(abs(afcomplex), float);
271-
CHECK_RESULT_TYPE(sum(afcomplex), std::complex<double>);
271+
CHECK_RESULT_TYPE(sum(afcomplex), std::complex<float>);
272272
CHECK_RESULT_TYPE(mean(afcomplex), std::complex<double>);
273273
}
274274

test/test_xreducer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ namespace xt
498498
{
499499
// check that there is no overflow
500500
xarray<uint8_t> c = 2 * ones<uint8_t>({34});
501-
EXPECT_EQ(1ULL << 34, prod(c)());
501+
EXPECT_EQ(1ULL << 34, prod<long long>(c)());
502502
}
503503

504504
#define TEST_OPT_PROD(INPUT) \
@@ -868,9 +868,9 @@ namespace xt
868868
EXPECT_TRUE(b_fx_2 == sum(c, {0, 1}));
869869
EXPECT_EQ(b_fx_3, sum(c, {0, 1, 2}));
870870

871-
truth = std::is_same<std::decay_t<decltype(b_fx_1)>, xtensor_fixed<long long, xshape<5>>>::value;
871+
truth = std::is_same<std::decay_t<decltype(b_fx_1)>, xtensor_fixed<int, xshape<5>>>::value;
872872
EXPECT_TRUE(truth);
873-
truth = std::is_same<std::decay_t<decltype(b_fx_3)>, xtensor_fixed<long long, xshape<>>>::value;
873+
truth = std::is_same<std::decay_t<decltype(b_fx_3)>, xtensor_fixed<int, xshape<>>>::value;
874874
EXPECT_TRUE(truth);
875875

876876
truth = std::is_same<xshape<1, 3>, typename fixed_xreducer_shape_type<xshape<1, 5, 3>, xshape<1>>::type>();

0 commit comments

Comments
 (0)