Skip to content

Commit a0ac70d

Browse files
committed
Refactor argsort to use call_over_leading_axis
1 parent 67ec364 commit a0ac70d

1 file changed

Lines changed: 21 additions & 55 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 21 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -349,50 +349,6 @@ namespace xt
349349
typename T::temporary_type>::type;
350350
};
351351

352-
template <class Ed, class Ei>
353-
inline void argsort_over_leading_axis(const Ed& data, Ei& inds)
354-
{
355-
std::size_t n_iters = 1;
356-
std::ptrdiff_t data_secondary_stride, inds_secondary_stride;
357-
358-
if (data.layout() == layout_type::row_major)
359-
{
360-
n_iters = std::accumulate(
361-
data.shape().begin(),
362-
data.shape().end() - 1,
363-
std::size_t(1),
364-
std::multiplies<>()
365-
);
366-
data_secondary_stride = static_cast<std::ptrdiff_t>(data.shape(data.dimension() - 1));
367-
inds_secondary_stride = static_cast<std::ptrdiff_t>(inds.shape(inds.dimension() - 1));
368-
}
369-
else
370-
{
371-
n_iters = std::accumulate(
372-
data.shape().begin() + 1,
373-
data.shape().end(),
374-
std::size_t(1),
375-
std::multiplies<>()
376-
);
377-
data_secondary_stride = static_cast<std::ptrdiff_t>(data.shape(0));
378-
inds_secondary_stride = static_cast<std::ptrdiff_t>(inds.shape(0));
379-
}
380-
381-
auto ptr = data.data();
382-
auto indices_ptr = inds.data();
383-
384-
for (std::size_t i = 0; i < n_iters;
385-
++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
386-
{
387-
auto comp = [&ptr](std::size_t x, std::size_t y)
388-
{
389-
return *(ptr + x) < *(ptr + y);
390-
};
391-
std::iota(indices_ptr, indices_ptr + inds_secondary_stride, 0);
392-
std::sort(indices_ptr, indices_ptr + inds_secondary_stride, comp);
393-
}
394-
}
395-
396352
template <class E, class R = typename detail::linear_argsort_result_type<E>::type>
397353
inline auto flatten_argsort_impl(const xexpression<E>& e)
398354
{
@@ -449,23 +405,33 @@ namespace xt
449405
return detail::flatten_argsort_impl<E, result_type>(e);
450406
}
451407

452-
if (ax != detail::leading_axis(de))
408+
const auto argsort = [](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
453409
{
454-
dynamic_shape<std::size_t> permutation, reverse_permutation;
455-
std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
410+
std::iota(res_begin, res_end, 0);
411+
std::sort(
412+
res_begin,
413+
res_end,
414+
[&ev_begin](auto const& i, auto const& j)
415+
{
416+
return *(ev_begin + i) < *(ev_begin + j);
417+
}
418+
);
419+
};
456420

457-
eval_type ev = transpose(de, permutation);
458-
result_type res = result_type::from_shape(ev.shape());
459-
detail::argsort_over_leading_axis(ev, res);
460-
res = transpose(res, reverse_permutation);
461-
return res;
462-
}
463-
else
421+
if (ax == detail::leading_axis(de))
464422
{
465423
result_type res = result_type::from_shape(de.shape());
466-
detail::argsort_over_leading_axis(de, res);
424+
detail::call_over_leading_axis(res, de, argsort);
467425
return res;
468426
}
427+
428+
dynamic_shape<std::size_t> permutation, reverse_permutation;
429+
std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
430+
eval_type ev = transpose(de, permutation);
431+
result_type res = result_type::from_shape(ev.shape());
432+
detail::call_over_leading_axis(res, ev, argsort);
433+
res = transpose(res, reverse_permutation);
434+
return res;
469435
}
470436

471437
/************************************************

0 commit comments

Comments
 (0)