@@ -2,10 +2,25 @@ module PoissonRandom
22
33using Random
44using LogExpFunctions: log1pmx
5+ using SpecialFunctions: loggamma
56
67export pois_rand
78
8- count_rand (λ) = count_rand (Random. GLOBAL_RNG, λ)
9+ # GPU-compatible Poisson sampling
10+ randexp (T:: Type ) = - log (rand (T))
11+ randexp () = randexp (Float64)
12+
13+ function count_rand (λ)
14+ λ = Float64 (λ)
15+ n = 0
16+ c = randexp (Float64)
17+ while c < λ
18+ n += 1
19+ c += randexp (Float64)
20+ end
21+ return n
22+ end
23+
924function count_rand (rng:: AbstractRNG , λ)
1025 n = 0
1126 c = randexp (rng)
2439#
2540# For μ sufficiently large, (i.e. >= 10.0)
2641#
27- ad_rand (λ) = ad_rand (Random. GLOBAL_RNG, λ)
42+ function ad_rand (λ)
43+ λ = Float64 (λ)
44+ s = sqrt (λ)
45+ d = 6.0 * λ^ 2
46+ L = floor (Int, λ - 1.1484 )
47+
48+ G = λ + s * randn ()
49+
50+ if G >= 0
51+ K = floor (Int, G)
52+ if K >= L
53+ return K
54+ end
55+
56+ U = rand ()
57+ if d * U >= (λ - K)^ 3
58+ return K
59+ end
60+
61+ px, py, fx, fy = procf (λ, K, s)
62+ if fy * (1 - U) <= py * exp (px - fx)
63+ return K
64+ end
65+ end
66+
67+ while true
68+ E = randexp ()
69+ U = 2 * rand () - 1
70+ T_val = 1.8 + copysign (E, U)
71+ if T_val <= - 0.6744
72+ continue
73+ end
74+
75+ K = floor (Int, λ + s * T_val)
76+ px, py, fx, fy = procf (λ, K, s)
77+ c = 0.1069 / λ
78+
79+ @fastmath if c * abs (U) <= py * exp (px + E) - fy * exp (fx + E)
80+ return K
81+ end
82+ end
83+ end
84+
2885function ad_rand (rng:: AbstractRNG , λ)
2986 s = sqrt (λ)
3087 d = 6 * λ^ 2
@@ -75,6 +132,35 @@ function ad_rand(rng::AbstractRNG, λ)
75132end
76133
77134# Procedure F
135+ function procf (λ, K:: Int , s:: Float64 )
136+ INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
137+ ω = INV_SQRT_2PI / s
138+ b1 = 1 / (24 * λ)
139+ b2 = 0.3 * b1^ 2
140+ c3 = b1 * b2 / 7
141+ c2 = b2 - 15 * c3
142+ c1 = b1 - 6 * b2 + 45 * c3
143+ c0 = 1 - b1 + 3 * b2 - 15 * c3
144+
145+ if K < 10
146+ px = - λ
147+ log_py = K * log (λ) - loggamma (K + 1 ) # log(K!) via loggamma
148+ py = exp (log_py)
149+ else
150+ δ = 1 / (12 * K)
151+ δ -= 4.8 * δ^ 3
152+ V = (λ - K) / K
153+ px = K * log1pmx (V) - δ
154+ py = INV_SQRT_2PI / sqrt (K)
155+ end
156+
157+ X = (K - λ + 0.5 ) / s
158+ X2 = X^ 2
159+ fx = - X2 / 2
160+ fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
161+ return px, py, fx, fy
162+ end
163+
78164function procf (λ, K:: Int , s:: Float64 )
79165 # can be pre-computed, but does not seem to affect performance
80166 INV_SQRT_2PI = inv (sqrt (2pi ))
@@ -114,16 +200,16 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm.
114200## Examples
115201
116202```julia
117- # Simple Poisson random
203+ # Simple Poisson random which works on GPU
118204pois_rand(λ)
119205
120- # Using another RNG
206+ # Using RNG
121207using RandomNumbers
122208rng = Xorshifts.Xoroshiro128Plus()
123209pois_rand(rng, λ)
124210```
125211"""
126- pois_rand (λ) = pois_rand (Random . GLOBAL_RNG, λ)
212+ pois_rand (λ) = λ < 6 ? count_rand (λ) : ad_rand ( λ)
127213pois_rand (rng:: AbstractRNG , λ) = λ < 6 ? count_rand (rng, λ) : ad_rand (rng, λ)
128214
129215end # module
0 commit comments