Skip to content

Commit b5f4a2d

Browse files
authored
Merge pull request #2241 from AntoinePrv/weighted-choice
Implement xt::random::choice with weights vector
2 parents 40bf942 + f997d26 commit b5f4a2d

3 files changed

Lines changed: 120 additions & 15 deletions

File tree

docs/source/api/xrandom.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ Defined in ``xtensor/xrandom.hpp``
8080
:project: xtensor
8181

8282
.. _random-choice-function-reference:
83-
.. doxygenfunction:: xt::random::choice
83+
.. doxygenfunction:: xt::random::choice(const xexpression<T>&, std::size_t, bool, E&)
84+
:project: xtensor
85+
.. doxygenfunction:: xt::random::choice(const xexpression<T>&, std::size_t, const xexpression<W>&, bool, E&)
8486
:project: xtensor
8587

8688
.. _random-shuffle-function-reference:

include/xtensor/xrandom.hpp

Lines changed: 95 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,21 @@
1414
#ifndef XTENSOR_RANDOM_HPP
1515
#define XTENSOR_RANDOM_HPP
1616

17+
#include <algorithm>
1718
#include <functional>
1819
#include <random>
1920
#include <utility>
21+
#include <type_traits>
22+
23+
#include <xtl/xspan.hpp>
2024

2125
#include "xbuilder.hpp"
2226
#include "xgenerator.hpp"
27+
#include "xindex_view.hpp"
2328
#include "xtensor.hpp"
2429
#include "xtensor_config.hpp"
2530
#include "xview.hpp"
31+
#include "xmath.hpp"
2632

2733
namespace xt
2834
{
@@ -180,6 +186,12 @@ namespace xt
180186
template <class T, class E = random::default_engine_type>
181187
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace = true,
182188
E& engine = random::get_default_random_engine());
189+
190+
template <class T, class W, class E = random::default_engine_type>
191+
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n,
192+
const xexpression<W>& weights,
193+
bool replace = true,
194+
E& engine = random::get_default_random_engine());
183195
}
184196

185197
namespace detail
@@ -755,14 +767,8 @@ namespace xt
755767
xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace, E& engine)
756768
{
757769
const auto& de = e.derived_cast();
758-
if (de.dimension() != 1)
759-
{
760-
XTENSOR_THROW(std::runtime_error, "Sample expression must be 1 dimensional");
761-
}
762-
if (de.size() < n && !replace)
763-
{
764-
XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n");
765-
}
770+
XTENSOR_ASSERT((de.dimension() == 1));
771+
XTENSOR_ASSERT((replace || n <= de.size()));
766772
using result_type = xtensor<typename T::value_type, 1>;
767773
using size_type = typename result_type::size_type;
768774
result_type result;
@@ -792,6 +798,87 @@ namespace xt
792798
}
793799
return result;
794800
}
801+
802+
/**
803+
* Weighted random sampling.
804+
*
805+
* Randomly sample n unique elements from xexpression ``e`` using the discrete distribution parametrized by
806+
* the weights ``w``.
807+
* When sampling with replacement, this means that the probability to sample element ``e[i]`` is defined as
808+
* ``w[i] / sum(w)``.
809+
* Without replacement, this only describes the probability of the first sample element.
810+
* In successive samples, the weight of items already sampled is assumed to be zero.
811+
*
812+
* For weighted random sampling with replacement, binary search with cumulative weights alogrithm is used.
813+
* For weighted random sampling without replacement, the algorithm used is the exponential sort from
814+
* [Efraimidis and Spirakis](https://doi.org/10.1016/j.ipl.2005.11.003) (2006) with the ``weight / randexp(1)``
815+
* [trick](https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/) from Kirill Müller.
816+
*
817+
* Note: this function makes a copy of your data, and only 1D data is accepted.
818+
*
819+
* @param e expression to sample from
820+
* @param n number of elements to sample
821+
* @param w expression for the weight distribution.
822+
* Weights must be positive and real-valued but need not sum to 1.
823+
* @param replace set true to sample with replacement
824+
* @param engine random number engine
825+
*
826+
* @return xtensor containing 1D container of sampled elements
827+
*/
828+
template <class T, class W, class E>
829+
xtensor<typename T::value_type, 1>
830+
choice(const xexpression<T>& e, std::size_t n, const xexpression<W>& weights, bool replace, E& engine)
831+
{
832+
const auto& de = e.derived_cast();
833+
const auto& dweights = weights.derived_cast();
834+
XTENSOR_ASSERT((de.dimension() == 1));
835+
XTENSOR_ASSERT((replace || n <= de.size()));
836+
XTENSOR_ASSERT((de.size() == dweights.size()));
837+
XTENSOR_ASSERT((de.dimension() == dweights.dimension()));
838+
XTENSOR_ASSERT(xt::all(dweights >= 0));
839+
static_assert(std::is_floating_point<typename W::value_type>::value,
840+
"Weight expression must be of floating point type");
841+
using result_type = xtensor<typename T::value_type, 1>;
842+
using size_type = typename result_type::size_type;
843+
using weight_type = typename W::value_type;
844+
result_type result;
845+
result.resize({n});
846+
847+
if (replace)
848+
{
849+
// Sample u uniformly in the range [0, sum(weights)[
850+
// The index idx of the sampled element is such that weight_cumul[idx - 1] <= u < weight_cumul[idx].
851+
// Where weight_cumul[-1] is implicitly 0, as the empty sum.
852+
const auto wc = eval(cumsum(dweights));
853+
std::uniform_real_distribution<weight_type> weight_dist{0, wc[wc.size() - 1]};
854+
for(auto& x : result)
855+
{
856+
const auto u = weight_dist(engine);
857+
const auto idx = static_cast<size_type>(std::upper_bound(wc.cbegin(), wc.cend(), u) - wc.cbegin());
858+
x = de[idx];
859+
}
860+
861+
}
862+
else
863+
{
864+
// Compute (modified) keys as weight/randexp(1).
865+
xtensor<weight_type, 1> keys;
866+
keys.resize({dweights.size()});
867+
std::exponential_distribution<weight_type> randexp{weight_type(1)};
868+
std::transform(dweights.cbegin(), dweights.cend(), keys.begin(),
869+
[&randexp, &engine](auto w){ return w / randexp(engine); });
870+
871+
// Find indexes for the n biggest key
872+
xtensor<size_type, 1> indices = arange<size_type>(0, dweights.size());
873+
std::partial_sort(indices.begin(), indices.begin() + n, indices.end(),
874+
[&keys](auto i, auto j) { return keys[i] > keys[j]; });
875+
876+
// Return samples with the n biggest keys
877+
result = index_view(de, xtl::span<size_type>{indices.data(), n});
878+
}
879+
return result;
880+
}
881+
795882
}
796883
}
797884

