Skip to content

Commit 9ed0ad9

Browse files
add as_strided free function
add helpers for readable sfinae use aliases to make sfinae work with all compilers add docstrings add tests
1 parent 1c9b74a commit 9ed0ad9

2 files changed

Lines changed: 360 additions & 0 deletions

File tree

include/xtensor/xeval.hpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,122 @@ namespace xt
4747
{
4848
return std::forward<T>(t);
4949
}
50+
51+
namespace detail
52+
{
53+
/************************************
54+
* layout_remove_any implementation *
55+
************************************/
56+
57+
constexpr layout_type layout_remove_any(const layout_type layout)
58+
{
59+
return layout == layout_type::any ? XTENSOR_DEFAULT_LAYOUT : layout;
60+
}
61+
62+
/**********************************
63+
* has_same_layout implementation *
64+
**********************************/
65+
66+
template <layout_type L = layout_type::any, class E>
67+
constexpr bool has_same_layout()
68+
{
69+
return (std::decay_t<E>::static_layout == L) || (L == layout_type::any);
70+
}
71+
72+
template <layout_type L = layout_type::any, class E>
73+
constexpr bool has_same_layout(E&&)
74+
{
75+
return has_same_layout<L, E>();
76+
}
77+
78+
template <class E1, class E2>
79+
constexpr bool has_same_layout(E1&&, E2&&)
80+
{
81+
return has_same_layout<std::decay_t<E1>::static_layout, E2>();
82+
}
83+
84+
/*********************************
85+
* has_fixed_dims implementation *
86+
*********************************/
87+
88+
template <class E>
89+
constexpr bool has_fixed_dims()
90+
{
91+
return detail::is_array<typename std::decay_t<E>::shape_type>::value;
92+
}
93+
94+
template <class E>
95+
constexpr bool has_fixed_dims(E&&)
96+
{
97+
return has_fixed_dims<E>();
98+
}
99+
100+
/****************************************
101+
* as_xarray_container_t implementation *
102+
****************************************/
103+
104+
template <class E, layout_type L>
105+
using as_xarray_container_t = xarray<typename std::decay_t<E>::value_type, detail::layout_remove_any(L)>;
106+
107+
/*****************************************
108+
* as_xtensor_container_t implementation *
109+
*****************************************/
110+
111+
template <class E, layout_type L>
112+
using as_xtensor_container_t = xtensor<typename std::decay_t<E>::value_type,
113+
std::tuple_size<typename std::decay_t<E>::shape_type>::value,
114+
detail::layout_remove_any(L)>;
115+
}
116+
117+
/**
118+
* Force evaluation of xexpression not providing a data interface
119+
* and convert to the required layout.
120+
*
121+
* @warning This function should be used in a local context only.
122+
* Returning the value returned by this function could lead to a dangling reference.
123+
*
124+
* @return The expression when it already provides a data interface with the correct layout,
125+
* an evaluated xarray or xtensor depending on shape type otherwise.
126+
*
127+
* \code{.cpp}
128+
* xarray<double, layout_type::row_major> a = {1,2,3,4};
129+
* auto&& b = xt::as_strided(a); // b is a reference to a, no copy!
130+
* auto&& c = xt::as_strided<layout_type::column_major>(a); // b is xarray<double> with the required layout
131+
* auto&& a_cast = xt::cast<int>(a); // a_cast is an xexpression
132+
* auto&& d = xt::as_strided(a_cast); // d is xarray<int>, not an xexpression
133+
* auto&& e = xt::as_strided<layout_type::column_major>(a_cast); // d is xarray<int> with the required layout
134+
* \endcode
135+
*/
136+
template <layout_type L = layout_type::any, class E>
137+
inline auto as_strided(E&& e)
138+
-> std::enable_if_t<has_data_interface<std::decay_t<E>>::value
139+
&& detail::has_same_layout<L, E>(),
140+
E&&>
141+
{
142+
return std::forward<E>(e);
143+
}
144+
145+
/// @cond DOXYGEN_INCLUDE_SFINAE
146+
template <layout_type L = layout_type::any, class E>
147+
inline auto as_strided(E&& e)
148+
-> std::enable_if_t<(!(has_data_interface<std::decay_t<E>>::value
149+
&& detail::has_same_layout<L, E>()))
150+
&& detail::has_fixed_dims<E>(),
151+
detail::as_xtensor_container_t<E, L>>
152+
{
153+
return e;
154+
}
155+
156+
/// @cond DOXYGEN_INCLUDE_SFINAE
157+
template <layout_type L = layout_type::any, class E>
158+
inline auto as_strided(E&& e)
159+
-> std::enable_if_t<(!(has_data_interface<std::decay_t<E>>::value
160+
&& detail::has_same_layout<L, E>()))
161+
&& (!detail::has_fixed_dims<E>()),
162+
detail::as_xarray_container_t<E, L>>
163+
{
164+
return e;
165+
}
50166
}
51167

