|
14 | 14 | #ifndef XTENSOR_RANDOM_HPP |
15 | 15 | #define XTENSOR_RANDOM_HPP |
16 | 16 |
|
| 17 | +#include <algorithm> |
17 | 18 | #include <functional> |
18 | 19 | #include <random> |
19 | 20 | #include <utility> |
| 21 | +#include <type_traits> |
| 22 | + |
| 23 | +#include <xtl/xspan.hpp> |
20 | 24 |
|
21 | 25 | #include "xbuilder.hpp" |
22 | 26 | #include "xgenerator.hpp" |
| 27 | +#include "xindex_view.hpp" |
23 | 28 | #include "xtensor.hpp" |
24 | 29 | #include "xtensor_config.hpp" |
25 | 30 | #include "xview.hpp" |
| 31 | +#include "xmath.hpp" |
26 | 32 |
|
27 | 33 | namespace xt |
28 | 34 | { |
@@ -180,6 +186,12 @@ namespace xt |
180 | 186 | template <class T, class E = random::default_engine_type> |
181 | 187 | xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace = true, |
182 | 188 | E& engine = random::get_default_random_engine()); |
| 189 | + |
| 190 | + template <class T, class W, class E = random::default_engine_type> |
| 191 | + xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, |
| 192 | + const xexpression<W>& weights, |
| 193 | + bool replace = true, |
| 194 | + E& engine = random::get_default_random_engine()); |
183 | 195 | } |
184 | 196 |
|
185 | 197 | namespace detail |
@@ -755,14 +767,8 @@ namespace xt |
755 | 767 | xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace, E& engine) |
756 | 768 | { |
757 | 769 | const auto& de = e.derived_cast(); |
758 | | - if (de.dimension() != 1) |
759 | | - { |
760 | | - XTENSOR_THROW(std::runtime_error, "Sample expression must be 1 dimensional"); |
761 | | - } |
762 | | - if (de.size() < n && !replace) |
763 | | - { |
764 | | - XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n"); |
765 | | - } |
| 770 | + XTENSOR_ASSERT((de.dimension() == 1)); |
| 771 | + XTENSOR_ASSERT((replace || n <= de.size())); |
766 | 772 | using result_type = xtensor<typename T::value_type, 1>; |
767 | 773 | using size_type = typename result_type::size_type; |
768 | 774 | result_type result; |
@@ -792,6 +798,87 @@ namespace xt |
792 | 798 | } |
793 | 799 | return result; |
794 | 800 | } |
| 801 | + |
| 802 | + /** |
| 803 | + * Weighted random sampling. |
| 804 | + * |
| 805 | + * Randomly sample n unique elements from xexpression ``e`` using the discrete distribution parametrized by |
| 806 | + * the weights ``w``. |
| 807 | + * When sampling with replacement, this means that the probability to sample element ``e[i]`` is defined as |
| 808 | + * ``w[i] / sum(w)``. |
| 809 | + * Without replacement, this only describes the probability of the first sample element. |
| 810 | + * In successive samples, the weight of items already sampled is assumed to be zero. |
| 811 | + * |
| 812 | + * For weighted random sampling with replacement, binary search with cumulative weights alogrithm is used. |
| 813 | + * For weighted random sampling without replacement, the algorithm used is the exponential sort from |
| 814 | + * [Efraimidis and Spirakis](https://doi.org/10.1016/j.ipl.2005.11.003) (2006) with the ``weight / randexp(1)`` |
| 815 | + * [trick](https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/) from Kirill Müller. |
| 816 | + * |
| 817 | + * Note: this function makes a copy of your data, and only 1D data is accepted. |
| 818 | + * |
| 819 | + * @param e expression to sample from |
| 820 | + * @param n number of elements to sample |
| 821 | + * @param w expression for the weight distribution. |
| 822 | + * Weights must be positive and real-valued but need not sum to 1. |
| 823 | + * @param replace set true to sample with replacement |
| 824 | + * @param engine random number engine |
| 825 | + * |
| 826 | + * @return xtensor containing 1D container of sampled elements |
| 827 | + */ |
| 828 | + template <class T, class W, class E> |
| 829 | + xtensor<typename T::value_type, 1> |
| 830 | + choice(const xexpression<T>& e, std::size_t n, const xexpression<W>& weights, bool replace, E& engine) |
| 831 | + { |
| 832 | + const auto& de = e.derived_cast(); |
| 833 | + const auto& dweights = weights.derived_cast(); |
| 834 | + XTENSOR_ASSERT((de.dimension() == 1)); |
| 835 | + XTENSOR_ASSERT((replace || n <= de.size())); |
| 836 | + XTENSOR_ASSERT((de.size() == dweights.size())); |
| 837 | + XTENSOR_ASSERT((de.dimension() == dweights.dimension())); |
| 838 | + XTENSOR_ASSERT(xt::all(dweights >= 0)); |
| 839 | + static_assert(std::is_floating_point<typename W::value_type>::value, |
| 840 | + "Weight expression must be of floating point type"); |
| 841 | + using result_type = xtensor<typename T::value_type, 1>; |
| 842 | + using size_type = typename result_type::size_type; |
| 843 | + using weight_type = typename W::value_type; |
| 844 | + result_type result; |
| 845 | + result.resize({n}); |
| 846 | + |
| 847 | + if (replace) |
| 848 | + { |
| 849 | + // Sample u uniformly in the range [0, sum(weights)[ |
| 850 | + // The index idx of the sampled element is such that weight_cumul[idx - 1] <= u < weight_cumul[idx]. |
| 851 | + // Where weight_cumul[-1] is implicitly 0, as the empty sum. |
| 852 | + const auto wc = eval(cumsum(dweights)); |
| 853 | + std::uniform_real_distribution<weight_type> weight_dist{0, wc[wc.size() - 1]}; |
| 854 | + for(auto& x : result) |
| 855 | + { |
| 856 | + const auto u = weight_dist(engine); |
| 857 | + const auto idx = static_cast<size_type>(std::upper_bound(wc.cbegin(), wc.cend(), u) - wc.cbegin()); |
| 858 | + x = de[idx]; |
| 859 | + } |
| 860 | + |
| 861 | + } |
| 862 | + else |
| 863 | + { |
| 864 | + // Compute (modified) keys as weight/randexp(1). |
| 865 | + xtensor<weight_type, 1> keys; |
| 866 | + keys.resize({dweights.size()}); |
| 867 | + std::exponential_distribution<weight_type> randexp{weight_type(1)}; |
| 868 | + std::transform(dweights.cbegin(), dweights.cend(), keys.begin(), |
| 869 | + [&randexp, &engine](auto w){ return w / randexp(engine); }); |
| 870 | + |
| 871 | + // Find indexes for the n biggest key |
| 872 | + xtensor<size_type, 1> indices = arange<size_type>(0, dweights.size()); |
| 873 | + std::partial_sort(indices.begin(), indices.begin() + n, indices.end(), |
| 874 | + [&keys](auto i, auto j) { return keys[i] > keys[j]; }); |
| 875 | + |
| 876 | + // Return samples with the n biggest keys |
| 877 | + result = index_view(de, xtl::span<size_type>{indices.data(), n}); |
| 878 | + } |
| 879 | + return result; |
| 880 | + } |
| 881 | + |
795 | 882 | } |
796 | 883 | } |
797 | 884 |
|
|
0 commit comments