Skip to content

Commit 00b5824

Browse files
committed
random: add Distribution type, and implement some concrete ones
* Normal & Exponential distributions * Pair * Complex
1 parent cb7123b commit 00b5824

5 files changed

Lines changed: 231 additions & 1 deletion

File tree

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
julia 0.6
1+
julia 0.7.0-DEV.3261

src/RandomExtensions.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
11
module RandomExtensions
22

3+
export Combine, Uniform, Normal, Exponential, CloseOpen
4+
5+
import Random: Sampler, rand
6+
7+
using Random
8+
using Random: SamplerTrivial, SamplerSimple, SamplerTag, Repetition
9+
10+
include("distributions.jl")
11+
include("sampling.jl")
12+
313
end # module

src/distributions.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# definition of some distribution types
2+
3+
4+
## Distribution & Combine
5+
6+
abstract type Distribution{T} end
7+
8+
Base.eltype(::Type{Distribution{T}}) where {T} = T
9+
10+
abstract type Combine{T} <: Distribution{T} end
11+
12+
struct Combine0{T} <: Combine{T} end
13+
14+
Combine(::Type{T}) where {T} = Combine0{T}()
15+
16+
struct Combine1{T,X} <: Combine{T}
17+
x::X
18+
end
19+
20+
Combine(::Type{T}, x::X) where {T,X} = Combine1{T,X}(x)
21+
Combine(::Type{T}, ::Type{X}) where {T,X} = Combine1{T,Type{X}}(X)
22+
23+
struct Combine2{T,X,Y} <: Combine{T}
24+
x::X
25+
y::Y
26+
end
27+
28+
Combine(::Type{T}, x::X, y::Y) where {T,X,Y} = Combine2{deduce_type(T,X,Y),X,Y}(x, y)
29+
Combine(::Type{T}, ::Type{X}, y::Y) where {T,X,Y} = Combine2{deduce_type(T,X,Y),Type{X},Y}(X, y)
30+
Combine(::Type{T}, x::X, ::Type{Y}) where {T,X,Y} = Combine2{deduce_type(T,X,Y),X,Type{Y}}(x, Y)
31+
Combine(::Type{T}, ::Type{X}, ::Type{Y}) where {T,X,Y} = Combine2{deduce_type(T,X,Y),Type{X},Type{Y}}(X, Y)
32+
33+
deduce_type(::Type{T}, ::Type{X}, ::Type{Y}) where {T,X,Y} = _deduce_type(T, Val(isconcretetype(T)), eltype(X), eltype(Y))
34+
deduce_type(::Type{T}, ::Type{X}) where {T,X} = _deduce_type(T, Val(isconcretetype(T)), eltype(X))
35+
36+
_deduce_type(::Type{T}, ::Val{true}, ::Type{X}, ::Type{Y}) where {T,X,Y} = T
37+
_deduce_type(::Type{T}, ::Val{false}, ::Type{X}, ::Type{Y}) where {T,X,Y} = deduce_type(T{X}, Y)
38+
39+
_deduce_type(::Type{T}, ::Val{true}, ::Type{X}) where {T,X} = T
40+
_deduce_type(::Type{T}, ::Val{false}, ::Type{X}) where {T,X} = T{X}
41+
42+
43+
## Uniform
44+
45+
abstract type Uniform{T} <: Distribution{T} end
46+
47+
48+
struct UniformType{T} <: Uniform{T} end
49+
50+
Uniform(::Type{T}) where {T} = UniformType{T}()
51+
52+
Base.getindex(::UniformType{T}) where {T} = T
53+
54+
struct UniformWrap{T,E} <: Uniform{E}
55+
val::T
56+
end
57+
58+
Uniform(x::T) where {T} = UniformWrap{T,eltype(T)}(x)
59+
60+
Base.getindex(x::UniformWrap) = x.val
61+
62+
63+
## Normal & Exponential
64+
65+
abstract type Normal{T} <: Distribution{T} end
66+
67+
struct Normal01{T} <: Normal{T} end
68+
69+
struct Normalμσ{T} <: Normal{T}
70+
μ::T
71+
σ::T
72+
end
73+
74+
Normal(::Type{T}=Float64) where {T} = Normal01{T}()
75+
Normal::T, σ::T) where {T} = Normalμσ(μ, σ)
76+
77+
abstract type Exponential{T} <: Distribution{T} end
78+
79+
struct Exponential1{T} <: Exponential{T} end
80+
81+
struct Exponentialθ{T} <: Exponential{T}
82+
θ::T
83+
end
84+
85+
Exponential(::Type{T}=Float64) where {T<:AbstractFloat} = Exponential1{T}()
86+
Exponential::T) where {T<:AbstractFloat} = Exponentialθ(θ)
87+
88+
89+
## floats
90+
91+
abstract type FloatInterval{T<:AbstractFloat} <: Uniform{T} end
92+
abstract type CloseOpen{T<:AbstractFloat} <: FloatInterval{T} end
93+
94+
struct CloseOpen01{T<:AbstractFloat} <: CloseOpen{T} end # interval [0,1)
95+
struct CloseOpen12{T<:AbstractFloat} <: CloseOpen{T} end # interval [1,2)
96+
97+
struct CloseOpenAB{T<:AbstractFloat} <: CloseOpen{T} # interval [a,b)
98+
a::T
99+
b::T
100+
end
101+
102+
const FloatInterval_64 = FloatInterval{Float64}
103+
const CloseOpen01_64 = CloseOpen01{Float64}
104+
const CloseOpen12_64 = CloseOpen12{Float64}
105+
106+
CloseOpen01(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen01{T}()
107+
CloseOpen12(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen12{T}()
108+
109+
CloseOpen(::Type{T}=Float64) where {T<:AbstractFloat} = CloseOpen01{T}()
110+
CloseOpen(a::T, b::T) where {T<:AbstractFloat} = CloseOpenAB{T}(a, b)
111+
112+
113+
Base.eltype(::Type{<:FloatInterval{T}}) where {T<:AbstractFloat} = T
114+
115+
116+
## a dummy container type to take advangage of SamplerTag constructor
117+
118+
struct Cont{T} end
119+
120+
Base.eltype(::Type{Cont{T}}) where {T} = T

src/sampling.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# definition of samplers and random generation
2+
3+
4+
## Uniform
5+
6+
Sampler(rng::AbstractRNG, d::Union{UniformWrap,UniformType}, n::Repetition) =
7+
Sampler(rng, d[], n)
8+
9+
10+
## floats
11+
12+
### override def from Random
13+
14+
Sampler(rng::AbstractRNG, ::Type{T}, n::Repetition) where {T<:AbstractFloat} =
15+
Sampler(rng, CloseOpen01(T), n)
16+
17+
### fall-back on Random definitions
18+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen01{T}}) where {T} =
19+
rand(r, SamplerTrivial(Random.CloseOpen01{T}()))
20+
21+
rand(r::AbstractRNG, ::SamplerTrivial{CloseOpen12{T}}) where {T} =
22+
rand(r, SamplerTrivial(Random.CloseOpen12{T}()))
23+
24+
### CloseOpenAB
25+
26+
Sampler(rng::AbstractRNG, d::CloseOpenAB{T}, n::Repetition) where {T} =
27+
SamplerTag{CloseOpenAB{T}}((a=d.a, d=d.b - d.a, sp=Sampler(rng, CloseOpen01{T}(), n)))
28+
29+
rand(rng::AbstractRNG, sp::SamplerTag{CloseOpenAB{T}}) where {T} =
30+
sp.data.a + sp.data.d * rand(rng, sp.data.sp)
31+
32+
33+
## sampler for pairs and complex numbers
34+
35+
function Sampler(rng::AbstractRNG, u::Combine2{T}, n::Repetition) where T <: Union{Pair,Complex}
36+
sp1 = Sampler(rng, u.x, n)
37+
sp2 = u.x == u.y ? sp1 : Sampler(rng, u.y, n)
38+
SamplerTag{Cont{T}}((sp1, sp2))
39+
end
40+
41+
rand(rng::AbstractRNG, sp::SamplerTag{Cont{T}}) where {T<:Union{Pair,Complex}} =
42+
T(rand(rng, sp.data[1]), rand(rng, sp.data[2]))
43+
44+
45+
### additional methods for complex numbers
46+
47+
Sampler(rng::AbstractRNG, u::Combine1{Complex}, n::Repetition) =
48+
Sampler(rng, Combine(Complex, u.x, u.x), n)
49+
50+
Sampler(rng::AbstractRNG, ::Type{Complex{T}}, n::Repetition) where {T<:Real} =
51+
Sampler(rng, Combine(Complex, T, T), n)
52+
53+
54+
## Normal & Exponential
55+
56+
rand(rng::AbstractRNG, ::SamplerTrivial{Normal01{T}}) where {T<:Union{AbstractFloat,Complex{<:AbstractFloat}}} =
57+
randn(rng, T)
58+
59+
Sampler(rng::AbstractRNG, d::Normalμσ{T}, n::Repetition) where {T} =
60+
SamplerSimple(d, Sampler(rng, Normal(T), n))
61+
62+
rand(rng::AbstractRNG, sp::SamplerSimple{Normalμσ{T},<:Sampler}) where {T} =
63+
sp[].μ + sp[].σ * rand(rng, sp.data)
64+
65+
rand(rng::AbstractRNG, ::SamplerTrivial{Exponential1{T}}) where {T<:AbstractFloat} =
66+
randexp(rng, T)
67+
68+
Sampler(rng::AbstractRNG, d::Exponentialθ{T}, n::Repetition) where {T} =
69+
SamplerSimple(d, Sampler(rng, Exponential(T), n))
70+
71+
rand(rng::AbstractRNG, sp::SamplerSimple{Exponentialθ{T},<:Sampler}) where {T} =
72+
sp[].θ * rand(rng, sp.data)

