Skip to content

Commit d4bedf5

Browse files
committed
Replace XTENSOR_THROW with XTENSOR_ASSERT in xt::random::choice
1 parent b5b7371 commit d4bedf5

2 files changed

Lines changed: 6 additions & 37 deletions

File tree

include/xtensor/xrandom.hpp

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -767,14 +767,8 @@ namespace xt
767767
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace, E& engine)
768768
{
769769
const auto& de = e.derived_cast();
770-
if (de.dimension() != 1)
771-
{
772-
XTENSOR_THROW(std::runtime_error, "Sample expression must be 1 dimensional");
773-
}
774-
if (de.size() < n && !replace)
775-
{
776-
XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n");
777-
}
770+
XTENSOR_ASSERT((de.dimension() == 1));
771+
XTENSOR_ASSERT((replace || n <= de.size()));
778772
using result_type = xtensor<typename T::value_type, 1>;
779773
using size_type = typename result_type::size_type;
780774
result_type result;
@@ -837,18 +831,10 @@ namespace xt
837831
{
838832
const auto& de = e.derived_cast();
839833
const auto& dweights = weights.derived_cast();
840-
if (de.dimension() != 1)
841-
{
842-
XTENSOR_THROW(std::runtime_error, "Sample and weight expression must be 1 dimensional");
843-
}
844-
if (de.size() < n && !replace)
845-
{
846-
XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n");
847-
}
848-
if (de.size() != dweights.size() || de.dimension() != dweights.dimension())
849-
{
850-
XTENSOR_THROW(std::runtime_error, "Sample and weight expression must have the same size");
851-
}
834+
XTENSOR_ASSERT((de.dimension() == 1));
835+
XTENSOR_ASSERT((replace || n <= de.size()));
836+
XTENSOR_ASSERT((de.size() == dweights.size()));
837+
XTENSOR_ASSERT((de.dimension() == dweights.dimension()));
852838
XTENSOR_ASSERT(xt::all(dweights >= 0));
853839
static_assert(std::is_floating_point<typename W::value_type>::value,
854840
"Weight expression must be of floating point type");

test/test_xrandom.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,6 @@ namespace xt
161161
auto acr3 = xt::random::choice(a, 5, true);
162162
ASSERT_EQ(acr1, acr3);
163163
ASSERT_NE(acr1, acr2);
164-
165-
xarray<double> b = {-1, 1};
166-
xt::random::seed(42);
167-
XT_ASSERT_THROW(xt::random::choice(b, 5, false), std::runtime_error);
168-
XT_ASSERT_NO_THROW(xt::random::choice(b, 5, true));
169-
xarray<double> multidim_input = { {1,2,3}, {3,4,5} };
170-
XT_ASSERT_THROW(xt::random::choice(multidim_input, 5, true), std::runtime_error);
171164
}
172165

173166
TEST(xrandom, weighted_choice)
@@ -188,16 +181,6 @@ namespace xt
188181
ASSERT_TRUE(all(isin(ac1, a)));
189182
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
190183
}
191-
192-
xarray<double> b = {-1, 1};
193-
xarray<double> v = {1, 1};
194-
xt::random::seed(42);
195-
XT_ASSERT_THROW(xt::random::choice(b, 5, v, false), std::runtime_error);
196-
XT_ASSERT_NO_THROW(xt::random::choice(b, 5, v, true));
197-
xarray<double> multidim_input = { {1,2,3}, {3,4,5} };
198-
XT_ASSERT_THROW(xt::random::choice(multidim_input, 5, v, true), std::runtime_error);
199-
xarray<double> bad_count_weights = {1, 1, 4};
200-
XT_ASSERT_THROW(xt::random::choice(b, 5, bad_count_weights, true), std::runtime_error);
201184
}
202185

203186
TEST(xrandom, shuffle)

0 commit comments

Comments
 (0)