Skip to content

Commit 340f7b6

Browse files
committed
add rand(make(Dict, ...))
1 parent 4c7fbc0 commit 340f7b6

3 files changed

Lines changed: 49 additions & 47 deletions

File tree

src/containers.jl

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,49 +96,19 @@ end
9696
@make_array_container(BitArray)
9797

9898

99-
## dicts
99+
## sets/dicts
100100

101-
# again same inference bug
102-
# TODO: extend to AbstractDict ? (needs to work-around the inderence bug)
103-
default_sampling(::Type{<:Dict{K,V}}) where {K,V} = Pair{K,V}
104-
default_sampling(D::Type{<:Dict}) = throw(ArgumentError("under-specified scalar type for $D"))
105-
106-
rand!(A::AbstractDict{K,V}, dist::Union{Type{<:Pair},Distribution{<:Pair}}=make(Pair, K, V)) where {K,V} =
107-
rand!(GLOBAL_RNG, A, dist)
108-
109-
rand!(rng::AbstractRNG, A::AbstractDict{K,V},
110-
dist::Union{Type{<:Pair},Distribution{<:Pair}}=make(Pair, K, V)) where {K,V} =
111-
rand!(rng, A, Sampler(rng, dist))
112-
113-
function _rand!(rng::AbstractRNG, A::Union{AbstractDict,AbstractSet}, n::Integer, sp::Sampler)
101+
function _rand!(rng::AbstractRNG, A::SetDict, n::Integer, sp::Sampler)
114102
empty!(A)
115103
while length(A) < n
116104
push!(A, rand(rng, sp))
117105
end
118106
A
119107
end
120108

121-
rand!(rng::AbstractRNG, A::AbstractDict{K,V}, sp::Sampler) where {K,V} = _rand!(rng, A, length(A), sp)
122-
123-
rand(rng::AbstractRNG, dist::Distribution{P}, ::Type{T}, n::Integer) where {P<:Pair,T<:AbstractDict} =
124-
_rand!(rng, deduce_type(T, fieldtype(P, 1), fieldtype(P, 2))(), n, Sampler(rng, dist))
125-
126-
rand(rng::AbstractRNG, ::Type{P}, ::Type{T}, n::Integer) where {P<:Pair,T<:AbstractDict} = rand(rng, Uniform(P), T, n)
127-
128-
rand(rng::AbstractRNG, ::Type{T}, n::Integer) where {T<:AbstractDict} = rand(rng, default_sampling(T), T, n)
129-
130-
rand(u::Distribution{<:Pair}, ::Type{T}, n::Integer) where {T<:AbstractDict} = rand(GLOBAL_RNG, u, T, n)
131-
132-
rand(::Type{P}, ::Type{T}, n::Integer) where {P<:Pair,T<:AbstractDict} = rand(GLOBAL_RNG, Uniform(P), T, n)
133-
134-
rand(::Type{T}, n::Integer) where {T<:AbstractDict} = rand(GLOBAL_RNG, default_sampling(T), T, n)
135-
136-
137-
## sets
138-
139-
rand!( A::AbstractSet{T}, X) where {T} = rand!(GLOBAL_RNG, A, X)
140-
rand!(rng::AbstractRNG, A::AbstractSet, X) = _rand!(rng, A, length(A), sampler(rng, X))
141-
rand!( A::AbstractSet{T}, ::Type{X}=default_sampling(A)) where {T,X} = rand!(GLOBAL_RNG, A, X)
142-
rand!(rng::AbstractRNG, A::AbstractSet{T}, ::Type{X}=default_sampling(A)) where {T,X} = rand!(rng, A, Sampler(rng, X))
109+
rand!( A::SetDict, X) = rand!(GLOBAL_RNG, A, X)
110+
rand!(rng::AbstractRNG, A::SetDict, X) = _rand!(rng, A, length(A), sampler(rng, X))
111+
rand!( A::SetDict, ::Type{X}=default_sampling(A)) where {X} = rand!(GLOBAL_RNG, A, X)
112+
rand!(rng::AbstractRNG, A::SetDict, ::Type{X}=default_sampling(A)) where {X} = rand!(rng, A, Sampler(rng, X))
143113

144-
@make_container(T::Type{<:AbstractSet}, n::Integer)
114+
@make_container(T::Type{<:SetDict}, n::Integer)

src/sampling.jl

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -247,25 +247,28 @@ end
247247

248248
## collections
249249

