Skip to content

Commit 67f3bca

Browse files
committed
@rand: fix missing esc
1 parent 198f49a commit 67f3bca

2 files changed

Lines changed: 37 additions & 25 deletions

File tree

src/macros.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,53 @@ macro rand(exp)
33
end
44

55
function rand_macro(ex)
6-
ex isa Expr &&
7-
ex.head (:(=), :function) &&
8-
ex.args[1].head == :call &&
9-
ex.args[1].args[1] == :rand || throw(ArgumentError(
6+
ex isa Expr && ex.head (:(=), :function) ||
7+
throw(ArgumentError("@rand requires an expression defining `rand`"))
8+
sig = ex.args[1]
9+
body = ex.args[2]
10+
11+
sig.head == :call &&
12+
sig.args[1] == :rand || throw(ArgumentError(
1013
"@rand requires a function expression defining `rand`"))
1114

12-
name = ex.args[1].args[2].args[1] # x
13-
namefull = ex.args[1].args[2] # x::X
14-
@assert namefull.head == :(::)
15-
namefull.args[2] = esc(namefull.args[2])
15+
argname = sig.args[2].args[1] # x
16+
namefull = sig.args[2] # x::X
17+
@assert namefull.head == :(::) # TODO: throw exception
1618

1719
sps = Any[] # sub-samplers
18-
body = samplerize!(sps, ex.args[2], name)
20+
rng = gensym()
21+
body = samplerize!(sps, body, argname, rng)
1922
istrivial = isempty(sps)
2023
# rand -> Base.rand
21-
ex = Expr(ex.head, Expr(:call, :(Random.rand),
22-
:(rng::AbstractRNG),
23-
as_sampler(ex.args[1].args[2], istrivial)),
24-
body)
24+
25+
exsig = Expr(:call,
26+
:(Random.rand),
27+
:($(esc(rng))::AbstractRNG),
28+
esc(as_sampler(namefull, istrivial)))
29+
30+
ex = Expr(ex.head, exsig, esc(body))
2531

2632
sp = if istrivial
2733
# we explicitly define Sampler even in the trivial case to handle
2834
# redefinitions, where the old rand/sampler pair (for SamplerSimple)
2935
# is overwritten by a new one (for SamplerTrivial)
3036
quote
3137
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
32-
SamplerTrivial($name)
38+
SamplerTrivial($(esc(argname)))
3339
end
3440
else
3541
quote
3642
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
37-
SamplerSimple($name, tuple(SP))
43+
SamplerSimple($(esc(argname)), tuple(SP))
3844
end
3945
end
4046

4147
# insert x::X in the argument list, between RNG and n::Repetition
42-
insert!(sp.args[2].args[1].args[1].args, 3, namefull)
48+
insert!(sp.args[2].args[1].args[1].args, 3, esc(namefull))
4349

4450
# insert inner samplers
4551
if !istrivial
46-
SP = [Expr(:call, :Sampler, :RNG, x, :n) for x in sps]
52+
SP = [Expr(:call, :Sampler, :RNG, esc(x), :n) for x in sps]
4753
@assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args)
4854
append!(sp.args[2].args[2].args[2].args[3].args, SP)
4955
end
@@ -55,14 +61,14 @@ function rand_macro(ex)
5561
end
5662

5763
function as_sampler(ex, istrivial)
58-
t = istrivial ? :SamplerTrivial : :SamplerSimple
64+
t = istrivial ? :(RandomExtensions.SamplerTrivial) : :(RandomExtensions.SamplerSimple)
5965
Expr(:(::),
6066
ex.args[1],
6167
Expr(:curly, t,
6268
Expr(:(<:), ex.args[2])))
6369
end
6470

65-
function samplerize!(sps, ex, name)
71+
function samplerize!(sps, ex, name, rng)
6672
if ex == name
6773
# not within a rand() call
6874
return Expr(:ref, name) # name -> name[]
@@ -72,8 +78,8 @@ function samplerize!(sps, ex, name)
7278
# TODO: handle Repetition == Val(Inf) for arrays
7379
push!(sps, ex.args[2])
7480
i = length(sps)
75-
Expr(:call, :rand, :rng, :($name.data[$i]), ex.args[3:end]...)
81+
Expr(:call, :rand, rng, :($name.data[$i]), ex.args[3:end]...)
7682
else
77-
Expr(ex.head, map(e -> samplerize!(sps, e, name), ex.args)...)
83+
Expr(ex.head, map(e -> samplerize!(sps, e, name, rng), ex.args)...)
7884
end
7985
end

test/runtests.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,9 +574,10 @@ end
574574
Base.eltype(::Type{Die}) = Int
575575

576576
@testset "@rand" begin
577-
d = Die(6)
578-
rng = MersenneTwister()
577+
# rng0 to be sure rng is not accessed in `esc`aped body of @rand
578+
rng0 = MersenneTwister()
579579

580+
d = Die(6)
580581
@rand function rand(d::Die)
581582
7
582583
end
@@ -585,8 +586,8 @@ Base.eltype(::Type{Die}) = Int
585586
# redefinition
586587
@rand rand(d::Die) = rand(1:d.n)
587588
@test rand(d) 1:6
588-
@test rand(rng, d) 1:6
589-
@test all((1:6), rand(rng, d, 10))
589+
@test rand(rng0, d) 1:6
590+
@test all((1:6), rand(rng0, d, 10))
590591
@test eltype(rand(d, 3)) == Int
591592

592593
# redefinition (multiple inner samplers)
@@ -601,4 +602,9 @@ Base.eltype(::Type{Die}) = Int
601602
# redefinition back to SamplerTrivial
602603
@rand rand(d::Die) = d.n
603604
@test all(==(6), rand(d, 100))
605+
606+
# test esc-correctness
607+
VAR = 100
608+
@rand rand(d::Die) = rand(VAR+1:VAR+d.n) - VAR
609+
@test all((1:6), rand(d, 100))
604610
end

0 commit comments

Comments
 (0)