Skip to content

Commit 5f12518

Browse files
committed
@rand: handle where clauses in definitions
1 parent 67f3bca commit 5f12518

2 files changed

Lines changed: 36 additions & 0 deletions

File tree

src/macros.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,17 @@ macro rand(exp)
33
end
44

55
function rand_macro(ex)
6+
whereparams = []
67
ex isa Expr && ex.head (:(=), :function) ||
78
throw(ArgumentError("@rand requires an expression defining `rand`"))
89
sig = ex.args[1]
910
body = ex.args[2]
1011

12+
if sig.head == :where
13+
append!(whereparams, sig.args[2:end])
14+
sig = sig.args[1]
15+
end
16+
1117
sig.head == :call &&
1218
sig.args[1] == :rand || throw(ArgumentError(
1319
"@rand requires a function expression defining `rand`"))
@@ -27,6 +33,10 @@ function rand_macro(ex)
2733
:($(esc(rng))::AbstractRNG),
2834
esc(as_sampler(namefull, istrivial)))
2935

36+
if !isempty(whereparams)
37+
exsig = Expr(:where, exsig, map(esc, whereparams)...)
38+
end
39+
3040
ex = Expr(ex.head, exsig, esc(body))
3141

3242
sp = if istrivial
@@ -43,6 +53,8 @@ function rand_macro(ex)
4353
SamplerSimple($(esc(argname)), tuple(SP))
4454
end
4555
end
56+
@assert sp.args[2].args[1].head == :where
57+
append!(sp.args[2].args[1].args, map(esc, whereparams))
4658

4759
# insert x::X in the argument list, between RNG and n::Repetition
4860
insert!(sp.args[2].args[1].args[1].args, 3, esc(namefull))

test/runtests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,12 @@ end
573573

574574
Base.eltype(::Type{Die}) = Int
575575

576+
struct DieT{T}
577+
n::T
578+
end
579+
580+
Base.eltype(::Type{DieT{T}}) where {T} = T
581+
576582
@testset "@rand" begin
577583
# rng0 to be sure rng is not accessed in `esc`aped body of @rand
578584
rng0 = MersenneTwister()
@@ -607,4 +613,22 @@ Base.eltype(::Type{Die}) = Int
607613
VAR = 100
608614
@rand rand(d::Die) = rand(VAR+1:VAR+d.n) - VAR
609615
@test all((1:6), rand(d, 100))
616+
617+
# with type parameters
618+
d = DieT(6)
619+
620+
@rand rand(d::DieT{T}) where {T} = 1
621+
@test rand(d) == 1
622+
623+
@rand rand(d::DieT{Int}) = 0
624+
@test rand(d) == 0
625+
626+
d = DieT(0x0)
627+
@rand rand(d::DieT{T}) where {T<:UInt8} = rand(typemin(T):typemax(T))
628+
@test rand(d) isa UInt8
629+
630+
d = DieT(true)
631+
@rand rand(d::DieT{T}) where {T<:Bool} =
632+
T(typemin(T) + rand(typemin(T):typemax(T)))
633+
@test rand(d) isa Bool
610634
end

0 commit comments

Comments
 (0)