Skip to content

Commit 67ec364

Browse files
committed
Fix argpartition
1 parent 5615c33 commit 67ec364

2 files changed

Lines changed: 132 additions & 114 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 84 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,53 @@ namespace xt
8484
}
8585
}
8686

87+
template <class E1, class E2, class F>
88+
inline void call_over_leading_axis(E1& e1, E2& e2, F&& fct)
89+
{
90+
std::size_t n_iters = 1;
91+
std::ptrdiff_t secondary_stride1, secondary_stride2;
92+
93+
if (e1.layout() == layout_type::row_major)
94+
{
95+
n_iters = std::accumulate(
96+
e1.shape().begin(),
97+
e1.shape().end() - 1,
98+
std::size_t(1),
99+
std::multiplies<>()
100+
);
101+
secondary_stride1 = adjust_secondary_stride(
102+
e1.strides()[e1.dimension() - 2],
103+
*(e1.shape().end() - 1)
104+
);
105+
secondary_stride2 = adjust_secondary_stride(
106+
e2.strides()[e2.dimension() - 2],
107+
*(e2.shape().end() - 2)
108+
);
109+
}
110+
else
111+
{
112+
n_iters = std::accumulate(
113+
e1.shape().begin() + 1,
114+
e1.shape().end(),
115+
std::size_t(1),
116+
std::multiplies<>()
117+
);
118+
secondary_stride1 = adjust_secondary_stride(e1.strides()[1], *(e1.shape().begin()));
119+
secondary_stride2 = adjust_secondary_stride(e2.strides()[1], *(e2.shape().begin()));
120+
}
121+
122+
const auto begin1 = e1.data();
123+
const auto end1 = begin1 + n_iters * secondary_stride1;
124+
const auto begin2 = e2.data();
125+
const auto end2 = begin2 + n_iters * secondary_stride2;
126+
auto iter1 = begin1;
127+
auto iter2 = begin2;
128+
for (; (iter1 < end1) && (iter2 < end2); iter1 += secondary_stride1, iter2 += secondary_stride2)
129+
{
130+
fct(iter1, iter1 + secondary_stride1, iter2, iter2 + secondary_stride2);
131+
}
132+
}
133+
87134
template <class E>
88135
inline std::size_t leading_axis(const E& e)
89136
{
@@ -630,30 +677,28 @@ namespace xt
630677

631678
const auto& de = e.derived_cast();
632679

633-
result_type ev = result_type::from_shape({de.size()});
680+
result_type res = result_type::from_shape({de.size()});
634681

635682
C kth_copy = kth_container;
636683
if (kth_copy.size() > 1)
637684
{
638685
std::sort(kth_copy.begin(), kth_copy.end());
639686
}
640687

641-
auto arg_lambda = [&de](std::size_t a, std::size_t b)
642-
{
643-
return de[a] < de[b];
644-
};
645-
646-
std::iota(ev.linear_begin(), ev.linear_end(), 0);
647-
std::size_t k_last = kth_copy.back();
648-
std::nth_element(ev.linear_begin(), ev.linear_begin() + k_last, ev.linear_end(), arg_lambda);
688+
std::iota(res.linear_begin(), res.linear_end(), 0);
649689

650-
for (auto it = (kth_copy.rbegin() + 1); it != kth_copy.rend(); ++it)
651-
{
652-
std::nth_element(ev.linear_begin(), ev.linear_begin() + *it, ev.linear_begin() + k_last, arg_lambda);
653-
k_last = *it;
654-
}
690+
detail::partition_iter(
691+
res.linear_begin(),
692+
res.linear_end(),
693+
kth_copy.rbegin(),
694+
kth_copy.rend(),
695+
[&de](std::size_t a, std::size_t b)
696+
{
697+
return de[a] < de[b];
698+
}
699+
);
655700

656-
return ev;
701+
return res;
657702
}
658703

659704
template <class E, class I, std::size_t N>
@@ -672,65 +717,6 @@ namespace xt
672717
return argpartition(e, std::array<std::size_t, 1>({kth}), tag);
673718
}
674719

