@@ -4,17 +4,22 @@ using Random
44using LogExpFunctions: log1pmx
55using SpecialFunctions: loggamma
66
7- export pois_rand
7+ export pois_rand, PassthroughRNG
88
99# GPU-compatible Poisson sampling PassthroughRNG
1010struct PassthroughRNG <: AbstractRNG end
1111
12+ rand (rng:: PassthroughRNG ) = Random. rand ()
13+ randexp (rng:: PassthroughRNG ) = Random. randexp ()
14+ randn (rng:: PassthroughRNG ) = Random. randn ()
15+
16+ count_rand (λ) = count_rand (Random. GLOBAL_RNG, λ)
1217function count_rand (rng:: AbstractRNG , λ)
1318 n = 0
14- c = rng isa PassthroughRNG ? randexp () : randexp (rng)
19+ c = randexp (rng)
1520 while c < λ
1621 n += 1
17- c += rng isa PassthroughRNG ? randexp () : randexp (rng)
22+ c += randexp (rng)
1823 end
1924 return n
2025end
2732#
2833# For μ sufficiently large, (i.e. >= 10.0)
2934#
35+ ad_rand (λ) = ad_rand (Random. GLOBAL_RNG, λ)
3036function ad_rand (rng:: AbstractRNG , λ)
31- λ = Float64 (λ)
3237 s = sqrt (λ)
33- d = 6.0 * λ^ 2
38+ d = 6 * λ^ 2
3439 L = floor (Int, λ - 1.1484 )
35-
36- G = λ + s * (rng isa PassthroughRNG ? randn () : randn ( rng) )
40+ # Step N
41+ G = λ + s * randn (rng)
3742
3843 if G >= 0
3944 K = floor (Int, G)
45+ # Step I
4046 if K >= L
4147 return K
4248 end
4349
44- U = rng isa PassthroughRNG ? rand () : rand (rng)
50+ # Step S
51+ U = rand (rng)
4552 if d * U >= (λ - K)^ 3
4653 return K
4754 end
4855
56+ # Step P
4957 px, py, fx, fy = procf (λ, K, s)
58+
59+ # Step Q
5060 if fy * (1 - U) <= py * exp (px - fx)
5161 return K
5262 end
5363 end
5464
5565 while true
56- E = rng isa PassthroughRNG ? randexp () : randexp (rng)
57- U = 2 * (rng isa PassthroughRNG ? rand () : rand (rng)) - 1
58- T_val = 1.8 + copysign (E, U)
59- if T_val <= - 0.6744
66+ # Step E
67+ E = randexp (rng)
68+ U = 2 * rand (rng) - 1
69+ T = 1.8 + copysign (E, U)
70+ if T <= - 0.6744
6071 continue
6172 end
6273
63- K = floor (Int, λ + s * T_val )
74+ K = floor (Int, λ + s * T )
6475 px, py, fx, fy = procf (λ, K, s)
6576 c = 0.1069 / λ
6677
78+ # Step H
6779 @fastmath if c * abs (U) <= py * exp (px + E) - fy * exp (fx + E)
6880 return K
6981 end
@@ -89,13 +101,12 @@ function procf(λ, K::Int, s::Float64)
89101 δ = 1 / (12 * K)
90102 δ -= 4.8 * δ^ 3
91103 V = (λ - K) / K
92- px = K * log1pmx (V) - δ
104+ px = K * log1pmx (V) - δ # avoids need for table
93105 py = INV_SQRT_2PI / sqrt (K)
94106 end
95-
96107 X = (K - λ + 0.5 ) / s
97108 X2 = X^ 2
98- fx = - X2 / 2
109+ fx = - X2 / 2 # missing negation in pseudo-algorithm, but appears in fortran code.
99110 fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
100111 return px, py, fx, fy
101112end
@@ -120,7 +131,7 @@ rng = Xorshifts.Xoroshiro128Plus()
120131pois_rand(rng, λ)
121132```
122133"""
123- pois_rand (λ) = λ < 6 ? count_rand ( PassthroughRNG (), λ) : ad_rand ( PassthroughRNG () , λ)
134+ pois_rand (λ) = pois_rand (Random . GLOBAL_RNG , λ)
124135pois_rand (rng:: AbstractRNG , λ) = λ < 6 ? count_rand (rng, λ) : ad_rand (rng, λ)
125136
126137end # module
0 commit comments