|
14 | 14 | #ifndef XTENSOR_RANDOM_HPP |
15 | 15 | #define XTENSOR_RANDOM_HPP |
16 | 16 |
|
| 17 | +#include <algorithm> |
17 | 18 | #include <functional> |
18 | 19 | #include <random> |
19 | 20 | #include <utility> |
| 21 | +#include <type_traits> |
| 22 | + |
| 23 | +#include <xtl/xspan.hpp> |
20 | 24 |
|
21 | 25 | #include "xbuilder.hpp" |
22 | 26 | #include "xgenerator.hpp" |
| 27 | +#include "xindex_view.hpp" |
23 | 28 | #include "xtensor.hpp" |
24 | 29 | #include "xtensor_config.hpp" |
25 | 30 | #include "xview.hpp" |
@@ -180,6 +185,12 @@ namespace xt |
// Declaration of the unweighted `choice` overload: takes an expression `e` and a
// count `n`, and returns a 1-D xtensor of `T::value_type`.
// `replace` defaults to true; `engine` defaults to the library-wide engine
// returned by random::get_default_random_engine().
180 | 185 | template <class T, class E = random::default_engine_type> |
181 | 186 | xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, bool replace = true, |
182 | 187 | E& engine = random::get_default_random_engine()); |
| 188 | + |
// Declaration of the weighted `choice` overload: same as the unweighted form but
// with an additional `weights` expression placed before `replace`.
// `replace` defaults to true; `engine` defaults to the library-wide engine
// returned by random::get_default_random_engine().
| 189 | + template <class T, class W, class E = random::default_engine_type> |
| 190 | + xtensor<typename T::value_type, 1> choice(const xexpression<T>& e, std::size_t n, |
| 191 | + const xexpression<W>& weights, |
| 192 | + bool replace = true, |
| 193 | + E& engine = random::get_default_random_engine()); |
183 | 194 | } |
184 | 195 |
|
185 | 196 | namespace detail |
@@ -792,6 +803,83 @@ namespace xt |
792 | 803 | } |
793 | 804 | return result; |
794 | 805 | } |
| 806 | + |
| 807 | + /** |
| 808 | + * Randomly select n unique elements from xexpression e using the weights distribution w. |
| 809 | + * |
| 810 | + * Note: this function makes a copy of your data, and only 1D data is accepted. |
| 811 | + * |
| 812 | + * @param e expression to sample from |
| 813 | + * @param n number of elements to sample |
| 814 | + * @param e expression for the weight distribution. |
| 815 | + * Weights must be positive and real-valued but need not sum to 1. |
| 816 | + * @param replace whether to sample with or without replacement |
| 817 | + * @param engine random number engine |
| 818 | + * |
| 819 | + * @return xtensor containing 1D container of sampled elements |
| 820 | + */ |
| 821 | + template <class T, class W, class E> |
| 822 | + xtensor<typename T::value_type, 1> |
| 823 | + choice(const xexpression<T>& e, std::size_t n, const xexpression<W>& weights, bool replace, E& engine) |
| 824 | + { |
| 825 | + const auto& de = e.derived_cast(); |
| 826 | + const auto& dweights = weights.derived_cast(); |
| 827 | + if (de.dimension() != 1) |
| 828 | + { |
| 829 | + XTENSOR_THROW(std::runtime_error, "Sample and weight expression must be 1 dimensional"); |
| 830 | + } |
| 831 | + if (de.size() < n && !replace) |
| 832 | + { |
| 833 | + XTENSOR_THROW(std::runtime_error, "If replace is false, then the sample expression's size must be > n"); |
| 834 | + } |
| 835 | + if (de.size() != dweights.size() || de.dimension() != dweights.dimension()) |
| 836 | + { |
| 837 | + XTENSOR_THROW(std::runtime_error, "Sample and weight expression must have the same size"); |
| 838 | + } |
| 839 | + static_assert(std::is_floating_point<typename W::value_type>::value, |
| 840 | + "Weight expression must be of floating point type"); |
| 841 | + using result_type = xtensor<typename T::value_type, 1>; |
| 842 | + using size_type = typename result_type::size_type; |
| 843 | + using weight_type = typename W::value_type; |
| 844 | + result_type result; |
| 845 | + result.resize({n}); |
| 846 | + |
| 847 | + if (replace) |
| 848 | + { |
| 849 | + XTENSOR_THROW(std::runtime_error, "Not implemented"); |
| 850 | + } |
| 851 | + else |
| 852 | + { |
| 853 | + // Algorithm from |
| 854 | + // Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir." |
| 855 | + // Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190. |
| 856 | + // doi:10.1016/j.ipl.2005.11.003. |
| 857 | + // https://www.sciencedirect.com/science/article/pii/S002001900500298X |
| 858 | + // |
| 859 | + // The keys computed are replaced with weight/randexp(1) instead rand()^(1/weight) as done in wrlmlR: |
| 860 | + // https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/ |
| 861 | + // https://web.archive.org/web/20201021162520/https://github.com/krlmlr/wrswoR/blob/master/src/sample_int_crank.cpp |
| 862 | + // As well as in JuliaStats: |
| 863 | + // https://web.archive.org/web/20201021162949/https://github.com/JuliaStats/StatsBase.jl/blob/master/src/sampling.jl |
| 864 | + |
| 865 | + // Compute (modified) keys as weight/randexp(1). |
| 866 | + xtensor<weight_type, 1> keys; |
| 867 | + keys.resize({dweights.size()}); |
| 868 | + std::exponential_distribution<weight_type> randexp{1.}; |
| 869 | + std::transform(dweights.storage().begin(), dweights.storage().end(), keys.begin(), |
| 870 | + [&randexp, &engine](auto w){ return w / randexp(engine); }); |
| 871 | + |
| 872 | + // Find indexes for the n biggest key |
| 873 | + xtensor<size_type, 1> indices = arange<size_type>(0, dweights.size()); |
| 874 | + std::partial_sort(indices.storage().begin(), indices.storage().begin() + n, indices.storage().end(), |
| 875 | + [&keys](auto i, auto j) { return keys[i] > keys[j]; }); |
| 876 | + |
| 877 | + // Return samples with the n biggest keys |
| 878 | + result = index_view(de, xtl::span<size_type>{indices.data(), n}); |
| 879 | + } |
| 880 | + return result; |
| 881 | + } |
| 882 | + |
795 | 883 | } |
796 | 884 | } |
797 | 885 |
|
|
0 commit comments