test/test_xrandom.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
* The full license is in the file LICENSE, distributed with this software. *
88
****************************************************************************/
99

10+
#include <type_traits>
11+
1012
#include "gtest/gtest.h"
1113
#include "test_common_macros.hpp"
1214
#if (defined(__GNUC__) && !defined(__clang__))
@@ -19,6 +21,7 @@
1921
#endif
2022
#include "xtensor/xarray.hpp"
2123
#include "xtensor/xview.hpp"
24+
#include "xtensor/xset_operation.hpp"
2225

2326
namespace xt
2427
{
@@ -158,13 +161,26 @@ namespace xt
158161
auto acr3 = xt::random::choice(a, 5, true);
159162
ASSERT_EQ(acr1, acr3);
160163
ASSERT_NE(acr1, acr2);
164+
}
161165

162-
xarray<double> b = {-1, 1};
163-
xt::random::seed(42);
164-
XT_ASSERT_THROW(xt::random::choice(b, 5, false), std::runtime_error);
165-
XT_ASSERT_NO_THROW(xt::random::choice(b, 5, true));
166-
xarray<double> multidim_input = { {1,2,3}, {3,4,5} };
167-
XT_ASSERT_THROW(xt::random::choice(multidim_input, 5, true), std::runtime_error);
166+
TEST(xrandom, weighted_choice)
167+
{
168+
xarray<int> a = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
169+
xarray<double> w = {1, 0, 2, 0, 1, 0, 1, 0, 2, 0, 1, 0};
170+
171+
for(bool replace : {true, false}) {
172+
xt::random::seed(42);
173+
auto ac1 = xt::random::choice(a, 6, w, replace);
174+
auto ac2 = xt::random::choice(a, 6, w, replace);
175+
xt::random::seed(42);
176+
auto ac3 = xt::random::choice(a, 6, w, replace);
177+
static_assert(std::is_same<decltype(a)::value_type, decltype(ac1)::value_type>::value,
178+
"Elements must be same type");
179+
ASSERT_EQ(ac1, ac3);
180+
ASSERT_NE(ac1, ac2);
181+
ASSERT_TRUE(all(isin(ac1, a)));
182+
ASSERT_TRUE(all(equal(ac1 % 2, 1)));
183+
}
168184
}
169185

170186
TEST(xrandom, shuffle)

0 commit comments

Comments
 (0)