Skip to content

Commit 889303b

Browse files
committed
add support for ImmutableDict
1 parent 6c30651 commit 889303b

2 files changed

Lines changed: 27 additions & 16 deletions

File tree

src/sampling.jl

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,13 @@ make(T::Type{<:SetDict}, ::Type{X}, n::Integer) where {X} = Make2{find_type(T, X
328328
make(T::Type{<:SetDict}, n::Integer) = make(T, default_sampling(T), Int(n))
329329

330330
Sampler(RNG::Type{<:AbstractRNG}, c::Make2{T}, n::Repetition) where {T<:SetDict} =
331-
SamplerTag{Cont{T}}((sampler(RNG, c.x, n), c.y))
331+
SamplerTag{Cont{T}}((sp = sampler(RNG, c.x, n),
332+
len = c.y))
332333

333334
function rand(rng::AbstractRNG, sp::SamplerTag{Cont{S}}) where {S<:SetDict}
334335
# assuming S() creates an empty set/dict
335336
s = sizehint!(S(), sp.data[2])
336-
_rand!(rng, s, sp.data[2], sp.data[1])
337+
_rand!(rng, s, sp.data.len, sp.data.sp)
337338
end
338339

339340
### sets
@@ -355,19 +356,28 @@ find_type(::Type{BitSet}, _, _) = BitSet
355356

356357
### dicts
357358

358-
# again same inference bug
359-
# TODO: extend to AbstractDict ? (needs to work-around the inderence bug)
360-
default_sampling(::Type{Dict{K,V}}) where {K,V} = Uniform(Pair{K,V})
361-
default_sampling(D::Type{<:Dict}) = throw(ArgumentError("under-specified scalar type for $D"))
362-
363359
find_type(D::Type{<:AbstractDict{K,V}}, _, ::Integer) where {K,V} = D
364360
find_type(D::Type{<:AbstractDict{K,V}}, ::Type, ::Integer) where {K,V} = D
365361

366-
#### Dict
362+
#### Dict/ImmutableDict
363+
364+
for D in (Dict, Base.ImmutableDict)
365+
@eval begin
366+
# again same inference bug
367+
# TODO: extend to AbstractDict ? (needs to work-around the inderence bug)
368+
default_sampling(::Type{$D{K,V}}) where {K,V} = Uniform(Pair{K,V})
369+
default_sampling(D::Type{<:$D}) = throw(ArgumentError("under-specified scalar type for $D"))
370+
371+
find_type(::Type{$D{K}}, X, ::Integer) where {K} = $D{K,fieldtype(val_gentype(X), 2)}
372+
find_type(::Type{$D{K,V} where K}, X, ::Integer) where {V} = $D{fieldtype(val_gentype(X), 1),V}
373+
find_type(::Type{$D}, X, ::Integer) = $D{fieldtype(val_gentype(X), 1),fieldtype(val_gentype(X), 2)}
374+
end
375+
end
367376

368-
find_type(::Type{Dict{K}}, X, ::Integer) where {K} = Dict{K,fieldtype(val_gentype(X), 2)}
369-
find_type(::Type{Dict{K,V} where K}, X, ::Integer) where {V} = Dict{fieldtype(val_gentype(X), 1),V}
370-
find_type(::Type{Dict}, X, ::Integer) = Dict{fieldtype(val_gentype(X), 1),fieldtype(val_gentype(X), 2)}
377+
rand(rng::AbstractRNG, sp::SamplerTag{Cont{S}}) where {S<:Base.ImmutableDict} =
378+
foldl((d, _) -> Base.ImmutableDict(d, rand(rng, sp.data.sp)),
379+
1:sp.data.len,
380+
init=S())
371381

372382

373383
### AbstractArray

test/runtests.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,13 +405,14 @@ end
405405
end
406406

407407
@testset "rand(make(Dict, ...))" begin
408-
for (D, S) = (Dict{Int16,Int16} => [],
409-
Dict{Int16} => [Pair{Int8,Int16}],
410-
Dict{K,Int16} where K => [Pair{Int16,Int8}],
411-
Dict => [Pair{Int16,Int16}])
408+
for BD = (Dict, Base.ImmutableDict),
409+
(D, S) = (BD{Int16,Int16} => [],
410+
BD{Int16} => [Pair{Int8,Int16}],
411+
BD{K,Int16} where K => [Pair{Int16,Int8}],
412+
BD => [Pair{Int16,Int16}])
412413

413414
d = rand(make(D, S..., 3))
414-
@test d isa Dict{Int16,Int16}
415+
@test d isa BD{Int16,Int16}
415416
@test length(d) == 3
416417
end
417418
end

0 commit comments

Comments
 (0)