Skip to content

Commit b5f09a6

Browse files
committed
@rand: allow redefinitions (SamplerSimple -> SamplerTrivial)
1 parent 75c36f1 commit b5f09a6

2 files changed

Lines changed: 27 additions & 7 deletions

File tree

src/macros.jl

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,32 @@ function rand_macro(ex)
2222
:(rng::AbstractRNG),
2323
as_sampler(ex.args[1].args[2], istrivial)),
2424
body)
25-
istrivial && return ex
26-
sp = quote
27-
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
28-
SamplerSimple($name, tuple(SP))
25+
26+
sp = if istrivial
27+
# we explicitly define Sampler even in the trivial case to handle
28+
# redefinitions, where the old rand/sampler pair (for SamplerSimple)
29+
# is overwritten by a new one (for SamplerTrivial)
30+
quote
31+
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
32+
SamplerTrivial($name)
33+
end
34+
else
35+
quote
36+
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
37+
SamplerSimple($name, tuple(SP))
38+
end
2939
end
40+
3041
# insert x::X in the argument list, between RNG and n::Repetition
3142
insert!(sp.args[2].args[1].args[1].args, 3, namefull)
32-
SP = [Expr(:call, :Sampler, :RNG, x, :n) for x in sps]
33-
@assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args)
34-
append!(sp.args[2].args[2].args[2].args[3].args, SP)
43+
44+
# insert inner samplers
45+
if !istrivial
46+
SP = [Expr(:call, :Sampler, :RNG, x, :n) for x in sps]
47+
@assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args)
48+
append!(sp.args[2].args[2].args[2].args[3].args, SP)
49+
end
50+
3551
quote
3652
$ex
3753
$(sp.args[2]) # unwrap the quote/block around the definition

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,4 +594,8 @@ end
594594
@rand rand(d::Die) = rand(10:10) + rand(1:d.n)
595595
@test rand(d) 11:16
596596
@test all((11:16), rand(d, 10))
597+
598+
# redefinition back to SamplerTrivial
599+
@rand rand(d::Die) = 0
600+
@test all(==(0), rand(d, 100))
597601
end

0 commit comments

Comments
 (0)