Skip to content

Commit e823ea6

Browse files
committed
Added dimension deduction for -1 shape parameters
1 parent 4beafb6 commit e823ea6

6 files changed

Lines changed: 94 additions & 19 deletions

File tree

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ jobs:
170170
ninja
171171
172172
- name: Configure using CMake
173-
run: cmake -Bbuild -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja
173+
run: cmake -Bbuild -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja
174174

175175
- name: Install
176176
working-directory: build

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# The full license is in the file LICENSE, distributed with this software. #
88
############################################################################
99

10-
cmake_minimum_required(VERSION 3.1)
10+
cmake_minimum_required(VERSION 3.5)
1111
project(xtensor CXX)
1212

1313
set(XTENSOR_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include)

benchmark/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# The full license is in the file LICENSE, distributed with this software. #
77
############################################################################
88

9-
cmake_minimum_required(VERSION 3.1)
9+
cmake_minimum_required(VERSION 3.5)
1010

1111
if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
1212
project(xtensor-benchmark)

include/xtensor/xstrided_view.hpp

Lines changed: 73 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,66 @@ namespace xt
807807
);
808808
}
809809

810+
namespace detail
811+
{
812+
// if shape is signed do the check
813+
template <class S, class R>
814+
using do_shape_recalculation = std::
815+
enable_if_t<std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, R>;
816+
817+
// if shape is unsigned pass through
818+
template <class S, class R>
819+
using no_shape_recalculation = std::
820+
enable_if_t<!std::is_signed<get_value_type_t<typename std::decay<S>::type>>::value, R>;
821+
822+
template <typename T>
823+
inline no_shape_recalculation<T, T> make_unsigned_shape(T shape)
824+
{
825+
return shape;
826+
}
827+
828+
template <typename S, typename Enable = void>
829+
struct rebind_shape;
830+
831+
template <class S>
832+
struct rebind_shape<S, std::enable_if_t<!std::is_signed<get_value_type_t<typename std::decay_t<S>>>::value>>
833+
{
834+
using Shape = S;
835+
};
836+
837+
template <class S>
838+
struct rebind_shape<S, std::enable_if_t<std::is_signed<get_value_type_t<typename std::decay_t<S>>>::value>>
839+
{
840+
using Shape = rebind_container_t<size_t, S>;
841+
};
842+
843+
template <class S, do_shape_recalculation<S, bool> = true>
844+
inline auto recalculate_shape_impl(S& shape, size_t size)
845+
{
846+
using value_type = get_value_type_t<typename std::decay_t<S>>;
847+
auto iter = std::find(shape.begin(), shape.end(), -1);
848+
if (iter != std::end(shape))
849+
{
850+
const auto total = std::accumulate(shape.cbegin(), shape.cend(), -1, std::multiplies<int>{});
851+
const auto missing_dimension = size / total;
852+
(*iter) = static_cast<value_type>(missing_dimension);
853+
}
854+
return shape;
855+
}
856+
857+
template <class S, no_shape_recalculation<S, bool> = true>
858+
inline auto recalculate_shape_impl(S& shape, size_t)
859+
{
860+
return shape;
861+
}
862+
863+
template <class S>
864+
inline auto recalculate_shape(S&& shape, size_t size)
865+
{
866+
return recalculate_shape_impl(shape, size);
867+
}
868+
}
869+
810870
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class S>
811871
inline auto reshape_view(E&& e, S&& shape)
812872
{
@@ -815,18 +875,26 @@ namespace xt
815875
"traversal has to be row or column major"
816876
);
817877

818-
using shape_type = std::decay_t<S>;
819-
get_strides_t<shape_type> strides;
878+
using shape_type = std::decay_t<decltype(shape)>;
879+
using unsigned_shape_type = typename detail::rebind_shape<shape_type>::Shape;
880+
get_strides_t<unsigned_shape_type> strides;
820881

882+
detail::recalculate_shape(shape, e.size());
821883
xt::resize_container(strides, shape.size());
822884
compute_strides(shape, L, strides);
823885
constexpr auto computed_layout = std::decay_t<E>::static_layout == L ? L : layout_type::dynamic;
824886
using view_type = xstrided_view<
825887
xclosure_t<E>,
826-
shape_type,
888+
unsigned_shape_type,
827889
computed_layout,
828890
detail::flat_adaptor_getter<xclosure_t<E>, L>>;
829-
return view_type(std::forward<E>(e), std::forward<S>(shape), std::move(strides), 0, e.layout());
891+
return view_type(
892+
std::forward<E>(e),
893+
xtl::forward_sequence<unsigned_shape_type, S>(shape),
894+
std::move(strides),
895+
0,
896+
e.layout()
897+
);
830898
}
831899

832900
/**
@@ -858,7 +926,7 @@ namespace xt
858926
template <layout_type L = XTENSOR_DEFAULT_TRAVERSAL, class E, class I, std::size_t N>
859927
inline auto reshape_view(E&& e, const I (&shape)[N])
860928
{
861-
using shape_type = std::array<std::size_t, N>;
929+
using shape_type = std::array<I, N>;
862930
return reshape_view<L>(std::forward<E>(e), xtl::forward_sequence<shape_type, decltype(shape)>(shape));
863931
}
864932
}

test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# The full license is in the file LICENSE, distributed with this software. #
88
############################################################################
99

10-
cmake_minimum_required(VERSION 3.1)
10+
cmake_minimum_required(VERSION 3.5)
1111

1212
if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
1313
project(xtensor-test CXX)

test/test_xstrided_view.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -696,24 +696,31 @@ namespace xt
696696
EXPECT_EQ(av, e);
697697
EXPECT_EQ(av, a);
698698

699-
bool truthy;
700-
truthy = std::is_same<
701-
typename decltype(xv)::temporary_type,
702-
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>();
703-
EXPECT_TRUE(truthy);
704-
705-
truthy = std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>(
699+
static_assert(
700+
std::is_same<
701+
typename decltype(xv)::temporary_type,
702+
xtensor_fixed<double, xshape<3, 3>, XTENSOR_DEFAULT_LAYOUT>>::value,
703+
"Container types do not match"
704+
);
705+
static_assert(
706+
std::is_same<typename decltype(av)::temporary_type, xtensor<double, 2, XTENSOR_DEFAULT_LAYOUT>>::value,
707+
"Container types do not match"
708+
);
709+
static_assert(
710+
std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value,
711+
"Shape types do not match"
706712
);
707-
EXPECT_TRUE(truthy);
708-
truthy = std::is_same<typename decltype(av)::shape_type, typename decltype(e)::shape_type>::value;
709-
EXPECT_TRUE(truthy);
710713

711714
xarray<int> xa = {{1, 2, 3}, {4, 5, 6}};
712715
std::vector<std::size_t> new_shape = {3, 2};
713716
auto xrv = reshape_view(xa, new_shape);
714717

715718
xarray<int> xres = {{1, 2}, {3, 4}, {5, 6}};
716719
EXPECT_EQ(xrv, xres);
720+
721+
auto nv = xt::reshape_view<XTENSOR_DEFAULT_LAYOUT>(a, {-1, 3});
722+
std::vector<size_t> expected_shape({3, 3});
723+
EXPECT_TRUE(std::equal(nv.shape().begin(), nv.shape().end(), expected_shape.begin()));
717724
}
718725

719726
TEST(xstrided_view, reshape_view_assign)

0 commit comments

Comments
 (0)