@@ -807,6 +807,66 @@ namespace xt
807807 );
808808 }
809809
810+ namespace detail
811+ {
812+ // if shape is signed do the check
813+ template <class S , class R >
814+ using do_shape_recalculation = std::
815+ enable_if_t <std::is_signed<get_value_type_t <typename std::decay<S>::type>>::value, R>;
816+
817+ // if shape is unsigned pass through
818+ template <class S , class R >
819+ using no_shape_recalculation = std::
820+ enable_if_t <!std::is_signed<get_value_type_t <typename std::decay<S>::type>>::value, R>;
821+
822+ template <typename T>
823+ inline no_shape_recalculation<T, T> make_unsigned_shape (T shape)
824+ {
825+ return shape;
826+ }
827+
828+ template <typename S, typename Enable = void >
829+ struct rebind_shape ;
830+
831+ template <class S >
832+ struct rebind_shape <S, std::enable_if_t <!std::is_signed<get_value_type_t <typename std::decay_t <S>>>::value>>
833+ {
834+ using Shape = S;
835+ };
836+
837+ template <class S >
838+ struct rebind_shape <S, std::enable_if_t <std::is_signed<get_value_type_t <typename std::decay_t <S>>>::value>>
839+ {
840+ using Shape = rebind_container_t <size_t , S>;
841+ };
842+
843+ template <class S , do_shape_recalculation<S, bool > = true >
844+ inline auto recalculate_shape_impl (S& shape, size_t size)
845+ {
846+ using value_type = get_value_type_t <typename std::decay_t <S>>;
847+ auto iter = std::find (shape.begin (), shape.end (), -1 );
848+ if (iter != std::end (shape))
849+ {
850+ const auto total = std::accumulate (shape.cbegin (), shape.cend (), -1 , std::multiplies<int >{});
851+ const auto missing_dimension = size / total;
852+ (*iter) = static_cast <value_type>(missing_dimension);
853+ }
854+ return shape;
855+ }
856+
857+ template <class S , no_shape_recalculation<S, bool > = true >
858+ inline auto recalculate_shape_impl (S& shape, size_t )
859+ {
860+ return shape;
861+ }
862+
863+ template <class S >
864+ inline auto recalculate_shape (S&& shape, size_t size)
865+ {
866+ return recalculate_shape_impl (shape, size);
867+ }
868+ }
869+
810870 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E , class S >
811871 inline auto reshape_view (E&& e, S&& shape)
812872 {
@@ -815,18 +875,26 @@ namespace xt
815875 " traversal has to be row or column major"
816876 );
817877
818- using shape_type = std::decay_t <S>;
819- get_strides_t <shape_type> strides;
878+ using shape_type = std::decay_t <decltype (shape)>;
879+ using unsigned_shape_type = typename detail::rebind_shape<shape_type>::Shape;
880+ get_strides_t <unsigned_shape_type> strides;
820881
882+ detail::recalculate_shape (shape, e.size ());
821883 xt::resize_container (strides, shape.size ());
822884 compute_strides (shape, L, strides);
823885 constexpr auto computed_layout = std::decay_t <E>::static_layout == L ? L : layout_type::dynamic;
824886 using view_type = xstrided_view<
825887 xclosure_t <E>,
826- shape_type ,
888+ unsigned_shape_type ,
827889 computed_layout,
828890 detail::flat_adaptor_getter<xclosure_t <E>, L>>;
829- return view_type (std::forward<E>(e), std::forward<S>(shape), std::move (strides), 0 , e.layout ());
891+ return view_type (
892+ std::forward<E>(e),
893+ xtl::forward_sequence<unsigned_shape_type, S>(shape),
894+ std::move (strides),
895+ 0 ,
896+ e.layout ()
897+ );
830898 }
831899
832900 /* *
@@ -858,7 +926,7 @@ namespace xt
858926 template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E , class I , std::size_t N>
859927 inline auto reshape_view (E&& e, const I (&shape)[N])
860928 {
861- using shape_type = std::array<std:: size_t , N>;
929+ using shape_type = std::array<I , N>;
862930 return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype (shape)>(shape));
863931 }
864932}
0 commit comments