test/runtests.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,30 @@
11
using RandomExtensions
22
using Test
3+
4+
@testset "Distributions" begin
5+
# Normal/Exponential
6+
@test rand(Normal()) isa Float64
7+
@test rand(Normal(0.0, 1.0)) isa Float64
8+
@test rand(Exponential()) isa Float64
9+
@test rand(Exponential(1.0)) isa Float64
10+
@test rand(Normal(Float32)) isa Float32
11+
@test rand(Exponential(Float32)) isa Float32
12+
@test rand(Normal(ComplexF64)) isa ComplexF64
13+
14+
# pairs/complexes
15+
@test rand(Combine(Pair, 1:3, Float64)) isa Pair{Int,Float64}
16+
z = rand(Combine(Complex, 1:3, 6:9))
17+
@test z.re 1:3
18+
@test z.im 6:9
19+
@test z isa Complex{Int}
20+
z = rand(Combine(ComplexF64, 1:3, 6:9))
21+
@test z.re 1:3
22+
@test z.im 6:9
23+
@test z isa ComplexF64
24+
25+
# Uniform
26+
@test rand(Uniform(Float64)) isa Float64
27+
@test rand(Uniform(1:10)) isa Int
28+
@test rand(Uniform(1:10)) 1:10
29+
@test rand(Uniform(Int)) isa Int
30+
end

0 commit comments

Comments
 (0)