Skip to content

Commit c1322ef

Browse files
committed
enable Sampler(rng, Tuple{...})
1 parent 7f81bef commit c1322ef

2 files changed

Lines changed: 43 additions & 32 deletions

File tree

src/sampling.jl

Lines changed: 33 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -195,35 +195,12 @@ Sampler(RNG::Type{<:AbstractRNG}, ::Type{Complex{T}}, n::Repetition) where {T<:R
195195

196196
### sampler for tuples
197197

198-
@generated function Sampler(RNG::Type{<:AbstractRNG}, ::Type{T}, n::Repetition) where {T<:Tuple}
199-
d = Dict{DataType,Int}()
200-
sps = []
201-
for t in T.parameters
202-
i = get(d, t, nothing)
203-
if i === nothing
204-
push!(sps, :(Sampler(RNG, $t, n)))
205-
d[t] = length(sps)
206-
else
207-
push!(sps, Val(i))
208-
end
209-
end
210-
:(SamplerTag{Cont{T}}(tuple($(sps...))))
211-
end
198+
#### "simple scalar" (non-make) version
212199

213-
@generated function rand(rng::AbstractRNG, sp::SamplerTag{Cont{T},S}) where {T<:Tuple,S<:Tuple}
214-
@assert fieldcount(T) == fieldcount(S)
215-
rands = []
216-
for i = 1:fieldcount(T)
217-
j = fieldtype(S, i) <: Val ?
218-
fieldtype(S, i).parameters[1] :
219-
i
220-
push!(rands, :(convert($(fieldtype(T, i)),
221-
rand(rng, sp.data[$j]))))
222-
end
223-
:(tuple($(rands...)))
224-
end
200+
Sampler(RNG::Type{<:AbstractRNG}, ::Type{T}, n::Repetition) where {T<:Tuple} =
201+
Sampler(RNG, make(T), n)
225202

226-
#### with make
203+
#### make
227204

228205
# implement make(Tuple, S1, S2...), e.g. for rand(make(Tuple, Int, 1:3)),
229206
# and make(NTuple{N}, S)
@@ -293,18 +270,43 @@ make(::Type{T}, ::Type{X}, Y, ::Type{Z}) where {T<:Tuple,X,Z} = _make(
293270
make(::Type{T}, X, ::Type{Y}, ::Type{Z}) where {T<:Tuple,Y,Z} = _make(T, X, Y, Z)
294271
make(::Type{T}, ::Type{X}, ::Type{Y}, ::Type{Z}) where {T<:Tuple,X,Y,Z} = _make(T, X, Y, Z)
295272

296-
##### Sampler for general tuples (rand is already implemented above, like for rand(Tuple{...})
273+
#### Sampler for general tuples
297274

298275
@generated function Sampler(RNG::Type{<:AbstractRNG}, c::Make1{T,X}, n::Repetition) where {T<:Tuple,X<:Tuple}
299276
@assert fieldcount(T) == fieldcount(X)
300277
sps = [:(sampler(RNG, c.x[$i], n)) for i in 1:length(T.parameters)]
301278
:(SamplerTag{Cont{T}}(tuple($(sps...))))
302279
end
303280

304-
Sampler(RNG::Type{<:AbstractRNG}, ::Make0{T}, n::Repetition) where {T<:Tuple} =
305-
Sampler(RNG, T, n)
281+
@generated function Sampler(RNG::Type{<:AbstractRNG}, ::Make0{T}, n::Repetition) where {T<:Tuple}
282+
d = Dict{DataType,Int}()
283+
sps = []
284+
for t in T.parameters
285+
i = get(d, t, nothing)
286+
if i === nothing
287+
push!(sps, :(Sampler(RNG, $t, n)))
288+
d[t] = length(sps)
289+
else
290+
push!(sps, Val(i))
291+
end
292+
end
293+
:(SamplerTag{Cont{T}}(tuple($(sps...))))
294+
end
295+
296+
@generated function rand(rng::AbstractRNG, sp::SamplerTag{Cont{T},S}) where {T<:Tuple,S<:Tuple}
297+
@assert fieldcount(T) == fieldcount(S)
298+
rands = []
299+
for i = 1:fieldcount(T)
300+
j = fieldtype(S, i) <: Val ?
301+
fieldtype(S, i).parameters[1] :
302+
i
303+
push!(rands, :(convert($(fieldtype(T, i)),
304+
rand(rng, sp.data[$j]))))
305+
end
306+
:(tuple($(rands...)))
307+
end
306308

307-
##### for "NTuple-like"
309+
#### for "NTuple-like"
308310

309311
# should catch Tuple{Integer,Integer} which is not NTuple, or even Tuple{Int,UInt}, when only one sampler was passed
310312

test/runtests.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using RandomExtensions, Random, SparseArrays
2-
using Random: Sampler
2+
using Random: Sampler, gentype
33
using Test
44

55
@testset "Distributions" begin
@@ -309,6 +309,15 @@ end
309309
@test rand(T) isa Tuple{tlist...}
310310
end
311311
@test rand(Tuple{}) === ()
312+
sp = Sampler(MersenneTwister, Tuple)
313+
@test gentype(sp) == Tuple{}
314+
@test rand(sp) == ()
315+
sp = Sampler(MersenneTwister, NTuple{3})
316+
@test gentype(sp) == NTuple{3,Float64}
317+
@test rand(sp) isa NTuple{3,Float64}
318+
sp = Sampler(MersenneTwister, Tuple{Int8,UInt8})
319+
@test gentype(sp) == Tuple{Int8,UInt8}
320+
@test rand(sp) isa Tuple{Int8,UInt8}
312321
end
313322

314323
@testset "rand(make(Tuple, ...))" begin

0 commit comments

Comments
 (0)