Skip to content

Commit 7cd2d4a

Browse files
committed
Factorize implementation of argsort
1 parent 7aa868e commit 7cd2d4a

1 file changed

Lines changed: 47 additions & 16 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,49 @@ namespace xt
268268

269269
namespace detail
270270
{
271+
template <class ConstRandomIt, class RandomIt, class Compare>
272+
inline void argsort_iter(
273+
ConstRandomIt data_begin,
274+
ConstRandomIt data_end,
275+
RandomIt idx_begin,
276+
[[maybe_unused]] RandomIt idx_end,
277+
Compare comp
278+
)
279+
{
280+
XTENSOR_ASSERT(std::distance(data_begin, data_end) >= 0);
281+
XTENSOR_ASSERT(std::distance(idx_begin, idx_end) == std::distance(data_begin, data_end));
282+
283+
std::iota(idx_begin, idx_end, 0);
284+
std::sort(
285+
idx_begin,
286+
idx_end,
287+
[&](const auto i, const auto j)
288+
{
289+
return comp(*(data_begin + i), *(data_begin + j));
290+
}
291+
);
292+
}
293+
294+
template <class ConstRandomIt, class RandomIt>
295+
inline void argsort_iter(
296+
ConstRandomIt data_begin,
297+
ConstRandomIt data_end,
298+
RandomIt idx_begin,
299+
[[maybe_unused]] RandomIt idx_end
300+
)
301+
{
302+
return argsort_iter(
303+
std::move(data_begin),
304+
std::move(data_end),
305+
std::move(idx_begin),
306+
std::move(idx_end),
307+
[](const auto& x, const auto& y) -> bool
308+
{
309+
return x < y;
310+
}
311+
);
312+
}
313+
271314
template <class VT, class T>
272315
struct rebind_value_type
273316
{
@@ -336,12 +379,8 @@ namespace xt
336379
using result_type = R;
337380
result_type result;
338381
result.resize({de.size()});
339-
auto comp = [&ad](std::size_t x, std::size_t y)
340-
{
341-
return ad[x] < ad[y];
342-
};
343-
std::iota(result.begin(), result.end(), 0);
344-
std::sort(result.begin(), result.end(), comp);
382+
383+
detail::argsort_iter(de.cbegin(), de.cend(), result.begin(), result.end());
345384

346385
return result;
347386
}
@@ -380,17 +419,9 @@ namespace xt
380419
return detail::flatten_argsort_impl<E, result_type>(e);
381420
}
382421

383-
const auto argsort = [](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
422+
const auto argsort = [](auto res_begin, auto res_end, auto ev_begin, auto ev_end)
384423
{
385-
std::iota(res_begin, res_end, 0);
386-
std::sort(
387-
res_begin,
388-
res_end,
389-
[&ev_begin](auto const& i, auto const& j)
390-
{
391-
return *(ev_begin + i) < *(ev_begin + j);
392-
}
393-
);
424+
detail::argsort_iter(ev_begin, ev_end, res_begin, res_end);
394425
};
395426

396427
if (ax == detail::leading_axis(de))

0 commit comments

Comments
 (0)