Skip to content

Commit 7aa868e

Browse files
committed
Fix xsort calls over axis
1 parent 13f3221 commit 7aa868e

2 files changed

Lines changed: 34 additions & 52 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 30 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -46,35 +46,39 @@ namespace xt
4646
return stride != 0 ? stride : static_cast<std::ptrdiff_t>(shape);
4747
}
4848

49-
template <class E, class F>
50-
inline void call_over_leading_axis(E& ev, F&& fct)
49+
template <class E>
50+
inline std::ptrdiff_t get_secondary_stride(const E& ev)
5151
{
52-
std::size_t n_iters = 1;
53-
std::ptrdiff_t secondary_stride;
52+
if (ev.layout() == layout_type::row_major)
53+
{
54+
return adjust_secondary_stride(ev.strides()[ev.dimension() - 2], *(ev.shape().end() - 1));
55+
}
56+
57+
return adjust_secondary_stride(ev.strides()[1], *(ev.shape().begin()));
58+
}
5459

60+
template <class E>
61+
inline std::size_t leading_axis_n_iters(const E& ev)
62+
{
5563
if (ev.layout() == layout_type::row_major)
5664
{
57-
n_iters = std::accumulate(
65+
return std::accumulate(
5866
ev.shape().begin(),
5967
ev.shape().end() - 1,
6068
std::size_t(1),
6169
std::multiplies<>()
6270
);
63-
secondary_stride = adjust_secondary_stride(
64-
ev.strides()[ev.dimension() - 2],
65-
*(ev.shape().end() - 1)
66-
);
67-
}
68-
else
69-
{
70-
n_iters = std::accumulate(
71-
ev.shape().begin() + 1,
72-
ev.shape().end(),
73-
std::size_t(1),
74-
std::multiplies<>()
75-
);
76-
secondary_stride = adjust_secondary_stride(ev.strides()[1], *(ev.shape().begin()));
7771
}
72+
return std::accumulate(ev.shape().begin() + 1, ev.shape().end(), std::size_t(1), std::multiplies<>());
73+
}
74+
75+
template <class E, class F>
76+
inline void call_over_leading_axis(E& ev, F&& fct)
77+
{
78+
XTENSOR_ASSERT(ev.dimension() >= 2);
79+
80+
const std::size_t n_iters = leading_axis_n_iters(ev);
81+
const std::ptrdiff_t secondary_stride = get_secondary_stride(ev);
7882

7983
const auto begin = ev.data();
8084
const auto end = begin + n_iters * secondary_stride;
@@ -87,37 +91,13 @@ namespace xt
8791
template <class E1, class E2, class F>
8892
inline void call_over_leading_axis(E1& e1, E2& e2, F&& fct)
8993
{
90-
std::size_t n_iters = 1;
91-
std::ptrdiff_t secondary_stride1, secondary_stride2;
94+
XTENSOR_ASSERT(e1.dimension() >= 2);
95+
XTENSOR_ASSERT(e1.dimension() == e2.dimension());
9296

93-
if (e1.layout() == layout_type::row_major)
94-
{
95-
n_iters = std::accumulate(
96-
e1.shape().begin(),
97-
e1.shape().end() - 1,
98-
std::size_t(1),
99-
std::multiplies<>()
100-
);
101-
secondary_stride1 = adjust_secondary_stride(
102-
e1.strides()[e1.dimension() - 2],
103-
*(e1.shape().end() - 1)
104-
);
105-
secondary_stride2 = adjust_secondary_stride(
106-
e2.strides()[e2.dimension() - 2],
107-
*(e2.shape().end() - 2)
108-
);
109-
}
110-
else
111-
{
112-
n_iters = std::accumulate(
113-
e1.shape().begin() + 1,
114-
e1.shape().end(),
115-
std::size_t(1),
116-
std::multiplies<>()
117-
);
118-
secondary_stride1 = adjust_secondary_stride(e1.strides()[1], *(e1.shape().begin()));
119-
secondary_stride2 = adjust_secondary_stride(e2.strides()[1], *(e2.shape().begin()));
120-
}
97+
const std::size_t n_iters = leading_axis_n_iters(e1);
98+
std::ptrdiff_t const secondary_stride1 = get_secondary_stride(e1);
99+
std::ptrdiff_t const secondary_stride2 = get_secondary_stride(e2);
100+
XTENSOR_ASSERT(secondary_stride1 == secondary_stride2);
121101

122102
const auto begin1 = e1.data();
123103
const auto end1 = begin1 + n_iters * secondary_stride1;
@@ -192,7 +172,7 @@ namespace xt
192172
}
193173

194174
dynamic_shape<std::size_t> permutation, reverse_permutation;
195-
std::tie(permutation, reverse_permutation) = get_permutations(e.dimension(), axis, e.layout());
175+
std::tie(permutation, reverse_permutation) = get_permutations(e.dimension(), ax, e.layout());
196176
R res = transpose(e, permutation);
197177
detail::call_over_leading_axis(res, std::forward<F>(lambda));
198178
res = transpose(res, reverse_permutation);

test/test_xsort.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,15 +426,16 @@ namespace xt
426426

427427
SUBCASE("complex")
428428
{
429-
xt::xarray<double> a = {
429+
const xt::xarray<double> data = {
430430
1014., 1017., 1019., 1020., 1023., 1026., 1026., 1028., 1030., 1032., 1039., 1047., 1071.,
431431
927., 932., 935., 943., 944., 944., 945., 948., 952., 962., 968., 968., 969.,
432432
969., 974., 981., 993., 994., 994., 1003., 1007., 1008., 1008., 1012., 1013., 1014.,
433433
1080., 1085., 1088., 1111., 1112., 1117., 1119., 1128., 1130., 1209., 1309., 1426.};
434-
xt::xtensor_fixed<std::size_t, xt::xshape<4>> kth = {17, 32, 18, 33};
434+
const xt::xtensor_fixed<std::size_t, xt::xshape<4>> kth = {17, 32, 18, 33};
435435

436436
SUBCASE("1D")
437437
{
438+
const auto& a = data;
438439
const auto argpart = xt::argpartition(a, kth);
439440
for (std::size_t k : kth)
440441
{
@@ -445,6 +446,7 @@ namespace xt
445446

446447
SUBCASE("2D")
447448
{
449+
auto a = data;
448450
a.reshape({1, a.size()});
449451
const auto argpart = xt::argpartition(a, kth, 1);
450452
for (std::size_t k : kth)

0 commit comments

Comments
 (0)