Skip to content

Commit 64b8c30

Browse files
committed
@rand: accept anonymous functions
1 parent 3b557b7 commit 64b8c30

2 files changed

Lines changed: 49 additions & 2 deletions

File tree

src/macros.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ end
44

55
function rand_macro(ex)
66
whereparams = []
7-
ex isa Expr && ex.head (:(=), :function) ||
7+
ex isa Expr && ex.head (:(=), :function, :->) ||
88
throw(ArgumentError("@rand requires an expression defining `rand`"))
99
sig = ex.args[1]
1010
body = ex.args[2]
@@ -14,6 +14,14 @@ function rand_macro(ex)
1414
sig = sig.args[1]
1515
end
1616

17+
if ex.head == :function && sig.head == :tuple # anonymous function
18+
sig = Expr(:call, :rand, sig.args...)
19+
end
20+
if ex.head == :->
21+
# TODO: check that only one argument is passed
22+
sig = Expr(:call, :rand, sig)
23+
end
24+
1725
sig.head == :call &&
1826
sig.args[1] == :rand || throw(ArgumentError(
1927
"@rand requires a function expression defining `rand`"))
@@ -37,7 +45,7 @@ function rand_macro(ex)
3745
exsig = Expr(:where, exsig, map(esc, whereparams)...)
3846
end
3947

40-
ex = Expr(ex.head, exsig, esc(body))
48+
ex = Expr(:function, exsig, esc(body))
4149

4250
sp = if istrivial
4351
# we explicitly define Sampler even in the trivial case to handle

test/runtests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,10 @@ Base.eltype(::Type{DieT{T}}) where {T} = T
588588
7
589589
end
590590
@test rand(d) == 7
591+
@rand function (d::Die) 7 end
592+
@test rand(d) == 7
593+
@rand (d::Die) -> 7
594+
@test rand(d) == 7
591595

592596
# redefinition
593597
@rand rand(d::Die) = rand(1:d.n)
@@ -596,11 +600,33 @@ Base.eltype(::Type{DieT{T}}) where {T} = T
596600
@test all((1:6), rand(rng0, d, 10))
597601
@test eltype(rand(d, 3)) == Int
598602

603+
@rand function (d::Die) rand(1:d.n) end
604+
@test rand(d) 1:6
605+
@test rand(rng0, d) 1:6
606+
@test all((1:6), rand(rng0, d, 10))
607+
@test eltype(rand(d, 3)) == Int
608+
609+
@rand (d::Die) -> rand(1:d.n)
610+
@test rand(d) 1:6
611+
@test rand(rng0, d) 1:6
612+
@test all((1:6), rand(rng0, d, 10))
613+
@test eltype(rand(d, 3)) == Int
614+
599615
# redefinition (multiple inner samplers)
600616
@rand rand(d::Die) = rand(10:10) + rand(1:d.n)
601617
@test rand(d) 11:16
602618
@test all((11:16), rand(d, 10))
603619

620+
@rand function (d::Die)
621+
rand(10:10) + rand(1:d.n)
622+
end
623+
@test rand(d) 11:16
624+
@test all((11:16), rand(d, 10))
625+
626+
@rand (d::Die) -> rand(10:10) + rand(1:d.n)
627+
@test rand(d) 11:16
628+
@test all((11:16), rand(d, 10))
629+
604630
# redefinition access to argument not within rand call
605631
@rand rand(d::Die) = rand(1:1) + d.n
606632
@test all(==(7), rand(d, 10))
@@ -620,6 +646,19 @@ Base.eltype(::Type{DieT{T}}) where {T} = T
620646
@rand rand(d::DieT{T}) where {T} = 1
621647
@test rand(d) == 1
622648

649+
@rand function rand(d::DieT{T}) where {T}
650+
2
651+
end
652+
@test rand(d) == 2
653+
654+
@rand function (d::DieT{T}) where {T}
655+
3
656+
end
657+
@test rand(d) == 3
658+
659+
@rand ((d::DieT{T}) where {T}) -> 4
660+
@test rand(d) == 4
661+
623662
@rand rand(d::DieT{Int}) = 0
624663
@test rand(d) == 0
625664

0 commit comments

Comments
 (0)