675-
namespace detail
676-
{
677-
template <class Ed, class Ei>
678-
inline void
679-
argpartition_over_leading_axis(const Ed& data, Ei& inds, std::size_t kth, std::ptrdiff_t last)
680-
{
681-
std::size_t n_iters = 1;
682-
std::ptrdiff_t data_secondary_stride, inds_secondary_stride;
683-
684-
if (data.layout() == layout_type::row_major)
685-
{
686-
n_iters = std::accumulate(
687-
data.shape().begin(),
688-
data.shape().end() - 1,
689-
std::size_t(1),
690-
std::multiplies<>()
691-
);
692-
data_secondary_stride = data.strides()[data.dimension() - 2];
693-
inds_secondary_stride = inds.strides()[inds.dimension() - 2];
694-
}
695-
else
696-
{
697-
n_iters = std::accumulate(
698-
data.shape().begin() + 1,
699-
data.shape().end(),
700-
std::size_t(1),
701-
std::multiplies<>()
702-
);
703-
data_secondary_stride = data.strides()[1];
704-
inds_secondary_stride = inds.strides()[1];
705-
}
706-
707-
auto ptr = data.data();
708-
auto indices_ptr = inds.data();
709-
auto comp = [&ptr](std::size_t x, std::size_t y)
710-
{
711-
return *(ptr + x) < *(ptr + y);
712-
};
713-
714-
if (last == -1) // initialize
715-
{
716-
for (std::size_t i = 0; i < n_iters;
717-
++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
718-
{
719-
std::iota(indices_ptr, indices_ptr + inds_secondary_stride, 0);
720-
std::nth_element(indices_ptr, indices_ptr + kth, indices_ptr + inds_secondary_stride, comp);
721-
}
722-
}
723-
else
724-
{
725-
for (std::size_t i = 0; i < n_iters;
726-
++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
727-
{
728-
std::nth_element(indices_ptr, indices_ptr + kth, indices_ptr + last, comp);
729-
}
730-
}
731-
}
732-
}
733-
734720
template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
735721
inline auto argpartition(const xexpression<E>& e, const C& kth_container, std::ptrdiff_t axis = -1)
736722
{
@@ -739,8 +725,6 @@ namespace xt
739725

740726
const auto& de = e.derived_cast();
741727

742-
std::size_t ax = normalize_axis(de.dimension(), axis);
743-
744728
if (de.dimension() == 1)
745729
{
746730
return argpartition<E, C, result_type>(e, kth_container, xnone());
@@ -751,38 +735,35 @@ namespace xt
751735
{
752736
std::sort(kth_copy.begin(), kth_copy.end());
753737
}
754-
755-
eval_type ev;
756-
result_type res;
757-
758-
dynamic_shape<std::size_t> permutation, reverse_permutation;
759-
bool is_leading_axis = (ax == detail::leading_axis(de));
760-
761-
if (!is_leading_axis)
762-
{
763-
std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
764-
ev = transpose(de, permutation);
765-
}
766-
else
767-
{
768-
ev = de;
769-
}
770-
res.resize(ev.shape());
771-
772-
std::size_t kth = kth_copy.back();
773-
detail::argpartition_over_leading_axis(ev, res, kth, -1);
774-
775-
for (auto it = (kth_copy.rbegin() + 1); it != kth_copy.rend(); ++it)
738+
const auto argpartition_w_kth = [&kth_copy](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
776739
{
777-
detail::argpartition_over_leading_axis(ev, res, *it, static_cast<std::ptrdiff_t>(kth));
778-
kth = *it;
779-
}
740+
std::iota(res_begin, res_end, 0);
741+
detail::partition_iter(
742+
res_begin,
743+
res_end,
744+
kth_copy.rbegin(),
745+
kth_copy.rend(),
746+
[&ev_begin](auto const& i, auto const& j)
747+
{
748+
return *(ev_begin + i) < *(ev_begin + j);
749+
}
750+
);
751+
};
780752

781-
if (!is_leading_axis)
753+
std::size_t const ax = normalize_axis(de.dimension(), axis);
754+
if (ax == detail::leading_axis(de))
782755
{
783-
res = transpose(res, reverse_permutation);
756+
result_type res = result_type::from_shape(de.shape());
757+
detail::call_over_leading_axis(res, de, argpartition_w_kth);
758+
return res;
784759
}
785760

761+
dynamic_shape<std::size_t> permutation, reverse_permutation;
762+
std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
763+
eval_type ev = transpose(de, permutation);
764+
result_type res = result_type::from_shape(ev.shape());
765+
detail::call_over_leading_axis(res, ev, argpartition_w_kth);
766+
res = transpose(res, reverse_permutation);
786767
return res;
787768
}
788769

test/test_xsort.cpp

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ namespace xt
396396
bool res = true;
397397
for (std::size_t i = 0; i < pos; ++i)
398398
{
399-
res = res && arr(idxs[i]) < arr(idxs[pos]);
399+
res = res && arr(idxs[i]) <= arr(idxs[pos]);
400400
}
401401
for (std::size_t i = pos; i < arr.size(); ++i)
402402
{
@@ -407,16 +407,53 @@ namespace xt
407407

408408
TEST(xsort, argpartition)
409409
{
410-
xt::xarray<int> a = {3, 4, 2, 1};
411-
auto r1 = xt::argpartition(a, 2);
412-
EXPECT_TRUE(check_argpartition(a, r1, 2));
413-
414-
std::size_t s = a.size();
415-
int* arr = a.data();
416-
dynamic_shape<std::size_t> sh = {s};
417-
auto b = xt::adapt(arr, s, xt::no_ownership(), sh);
418-
auto r2 = xt::argpartition(b, 2);
419-
EXPECT_TRUE(check_argpartition(b, r2, 2));
410+
SUBCASE("simple")
411+
{
412+
xt::xarray<int> a = {3, 4, 2, 1};
413+
auto r1 = xt::argpartition(a, 2);
414+
EXPECT_TRUE(check_argpartition(a, r1, 2));
415+
416+
SUBCASE("adapt")
417+
{
418+
std::size_t s = a.size();
419+
int* arr = a.data();
420+
dynamic_shape<std::size_t> sh = {s};
421+
auto b = xt::adapt(arr, s, xt::no_ownership(), sh);
422+
auto r2 = xt::argpartition(b, 2);
423+
EXPECT_TRUE(check_argpartition(b, r2, 2));
424+
}
425+
}
426+
427+
SUBCASE("complex")
428+
{
429+
xt::xarray<double> a = {
430+
1014., 1017., 1019., 1020., 1023., 1026., 1026., 1028., 1030., 1032., 1039., 1047., 1071.,
431+
927., 932., 935., 943., 944., 944., 945., 948., 952., 962., 968., 968., 969.,
432+
969., 974., 981., 993., 994., 994., 1003., 1007., 1008., 1008., 1012., 1013., 1014.,
433+
1080., 1085., 1088., 1111., 1112., 1117., 1119., 1128., 1130., 1209., 1309., 1426.};
434+
xt::xtensor_fixed<std::size_t, xt::xshape<4>> kth = {17, 32, 18, 33};
435+
436+
SUBCASE("1D")
437+
{
438+
const auto argpart = xt::argpartition(a, kth);
439+
for (std::size_t k : kth)
440+
{
441+
CAPTURE(k);
442+
EXPECT_TRUE(check_argpartition(a, argpart, k));
443+
}
444+
}
445+
446+
SUBCASE("2D")
447+
{
448+
a.reshape({1, a.size()});
449+
const auto argpart = xt::argpartition(a, kth, 1);
450+
for (std::size_t k : kth)
451+
{
452+
CAPTURE(k);
453+
EXPECT_TRUE(check_argpartition(a, argpart, k));
454+
}
455+
}
456+
}
420457
}
421458

422459
TEST(xsort, quantile)

0 commit comments

Comments
 (0)