Skip to content

Commit 75c36f1

Browse files
committed
add @rand macro for easy rand/Sampler methods definitions
1 parent bc83a00 commit 75c36f1

3 files changed

Lines changed: 92 additions & 1 deletion

File tree

src/RandomExtensions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module RandomExtensions
22

33
export make, Uniform, Normal, Exponential, CloseOpen, OpenClose, OpenOpen, CloseClose, Rand,
4-
Bernoulli, Categorical
4+
Bernoulli, Categorical, @rand
55

66
# re-exports from Random, which don't overlap with new functionality and not from misc.jl
77
export rand!, AbstractRNG, MersenneTwister, RandomDevice
@@ -39,6 +39,7 @@ include("distributions.jl")
3939
include("sampling.jl")
4040
include("containers.jl")
4141
include("iteration.jl")
42+
include("macros.jl")
4243

4344

4445
## updated rand docstring (TODO: replace Base's one)

src/macros.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
macro rand(exp)
2+
rand_macro(exp)
3+
end
4+
5+
function rand_macro(ex)
6+
ex isa Expr &&
7+
ex.head (:(=), :function) &&
8+
ex.args[1].head == :call &&
9+
ex.args[1].args[1] == :rand || throw(ArgumentError(
10+
"@rand requires a function expression defining `rand`"))
11+
12+
name = ex.args[1].args[2].args[1] # x
13+
namefull = ex.args[1].args[2] # x::X
14+
@assert namefull.head == :(::)
15+
namefull.args[2] = esc(namefull.args[2])
16+
17+
sps = Any[] # sub-samplers
18+
body = samplerize!(sps, ex.args[2], name)
19+
istrivial = isempty(sps)
20+
# rand -> Base.rand
21+
ex = Expr(ex.head, Expr(:call, :(Random.rand),
22+
:(rng::AbstractRNG),
23+
as_sampler(ex.args[1].args[2], istrivial)),
24+
body)
25+
istrivial && return ex
26+
sp = quote
27+
Random.Sampler(::Type{RNG}, n::Repetition) where {RNG<:AbstractRNG} =
28+
SamplerSimple($name, tuple(SP))
29+
end
30+
# insert x::X in the argument list, between RNG and n::Repetition
31+
insert!(sp.args[2].args[1].args[1].args, 3, namefull)
32+
SP = [Expr(:call, :Sampler, :RNG, x, :n) for x in sps]
33+
@assert :SP == pop!(sp.args[2].args[2].args[2].args[3].args)
34+
append!(sp.args[2].args[2].args[2].args[3].args, SP)
35+
quote
36+
$ex
37+
$(sp.args[2]) # unwrap the quote/block around the definition
38+
end
39+
end
40+
41+
function as_sampler(ex, istrivial)
42+
t = istrivial ? :SamplerTrivial : :SamplerSimple
43+
Expr(:(::),
44+
ex.args[1],
45+
Expr(:curly, t,
46+
Expr(:(<:), ex.args[2])))
47+
end
48+
49+
function samplerize!(sps, ex, name)
50+
ex isa Expr || return ex
51+
if ex.head == :call && ex.args[1] == :rand
52+
# TODO: handle Repetition == Val(Inf) for arrays
53+
push!(sps, ex.args[2])
54+
i = length(sps)
55+
Expr(:call, :rand, :rng, :($name.data[$i]), ex.args[3:end]...)
56+
else
57+
Expr(ex.head, map(e -> samplerize!(sps, e, name), ex.args)...)
58+
end
59+
end

test/runtests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,3 +564,34 @@ end
564564
@test rand(make(1:3)) 1:3
565565
@test rand(make(Float64)) isa Float64
566566
end
567+
568+
## @rand
569+
570+
struct Die
571+
n::Int
572+
end
573+
574+
Base.eltype(::Type{Die}) = Int
575+
576+
@rand function rand(d::Die)
577+
7
578+
end
579+
580+
@testset "@rand" begin
581+
d = Die(6)
582+
rng = MersenneTwister()
583+
584+
@test rand(d) == 7
585+
586+
# redefinition
587+
@rand rand(d::Die) = rand(1:d.n)
588+
@test rand(d) 1:6
589+
@test rand(rng, d) 1:6
590+
@test all((1:6), rand(rng, d, 10))
591+
@test eltype(rand(d, 3)) == Int
592+
593+
# redefinition (multiple inner samplers)
594+
@rand rand(d::Die) = rand(10:10) + rand(1:d.n)
595+
@test rand(d) 11:16
596+
@test all((11:16), rand(d, 10))
597+
end

0 commit comments

Comments
 (0)