Skip to content

Commit 13f3221

Browse files
committed
Refactor call_lambda_over_axis
1 parent a0ac70d commit 13f3221

1 file changed

Lines changed: 31 additions & 57 deletions

File tree

include/xtensor/xsort.hpp

Lines changed: 31 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -173,23 +173,30 @@ namespace xt
173173
return std::make_pair(std::move(permutation), std::move(reverse_permutation));
174174
}
175175

176-
template <class E, class R, class F>
177-
inline auto run_lambda_over_axis(const E& e, R& res, std::size_t axis, F&& lambda)
176+
template <class R, class E, class F>
177+
inline R map_axis(const E& e, std::ptrdiff_t axis, F&& lambda)
178178
{
179-
if (axis != detail::leading_axis(e))
179+
if (e.dimension() == 1)
180180
{
181-
dynamic_shape<std::size_t> permutation, reverse_permutation;
182-
std::tie(permutation, reverse_permutation) = get_permutations(e.dimension(), axis, e.layout());
183-
184-
res = transpose(e, permutation);
185-
detail::call_over_leading_axis(res, std::forward<F>(lambda));
186-
res = transpose(res, reverse_permutation);
181+
R res = e;
182+
lambda(res.begin(), res.end());
183+
return res;
187184
}
188-
else
185+
186+
std::size_t const ax = normalize_axis(e.dimension(), axis);
187+
if (ax == detail::leading_axis(e))
189188
{
190-
res = e;
189+
R res = e;
191190
detail::call_over_leading_axis(res, std::forward<F>(lambda));
191+
return res;
192192
}
193+
194+
dynamic_shape<std::size_t> permutation, reverse_permutation;
195+
std::tie(permutation, reverse_permutation) = get_permutations(e.dimension(), axis, e.layout());
196+
R res = transpose(e, permutation);
197+
detail::call_over_leading_axis(res, std::forward<F>(lambda));
198+
res = transpose(res, reverse_permutation);
199+
return res;
193200
}
194201

195202
template <class VT>
@@ -269,26 +276,14 @@ namespace xt
269276
{
270277
using eval_type = typename detail::sort_eval_type<E>::type;
271278

272-
const auto& de = e.derived_cast();
273-
274-
if (de.dimension() == 1)
275-
{
276-
return detail::flat_sort_impl<std::decay_t<decltype(de)>, eval_type>(de);
277-
}
278-
279-
std::size_t ax = normalize_axis(de.dimension(), axis);
280-
281-
eval_type res;
282-
detail::run_lambda_over_axis(
283-
de,
284-
res,
285-
ax,
279+
return detail::map_axis<eval_type>(
280+
e.derived_cast(),
281+
axis,
286282
[](auto begin, auto end)
287283
{
288284
std::sort(begin, end);
289285
}
290286
);
291-
return res;
292287
}
293288

294289
namespace detail
@@ -551,41 +546,20 @@ namespace xt
551546
}
552547

553548
template <class E, class C, class = std::enable_if_t<!xtl::is_integral<C>::value, int>>
554-
inline auto partition(const xexpression<E>& e, const C& kth_container, std::ptrdiff_t axis = -1)
549+
inline auto partition(const xexpression<E>& e, C kth_container, std::ptrdiff_t axis = -1)
555550
{
556551
using eval_type = typename detail::sort_eval_type<E>::type;
557552

558-
const auto& de = e.derived_cast();
553+
std::sort(kth_container.begin(), kth_container.end());
559554

560-
if (de.dimension() == 1)
561-
{
562-
return partition<E, C, eval_type>(de, kth_container, xnone());
563-
}
564-
565-
C kth_copy = kth_container;
566-
if (kth_copy.size() > 1)
567-
{
568-
std::sort(kth_copy.begin(), kth_copy.end());
569-
}
570-
const auto partition_w_kth = [&kth_copy](auto begin, auto end)
571-
{
572-
detail::partition_iter(begin, end, kth_copy.rbegin(), kth_copy.rend());
573-
};
574-
575-
std::size_t const ax = normalize_axis(de.dimension(), axis);
576-
if (ax == detail::leading_axis(de))
577-
{
578-
eval_type res = de;
579-
detail::call_over_leading_axis(res, partition_w_kth);
580-
return res;
581-
}
582-
583-
dynamic_shape<std::size_t> permutation, reverse_permutation;
584-
std::tie(permutation, reverse_permutation) = detail::get_permutations(de.dimension(), ax, de.layout());
585-
eval_type res = transpose(de, permutation);
586-
detail::call_over_leading_axis(res, partition_w_kth);
587-
res = transpose(res, reverse_permutation);
588-
return res;
555+
return detail::map_axis<eval_type>(
556+
e.derived_cast(),
557+
axis,
558+
[&kth_container](auto begin, auto end)
559+
{
560+
detail::partition_iter(begin, end, kth_container.rbegin(), kth_container.rend());
561+
}
562+
);
589563
}
590564

591565
template <class E, class T, std::size_t N>

0 commit comments

Comments
 (0)