Skip to content

Commit bb01b40

Browse files
committed
@rand: use Val(Inf) for sub-samplers generating arrays
1 parent 64b8c30 commit bb01b40

2 files changed

Lines changed: 16 additions & 5 deletions

File tree

src/macros.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,12 @@ function rand_macro(ex)
3030
namefull = sig.args[2] # x::X
3131
@assert namefull.head == :(::) # TODO: throw exception
3232

33-
sps = Any[] # sub-samplers
33+
# sub-samplers; second argument true forces Val(Inf) for the sampler
34+
sps = Pair{<:Any,Bool}[]
35+
3436
rng = gensym()
3537
body = samplerize!(sps, body, argname, rng)
3638
istrivial = isempty(sps)
37-
# rand -> Base.rand
3839

3940
exsig = Expr(:call,
4041
:(Random.rand),
@@ -69,7 +70,7 @@ function rand_macro(ex)
6970

7071
# insert inner samplers
7172
if !istrivial
72-
SP = [Expr(:call, :Sampler, :RNG, esc(x), :n) for x in sps]
73+
SP = [Expr(:call, :Sampler, :RNG, esc(x), many ? Val{Inf}() : :n) for (x, many) in sps]
7374
@assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args)
7475
append!(sp.args[2].args[2].args[2].args[3].args, SP)
7576
end
@@ -95,8 +96,9 @@ function samplerize!(sps, ex, name, rng)
9596
end
9697
ex isa Expr || return ex
9798
if ex.head == :call && ex.args[1] == :rand
98-
# TODO: handle Repetition == Val(Inf) for arrays
99-
push!(sps, ex.args[2])
99+
# we assume that if rand has more than one arg, we want
100+
# a Val(Inf) sampler (e.g. rand(1:9, 2, 3)
101+
push!(sps, ex.args[2] => length(ex.args) > 2)
100102
i = length(sps)
101103
Expr(:call, :rand, rng, :($name.data[$i]), ex.args[3:end]...)
102104
else

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,15 @@ Base.eltype(::Type{DieT{T}}) where {T} = T
635635
@rand rand(d::Die) = d.n
636636
@test all(==(6), rand(d, 100))
637637

638+
# redefinition (Val(Inf) iff rand calls with 2+ arguments)
639+
@rand rand(d::Die) = (rand("asd") + rand("asd", 3)[1]; 1)
640+
s = Sampler(MersenneTwister, d, Val(1))
641+
@test s.data[1] isa Random.SamplerSimple{String}
642+
@test s.data[2] isa Random.SamplerSimple{Vector{Char}} # "proof" of Val(Inf)
643+
s = Sampler(MersenneTwister, d, Val(Inf))
644+
@test s.data[1] isa Random.SamplerSimple{Vector{Char}}
645+
@test s.data[2] isa Random.SamplerSimple{Vector{Char}}
646+
638647
# test esc-correctness
639648
VAR = 100
640649
@rand rand(d::Die) = rand(VAR+1:VAR+d.n) - VAR

0 commit comments

Comments
 (0)