@@ -84,6 +84,53 @@ namespace xt
8484 }
8585 }
8686
87+ template <class E1 , class E2 , class F >
88+ inline void call_over_leading_axis (E1 & e1 , E2 & e2 , F&& fct)
89+ {
90+ std::size_t n_iters = 1 ;
91+ std::ptrdiff_t secondary_stride1, secondary_stride2;
92+
93+ if (e1 .layout () == layout_type::row_major)
94+ {
95+ n_iters = std::accumulate (
96+ e1 .shape ().begin (),
97+ e1 .shape ().end () - 1 ,
98+ std::size_t (1 ),
99+ std::multiplies<>()
100+ );
101+ secondary_stride1 = adjust_secondary_stride (
102+ e1 .strides ()[e1 .dimension () - 2 ],
103+ *(e1 .shape ().end () - 1 )
104+ );
105+ secondary_stride2 = adjust_secondary_stride (
106+ e2 .strides ()[e2 .dimension () - 2 ],
107+ *(e2 .shape ().end () - 2 )
108+ );
109+ }
110+ else
111+ {
112+ n_iters = std::accumulate (
113+ e1 .shape ().begin () + 1 ,
114+ e1 .shape ().end (),
115+ std::size_t (1 ),
116+ std::multiplies<>()
117+ );
118+ secondary_stride1 = adjust_secondary_stride (e1 .strides ()[1 ], *(e1 .shape ().begin ()));
119+ secondary_stride2 = adjust_secondary_stride (e2 .strides ()[1 ], *(e2 .shape ().begin ()));
120+ }
121+
122+ const auto begin1 = e1 .data ();
123+ const auto end1 = begin1 + n_iters * secondary_stride1;
124+ const auto begin2 = e2 .data ();
125+ const auto end2 = begin2 + n_iters * secondary_stride2;
126+ auto iter1 = begin1;
127+ auto iter2 = begin2;
128+ for (; (iter1 < end1) && (iter2 < end2); iter1 += secondary_stride1, iter2 += secondary_stride2)
129+ {
130+ fct (iter1, iter1 + secondary_stride1, iter2, iter2 + secondary_stride2);
131+ }
132+ }
133+
87134 template <class E >
88135 inline std::size_t leading_axis (const E& e)
89136 {
@@ -630,30 +677,28 @@ namespace xt
630677
631678 const auto & de = e.derived_cast ();
632679
633- result_type ev = result_type::from_shape ({de.size ()});
680+ result_type res = result_type::from_shape ({de.size ()});
634681
635682 C kth_copy = kth_container;
636683 if (kth_copy.size () > 1 )
637684 {
638685 std::sort (kth_copy.begin (), kth_copy.end ());
639686 }
640687
641- auto arg_lambda = [&de](std::size_t a, std::size_t b)
642- {
643- return de[a] < de[b];
644- };
645-
646- std::iota (ev.linear_begin (), ev.linear_end (), 0 );
647- std::size_t k_last = kth_copy.back ();
648- std::nth_element (ev.linear_begin (), ev.linear_begin () + k_last, ev.linear_end (), arg_lambda);
688+ std::iota (res.linear_begin (), res.linear_end (), 0 );
649689
650- for (auto it = (kth_copy.rbegin () + 1 ); it != kth_copy.rend (); ++it)
651- {
652- std::nth_element (ev.linear_begin (), ev.linear_begin () + *it, ev.linear_begin () + k_last, arg_lambda);
653- k_last = *it;
654- }
690+ detail::partition_iter (
691+ res.linear_begin (),
692+ res.linear_end (),
693+ kth_copy.rbegin (),
694+ kth_copy.rend (),
695+ [&de](std::size_t a, std::size_t b)
696+ {
697+ return de[a] < de[b];
698+ }
699+ );
655700
656- return ev ;
701+ return res ;
657702 }
658703
659704 template <class E , class I , std::size_t N>
@@ -672,65 +717,6 @@ namespace xt
672717 return argpartition (e, std::array<std::size_t , 1 >({kth}), tag);
673718 }
674719
675- namespace detail
676- {
677- template <class Ed , class Ei >
678- inline void
679- argpartition_over_leading_axis (const Ed& data, Ei& inds, std::size_t kth, std::ptrdiff_t last)
680- {
681- std::size_t n_iters = 1 ;
682- std::ptrdiff_t data_secondary_stride, inds_secondary_stride;
683-
684- if (data.layout () == layout_type::row_major)
685- {
686- n_iters = std::accumulate (
687- data.shape ().begin (),
688- data.shape ().end () - 1 ,
689- std::size_t (1 ),
690- std::multiplies<>()
691- );
692- data_secondary_stride = data.strides ()[data.dimension () - 2 ];
693- inds_secondary_stride = inds.strides ()[inds.dimension () - 2 ];
694- }
695- else
696- {
697- n_iters = std::accumulate (
698- data.shape ().begin () + 1 ,
699- data.shape ().end (),
700- std::size_t (1 ),
701- std::multiplies<>()
702- );
703- data_secondary_stride = data.strides ()[1 ];
704- inds_secondary_stride = inds.strides ()[1 ];
705- }
706-
707- auto ptr = data.data ();
708- auto indices_ptr = inds.data ();
709- auto comp = [&ptr](std::size_t x, std::size_t y)
710- {
711- return *(ptr + x) < *(ptr + y);
712- };
713-
714- if (last == -1 ) // initialize
715- {
716- for (std::size_t i = 0 ; i < n_iters;
717- ++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
718- {
719- std::iota (indices_ptr, indices_ptr + inds_secondary_stride, 0 );
720- std::nth_element (indices_ptr, indices_ptr + kth, indices_ptr + inds_secondary_stride, comp);
721- }
722- }
723- else
724- {
725- for (std::size_t i = 0 ; i < n_iters;
726- ++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
727- {
728- std::nth_element (indices_ptr, indices_ptr + kth, indices_ptr + last, comp);
729- }
730- }
731- }
732- }
733-
734720 template <class E , class C , class = std::enable_if_t <!xtl::is_integral<C>::value, int >>
735721 inline auto argpartition (const xexpression<E>& e, const C& kth_container, std::ptrdiff_t axis = -1 )
736722 {
@@ -739,8 +725,6 @@ namespace xt
739725
740726 const auto & de = e.derived_cast ();
741727
742- std::size_t ax = normalize_axis (de.dimension (), axis);
743-
744728 if (de.dimension () == 1 )
745729 {
746730 return argpartition<E, C, result_type>(e, kth_container, xnone ());
@@ -751,38 +735,35 @@ namespace xt
751735 {
752736 std::sort (kth_copy.begin (), kth_copy.end ());
753737 }
754-
755- eval_type ev;
756- result_type res;
757-
758- dynamic_shape<std::size_t > permutation, reverse_permutation;
759- bool is_leading_axis = (ax == detail::leading_axis (de));
760-
761- if (!is_leading_axis)
762- {
763- std::tie (permutation, reverse_permutation) = detail::get_permutations (de.dimension (), ax, de.layout ());
764- ev = transpose (de, permutation);
765- }
766- else
767- {
768- ev = de;
769- }
770- res.resize (ev.shape ());
771-
772- std::size_t kth = kth_copy.back ();
773- detail::argpartition_over_leading_axis (ev, res, kth, -1 );
774-
775- for (auto it = (kth_copy.rbegin () + 1 ); it != kth_copy.rend (); ++it)
738+ const auto argpartition_w_kth = [&kth_copy](auto res_begin, auto res_end, auto ev_begin, auto /* ev_end*/ )
776739 {
777- detail::argpartition_over_leading_axis (ev, res, *it, static_cast <std::ptrdiff_t >(kth));
778- kth = *it;
779- }
740+ std::iota (res_begin, res_end, 0 );
741+ detail::partition_iter (
742+ res_begin,
743+ res_end,
744+ kth_copy.rbegin (),
745+ kth_copy.rend (),
746+ [&ev_begin](auto const & i, auto const & j)
747+ {
748+ return *(ev_begin + i) < *(ev_begin + j);
749+ }
750+ );
751+ };
780752
781- if (!is_leading_axis)
753+ std::size_t const ax = normalize_axis (de.dimension (), axis);
754+ if (ax == detail::leading_axis (de))
782755 {
783- res = transpose (res, reverse_permutation);
756+ result_type res = result_type::from_shape (de.shape ());
757+ detail::call_over_leading_axis (res, de, argpartition_w_kth);
758+ return res;
784759 }
785760
761+ dynamic_shape<std::size_t > permutation, reverse_permutation;
762+ std::tie (permutation, reverse_permutation) = detail::get_permutations (de.dimension (), ax, de.layout ());
763+ eval_type ev = transpose (de, permutation);
764+ result_type res = result_type::from_shape (ev.shape ());
765+ detail::call_over_leading_axis (res, ev, argpartition_w_kth);
766+ res = transpose (res, reverse_permutation);
786767 return res;
787768 }
788769
0 commit comments