@@ -806,15 +806,27 @@ namespace xt
806806 }
807807
808808 /* *
809- * Randomly select n unique elements from xexpression e using the weights distribution w.
809+ * Weighted random sampling.
810+ *
811+ * Randomly sample n unique elements from xexpression ``e`` using the discrete distribution parametrized by
812+ * the weights ``w``.
813+ * When sampling with replacement, this means that the probability to sample element ``e[i]`` is defined as
814+ * ``w[i] / sum(w)``.
815+ * Without replacement, this only describes the probability of the first sample element.
816+ * In successive samples, the weight of items already sampled is assumed to be zero.
817+ *
818+ * For weighted random sampling with replacement, a binary search over the cumulative weights is used.
819+ * For weighted random sampling without replacement, the algorithm used is the exponential sort from
820+ * [Efraimidis and Spirakis](https://doi.org/10.1016/j.ipl.2005.11.003) (2006) with the ``weight / randexp(1)``
821+ * [trick](https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/) from Kirill Müller.
810822 *
811823 * Note: this function makes a copy of your data, and only 1D data is accepted.
812824 *
813825 * @param e expression to sample from
814826 * @param n number of elements to sample
815- * @param e expression for the weight distribution.
827+ * @param w expression for the weight distribution.
816828 * Weights must be positive and real-valued but need not sum to 1.
817- * @param replace whether to sample with or without replacement
829+ * @param replace set true to sample with replacement
818830 * @param engine random number engine
819831 *
820832 * @return xtensor containing 1D container of sampled elements
@@ -837,6 +849,7 @@ namespace xt
837849 {
838850 XTENSOR_THROW (std::runtime_error, " Sample and weight expression must have the same size" );
839851 }
852+ XTENSOR_ASSERT (xt::all (dweights >= 0 ));
840853 static_assert (std::is_floating_point<typename W::value_type>::value,
841854 " Weight expression must be of floating point type" );
842855 using result_type = xtensor<typename T::value_type, 1 >;
@@ -848,37 +861,26 @@ namespace xt
848861 if (replace)
849862 {
850863 // Sample u uniformly in the range [0, sum(weights)[
851- // The index idx of the sampled element in e is the largest idx such that weight_cumul[idx] < u (given by std::upper_bound - 1).
852- const xtensor<weight_type, 1 > weight_cumul = cumsum (dweights); // 0 included as first elem
864+ // The index idx of the sampled element is such that weight_cumul[idx - 1] <= u < weight_cumul[idx].
865+ // Where weight_cumul[-1] is implicitly 0, as the empty sum.
866+ const auto weight_cumul = eval (cumsum (dweights));
853867 const auto weight_cumul_begin = weight_cumul.storage ().begin ();
854868 std::uniform_real_distribution<weight_type> weight_dist{0 , weight_cumul[weight_cumul.size () - 1 ]};
855869 for (auto & x : result)
856870 {
857871 const auto u = weight_dist (engine);
858- const auto idx_iter = std::upper_bound (weight_cumul_begin, weight_cumul.storage ().end (), u) - 1 ;
872+ const auto idx_iter = std::upper_bound (weight_cumul_begin, weight_cumul.storage ().end (), u);
859873 const auto idx = static_cast <size_type>(idx_iter - weight_cumul_begin);
860874 x = de.storage ()[idx];
861875 }
862876
863877 }
864878 else
865879 {
866- // Algorithm from
867- // Efraimidis PS, Spirakis PG (2006). "Weighted random sampling with a reservoir."
868- // Information Processing Letters, 97 (5), 181-185. ISSN 0020-0190.
869- // doi:10.1016/j.ipl.2005.11.003.
870- // https://www.sciencedirect.com/science/article/pii/S002001900500298X
871- //
872- // The keys computed are replaced with weight/randexp(1) instead rand()^(1/weight) as done in wrlmlR:
873- // https://web.archive.org/web/20201021162211/https://krlmlr.github.io/wrswoR/
874- // https://web.archive.org/web/20201021162520/https://github.com/krlmlr/wrswoR/blob/master/src/sample_int_crank.cpp
875- // As well as in JuliaStats:
876- // https://web.archive.org/web/20201021162949/https://github.com/JuliaStats/StatsBase.jl/blob/master/src/sampling.jl
877-
878880 // Compute (modified) keys as weight/randexp(1).
879881 xtensor<weight_type, 1 > keys;
880882 keys.resize ({dweights.size ()});
881- std::exponential_distribution<weight_type> randexp{1 . };
883+ std::exponential_distribution<weight_type> randexp{weight_type ( 1 ) };
882884 std::transform (dweights.storage ().begin (), dweights.storage ().end (), keys.begin (),
883885 [&randexp, &engine](auto w){ return w / randexp (engine); });
884886
0 commit comments