@@ -3,47 +3,53 @@ macro rand(exp)
33end
44
55function rand_macro (ex)
6- ex isa Expr &&
7- ex. head ∈ (:(= ), :function ) &&
8- ex. args[1 ]. head == :call &&
9- ex. args[1 ]. args[1 ] == :rand || throw (ArgumentError (
6+ ex isa Expr && ex. head ∈ (:(= ), :function ) ||
7+ throw (ArgumentError (" @rand requires an expression defining `rand`" ))
8+ sig = ex. args[1 ]
9+ body = ex. args[2 ]
10+
11+ sig. head == :call &&
12+ sig. args[1 ] == :rand || throw (ArgumentError (
1013 " @rand requires a function expression defining `rand`" ))
1114
12- name = ex. args[1 ]. args[2 ]. args[1 ] # x
13- namefull = ex. args[1 ]. args[2 ] # x::X
14- @assert namefull. head == :(:: )
15- namefull. args[2 ] = esc (namefull. args[2 ])
15+ argname = sig. args[2 ]. args[1 ] # x
16+ namefull = sig. args[2 ] # x::X
17+ @assert namefull. head == :(:: ) # TODO : throw exception
1618
1719 sps = Any[] # sub-samplers
18- body = samplerize! (sps, ex. args[2 ], name)
20+ rng = gensym ()
21+ body = samplerize! (sps, body, argname, rng)
1922 istrivial = isempty (sps)
2023 # rand -> Base.rand
21- ex = Expr (ex. head, Expr (:call , :(Random. rand),
22- :(rng:: AbstractRNG ),
23- as_sampler (ex. args[1 ]. args[2 ], istrivial)),
24- body)
24+
25+ exsig = Expr (:call ,
26+ :(Random. rand),
27+ :($ (esc (rng)):: AbstractRNG ),
28+ esc (as_sampler (namefull, istrivial)))
29+
30+ ex = Expr (ex. head, exsig, esc (body))
2531
2632 sp = if istrivial
2733 # we explicitly define Sampler even in the trivial case to handle
2834 # redefinitions, where the old rand/sampler pair (for SamplerSimple)
2935 # is overwritten by a new one (for SamplerTrivial)
3036 quote
3137 Random. Sampler (:: Type{RNG} , n:: Repetition ) where {RNG<: AbstractRNG } =
32- SamplerTrivial ($ name )
38+ SamplerTrivial ($ ( esc (argname)) )
3339 end
3440 else
3541 quote
3642 Random. Sampler (:: Type{RNG} , n:: Repetition ) where {RNG<: AbstractRNG } =
37- SamplerSimple ($ name , tuple (SP))
43+ SamplerSimple ($ ( esc (argname)) , tuple (SP))
3844 end
3945 end
4046
4147 # insert x::X in the argument list, between RNG and n::Repetition
42- insert! (sp. args[2 ]. args[1 ]. args[1 ]. args, 3 , namefull)
48+ insert! (sp. args[2 ]. args[1 ]. args[1 ]. args, 3 , esc ( namefull) )
4349
4450 # insert inner samplers
4551 if ! istrivial
46- SP = [Expr (:call , :Sampler , :RNG , x , :n ) for x in sps]
52+ SP = [Expr (:call , :Sampler , :RNG , esc (x) , :n ) for x in sps]
4753 @assert :SP == pop! (sp. args[2 ]. args[2 ]. args[2 ]. args[3 ]. args)
4854 append! (sp. args[2 ]. args[2 ]. args[2 ]. args[3 ]. args, SP)
4955 end
@@ -55,14 +61,14 @@ function rand_macro(ex)
5561end
5662
5763function as_sampler (ex, istrivial)
58- t = istrivial ? :SamplerTrivial : :SamplerSimple
64+ t = istrivial ? :(RandomExtensions . SamplerTrivial) : :(RandomExtensions . SamplerSimple)
5965 Expr (:(:: ),
6066 ex. args[1 ],
6167 Expr (:curly , t,
6268 Expr (:(< :), ex. args[2 ])))
6369end
6470
65- function samplerize! (sps, ex, name)
71+ function samplerize! (sps, ex, name, rng )
6672 if ex == name
6773 # not within a rand() call
6874 return Expr (:ref , name) # name -> name[]
@@ -72,8 +78,8 @@ function samplerize!(sps, ex, name)
7278 # TODO : handle Repetition == Val(Inf) for arrays
7379 push! (sps, ex. args[2 ])
7480 i = length (sps)
75- Expr (:call , :rand , : rng , :($ name. data[$ i]), ex. args[3 : end ]. .. )
81+ Expr (:call , :rand , rng, :($ name. data[$ i]), ex. args[3 : end ]. .. )
7682 else
77- Expr (ex. head, map (e -> samplerize! (sps, e, name), ex. args)... )
83+ Expr (ex. head, map (e -> samplerize! (sps, e, name, rng ), ex. args)... )
7884 end
7985end
0 commit comments