Skip to content

Commit 644f980

Browse files
committed
add support for NamedTuple
1 parent c1322ef commit 644f980

3 files changed

Lines changed: 115 additions & 31 deletions

File tree

src/RandomExtensions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ export rand!, AbstractRNG, MersenneTwister, RandomDevice
88
import Random: Sampler, rand, rand!, gentype
99

1010
using Random
11-
using Random: GLOBAL_RNG, SamplerTrivial, SamplerSimple, SamplerTag, Repetition
11+
using Random: GLOBAL_RNG, SamplerTrivial, SamplerSimple, SamplerTag, SamplerType, Repetition
1212

1313
using SparseArrays: sprand, sprandn, AbstractSparseArray, SparseVector, SparseMatrixCSC
1414

src/sampling.jl

Lines changed: 77 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -197,32 +197,43 @@ Sampler(RNG::Type{<:AbstractRNG}, ::Type{Complex{T}}, n::Repetition) where {T<:R
197197

198198
#### "simple scalar" (non-make) version
199199

200-
Sampler(RNG::Type{<:AbstractRNG}, ::Type{T}, n::Repetition) where {T<:Tuple} =
200+
Sampler(RNG::Type{<:AbstractRNG}, ::Type{T}, n::Repetition) where {T<:Union{Tuple,NamedTuple}} =
201201
Sampler(RNG, make(T), n)
202202

203203
#### make
204204

205205
# implement make(Tuple, S1, S2...), e.g. for rand(make(Tuple, Int, 1:3)),
206206
# and make(NTuple{N}, S)
207207

208-
@generated function _make(::Type{T}, args...) where T <: Tuple
209-
if isempty(args)
210-
TT = T === Tuple ? Tuple{} :
211-
T === NTuple ? Tuple{} :
212-
T isa UnionAll && Type{T} <: Type{NTuple{N}} where N ? T{default_gentype(Tuple)} :
213-
T
214-
return :(Make0{$TT}())
215-
end
216-
isNT = length(args) == 1 && T !== Tuple && (
217-
T <: NTuple || !isa(T, UnionAll)) # !isa(Tuple, UnionAll) !!
208+
_find_type(::Type{T}) where {T<:Tuple} =
209+
T === Tuple ?
210+
Tuple{} :
211+
T === NTuple ?
212+
Tuple{} :
213+
T isa UnionAll && Type{T} <: Type{NTuple{N}} where N ?
214+
T{default_gentype(Tuple)} :
215+
T
218216

217+
function _find_type(::Type{T}, args...) where T <: Tuple
219218
types = [t <: Type ? t.parameters[1] : gentype(t) for t in args]
220-
TT = T === Tuple ? Tuple{types...} :
221-
isNT ? (T isa UnionAll ? Tuple{fill(types[1], fieldcount(T))...} : T ) :
219+
TT = T === Tuple ?
220+
Tuple{types...} :
221+
_isNTuple(T, args...) ?
222+
(T isa UnionAll ? Tuple{fill(types[1], fieldcount(T))...} : T ) :
222223
T
224+
TT
225+
end
226+
227+
_isNTuple(::Type{T}, args...) where {T<:Tuple} =
228+
length(args) == 1 && T !== Tuple && (
229+
T <: NTuple || !isa(T, UnionAll)) # !isa(Tuple, UnionAll) !!
230+
231+
@generated function _make(::Type{T}, args...) where T <: Tuple
232+
isempty(args) && return :(Make0{$(_find_type(T))}())
233+
TT = _find_type(T, args...)
223234
samples = [t <: Type ? :(UniformType{$(t.parameters[1])}()) :
224235
:(args[$i]) for (i, t) in enumerate(args)]
225-
if isNT
236+
if _isNTuple(T, args...)
226237
:(Make1{$TT}($(samples[1])))
227238
else
228239
quote
@@ -253,22 +264,26 @@ make(::Type{NTuple{N,T} where N}, ::Type{X}, n::Integer) where {T,X} = make(NTup
253264

254265
# disambiguate
255266

256-
make(::Type{T}, X) where {T<:Tuple} = _make(T, X)
257-
make(::Type{T}, ::Type{X}) where {T<:Tuple,X} = _make(T, X)
258-
259-
make(::Type{T}, X, Y) where {T<:Tuple} = _make(T, X, Y)
260-
make(::Type{T}, ::Type{X}, Y) where {T<:Tuple,X} = _make(T, X, Y)
261-
make(::Type{T}, X, ::Type{Y}) where {T<:Tuple,Y} = _make(T, X, Y)
262-
make(::Type{T}, ::Type{X}, ::Type{Y}) where {T<:Tuple,X,Y} = _make(T, X, Y)
263-
264-
make(::Type{T}, X, Y, Z) where {T<:Tuple} = _make(T, X, Y, Z)
265-
make(::Type{T}, ::Type{X}, Y, Z) where {T<:Tuple,X} = _make(T, X, Y, Z)
266-
make(::Type{T}, X, ::Type{Y}, Z) where {T<:Tuple,Y} = _make(T, X, Y, Z)
267-
make(::Type{T}, ::Type{X}, ::Type{Y}, Z) where {T<:Tuple,X,Y} = _make(T, X, Y, Z)
268-
make(::Type{T}, X, Y, ::Type{Z}) where {T<:Tuple,Z} = _make(T, X, Y, Z)
269-
make(::Type{T}, ::Type{X}, Y, ::Type{Z}) where {T<:Tuple,X,Z} = _make(T, X, Y, Z)
270-
make(::Type{T}, X, ::Type{Y}, ::Type{Z}) where {T<:Tuple,Y,Z} = _make(T, X, Y, Z)
271-
make(::Type{T}, ::Type{X}, ::Type{Y}, ::Type{Z}) where {T<:Tuple,X,Y,Z} = _make(T, X, Y, Z)
267+
for Tupl = (Tuple, NamedTuple)
268+
@eval begin
269+
make(::Type{T}, X) where {T<:$Tupl} = _make(T, X)
270+
make(::Type{T}, ::Type{X}) where {T<:$Tupl,X} = _make(T, X)
271+
272+
make(::Type{T}, X, Y) where {T<:$Tupl} = _make(T, X, Y)
273+
make(::Type{T}, ::Type{X}, Y) where {T<:$Tupl,X} = _make(T, X, Y)
274+
make(::Type{T}, X, ::Type{Y}) where {T<:$Tupl,Y} = _make(T, X, Y)
275+
make(::Type{T}, ::Type{X}, ::Type{Y}) where {T<:$Tupl,X,Y} = _make(T, X, Y)
276+
277+
make(::Type{T}, X, Y, Z) where {T<:$Tupl} = _make(T, X, Y, Z)
278+
make(::Type{T}, ::Type{X}, Y, Z) where {T<:$Tupl,X} = _make(T, X, Y, Z)
279+
make(::Type{T}, X, ::Type{Y}, Z) where {T<:$Tupl,Y} = _make(T, X, Y, Z)
280+
make(::Type{T}, ::Type{X}, ::Type{Y}, Z) where {T<:$Tupl,X,Y} = _make(T, X, Y, Z)
281+
make(::Type{T}, X, Y, ::Type{Z}) where {T<:$Tupl,Z} = _make(T, X, Y, Z)
282+
make(::Type{T}, ::Type{X}, Y, ::Type{Z}) where {T<:$Tupl,X,Z} = _make(T, X, Y, Z)
283+
make(::Type{T}, X, ::Type{Y}, ::Type{Z}) where {T<:$Tupl,Y,Z} = _make(T, X, Y, Z)
284+
make(::Type{T}, ::Type{X}, ::Type{Y}, ::Type{Z}) where {T<:$Tupl,X,Y,Z} = _make(T, X, Y, Z)
285+
end
286+
end
272287

273288
#### Sampler for general tuples
274289

@@ -318,6 +333,38 @@ Sampler(RNG::Type{<:AbstractRNG}, c::Make1{T,X}, n::Repetition) where {T<:Tuple,
318333
:(tuple($(rands...)))
319334
end
320335

336+
### named tuples
337+
338+
make(T::Type{<:NamedTuple}, args...) = _make(T, args...)
339+
340+
_make(::Type{NamedTuple{}}) = Make0{NamedTuple{}}()
341+
342+
@generated function _make(::Type{NamedTuple{K}}, X...) where {K}
343+
if length(X) <= 1
344+
NT = NamedTuple{K,_find_type(NTuple{length(K)}, X...)}
345+
:(Make1{$NT}(make(NTuple{length(K)}, X...)))
346+
else
347+
NT = NamedTuple{K,_find_type(Tuple, X...)}
348+
:(Make1{$NT}(make(Tuple, X...)))
349+
end
350+
end
351+
352+
function _make(::Type{NamedTuple{K,V}}, X...) where {K,V}
353+
Make1{NamedTuple{K,V}}(make(V, X...))
354+
end
355+
356+
# necessary to avoid circular defintions
357+
Sampler(RNG::Type{<:AbstractRNG}, m::Make0{NamedTuple}, n::Repetition) =
358+
SamplerType{NamedTuple}()
359+
360+
Sampler(RNG::Type{<:AbstractRNG}, m::Make1{T}, n::Repetition) where T <: NamedTuple =
361+
SamplerTag{Cont{T}}(Sampler(RNG, m.x , n))
362+
363+
rand(rng::AbstractRNG, sp::SamplerType{NamedTuple{}}) = NamedTuple()
364+
365+
rand(rng::AbstractRNG, sp::SamplerTag{Cont{T}}) where T <: NamedTuple =
366+
T(rand(rng, sp.data))
367+
321368

322369
## collections
323370

test/runtests.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,43 @@ end
381381
@test r isa NTuple{3,UInt8}
382382
end
383383

384+
@testset "NamedTuple" begin
385+
for t = (rand(make(NamedTuple)), rand(NamedTuple))
386+
@test t == NamedTuple()
387+
end
388+
389+
for t = (rand(make(NamedTuple{(:a,)})), rand(NamedTuple{(:a,)}))
390+
@test t isa NamedTuple{(:a,), Tuple{Float64}}
391+
end
392+
for t = (rand(make(NamedTuple{(:a,),Tuple{Int}})),
393+
rand(NamedTuple{(:a,),Tuple{Int}}))
394+
@test t isa NamedTuple{(:a,), Tuple{Int}}
395+
end
396+
397+
t = rand(make(NamedTuple{(:a,)}, 1:3))
398+
@test t isa NamedTuple{(:a,), Tuple{Int}}
399+
@test t.a 1:3
400+
t = rand(make(NamedTuple{(:a,),Tuple{Float64}}, 1:3))
401+
@test t isa NamedTuple{(:a,), Tuple{Float64}}
402+
@test t.a 1:3
403+
404+
405+
for t = (rand(make(NamedTuple{(:a, :b)})),
406+
rand(NamedTuple{(:a, :b)}))
407+
@test t isa NamedTuple{(:a, :b), Tuple{Float64,Float64}}
408+
end
409+
for t = (rand(make(NamedTuple{(:a, :b),Tuple{Int,UInt8}})),
410+
rand(NamedTuple{(:a, :b),Tuple{Int,UInt8}}))
411+
@test t isa NamedTuple{(:a, :b), Tuple{Int,UInt8}}
412+
end
413+
t = rand(make(NamedTuple{(:a, :b)}, 1:3))
414+
@test t isa NamedTuple{(:a, :b), Tuple{Int,Int}}
415+
@test t.a 1:3 && t.b 1:3
416+
t = rand(make(NamedTuple{(:a, :b),Tuple{Float64,UInt8}}, 1:3))
417+
@test t isa NamedTuple{(:a, :b), Tuple{Float64,UInt8}}
418+
@test t.a 1:3 && t.b 1:3
419+
end
420+
384421
@testset "rand(make(String, ...))" begin
385422
b = UInt8['0':'9';'A':'Z';'a':'z']
386423

0 commit comments

Comments
 (0)