Skip to content

Commit 8ef5aca

Browse files
committed
Add weighted choice without replacement
1 parent 40bf942 commit 8ef5aca

2 files changed

Lines changed: 118 additions & 0 deletions

File tree

include/xtensor/xrandom.hpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414
#ifndef XTENSOR_RANDOM_HPP
1515
#define XTENSOR_RANDOM_HPP
1616

17+
#include <algorithm>
1718
#include <functional>
1819
#include <random>
1920
#include <utility>
21+
#include <type_traits>
22+
23+
#include <xtl/xspan.hpp>
2024

2125
#include "xbuilder.hpp"
2226
#include "xgenerator.hpp"
27+
#include "xindex_view.hpp"
2328
#include "xtensor.hpp"
2429
#include "xtensor_config.hpp"
2530
#include "xview.hpp"
@@ -180,6 +185,12 @@ namespace xt
180185
template <class T, class E = random::default_engine_type>
181186
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace = true,
182187
E& engine = random::get_default_random_engine());
188+
189+
template <class T, class W, class E = random::default_engine_type>
190+
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n,
191+
const xexpression<W>& weights,
192+
bool replace = true,
193+
E& engine = random::get_default_random_engine());
183194
}
184195

185196
namespace detail
@@ -792,6 +803,83 @@ namespace xt
792803
}
793804
return result;
794805
}
806+
807+
/**
808+
* Randomly select n unique elements from xexpression e using the weights distribution w.
809+
*
810+
* Note: this function makes a copy of your data, and only 1D data is accepted.
811+
*
812+
* @param e expression to sample from
813+
* @param n number of elements to sample
814+
* @param e expression for the weight distribution.
815+
* Weights must be positive and real-valued but need not sum to 1.
816+
* @param replace whether to sample with or without replacement
817+
* @param engine random number engine
818+
*
819+
* @return xtensor containing 1D container of sampled elements
820+
*/
821+
template <class T, class W, class E>
822+
xtensor<typename T::value_type, 1>
823+
choice(const xexpression<T>& e, std::size_t n, const xexpression<W>& weights, bool replace, E& engine)
824+
{
825+
const auto& de = e.derived_cast();
826+
const auto& dweights = weights.derived_cast();
827+
if (de.dimension() != 1)
828+
{
829+
XTENSOR_THROW(std::runtime_error, "Sample and weight expression must be 1 dimensional");
830+
}
831+
if (de.size() < n && !replace)
832+
{
833+
XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n");
834+
}
835+
if (de.size() != dweights.size() || de.dimension() != dweights.dimension())
836+
{
837+
XTENSOR_THROW(std::runtime_error, "Sample and weight expression must have the same size");
838+
}
839+
static_assert(std::is_floating_point<typename W::value_type>::value,
840+
"Weight expression must be of floating point type");
841+
using result_type = xtensor<typename T::value_type, 1>;
842+
using size_type = typename result_type::size_type;
843+
using weight_type = typename W::value_type;
844+
result_type result;
845+
result.resize({n});
846+
847+
if (replace)
848+
{
849+
XTENSOR_THROW(std::runtime_error, "Not implemented");
850+
}
851+
else
852+
{
853+
// Algorithm from
854+
// Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
855+
// Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
856+
// doi:10.1016/j.ipl.2005.11.003.
857+
// https://www.sciencedirect.com/science/article/pii/S002001900500298X
858+
//
859+
// The keys computed are replaced with weight/randexp(1) instead rand()^(1/weight) as done in wrlmlR:
860+
// https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/
861+
// https://web.archive.org/web/20201021162520/https://github.com/krlmlr/wrswoR/blob/master/src/sample_int_crank.cpp
862+
// As well as in JuliaStats:
863+
// https://web.archive.org/web/20201021162949/https://github.com/JuliaStats/StatsBase.jl/blob/master/src/sampling.jl
864+
865+
// Compute (modified) keys as weight/randexp(1).
866+
xtensor<weight_type, 1> keys;
867+
keys.resize({dweights.size()});
868+
std::exponential_distribution<weight_type> randexp{1.};
869+
std::transform(dweights.storage().begin(), dweights.storage().end(), keys.begin(),
870+
[&randexp, &engine](auto w){ return w / randexp(engine); });
871+
872+
// Find indexes for the n biggest key
873+
xtensor<size_type, 1> indices = arange<size_type>(0, dweights.size());
874+
std::partial_sort(indices.storage().begin(), indices.storage().begin() + n, indices.storage().end(),
875+
[&keys](auto i, auto j) { return keys[i] > keys[j]; });
876+
877+
// Return samples with the n biggest keys
878+
result = index_view(de, xtl::span<size_type>{indices.data(), n});
879+
}
880+
return result;
881+
}
882+
795883
}
796884
}
797885

test/test_xrandom.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
* The full license is in the file LICENSE, distributed with this software. *
88
****************************************************************************/
99

10+
#include <type_traits>
11+
1012
#include "gtest/gtest.h"
1113
#include "test_common_macros.hpp"
1214
#if (defined(__GNUC__) && !defined(__clang__))
@@ -19,6 +21,7 @@
1921
#endif
2022
#include "xtensor/xarray.hpp"
2123
#include "xtensor/xview.hpp"
24+
#include "xtensor/xset_operation.hpp"
2225

2326
namespace xt
2427
{
@@ -167,6 +170,33 @@ namespace xt
167170
XT_ASSERT_THROW(xt::random::choice(multidim_input, 5, true), std::runtime_error);
168171
}
169172

173+
TEST(xrandom, weighted_choice)
174+
{
175+
xarray<int> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
176+
xarray<double> w = {1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0};
177+
xt::random::seed(42);
178+
auto ac1 = xt::random::choice(a, 6, w, false);
179+
auto ac2 = xt::random::choice(a, 6, w, false);
180+
xt::random::seed(42);
181+
auto ac3 = xt::random::choice(a, 6, w, false);
182+
183+
static_assert(std::is_same<decltype(a)::value_type, decltype(ac1)::value_type>::value, "Elements must be same type");
184+
ASSERT_EQ(ac1, ac3);
185+
ASSERT_NE(ac1, ac2);
186+
ASSERT_TRUE(all(isin(ac1, a)));
187+
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
188+
189+
xarray<double> b = {-1, 1};
190+
xarray<double> v = {1, 1};
191+
xt::random::seed(42);
192+
XT_ASSERT_THROW(xt::random::choice(b, 5, v, false), std::runtime_error);
193+
XT_ASSERT_NO_THROW(xt::random::choice(b, 5, v, true));
194+
xarray<double> multidim_input = { {1,2,3}, {3,4,5} };
195+
XT_ASSERT_THROW(xt::random::choice(multidim_input, 5, v, true), std::runtime_error);
196+
xarray<double> bad_count_weights = {1, 1, 4};
197+
XT_ASSERT_THROW(xt::random::choice(b, 5, bad_count_weights, true), std::runtime_error);
198+
}
199+
170200
TEST(xrandom, shuffle)
171201
{
172202
xarray<double> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};

0 commit comments

Comments
 (0)