@@ -46,35 +46,39 @@ namespace xt
4646 return stride != 0 ? stride : static_cast <std::ptrdiff_t >(shape);
4747 }
4848
49- template <class E , class F >
50- inline void call_over_leading_axis ( E& ev, F&& fct )
49+ template <class E >
50+ inline std:: ptrdiff_t get_secondary_stride ( const E& ev)
5151 {
52- std::size_t n_iters = 1 ;
53- std::ptrdiff_t secondary_stride;
52+ if (ev.layout () == layout_type::row_major)
53+ {
54+ return adjust_secondary_stride (ev.strides ()[ev.dimension () - 2 ], *(ev.shape ().end () - 1 ));
55+ }
56+
57+ return adjust_secondary_stride (ev.strides ()[1 ], *(ev.shape ().begin ()));
58+ }
5459
60+ template <class E >
61+ inline std::size_t leading_axis_n_iters (const E& ev)
62+ {
5563 if (ev.layout () == layout_type::row_major)
5664 {
57- n_iters = std::accumulate (
65+ return std::accumulate (
5866 ev.shape ().begin (),
5967 ev.shape ().end () - 1 ,
6068 std::size_t (1 ),
6169 std::multiplies<>()
6270 );
63- secondary_stride = adjust_secondary_stride (
64- ev.strides ()[ev.dimension () - 2 ],
65- *(ev.shape ().end () - 1 )
66- );
67- }
68- else
69- {
70- n_iters = std::accumulate (
71- ev.shape ().begin () + 1 ,
72- ev.shape ().end (),
73- std::size_t (1 ),
74- std::multiplies<>()
75- );
76- secondary_stride = adjust_secondary_stride (ev.strides ()[1 ], *(ev.shape ().begin ()));
7771 }
72+ return std::accumulate (ev.shape ().begin () + 1 , ev.shape ().end (), std::size_t (1 ), std::multiplies<>());
73+ }
74+
75+ template <class E , class F >
76+ inline void call_over_leading_axis (E& ev, F&& fct)
77+ {
78+ XTENSOR_ASSERT (ev.dimension () >= 2 );
79+
80+ const std::size_t n_iters = leading_axis_n_iters (ev);
81+ const std::ptrdiff_t secondary_stride = get_secondary_stride (ev);
7882
7983 const auto begin = ev.data ();
8084 const auto end = begin + n_iters * secondary_stride;
@@ -87,37 +91,13 @@ namespace xt
8791 template <class E1 , class E2 , class F >
8892 inline void call_over_leading_axis (E1 & e1 , E2 & e2 , F&& fct)
8993 {
90- std:: size_t n_iters = 1 ;
91- std:: ptrdiff_t secondary_stride1, secondary_stride2 ;
94+ XTENSOR_ASSERT ( e1 . dimension () >= 2 ) ;
95+ XTENSOR_ASSERT ( e1 . dimension () == e2 . dimension ()) ;
9296
93- if (e1 .layout () == layout_type::row_major)
94- {
95- n_iters = std::accumulate (
96- e1 .shape ().begin (),
97- e1 .shape ().end () - 1 ,
98- std::size_t (1 ),
99- std::multiplies<>()
100- );
101- secondary_stride1 = adjust_secondary_stride (
102- e1 .strides ()[e1 .dimension () - 2 ],
103- *(e1 .shape ().end () - 1 )
104- );
105- secondary_stride2 = adjust_secondary_stride (
106- e2 .strides ()[e2 .dimension () - 2 ],
107- *(e2 .shape ().end () - 2 )
108- );
109- }
110- else
111- {
112- n_iters = std::accumulate (
113- e1 .shape ().begin () + 1 ,
114- e1 .shape ().end (),
115- std::size_t (1 ),
116- std::multiplies<>()
117- );
118- secondary_stride1 = adjust_secondary_stride (e1 .strides ()[1 ], *(e1 .shape ().begin ()));
119- secondary_stride2 = adjust_secondary_stride (e2 .strides ()[1 ], *(e2 .shape ().begin ()));
120- }
97+ const std::size_t n_iters = leading_axis_n_iters (e1 );
98+ std::ptrdiff_t const secondary_stride1 = get_secondary_stride (e1 );
99+ std::ptrdiff_t const secondary_stride2 = get_secondary_stride (e2 );
100+ XTENSOR_ASSERT (secondary_stride1 == secondary_stride2);
121101
122102 const auto begin1 = e1 .data ();
123103 const auto end1 = begin1 + n_iters * secondary_stride1;
@@ -192,7 +172,7 @@ namespace xt
192172 }
193173
194174 dynamic_shape<std::size_t > permutation, reverse_permutation;
195- std::tie (permutation, reverse_permutation) = get_permutations (e.dimension (), axis , e.layout ());
175+ std::tie (permutation, reverse_permutation) = get_permutations (e.dimension (), ax , e.layout ());
196176 R res = transpose (e, permutation);
197177 detail::call_over_leading_axis (res, std::forward<F>(lambda));
198178 res = transpose (res, reverse_permutation);
0 commit comments