Skip to content

Commit 2a452e7

Browse files
refactor
1 parent 56b1720 commit 2a452e7

1 file changed

Lines changed: 28 additions & 17 deletions

File tree

src/PoissonRandom.jl

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,22 @@ using Random
44
using LogExpFunctions: log1pmx
55
using SpecialFunctions: loggamma
66

7-
export pois_rand
7+
export pois_rand, PassthroughRNG
88

99
# GPU-compatible Poisson sampling PassthroughRNG
1010
struct PassthroughRNG <: AbstractRNG end
1111

12+
rand(rng::PassthroughRNG) = Random.rand()
13+
randexp(rng::PassthroughRNG) = Random.randexp()
14+
randn(rng::PassthroughRNG) = Random.randn()
15+
16+
count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
1217
function count_rand(rng::AbstractRNG, λ)
1318
n = 0
14-
c = rng isa PassthroughRNG ? randexp() : randexp(rng)
19+
c = randexp(rng)
1520
while c < λ
1621
n += 1
17-
c += rng isa PassthroughRNG ? randexp() : randexp(rng)
22+
c += randexp(rng)
1823
end
1924
return n
2025
end
@@ -27,43 +32,50 @@ end
2732
#
2833
# For μ sufficiently large, (i.e. >= 10.0)
2934
#
35+
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
3036
function ad_rand(rng::AbstractRNG, λ)
31-
λ = Float64(λ)
3237
s = sqrt(λ)
33-
d = 6.0 * λ^2
38+
d = 6 * λ^2
3439
L = floor(Int, λ - 1.1484)
35-
36-
G = λ + s * (rng isa PassthroughRNG ? randn() : randn(rng))
40+
# Step N
41+
G = λ + s * randn(rng)
3742

3843
if G >= 0
3944
K = floor(Int, G)
45+
# Step I
4046
if K >= L
4147
return K
4248
end
4349

44-
U = rng isa PassthroughRNG ? rand() : rand(rng)
50+
# Step S
51+
U = rand(rng)
4552
if d * U >=- K)^3
4653
return K
4754
end
4855

56+
# Step P
4957
px, py, fx, fy = procf(λ, K, s)
58+
59+
# Step Q
5060
if fy * (1 - U) <= py * exp(px - fx)
5161
return K
5262
end
5363
end
5464

5565
while true
56-
E = rng isa PassthroughRNG ? randexp() : randexp(rng)
57-
U = 2 * (rng isa PassthroughRNG ? rand() : rand(rng)) - 1
58-
T_val = 1.8 + copysign(E, U)
59-
if T_val <= -0.6744
66+
# Step E
67+
E = randexp(rng)
68+
U = 2 * rand(rng) - 1
69+
T = 1.8 + copysign(E, U)
70+
if T <= -0.6744
6071
continue
6172
end
6273

63-
K = floor(Int, λ + s * T_val)
74+
K = floor(Int, λ + s * T)
6475
px, py, fx, fy = procf(λ, K, s)
6576
c = 0.1069 / λ
6677

78+
# Step H
6779
@fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
6880
return K
6981
end
@@ -89,13 +101,12 @@ function procf(λ, K::Int, s::Float64)
89101
δ = 1 / (12 * K)
90102
δ -= 4.8 * δ^3
91103
V =- K) / K
92-
px = K * log1pmx(V) - δ
104+
px = K * log1pmx(V) - δ # avoids need for table
93105
py = INV_SQRT_2PI / sqrt(K)
94106
end
95-
96107
X = (K - λ + 0.5) / s
97108
X2 = X^2
98-
fx = -X2 / 2
109+
fx = -X2 / 2 # missing negation in pseudo-algorithm, but appears in fortran code.
99110
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
100111
return px, py, fx, fy
101112
end
@@ -120,7 +131,7 @@ rng = Xorshifts.Xoroshiro128Plus()
120131
pois_rand(rng, λ)
121132
```
122133
"""
123-
pois_rand(λ) = λ < 6 ? count_rand(PassthroughRNG(), λ) : ad_rand(PassthroughRNG(), λ)
134+
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)
124135
pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
125136

126137
end # module

0 commit comments

Comments
 (0)