@@ -6,27 +6,15 @@ using SpecialFunctions: loggamma
66
77export pois_rand
88
9- # GPU-compatible Poisson sampling
10- randexp (T:: Type ) = - log (rand (T))
11- randexp () = randexp (Float64)
12-
13- function count_rand (λ)
14- λ = Float64 (λ)
15- n = 0
16- c = randexp (Float64)
17- while c < λ
18- n += 1
19- c += randexp (Float64)
20- end
21- return n
22- end
9+ # GPU-compatible Poisson sampling PassthroughRNG
10+ struct PassthroughRNG <: AbstractRNG end
2311
2412function count_rand (rng:: AbstractRNG , λ)
2513 n = 0
26- c = randexp (rng)
14+ c = rng isa PassthroughRNG ? randexp () : randexp (rng)
2715 while c < λ
2816 n += 1
29- c += randexp (rng)
17+ c += rng isa PassthroughRNG ? randexp () : randexp (rng)
3018 end
3119 return n
3220end
3927#
4028# For μ sufficiently large, (i.e. >= 10.0)
4129#
42- function ad_rand (λ)
30+ function ad_rand (rng :: AbstractRNG , λ)
4331 λ = Float64 (λ)
4432 s = sqrt (λ)
4533 d = 6.0 * λ^ 2
4634 L = floor (Int, λ - 1.1484 )
4735
48- G = λ + s * randn ()
36+ G = λ + s * (rng isa PassthroughRNG ? randn () : randn (rng) )
4937
5038 if G >= 0
5139 K = floor (Int, G)
5240 if K >= L
5341 return K
5442 end
5543
56- U = rand ()
44+ U = rng isa PassthroughRNG ? rand () : rand (rng )
5745 if d * U >= (λ - K)^ 3
5846 return K
5947 end
@@ -65,8 +53,8 @@ function ad_rand(λ)
6553 end
6654
6755 while true
68- E = randexp ()
69- U = 2 * rand () - 1
56+ E = rng isa PassthroughRNG ? randexp () : randexp (rng )
57+ U = 2 * (rng isa PassthroughRNG ? rand () : rand (rng) ) - 1
7058 T_val = 1.8 + copysign (E, U)
7159 if T_val <= - 0.6744
7260 continue
@@ -82,55 +70,6 @@ function ad_rand(λ)
8270 end
8371end
8472
85- function ad_rand (rng:: AbstractRNG , λ)
86- s = sqrt (λ)
87- d = 6 * λ^ 2
88- L = floor (Int, λ - 1.1484 )
89- # Step N
90- G = λ + s * randn (rng)
91-
92- if G >= 0
93- K = floor (Int, G)
94- # Step I
95- if K >= L
96- return K
97- end
98-
99- # Step S
100- U = rand (rng)
101- if d * U >= (λ - K)^ 3
102- return K
103- end
104-
105- # Step P
106- px, py, fx, fy = procf (λ, K, s)
107-
108- # Step Q
109- if fy * (1 - U) <= py * exp (px - fx)
110- return K
111- end
112- end
113-
114- while true
115- # Step E
116- E = randexp (rng)
117- U = 2 * rand (rng) - 1
118- T = 1.8 + copysign (E, U)
119- if T <= - 0.6744
120- continue
121- end
122-
123- K = floor (Int, λ + s * T)
124- px, py, fx, fy = procf (λ, K, s)
125- c = 0.1069 / λ
126-
127- # Step H
128- @fastmath if c * abs (U) <= py * exp (px + E) - fy * exp (fx + E)
129- return K
130- end
131- end
132- end
133-
13473# Procedure F
13574function procf (λ, K:: Int , s:: Float64 )
13675 INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
@@ -161,34 +100,6 @@ function procf(λ, K::Int, s::Float64)
161100 return px, py, fx, fy
162101end
163102
164- function procf (λ, K:: Int , s:: Float64 )
165- # can be pre-computed, but does not seem to affect performance
166- INV_SQRT_2PI = inv (sqrt (2pi ))
167- ω = INV_SQRT_2PI / s
168- b1 = inv (24 ) / λ
169- b2 = 0.3 * b1 * b1
170- c3 = inv (7 ) * b1 * b2
171- c2 = b2 - 15 * c3
172- c1 = b1 - 6 * b2 + 45 * c3
173- c0 = 1 - b1 + 3 * b2 - 15 * c3
174-
175- if K < 10
176- px = - float (λ)
177- py = λ^ K / factorial (K)
178- else
179- δ = inv (12 ) / K
180- δ -= 4.8 * δ^ 3
181- V = (λ - K) / K
182- px = K * log1pmx (V) - δ # avoids need for table
183- py = INV_SQRT_2PI / sqrt (K)
184- end
185- X = (K - λ + 0.5 ) / s
186- X2 = X^ 2
187- fx = X2 / - 2 # missing negation in pseudo-algorithm, but appears in fortran code.
188- fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
189- return px, py, fx, fy
190- end
191-
192103"""
193104```julia
194105pois_rand(λ)
@@ -209,7 +120,7 @@ rng = Xorshifts.Xoroshiro128Plus()
209120pois_rand(rng, λ)
210121```
211122"""
212- pois_rand (λ) = λ < 6 ? count_rand (λ) : ad_rand (λ)
123+ pois_rand (λ) = λ < 6 ? count_rand (PassthroughRNG (), λ) : ad_rand (PassthroughRNG (), λ)
213124pois_rand (rng:: AbstractRNG , λ) = λ < 6 ? count_rand (rng, λ) : ad_rand (rng, λ)
214125
215126end # module
0 commit comments