@@ -349,50 +349,6 @@ namespace xt
349349 typename T::temporary_type>::type;
350350 };
351351
352- template <class Ed , class Ei >
353- inline void argsort_over_leading_axis (const Ed& data, Ei& inds)
354- {
355- std::size_t n_iters = 1 ;
356- std::ptrdiff_t data_secondary_stride, inds_secondary_stride;
357-
358- if (data.layout () == layout_type::row_major)
359- {
360- n_iters = std::accumulate (
361- data.shape ().begin (),
362- data.shape ().end () - 1 ,
363- std::size_t (1 ),
364- std::multiplies<>()
365- );
366- data_secondary_stride = static_cast <std::ptrdiff_t >(data.shape (data.dimension () - 1 ));
367- inds_secondary_stride = static_cast <std::ptrdiff_t >(inds.shape (inds.dimension () - 1 ));
368- }
369- else
370- {
371- n_iters = std::accumulate (
372- data.shape ().begin () + 1 ,
373- data.shape ().end (),
374- std::size_t (1 ),
375- std::multiplies<>()
376- );
377- data_secondary_stride = static_cast <std::ptrdiff_t >(data.shape (0 ));
378- inds_secondary_stride = static_cast <std::ptrdiff_t >(inds.shape (0 ));
379- }
380-
381- auto ptr = data.data ();
382- auto indices_ptr = inds.data ();
383-
384- for (std::size_t i = 0 ; i < n_iters;
385- ++i, ptr += data_secondary_stride, indices_ptr += inds_secondary_stride)
386- {
387- auto comp = [&ptr](std::size_t x, std::size_t y)
388- {
389- return *(ptr + x) < *(ptr + y);
390- };
391- std::iota (indices_ptr, indices_ptr + inds_secondary_stride, 0 );
392- std::sort (indices_ptr, indices_ptr + inds_secondary_stride, comp);
393- }
394- }
395-
396352 template <class E , class R = typename detail::linear_argsort_result_type<E>::type>
397353 inline auto flatten_argsort_impl (const xexpression<E>& e)
398354 {
@@ -449,23 +405,33 @@ namespace xt
449405 return detail::flatten_argsort_impl<E, result_type>(e);
450406 }
451407
452- if (ax != detail::leading_axis (de) )
408+ const auto argsort = []( auto res_begin, auto res_end, auto ev_begin, auto /* ev_end */ )
453409 {
454- dynamic_shape<std::size_t > permutation, reverse_permutation;
455- std::tie (permutation, reverse_permutation) = detail::get_permutations (de.dimension (), ax, de.layout ());
410+ std::iota (res_begin, res_end, 0 );
411+ std::sort (
412+ res_begin,
413+ res_end,
414+ [&ev_begin](auto const & i, auto const & j)
415+ {
416+ return *(ev_begin + i) < *(ev_begin + j);
417+ }
418+ );
419+ };
456420
457- eval_type ev = transpose (de, permutation);
458- result_type res = result_type::from_shape (ev.shape ());
459- detail::argsort_over_leading_axis (ev, res);
460- res = transpose (res, reverse_permutation);
461- return res;
462- }
463- else
421+ if (ax == detail::leading_axis (de))
464422 {
465423 result_type res = result_type::from_shape (de.shape ());
466- detail::argsort_over_leading_axis ( de, res );
424+ detail::call_over_leading_axis (res, de, argsort );
467425 return res;
468426 }
427+
428+ dynamic_shape<std::size_t > permutation, reverse_permutation;
429+ std::tie (permutation, reverse_permutation) = detail::get_permutations (de.dimension (), ax, de.layout ());
430+ eval_type ev = transpose (de, permutation);
431+ result_type res = result_type::from_shape (ev.shape ());
432+ detail::call_over_leading_axis (res, ev, argsort);
433+ res = transpose (res, reverse_permutation);
434+ return res;
469435 }
470436
471437 /* ***********************************************
0 commit comments