Skip to content

Commit 2982403

Browse files
committed
Use universal reference for partition container
1 parent 36f146e commit 2982403

1 file changed

Lines changed: 26 additions & 29 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,14 @@ namespace xt
480480
}
481481
);
482482
}
483+
484+
template <class C>
485+
inline auto sorted(C&& container)
486+
{
487+
auto container_copy = std::forward<C>(container);
488+
std::sort(container_copy.begin(), container_copy.end());
489+
return std::move(container_copy);
490+
}
483491
}
484492

485493
/**
@@ -514,20 +522,16 @@ namespace xt
514522
class C,
515523
class R = detail::flatten_sort_result_type_t<E>,
516524
class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
517-
inline R partition(const xexpression<E>& e, const C& kth_container, placeholders::xtuph /*ax*/)
525+
inline R partition(const xexpression<E>& e, C&& kth_container, placeholders::xtuph /*ax*/)
518526
{
519527
const auto& de = e.derived_cast();
520528

521529
R ev = R::from_shape({de.size()});
522-
C kth_copy = kth_container;
523-
if (kth_copy.size() > 1)
524-
{
525-
std::sort(kth_copy.begin(), kth_copy.end());
526-
}
530+
auto kth_sorted = detail::sorted(std::forward<C>(kth_container));
527531

528532
std::copy(de.linear_cbegin(), de.linear_cend(), ev.linear_begin()); // flatten
529533

530-
detail::partition_iter(ev.linear_begin(), ev.linear_end(), kth_copy.rbegin(), kth_copy.rend());
534+
detail::partition_iter(ev.linear_begin(), ev.linear_end(), kth_sorted.rbegin(), kth_sorted.rend());
531535

532536
return ev;
533537
}
@@ -549,18 +553,18 @@ namespace xt
549553
}
550554

551555
template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
552-
inline auto partition(const xexpression<E>& e, C kth_container, std::ptrdiff_t axis = -1)
556+
inline auto partition(const xexpression<E>& e, C&& kth_container, std::ptrdiff_t axis = -1)
553557
{
554558
using eval_type = typename detail::sort_eval_type<E>::type;
555559

556-
std::sort(kth_container.begin(), kth_container.end());
560+
auto kth_sorted = detail::sorted(std::forward<C>(kth_container));
557561

558562
return detail::map_axis<eval_type>(
559563
e.derived_cast(),
560564
axis,
561-
[&kth_container](auto begin, auto end)
565+
[&kth_sorted](auto begin, auto end)
562566
{
563-
detail::partition_iter(begin, end, kth_container.rbegin(), kth_container.rend());
567+
detail::partition_iter(begin, end, kth_sorted.rbegin(), kth_sorted.rend());
564568
}
565569
);
566570
}
@@ -613,7 +617,7 @@ namespace xt
613617
class C,
614618
class R = typename detail::linear_argsort_result_type<typename detail::sort_eval_type<E>::type>::type,
615619
class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
616-
inline R argpartition(const xexpression<E>& e, const C& kth_container, placeholders::xtuph)
620+
inline R argpartition(const xexpression<E>& e, C&& kth_container, placeholders::xtuph)
617621
{
618622
using eval_type = typename detail::sort_eval_type<E>::type;
619623
using result_type = typename detail::linear_argsort_result_type<eval_type>::type;
@@ -622,19 +626,15 @@ namespace xt
622626

623627
result_type res = result_type::from_shape({de.size()});
624628

625-
C kth_copy = kth_container;
626-
if (kth_copy.size() > 1)
627-
{
628-
std::sort(kth_copy.begin(), kth_copy.end());
629-
}
629+
auto kth_sorted = detail::sorted(std::forward<C>(kth_container));
630630

631631
std::iota(res.linear_begin(), res.linear_end(), 0);
632632

633633
detail::partition_iter(
634634
res.linear_begin(),
635635
res.linear_end(),
636-
kth_copy.rbegin(),
637-
kth_copy.rend(),
636+
kth_sorted.rbegin(),
637+
kth_sorted.rend(),
638638
[&de](std::size_t a, std::size_t b)
639639
{
640640
return de[a] < de[b];
@@ -661,7 +661,7 @@ namespace xt
661661
}
662662

663663
template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
664-
inline auto argpartition(const xexpression<E>& e, const C& kth_container, std::ptrdiff_t axis = -1)
664+
inline auto argpartition(const xexpression<E>& e, C&& kth_container, std::ptrdiff_t axis = -1)
665665
{
666666
using eval_type = typename detail::sort_eval_type<E>::type;
667667
using result_type = typename detail::argsort_result_type<eval_type>::type;
@@ -670,22 +670,19 @@ namespace xt
670670

671671
if (de.dimension() == 1)
672672
{
673-
return argpartition<E, C, result_type>(e, kth_container, xnone());
673+
return argpartition<E, C, result_type>(e, std::forward<C>(kth_container), xnone());
674674
}
675675

676-
C kth_copy = kth_container;
677-
if (kth_copy.size() > 1)
678-
{
679-
std::sort(kth_copy.begin(), kth_copy.end());
680-
}
681-
const auto argpartition_w_kth = [&kth_copy](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
676+
auto kth_sorted = detail::sorted(std::forward<C>(kth_container));
677+
const auto argpartition_w_kth =
678+
[&kth_sorted](auto res_begin, auto res_end, auto ev_begin, auto /*ev_end*/)
682679
{
683680
std::iota(res_begin, res_end, 0);
684681
detail::partition_iter(
685682
res_begin,
686683
res_end,
687-
kth_copy.rbegin(),
688-
kth_copy.rend(),
684+
kth_sorted.rbegin(),
685+
kth_sorted.rend(),
689686
[&ev_begin](auto const& i, auto const& j)
690687
{
691688
return *(ev_begin + i) < *(ev_begin + j);

0 commit comments

Comments
 (0)