250-
### sets
251-
252-
default_sampling(::Type{<:AbstractSet}) = Float64
253-
default_sampling(::Type{<:AbstractSet{T}}) where {T} = T
250+
### sets/dicts
254251

255-
make(T::Type{<:AbstractSet}, n::Integer) = make(T, default_sampling(T), Int(n))
252+
const SetDict = Union{AbstractSet,AbstractDict}
256253

257-
make(T::Type{<:AbstractSet}, X, n::Integer) = Make2{find_type(T, X, n)}(X , Int(n))
258-
make(T::Type{<:AbstractSet}, ::Type{X}, n::Integer) where {X} = Make2{find_type(T, X, n)}(X , Int(n))
254+
make(T::Type{<:SetDict}, X, n::Integer) = Make2{find_type(T, X, n)}(X , Int(n))
255+
make(T::Type{<:SetDict}, ::Type{X}, n::Integer) where {X} = Make2{find_type(T, X, n)}(X , Int(n))
256+
make(T::Type{<:SetDict}, n::Integer) = make(T, default_sampling(T), Int(n))
259257

260-
Sampler(RNG::Type{<:AbstractRNG}, c::Make2{T}, n::Repetition) where {T<:AbstractSet} =
258+
Sampler(RNG::Type{<:AbstractRNG}, c::Make2{T}, n::Repetition) where {T<:SetDict} =
261259
SamplerTag{Cont{T}}((sampler(RNG, c.x, n), c.y))
262260

263-
function rand(rng::AbstractRNG, sp::SamplerTag{Cont{S}}) where {S<:AbstractSet}
264-
# assuming S() creates an empty set
261+
function rand(rng::AbstractRNG, sp::SamplerTag{Cont{S}}) where {S<:SetDict}
262+
# assuming S() creates an empty set/dict
265263
s = sizehint!(S(), sp.data[2])
266264
_rand!(rng, s, sp.data[2], sp.data[1])
267265
end
268266

267+
### sets
268+
269+
default_sampling(::Type{<:AbstractSet}) = Float64
270+
default_sampling(::Type{<:AbstractSet{T}}) where {T} = T
271+
269272
#### Set
270273

271274
find_type(::Type{Set}, X, _) = Set{val_gentype(X)}
@@ -278,6 +281,23 @@ default_sampling(::Type{BitSet}) = Int8 # almost arbitrary, may change
278281
find_type(::Type{BitSet}, _, _) = BitSet
279282

280283

284+
### dicts
285+
286+
# again same inference bug
287+
# TODO: extend to AbstractDict ? (needs to work-around the inderence bug)
288+
default_sampling(::Type{Dict{K,V}}) where {K,V} = Pair{K,V}
289+
default_sampling(D::Type{<:Dict}) = throw(ArgumentError("under-specified scalar type for $D"))
290+
291+
find_type(D::Type{<:AbstractDict{K,V}}, _, ::Integer) where {K,V} = D
292+
find_type(D::Type{<:AbstractDict{K,V}}, ::Type, ::Integer) where {K,V} = D
293+
294+
#### Dict
295+
296+
find_type(::Type{Dict{K}}, X, ::Integer) where {K} = Dict{K,fieldtype(val_gentype(X), 2)}
297+
find_type(::Type{Dict{K,V} where K}, X, ::Integer) where {V} = Dict{fieldtype(val_gentype(X), 1),V}
298+
find_type(::Type{Dict}, X, ::Integer) = Dict{fieldtype(val_gentype(X), 1),fieldtype(val_gentype(X), 2)}
299+
300+
281301
### AbstractArray
282302

283303
default_sampling(::Type{<:AbstractArray{T}}) where {T} = T

test/runtests.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,18 @@ end
360360
end
361361
end
362362

363+
@testset "rand(make(Dict, ...))" begin
364+
for (D, S) = (Dict{Int16,Int16} => [],
365+
Dict{Int16} => [Pair{Int8,Int16}],
366+
Dict{K,Int16} where K => [Pair{Int16,Int8}],
367+
Dict => [Pair{Int16,Int16}])
368+
369+
d = rand(make(D, S..., 3))
370+
@test d isa Dict{Int16,Int16}
371+
@test length(d) == 3
372+
end
373+
end
374+
363375
@testset "rand(make(Array/BitArray, ...))" begin
364376
for (T, Arr) = (Bool => BitArray, Float64 => Array{Float64}),
365377
k = ([], [T], [Bernoulli(T, 0.3)]),

0 commit comments

Comments
 (0)