Skip to content

Commit 88d348f

Browse files
committed
add Rand objects
1 parent 0da32ac commit 88d348f

3 files changed

Lines changed: 52 additions & 1 deletion

File tree

src/RandomExtensions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module RandomExtensions
22

3-
export Combine, Uniform, Normal, Exponential, CloseOpen
3+
export Combine, Uniform, Normal, Exponential, CloseOpen, Rand
44

55
import Random: Sampler, rand, rand!
66

@@ -13,6 +13,7 @@ using SparseArrays: sprand, sprandn
1313
include("distributions.jl")
1414
include("sampling.jl")
1515
include("containers.jl")
16+
include("iteration.jl")
1617

1718

1819
## updated rand docstring (TODO: replace Base's one)

src/iteration.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# iterating over on-the-fly generated random values
2+
3+
## Rand
4+
5+
struct Rand{R<:AbstractRNG,S<:Sampler}
6+
rng::R
7+
sp::S
8+
end
9+
10+
# X can be an explicit Distribution, or an implicit one like 1:10
11+
Rand(rng::AbstractRNG, X) = Rand(rng, Sampler(rng, X))
12+
Rand(rng::AbstractRNG, ::Type{X}=Float64) where {X} = Rand(rng, Sampler(rng, X))
13+
14+
Rand(X) = Rand(GLOBAL_RNG, X)
15+
Rand(::Type{X}=Float64) where {X} = Rand(GLOBAL_RNG, X)
16+
17+
(R::Rand)(args...) = rand(R.rng, R.sp, args...)
18+
19+
Base.start(R::Rand) = R
20+
21+
function Base.next(::Union{Rand,Distribution}, R::Rand)
22+
e = R()
23+
e, R
24+
end
25+
26+
Base.done(::Union{Rand,Distribution}, ::Rand) = false
27+
28+
Base.IteratorSize(::Type{<:Rand}) = Base.IsInfinite()
29+
30+
Base.IteratorEltype(::Type{<:Rand}) = Base.HasEltype()
31+
Base.eltype(::Type{<:Rand{R, <:Sampler{T}}}) where {R,T} = T
32+
33+
# convenience iteration over distributions
34+
35+
Base.start(d::Distribution) = Rand(GLOBAL_RNG, d)
36+
Base.IteratorSize(::Type{<:Distribution}) = Base.IsInfinite()

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,17 @@ end
8888
@test length(s) == 8
8989
@test Set(s) <= Set("asd")
9090
end
91+
92+
@testset "Rand" for rng in ([], [MersenneTwister(0)], [RandomDevice()])
93+
for XT = zip(([Int], [1:3], []), (Int, Int, Float64))
94+
X, T = XT
95+
r = Rand(rng..., X...)
96+
@test collect(Iterators.take(r, 10)) isa Vector{T}
97+
@test r() isa T
98+
@test r(2, 3) isa Matrix{T}
99+
@test r(.3, 2, 3) isa SparseMatrixCSC{T}
100+
end
101+
for d = (Uniform(1:10), Uniform(Int))
102+
@test collect(Iterators.take(d, 10)) isa Vector{Int}
103+
end
104+
end

0 commit comments

Comments
 (0)