Skip to content

Commit bc83a00

Browse files
committed
add Categorical distribution
1 parent fa6b81d commit bc83a00

4 files changed

Lines changed: 67 additions & 1 deletion

File tree

src/RandomExtensions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module RandomExtensions
22

3-
export make, Uniform, Normal, Exponential, CloseOpen, OpenClose, OpenOpen, CloseClose, Rand, Bernoulli
3+
export make, Uniform, Normal, Exponential, CloseOpen, OpenClose, OpenOpen, CloseClose, Rand,
4+
Bernoulli, Categorical
45

56
# re-exports from Random, which don't overlap with new functionality and not from misc.jl
67
export rand!, AbstractRNG, MersenneTwister, RandomDevice

src/distributions.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,34 @@ end
200200

201201
# Without an explicit outcome type, fall back to `Int`.
function Bernoulli(p::Real=0.5)
    return Bernoulli(Int, p)
end

# Build a `Bernoulli{T}` for the requested numeric type `T` with parameter `p`.
function Bernoulli(::Type{T}, p::Real=0.5) where {T<:Number}
    return Bernoulli{T}(p)
end
203+
204+
## Categorical
205+
206+
"""
    Categorical{T<:Number}(weights)

Categorical distribution built from an iterable of non-negative `weights`.

The weights are normalized by their sum and stored as the cumulative
distribution `cdf`, a non-decreasing vector whose last entry is exactly `1.0`.
`T` is the numeric type to which sampled category indices are converted.

# Throws
- `ArgumentError` if `weights` is empty, contains a negative (or `NaN`) entry,
  or does not sum to a finite, strictly positive value.
"""
struct Categorical{T<:Number} <: Distribution{T}
    cdf::Vector{Float64}

    function Categorical{T}(weights) where T
        # `accumulate` needs an array
        # TODO: `collect` will not be necessary anymore in Julia 1.5
        w = vec(weights isa AbstractArray ? weights : collect(weights))

        isempty(w) &&
            throw(ArgumentError("Categorical requires at least one category"))

        # `NaN >= 0` is false, so NaN weights are rejected here as well
        all(x -> Float64(x) >= 0, w) ||
            throw(ArgumentError("Categorical requires non-negative weights"))

        s = Float64(sum(w))
        (isfinite(s) && s > 0) ||
            throw(ArgumentError(
                "Categorical requires weights with a finite positive sum"))

        cdf = accumulate(w; init=0.0) do acc, x
            # clamp so that rounding can never push an entry above 1.0,
            # which keeps `cdf` sorted once the last entry is pinned below
            min(acc + Float64(x) / s, 1.0)
        end
        cdf[end] = 1.0 # to be sure the sampling algorithm terminates
        return new{T}(cdf)
    end
end
229+
230+
# Category indices are produced as `Int` when no type parameter is given.
function Categorical(weights)
    return Categorical{Int}(weights)
end

# `n` equally weighted categories; samples are converted to the type of `n`.
function Categorical(n::Number)
    share = 1.0 / Float64(n)
    return Categorical{typeof(n)}(Iterators.repeated(share, Int(n)))
end

src/sampling.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,22 @@ rand(rng::AbstractRNG, sp::SamplerTag{Bernoulli{T}}) where {T} =
136136
ifelse(rand(rng, CloseOpen12()) < sp.data, one(T), zero(T))
137137

138138

139+
## Categorical
140+
141+
# Build the sampler for `CloseOpen()` up front and store it next to the
# distribution, so that `rand` can draw from it via `sp.data`.
function Sampler(RNG::Type{<:AbstractRNG}, c::Categorical, n::Repetition)
    uniform = Sampler(RNG, CloseOpen(), n)
    return SamplerSimple(c, uniform)
end
143+
144+
# Draw a category index from the distribution `sp[]`, converted to `T`.
#
# `c` is drawn from the sampler stored in `sp.data`, which was built for
# `CloseOpen()` (presumably uniform on [0, 1) -- confirm against its definition).
# The selected category is the first one whose cumulative weight is *strictly*
# greater than `c`: with `>=`, a draw of exactly `c == 0.0` would select a
# leading category of weight zero. Since `cdf[end] == 1.0 > c`, `findfirst`
# always finds a match.
#
# unfortunately requires @inline to avoid allocating
@inline function rand(rng::AbstractRNG, sp::SamplerSimple{Categorical{T}}) where {T}
    c = rand(rng, sp.data)
    return T(findfirst(x -> x > c, sp[].cdf))
end

# NOTE:
# if length(cdf) is somewhere between 150 and 200, the following gets faster:
# T(searchsortedlast(sp[].cdf, rand(rng, sp.data)) + 1)
153+
154+
139155
## random elements from pairs
140156

141157
Sampler(RNG::Type{<:AbstractRNG}, t::Pair, n::Repetition) =

test/runtests.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,24 @@ using Test
6262
r = rand(Bernoulli(T, 1))
6363
@test r == 1
6464
end
65+
66+
# Categorical
n = rand(1:9)
# `Categorical(n)` has `n` equally weighted categories, so samples lie in 1:n
@test rand(Categorical(n)) ∈ 1:n
@test all(∈(1:n), rand(Categorical(n), 10))
@test rand(Categorical(n)) isa Int
c = Categorical(Float64(n))
@test rand(c) isa Float64
@test rand(c) ∈ 1:n
c = Categorical([1, 7, 2])
# cf. Bernoulli tests
# category 2 carries weight 7 out of 10, so about 700 hits are expected
@test 620 < count(==(2), rand(c, 1000)) < 780
@test rand(c) isa Int
@test rand(Categorical{Float64}((1, 2, 3, 4))) isa Float64

# an empty collection of weights is rejected, whatever its container type
@test_throws ArgumentError Categorical(())
@test_throws ArgumentError Categorical([])
@test_throws ArgumentError Categorical(x for x in 1:0)
6583
end
6684

6785
# range covering every `Int8` value, from `typemin(Int8)` to `typemax(Int8)`
const rInt8 = typemin(Int8):typemax(Int8)

0 commit comments

Comments
 (0)