@@ -494,9 +494,10 @@ namespace xt
494494 using size_type = std::size_t ;
495495 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
496496
497- template <class S >
498- inline value_type access (const tuple_type& t, size_type axis, S index ) const
497+ template <class It >
498+ inline value_type access (const tuple_type& t, size_type axis, It first, It last ) const
499499 {
500+ xindex index (first, last);
500501 auto match = [&index, axis](auto & arr)
501502 {
502503 if (index[axis] >= arr.shape ()[axis])
@@ -533,48 +534,59 @@ namespace xt
533534 using size_type = std::size_t ;
534535 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
535536
536- template <class S >
537- inline value_type access (const tuple_type& t, size_type axis, S index ) const
537+ template <class It >
538+ inline value_type access (const tuple_type& t, size_type axis, It first, It last ) const
538539 {
539- auto get_item = [&index ](auto & arr)
540+ auto get_item = [&](auto & arr)
540541 {
541- return arr[index];
542+ size_t offset = 0 ;
543+ const size_t end = arr.dimension ();
544+ bool after_axis = false ;
545+ for (size_t i = 0 ; i < end; i++)
546+ {
547+ if (i == axis)
548+ {
549+ after_axis = true ;
550+ }
551+ const auto & stride = arr.strides ()[i];
552+ const auto len = (*(first + i + after_axis));
553+ offset += len * stride;
554+ }
555+ const auto element = arr.begin () + offset;
556+ return *element;
542557 };
543- size_type i = index[axis];
544- index.erase (index.begin () + std::ptrdiff_t (axis));
558+ size_type i = *(first + axis);
545559 return apply<value_type>(i, get_item, t);
546560 }
547561 };
548562
549563 template <class ... CT>
550- class vstack_access : private concatenate_access <CT...>,
551- private stack_access<CT...>
564+ class vstack_access
552565 {
553566 public:
554-
555567 using tuple_type = std::tuple<CT...>;
556568 using size_type = std::size_t ;
557569 using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
558570
559- using concatenate_base = concatenate_access<CT...>;
560- using stack_base = stack_access<CT...>;
561-
562- template <class S >
563- inline value_type access (const tuple_type& t, size_type axis, S index) const
571+ template <class It >
572+ inline value_type access (const tuple_type& t, size_type axis, It first, It last) const
564573 {
565574 if (std::get<0 >(t).dimension () == 1 )
566575 {
567- return stack_base:: access (t, axis, index );
576+ return stack. access (t, axis, first, last );
568577 }
569578 else
570579 {
571- return concatenate_base:: access (t, axis, index );
580+ return concatonate. access (t, axis, first, last );
572581 }
573582 }
583+ private:
584+ concatenate_access<CT...> concatonate;
585+ stack_access<CT...> stack;
574586 };
575587
576588 template <template <class ...> class F , class ... CT>
577- class concatenate_invoker : private F <CT...>
589+ class concatenate_invoker
578590 {
579591 public:
580592
@@ -592,18 +604,18 @@ namespace xt
592604 inline value_type operator ()(Args... args) const
593605 {
594606 // TODO: avoid memory allocation
595- return this ->access (m_t , m_axis, xindex ({static_cast <size_type>(args)...}));
607+ xindex index ({static_cast <size_type>(args)...});
608+ return access_method.access (m_t , m_axis, index.begin (), index.end ());
596609 }
597610
598611 template <class It >
599612 inline value_type element (It first, It last) const
600613 {
601- // TODO: avoid memory allocation
602- return this ->access (m_t , m_axis, xindex (first, last));
614+ return access_method.access (m_t , m_axis, first, last);
603615 }
604616
605617 private:
606-
618+ F<CT...> access_method;
607619 tuple_type m_t ;
608620 size_type m_axis;
609621 };
0 commit comments