Skip to content

Commit b5b7371

Browse files
committed
Improve weighted random sampling implementation
- Add positive weights assertion for xt::random::choice - Replace std::all_of with xt::all - Explicitly construct randexp param - Fix test and weighted sampling with replacement error - Improve documentation of weighted random sampling
1 parent 5eeb50e commit b5b7371

3 files changed

Lines changed: 39 additions & 42 deletions

File tree

docs/source/api/xrandom.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ Defined in ``xtensor/xrandom.hpp``
8080
:project: xtensor
8181

8282
.. _random-choice-function-reference:
83-
.. doxygenfunction:: xt::random::choice
83+
.. doxygenfunction:: xt::random::choice(const xexpression<T>&, std::size_t, bool, E&)
84+
:project: xtensor
85+
.. doxygenfunction:: xt::random::choice(const xexpression<T>&, std::size_t, const xexpression<W>&, bool, E&)
8486
:project: xtensor
8587

8688
.. _random-shuffle-function-reference:

include/xtensor/xrandom.hpp

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -806,15 +806,27 @@ namespace xt
806806
}
807807

808808
/**
809-
* Randomly select n unique elements from xexpression e using the weights distribution w.
809+
* Weighted random sampling.
810+
*
811+
* Randomly sample n elements (unique only when sampling without replacement) from xexpression ``e`` using the discrete distribution parametrized by
812+
* the weights ``w``.
813+
* When sampling with replacement, this means that the probability to sample element ``e[i]`` is defined as
814+
* ``w[i] / sum(w)``.
815+
* Without replacement, this only describes the probability of the first sample element.
816+
* In successive samples, the weight of items already sampled is assumed to be zero.
817+
*
818+
* For weighted random sampling with replacement, a binary search over the cumulative weights is used.
819+
* For weighted random sampling without replacement, the algorithm used is the exponential sort from
820+
* [Efraimidis and Spirakis](https://doi.org/10.1016/j.ipl.2005.11.003) (2006) with the ``weight / randexp(1)``
821+
* [trick](https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/) from Kirill Müller.
810822
*
811823
* Note: this function makes a copy of your data, and only 1D data is accepted.
812824
*
813825
* @param e expression to sample from
814826
* @param n number of elements to sample
815-
* @param e expression for the weight distribution.
827+
* @param w expression for the weight distribution.
816828
* Weights must be non-negative and real-valued but need not sum to 1.
817-
* @param replace whether to sample with or without replacement
829+
* @param replace set true to sample with replacement
818830
* @param engine random number engine
819831
*
820832
* @return xtensor containing 1D container of sampled elements
@@ -837,6 +849,7 @@ namespace xt
837849
{
838850
XTENSOR_THROW(std::runtime_error, "Sample and weight expression must have the same size");
839851
}
852+
XTENSOR_ASSERT(xt::all(dweights >= 0));
840853
static_assert(std::is_floating_point<typename W::value_type>::value,
841854
"Weight expression must be of floating point type");
842855
using result_type = xtensor<typename T::value_type, 1>;
@@ -848,37 +861,26 @@ namespace xt
848861
if (replace)
849862
{
850863
// Sample u uniformly in the range [0, sum(weights)[
851-
// The index idx of the sampled element in e is the largest idx such that weight_cumul[idx] < u (given by std::upper_bound - 1).
852-
const xtensor<weight_type, 1> weight_cumul = cumsum(dweights); // 0 included as first elem
864+
// The index idx of the sampled element is such that weight_cumul[idx - 1] <= u < weight_cumul[idx].
865+
// Where weight_cumul[-1] is implicitly 0, as the empty sum.
866+
const auto weight_cumul = eval(cumsum(dweights));
853867
const auto weight_cumul_begin = weight_cumul.storage().begin();
854868
std::uniform_real_distribution<weight_type> weight_dist{0, weight_cumul[weight_cumul.size() - 1]};
855869
for(auto& x : result)
856870
{
857871
const auto u = weight_dist(engine);
858-
const auto idx_iter = std::upper_bound(weight_cumul_begin, weight_cumul.storage().end(), u) - 1;
872+
const auto idx_iter = std::upper_bound(weight_cumul_begin, weight_cumul.storage().end(), u);
859873
const auto idx = static_cast<size_type>(idx_iter - weight_cumul_begin);
860874
x = de.storage()[idx];
861875
}
862876

863877
}
864878
else
865879
{
866-
// Algorithm from
867-
// Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
868-
// Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
869-
// doi:10.1016/j.ipl.2005.11.003.
870-
// https://www.sciencedirect.com/science/article/pii/S002001900500298X
871-
//
872-
// The keys computed are replaced with weight/randexp(1) instead of rand()^(1/weight) as done in wrswoR:
873-
// https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/
874-
// https://web.archive.org/web/20201021162520/https://github.com/krlmlr/wrswoR/blob/master/src/sample_int_crank.cpp
875-
// As well as in JuliaStats:
876-
// https://web.archive.org/web/20201021162949/https://github.com/JuliaStats/StatsBase.jl/blob/master/src/sampling.jl
877-
878880
// Compute (modified) keys as weight/randexp(1).
879881
xtensor<weight_type, 1> keys;
880882
keys.resize({dweights.size()});
881-
std::exponential_distribution<weight_type> randexp{1.};
883+
std::exponential_distribution<weight_type> randexp{weight_type(1)};
882884
std::transform(dweights.storage().begin(), dweights.storage().end(), keys.begin(),
883885
[&randexp, &engine](auto w){ return w / randexp(engine); });
884886

test/test_xrandom.cpp

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -173,28 +173,21 @@ namespace xt
173173
TEST(xrandom, weighted_choice)
174174
{
175175
xarray<int> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
176-
xarray<double> w = {1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
177-
xt::random::seed(42);
178-
auto ac1 = xt::random::choice(a, 6, w, false);
179-
auto ac2 = xt::random::choice(a, 6, w, false);
180-
xt::random::seed(42);
181-
auto ac3 = xt::random::choice(a, 6, w, false);
182-
183-
static_assert(std::is_same<decltype(a)::value_type, decltype(ac1)::value_type>::value, "Elements must be same type");
184-
ASSERT_EQ(ac1, ac3);
185-
ASSERT_NE(ac1, ac2);
186-
ASSERT_TRUE(all(isin(ac1, a)));
187-
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
188-
189-
xt::random::seed(42);
190-
auto acr1 = xt::random::choice(a, 6, w, false);
191-
auto acr2 = xt::random::choice(a, 6, w, false);
192-
xt::random::seed(42);
193-
auto acr3 = xt::random::choice(a, 6, w, false);
194-
ASSERT_EQ(acr1, acr3);
195-
ASSERT_NE(acr1, acr2);
196-
ASSERT_TRUE(all(isin(acr1, a)));
197-
ASSERT_TRUE(all(equal(acr1 % 2, 1)));
176+
xarray<double> w = {1, 0, 2, 0, 1, 0, 1, 0, 2, 0, 1, 0};
177+
178+
for(bool replace : {true, false}) {
179+
xt::random::seed(42);
180+
auto ac1 = xt::random::choice(a, 6, w, replace);
181+
auto ac2 = xt::random::choice(a, 6, w, replace);
182+
xt::random::seed(42);
183+
auto ac3 = xt::random::choice(a, 6, w, replace);
184+
static_assert(std::is_same<decltype(a)::value_type, decltype(ac1)::value_type>::value,
185+
"Elements must be same type");
186+
ASSERT_EQ(ac1, ac3);
187+
ASSERT_NE(ac1, ac2);
188+
ASSERT_TRUE(all(isin(ac1, a)));
189+
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
190+
}
198191

199192
xarray<double> b = {-1, 1};
200193
xarray<double> v = {1, 1};

0 commit comments

Comments
 (0)