@@ -173,23 +173,30 @@ namespace xt
173173 return std::make_pair (std::move (permutation), std::move (reverse_permutation));
174174 }
175175
176- template <class E , class R , class F >
177- inline auto run_lambda_over_axis (const E& e, R& res, std::size_t axis, F&& lambda)
176+ template <class R , class E , class F >
177+ inline R map_axis (const E& e, std::ptrdiff_t axis, F&& lambda)
178178 {
179- if (axis != detail::leading_axis (e) )
179+ if (e. dimension () == 1 )
180180 {
181- dynamic_shape<std::size_t > permutation, reverse_permutation;
182- std::tie (permutation, reverse_permutation) = get_permutations (e.dimension (), axis, e.layout ());
183-
184- res = transpose (e, permutation);
185- detail::call_over_leading_axis (res, std::forward<F>(lambda));
186- res = transpose (res, reverse_permutation);
181+ R res = e;
182+ lambda (res.begin (), res.end ());
183+ return res;
187184 }
188- else
185+
186+ std::size_t const ax = normalize_axis (e.dimension (), axis);
187+ if (ax == detail::leading_axis (e))
189188 {
190- res = e;
189+ R res = e;
191190 detail::call_over_leading_axis (res, std::forward<F>(lambda));
191+ return res;
192192 }
193+
194+ dynamic_shape<std::size_t > permutation, reverse_permutation;
195+ std::tie (permutation, reverse_permutation) = get_permutations (e.dimension (), axis, e.layout ());
196+ R res = transpose (e, permutation);
197+ detail::call_over_leading_axis (res, std::forward<F>(lambda));
198+ res = transpose (res, reverse_permutation);
199+ return res;
193200 }
194201
195202 template <class VT >
@@ -269,26 +276,14 @@ namespace xt
269276 {
270277 using eval_type = typename detail::sort_eval_type<E>::type;
271278
272- const auto & de = e.derived_cast ();
273-
274- if (de.dimension () == 1 )
275- {
276- return detail::flat_sort_impl<std::decay_t <decltype (de)>, eval_type>(de);
277- }
278-
279- std::size_t ax = normalize_axis (de.dimension (), axis);
280-
281- eval_type res;
282- detail::run_lambda_over_axis (
283- de,
284- res,
285- ax,
279+ return detail::map_axis<eval_type>(
280+ e.derived_cast (),
281+ axis,
286282 [](auto begin, auto end)
287283 {
288284 std::sort (begin, end);
289285 }
290286 );
291- return res;
292287 }
293288
294289 namespace detail
@@ -551,41 +546,20 @@ namespace xt
551546 }
552547
553548 template <class E , class C , class = std::enable_if_t <!xtl::is_integral<C>::value, int >>
554- inline auto partition (const xexpression<E>& e, const C& kth_container, std::ptrdiff_t axis = -1 )
549+ inline auto partition (const xexpression<E>& e, C kth_container, std::ptrdiff_t axis = -1 )
555550 {
556551 using eval_type = typename detail::sort_eval_type<E>::type;
557552
558- const auto & de = e. derived_cast ( );
553+ std::sort (kth_container. begin (), kth_container. end () );
559554
560- if (de.dimension () == 1 )
561- {
562- return partition<E, C, eval_type>(de, kth_container, xnone ());
563- }
564-
565- C kth_copy = kth_container;
566- if (kth_copy.size () > 1 )
567- {
568- std::sort (kth_copy.begin (), kth_copy.end ());
569- }
570- const auto partition_w_kth = [&kth_copy](auto begin, auto end)
571- {
572- detail::partition_iter (begin, end, kth_copy.rbegin (), kth_copy.rend ());
573- };
574-
575- std::size_t const ax = normalize_axis (de.dimension (), axis);
576- if (ax == detail::leading_axis (de))
577- {
578- eval_type res = de;
579- detail::call_over_leading_axis (res, partition_w_kth);
580- return res;
581- }
582-
583- dynamic_shape<std::size_t > permutation, reverse_permutation;
584- std::tie (permutation, reverse_permutation) = detail::get_permutations (de.dimension (), ax, de.layout ());
585- eval_type res = transpose (de, permutation);
586- detail::call_over_leading_axis (res, partition_w_kth);
587- res = transpose (res, reverse_permutation);
588- return res;
555+ return detail::map_axis<eval_type>(
556+ e.derived_cast (),
557+ axis,
558+ [&kth_container](auto begin, auto end)
559+ {
560+ detail::partition_iter (begin, end, kth_container.rbegin (), kth_container.rend ());
561+ }
562+ );
589563 }
590564
591565 template <class E , class T , std::size_t N>
0 commit comments