Skip to content

Commit 5eeb50e

Browse files
committed
Add weighted choice with replacement
1 parent 8ef5aca commit 5eeb50e

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

include/xtensor/xrandom.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "xtensor.hpp"
2929
#include "xtensor_config.hpp"
3030
#include "xview.hpp"
31+
#include "xmath.hpp"
3132

3233
namespace xt
3334
{
@@ -846,7 +847,19 @@ namespace xt
846847

847848
if (replace)
848849
{
849-
XTENSOR_THROW(std::runtime_error, "Not implemented");
850+
// Sample u uniformly in the range [0, sum(weights)[
851+
// The index idx of the sampled element in e is the largest idx such that weight_cumul[idx] < u (given by std::upper_bound - 1).
852+
const xtensor<weight_type, 1> weight_cumul = cumsum(dweights); // 0 included as first elem
853+
const auto weight_cumul_begin = weight_cumul.storage().begin();
854+
std::uniform_real_distribution<weight_type> weight_dist{0, weight_cumul[weight_cumul.size() - 1]};
855+
for(auto& x : result)
856+
{
857+
const auto u = weight_dist(engine);
858+
const auto idx_iter = std::upper_bound(weight_cumul_begin, weight_cumul.storage().end(), u) - 1;
859+
const auto idx = static_cast<size_type>(idx_iter - weight_cumul_begin);
860+
x = de.storage()[idx];
861+
}
862+
850863
}
851864
else
852865
{

test/test_xrandom.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,16 @@ namespace xt
186186
ASSERT_TRUE(all(isin(ac1, a)));
187187
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
188188

189+
xt::random::seed(42);
190+
auto acr1 = xt::random::choice(a, 6, w, false);
191+
auto acr2 = xt::random::choice(a, 6, w, false);
192+
xt::random::seed(42);
193+
auto acr3 = xt::random::choice(a, 6, w, false);
194+
ASSERT_EQ(acr1, acr3);
195+
ASSERT_NE(acr1, acr2);
196+
ASSERT_TRUE(all(isin(acr1, a)));
197+
ASSERT_TRUE(all(equal(acr1 % 2, 1)));
198+
189199
xarray<double> b = {-1, 1};
190200
xarray<double> v = {1, 1};
191201
xt::random::seed(42);

0 commit comments

Comments
 (0)