Skip to content

Commit f997d26

Browse files
committed
Simplify iterators usage in weighted choice
1 parent d4bedf5 commit f997d26

1 file changed

Lines changed: 6 additions & 8 deletions

File tree

include/xtensor/xrandom.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -849,15 +849,13 @@ namespace xt
849849
// Sample u uniformly in the range [0, sum(weights)[
850850
// The index idx of the sampled element is such that weight_cumul[idx - 1] <= u < weight_cumul[idx].
851851
// Where weight_cumul[-1] is implicitly 0, as the empty sum.
852-
const auto weight_cumul = eval(cumsum(dweights));
853-
const auto weight_cumul_begin = weight_cumul.storage().begin();
854-
std::uniform_real_distribution<weight_type> weight_dist{0, weight_cumul[weight_cumul.size() - 1]};
852+
const auto wc = eval(cumsum(dweights));
853+
std::uniform_real_distribution<weight_type> weight_dist{0, wc[wc.size() - 1]};
855854
for(auto& x : result)
856855
{
857856
const auto u = weight_dist(engine);
858-
const auto idx_iter = std::upper_bound(weight_cumul_begin, weight_cumul.storage().end(), u);
859-
const auto idx = static_cast<size_type>(idx_iter - weight_cumul_begin);
860-
x = de.storage()[idx];
857+
const auto idx = static_cast<size_type>(std::upper_bound(wc.cbegin(), wc.cend(), u) - wc.cbegin());
858+
x = de[idx];
861859
}
862860

863861
}
@@ -867,12 +865,12 @@ namespace xt
867865
xtensor<weight_type, 1> keys;
868866
keys.resize({dweights.size()});
869867
std::exponential_distribution<weight_type> randexp{weight_type(1)};
870-
std::transform(dweights.storage().begin(), dweights.storage().end(), keys.begin(),
868+
std::transform(dweights.cbegin(), dweights.cend(), keys.begin(),
871869
[&randexp, &engine](auto w){ return w / randexp(engine); });
872870

873871
// Find indexes for the n biggest key
874872
xtensor<size_type, 1> indices = arange<size_type>(0, dweights.size());
875-
std::partial_sort(indices.storage().begin(), indices.storage().begin() + n, indices.storage().end(),
873+
std::partial_sort(indices.begin(), indices.begin() + n, indices.end(),
876874
[&keys](auto i, auto j) { return keys[i] > keys[j]; });
877875

878876
// Return samples with the n biggest keys

0 commit comments

Comments
 (0)