|
| 1 | +macro rand(exp) |
| 2 | + rand_macro(exp) |
| 3 | +end |
| 4 | + |
| 5 | +function rand_macro(ex) |
| 6 | + ex isa Expr && |
| 7 | + ex.head ∈ (:(=), :function) && |
| 8 | + ex.args[1].head == :call && |
| 9 | + ex.args[1].args[1] == :rand || throw(ArgumentError( |
| 10 | + "@rand requires a function expression defining `rand`")) |
| 11 | + |
| 12 | + name = ex.args[1].args[2].args[1] # x |
| 13 | + namefull = ex.args[1].args[2] # x::X |
| 14 | + @assert namefull.head == :(::) |
| 15 | + namefull.args[2] = esc(namefull.args[2]) |
| 16 | + |
| 17 | + sps = Any[] # sub-samplers |
| 18 | + body = samplerize!(sps, ex.args[2], name) |
| 19 | + istrivial = isempty(sps) |
| 20 | + # rand -> Base.rand |
| 21 | + ex = Expr(ex.head, Expr(:call, :(Random.rand), |
| 22 | + :(rng::AbstractRNG), |
| 23 | + as_sampler(ex.args[1].args[2], istrivial)), |
| 24 | + body) |
| 25 | + istrivial && return ex |
| 26 | + sp = quote |
| 27 | + Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} = |
| 28 | + SamplerSimple($name, tuple(SP)) |
| 29 | + end |
| 30 | + # insert x::X in the argument list, between RNG and n::Repetition |
| 31 | + insert!(sp.args[2].args[1].args[1].args, 3, namefull) |
| 32 | + SP = [Expr(:call, :Sampler, :RNG, x, :n) for x in sps] |
| 33 | + @assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args) |
| 34 | + append!(sp.args[2].args[2].args[2].args[3].args, SP) |
| 35 | + quote |
| 36 | + $ex |
| 37 | + $(sp.args[2]) # unwrap the quote/block around the definition |
| 38 | + end |
| 39 | +end |
| 40 | + |
| 41 | +function as_sampler(ex, istrivial) |
| 42 | + t = istrivial ? :SamplerTrivial : :SamplerSimple |
| 43 | + Expr(:(::), |
| 44 | + ex.args[1], |
| 45 | + Expr(:curly, t, |
| 46 | + Expr(:(<:), ex.args[2]))) |
| 47 | +end |
| 48 | + |
| 49 | +function samplerize!(sps, ex, name) |
| 50 | + ex isa Expr || return ex |
| 51 | + if ex.head == :call && ex.args[1] == :rand |
| 52 | + # TODO: handle Repetition == Val(Inf) for arrays |
| 53 | + push!(sps, ex.args[2]) |
| 54 | + i = length(sps) |
| 55 | + Expr(:call, :rand, :rng, :($name.data[$i]), ex.args[3:end]...) |
| 56 | + else |
| 57 | + Expr(ex.head, map(e -> samplerize!(sps, e, name), ex.args)...) |
| 58 | + end |
| 59 | +end |
0 commit comments