52168
#endif

test/test_xeval.cpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,248 @@ namespace xt
5858
bool type_eq_2 = std::is_same<decltype(i), xtensor<int, 1>&&>::value;
5959
EXPECT_TRUE(type_eq_2);
6060
}
61+
62+
63+
#define EXPECT_LAYOUT(EXPRESSION, LAYOUT) \
64+
EXPECT_TRUE((decltype(EXPRESSION)::static_layout == LAYOUT))
65+
66+
#define HAS_DATA_INTERFACE(EXPRESSION) \
67+
has_data_interface<std::decay_t<decltype(EXPRESSION)>>::value
68+
69+
#define EXPECT_XARRAY(EXPRESSION) \
70+
EXPECT_TRUE(!detail::is_array< \
71+
typename std::decay_t<decltype(EXPRESSION) \
72+
>::shape_type>::value)
73+
74+
#define EXPECT_XTENSOR(EXPRESSION) \
75+
EXPECT_TRUE(detail::is_array< \
76+
typename std::decay_t<decltype(EXPRESSION) \
77+
>::shape_type>::value == true)
78+
79+
80+
TEST(utils, has_same_layout)
81+
{
82+
xt::xtensor<double, 1, layout_type::row_major> ten1 {1., 2., 3.2};
83+
EXPECT_TRUE(detail::has_same_layout<layout_type::row_major>(ten1));
84+
EXPECT_FALSE(detail::has_same_layout<layout_type::column_major>(ten1));
85+
EXPECT_TRUE(detail::has_same_layout<layout_type::any>(ten1));
86+
87+
xt::xtensor<double, 1, layout_type::column_major> ten2 {1., 2., 3.2};
88+
EXPECT_TRUE(detail::has_same_layout<layout_type::column_major>(ten2));
89+
EXPECT_FALSE(detail::has_same_layout<layout_type::row_major>(ten2));
90+
EXPECT_TRUE(detail::has_same_layout<layout_type::any>(ten2));
91+
92+
EXPECT_FALSE((detail::has_same_layout(ten1, ten2)));
93+
EXPECT_TRUE((detail::has_same_layout(ten1, xt::xtensor<double, 1, layout_type::row_major>({1., 2., 3.2}))));
94+
EXPECT_TRUE((detail::has_same_layout(ten2, xt::xtensor<double, 1, layout_type::column_major>({1., 2., 3.2}))));
95+
}
96+
97+
TEST(utils, has_fixed_dims)
98+
{
99+
xt::xtensor<double, 1> ten {1., 2., 3.2};
100+
EXPECT_TRUE((detail::has_fixed_dims<xt::xtensor<double, 1>>()));
101+
EXPECT_TRUE(detail::has_fixed_dims(ten));
102+
103+
xt::xarray<double> arr {1., 2., 3.2};
104+
EXPECT_FALSE((detail::has_fixed_dims<xt::xarray<double>>()));
105+
EXPECT_FALSE(detail::has_fixed_dims(arr));
106+
}
107+
108+
TEST(utils, as_xarray_container_t)
109+
{
110+
using array_type = xt::xarray<double, layout_type::row_major>;
111+
112+
detail::as_xarray_container_t<array_type, layout_type::column_major> arr;
113+
EXPECT_XARRAY(arr);
114+
EXPECT_LAYOUT(arr, layout_type::column_major);
115+
}
116+
117+
TEST(utils, as_xtensor_container_t)
118+
{
119+
using tensor_type = xt::xtensor<double, 1, layout_type::row_major>;
120+
121+
detail::as_xtensor_container_t<tensor_type, layout_type::column_major> ten;
122+
EXPECT_XTENSOR(ten);
123+
EXPECT_LAYOUT(ten, layout_type::column_major);
124+
}
125+
126+
namespace testing
127+
{ // avoid collision with fixture class
128+
129+
class as_strided: public ::testing::Test
130+
{
131+
protected:
132+
133+
xt::xtensor<double, 1, layout_type::row_major> ten {1., 2., 3.2};
134+
xt::xarray<double, layout_type::row_major> arr {1., 2., 3.2};
135+
};
136+
137+
TEST_F(as_strided, array_reference)
138+
{
139+
EXPECT_LAYOUT(arr, layout_type::row_major);
140+
EXPECT_TRUE(HAS_DATA_INTERFACE(arr));
141+
EXPECT_XARRAY(arr);
142+
EXPECT_EQ(arr(2), 3.2);
143+
}
144+
145+
TEST_F(as_strided, tensor_reference)
146+
{
147+
EXPECT_LAYOUT(ten, layout_type::row_major);
148+
EXPECT_TRUE(HAS_DATA_INTERFACE(ten));
149+
EXPECT_XTENSOR(ten);
150+
EXPECT_EQ(ten(2), 3.2);
151+
}
152+
153+
TEST_F(as_strided, array_layout_unchanged)
154+
{
155+
auto res_lvalue = xt::as_strided<layout_type::row_major>(arr);
156+
EXPECT_LAYOUT(res_lvalue, layout_type::row_major);
157+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
158+
EXPECT_XARRAY(res_lvalue);
159+
EXPECT_EQ(res_lvalue(2), 3.2);
160+
161+
auto res_rvalue = xt::as_strided<layout_type::row_major>(
162+
xt::xarray<double, layout_type::row_major>({1., 2., 3.2})
163+
);
164+
EXPECT_LAYOUT(res_rvalue, layout_type::row_major);
165+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
166+
EXPECT_XARRAY(res_rvalue);
167+
EXPECT_EQ(res_rvalue(2), 3.2);
168+
}
169+
170+
TEST_F(as_strided, tensor_layout_unchanged)
171+
{
172+
auto res_lvalue = xt::as_strided<layout_type::row_major>(ten);
173+
EXPECT_LAYOUT(res_lvalue, layout_type::row_major);
174+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
175+
EXPECT_XTENSOR(res_lvalue);
176+
EXPECT_EQ(res_lvalue(2), 3.2);
177+
178+
auto res_rvalue = xt::as_strided<layout_type::row_major>(
179+
xt::xtensor<double, 1, layout_type::row_major>({1., 2., 3.2})
180+
);
181+
EXPECT_LAYOUT(res_rvalue, layout_type::row_major);
182+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
183+
EXPECT_XTENSOR(res_rvalue);
184+
EXPECT_EQ(res_rvalue(2), 3.2);
185+
}
186+
187+
TEST_F(as_strided, array_layout_change)
188+
{
189+
auto res_lvalue = xt::as_strided<layout_type::column_major>(arr);
190+
EXPECT_LAYOUT(res_lvalue, layout_type::column_major);
191+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
192+
EXPECT_XARRAY(res_lvalue);
193+
EXPECT_EQ(res_lvalue(2), 3.2);
194+
195+
auto res_rvalue = xt::as_strided<layout_type::column_major>(
196+
xt::xarray<double, layout_type::row_major>({1., 2., 3.2})
197+
);
198+
EXPECT_LAYOUT(res_rvalue, layout_type::column_major);
199+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
200+
EXPECT_XARRAY(res_rvalue);
201+
EXPECT_EQ(res_rvalue(2), 3.2);
202+
}
203+
204+
TEST_F(as_strided, tensor_layout_changed)
205+
{
206+
auto res_lvalue = xt::as_strided<layout_type::column_major>(ten);
207+
EXPECT_LAYOUT(res_lvalue, layout_type::column_major);
208+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
209+
EXPECT_XTENSOR(res_lvalue);
210+
EXPECT_EQ(res_lvalue(2), 3.2);
211+
212+
auto res_rvalue = xt::as_strided<layout_type::column_major>(
213+
xt::xtensor<double, 1, layout_type::row_major>({1., 2., 3.2})
214+
);
215+
EXPECT_LAYOUT(res_rvalue, layout_type::column_major);
216+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
217+
EXPECT_XTENSOR(res_rvalue);
218+
EXPECT_EQ(res_rvalue(2), 3.2);
219+
}
220+
221+
TEST_F(as_strided, array_no_data_interface_layout_unchanged)
222+
{
223+
auto array_cast = xt::cast<int>(arr);
224+
EXPECT_FALSE(HAS_DATA_INTERFACE(array_cast));
225+
EXPECT_XARRAY(array_cast);
226+
227+
auto res_lvalue = xt::as_strided<layout_type::row_major>(array_cast);
228+
EXPECT_LAYOUT(res_lvalue, layout_type::row_major);
229+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
230+
EXPECT_XARRAY(res_lvalue);
231+
EXPECT_EQ(res_lvalue(2), 3);
232+
233+
auto res_rvalue = xt::as_strided<layout_type::row_major>(
234+
xt::cast<int>(arr)
235+
);
236+
EXPECT_LAYOUT(res_rvalue, layout_type::row_major);
237+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
238+
EXPECT_XARRAY(res_rvalue);
239+
EXPECT_EQ(res_rvalue(2), 3);
240+
}
241+
242+
TEST_F(as_strided, tensor_no_data_interface_layout_unchanged)
243+
{
244+
auto tensor_cast = xt::cast<int>(ten);
245+
EXPECT_FALSE(HAS_DATA_INTERFACE(tensor_cast));
246+
EXPECT_XTENSOR(tensor_cast);
247+
248+
auto res_lvalue = xt::as_strided<layout_type::row_major>(tensor_cast);
249+
EXPECT_LAYOUT(res_lvalue, layout_type::row_major);
250+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
251+
EXPECT_XTENSOR(res_lvalue);
252+
EXPECT_EQ(res_lvalue(2), 3);
253+
254+
auto res_rvalue = xt::as_strided<layout_type::row_major>(
255+
xt::cast<int>(ten)
256+
);
257+
EXPECT_LAYOUT(res_rvalue, layout_type::row_major);
258+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
259+
EXPECT_XTENSOR(res_rvalue);
260+
EXPECT_EQ(res_rvalue(2), 3);
261+
}
262+
263+
TEST_F(as_strided, array_no_data_interface_layout_changed)
264+
{
265+
auto array_cast = xt::cast<int>(arr);
266+
EXPECT_FALSE(HAS_DATA_INTERFACE(array_cast));
267+
EXPECT_XARRAY(array_cast);
268+
269+
auto res_lvalue = xt::as_strided<layout_type::column_major>(array_cast);
270+
EXPECT_LAYOUT(res_lvalue, layout_type::column_major);
271+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
272+
EXPECT_XARRAY(res_lvalue);
273+
EXPECT_EQ(res_lvalue(2), 3);
274+
275+
auto res_rvalue = xt::as_strided<layout_type::column_major>(
276+
xt::cast<int>(arr)
277+
);
278+
EXPECT_LAYOUT(res_rvalue, layout_type::column_major);
279+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
280+
EXPECT_XARRAY(res_rvalue);
281+
EXPECT_EQ(res_rvalue(2), 3);
282+
}
283+
284+
TEST_F(as_strided, tensor_no_data_interface_layout_changed)
285+
{
286+
auto tensor_cast = xt::cast<int>(ten);
287+
EXPECT_FALSE(HAS_DATA_INTERFACE(tensor_cast));
288+
EXPECT_XTENSOR(tensor_cast);
289+
290+
auto res_lvalue = xt::as_strided<layout_type::column_major>(tensor_cast);
291+
EXPECT_LAYOUT(res_lvalue, layout_type::column_major);
292+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_lvalue));
293+
EXPECT_XTENSOR(res_lvalue);
294+
EXPECT_EQ(res_lvalue(2), 3);
295+
296+
auto res_rvalue = xt::as_strided<layout_type::column_major>(
297+
xt::cast<int>(ten)
298+
);
299+
EXPECT_LAYOUT(res_rvalue, layout_type::column_major);
300+
EXPECT_TRUE(HAS_DATA_INTERFACE(res_rvalue));
301+
EXPECT_XTENSOR(res_rvalue);
302+
EXPECT_EQ(res_rvalue(2), 3);
303+
}
304+
}
61305
}

0 commit comments

Comments
 (0)