@@ -3,11 +3,17 @@ macro rand(exp)
33end
44
55function rand_macro (ex)
6+ whereparams = []
67 ex isa Expr && ex. head ∈ (:(= ), :function ) ||
78 throw (ArgumentError (" @rand requires an expression defining `rand`" ))
89 sig = ex. args[1 ]
910 body = ex. args[2 ]
1011
12+ if sig. head == :where
13+ append! (whereparams, sig. args[2 : end ])
14+ sig = sig. args[1 ]
15+ end
16+
1117 sig. head == :call &&
1218 sig. args[1 ] == :rand || throw (ArgumentError (
1319 " @rand requires a function expression defining `rand`" ))
@@ -27,6 +33,10 @@ function rand_macro(ex)
2733 :($ (esc (rng)):: AbstractRNG ),
2834 esc (as_sampler (namefull, istrivial)))
2935
36+ if ! isempty (whereparams)
37+ exsig = Expr (:where , exsig, map (esc, whereparams)... )
38+ end
39+
3040 ex = Expr (ex. head, exsig, esc (body))
3141
3242 sp = if istrivial
@@ -43,6 +53,8 @@ function rand_macro(ex)
4353 SamplerSimple ($ (esc (argname)), tuple (SP))
4454 end
4555 end
56+ @assert sp. args[2 ]. args[1 ]. head == :where
57+ append! (sp. args[2 ]. args[1 ]. args, map (esc, whereparams))
4658
4759 # insert x::X in the argument list, between RNG and n::Repetition
4860 insert! (sp. args[2 ]. args[1 ]. args[1 ]. args, 3 , esc (namefull))
0 commit comments