Skip to content

Commit 56b1720

Browse files
added PassthroughRNG implementation
1 parent 7f010a7 commit 56b1720

1 file changed

Lines changed: 10 additions & 99 deletions

File tree

src/PoissonRandom.jl

Lines changed: 10 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,15 @@ using SpecialFunctions: loggamma
66

77
export pois_rand
88

9-
# GPU-compatible Poisson sampling
10-
randexp(T::Type) = -log(rand(T))
11-
randexp() = randexp(Float64)
12-
13-
function count_rand(λ)
14-
λ = Float64(λ)
15-
n = 0
16-
c = randexp(Float64)
17-
while c < λ
18-
n += 1
19-
c += randexp(Float64)
20-
end
21-
return n
22-
end
9+
# GPU-compatible Poisson sampling PassthroughRNG
10+
struct PassthroughRNG <: AbstractRNG end
2311

2412
function count_rand(rng::AbstractRNG, λ)
2513
n = 0
26-
c = randexp(rng)
14+
c = rng isa PassthroughRNG ? randexp() : randexp(rng)
2715
while c < λ
2816
n += 1
29-
c += randexp(rng)
17+
c += rng isa PassthroughRNG ? randexp() : randexp(rng)
3018
end
3119
return n
3220
end
@@ -39,21 +27,21 @@ end
3927
#
4028
# For μ sufficiently large, (i.e. >= 10.0)
4129
#
42-
function ad_rand(λ)
30+
function ad_rand(rng::AbstractRNG, λ)
4331
λ = Float64(λ)
4432
s = sqrt(λ)
4533
d = 6.0 * λ^2
4634
L = floor(Int, λ - 1.1484)
4735

48-
G = λ + s * randn()
36+
G = λ + s * (rng isa PassthroughRNG ? randn() : randn(rng))
4937

5038
if G >= 0
5139
K = floor(Int, G)
5240
if K >= L
5341
return K
5442
end
5543

56-
U = rand()
44+
U = rng isa PassthroughRNG ? rand() : rand(rng)
5745
if d * U >=- K)^3
5846
return K
5947
end
@@ -65,8 +53,8 @@ function ad_rand(λ)
6553
end
6654

6755
while true
68-
E = randexp()
69-
U = 2 * rand() - 1
56+
E = rng isa PassthroughRNG ? randexp() : randexp(rng)
57+
U = 2 * (rng isa PassthroughRNG ? rand() : rand(rng)) - 1
7058
T_val = 1.8 + copysign(E, U)
7159
if T_val <= -0.6744
7260
continue
@@ -82,55 +70,6 @@ function ad_rand(λ)
8270
end
8371
end
8472

85-
function ad_rand(rng::AbstractRNG, λ)
86-
s = sqrt(λ)
87-
d = 6 * λ^2
88-
L = floor(Int, λ - 1.1484)
89-
# Step N
90-
G = λ + s * randn(rng)
91-
92-
if G >= 0
93-
K = floor(Int, G)
94-
# Step I
95-
if K >= L
96-
return K
97-
end
98-
99-
# Step S
100-
U = rand(rng)
101-
if d * U >=- K)^3
102-
return K
103-
end
104-
105-
# Step P
106-
px, py, fx, fy = procf(λ, K, s)
107-
108-
# Step Q
109-
if fy * (1 - U) <= py * exp(px - fx)
110-
return K
111-
end
112-
end
113-
114-
while true
115-
# Step E
116-
E = randexp(rng)
117-
U = 2 * rand(rng) - 1
118-
T = 1.8 + copysign(E, U)
119-
if T <= -0.6744
120-
continue
121-
end
122-
123-
K = floor(Int, λ + s * T)
124-
px, py, fx, fy = procf(λ, K, s)
125-
c = 0.1069 / λ
126-
127-
# Step H
128-
@fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
129-
return K
130-
end
131-
end
132-
end
133-
13473
# Procedure F
13574
function procf(λ, K::Int, s::Float64)
13675
INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
@@ -161,34 +100,6 @@ function procf(λ, K::Int, s::Float64)
161100
return px, py, fx, fy
162101
end
163102

164-
function procf(λ, K::Int, s::Float64)
165-
# can be pre-computed, but does not seem to affect performance
166-
INV_SQRT_2PI = inv(sqrt(2pi))
167-
ω = INV_SQRT_2PI / s
168-
b1 = inv(24) / λ
169-
b2 = 0.3 * b1 * b1
170-
c3 = inv(7) * b1 * b2
171-
c2 = b2 - 15 * c3
172-
c1 = b1 - 6 * b2 + 45 * c3
173-
c0 = 1 - b1 + 3 * b2 - 15 * c3
174-
175-
if K < 10
176-
px = -float(λ)
177-
py = λ^K / factorial(K)
178-
else
179-
δ = inv(12) / K
180-
δ -= 4.8 * δ^3
181-
V =- K) / K
182-
px = K * log1pmx(V) - δ # avoids need for table
183-
py = INV_SQRT_2PI / sqrt(K)
184-
end
185-
X = (K - λ + 0.5) / s
186-
X2 = X^2
187-
fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code.
188-
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
189-
return px, py, fx, fy
190-
end
191-
192103
"""
193104
```julia
194105
pois_rand(λ)
@@ -209,7 +120,7 @@ rng = Xorshifts.Xoroshiro128Plus()
209120
pois_rand(rng, λ)
210121
```
211122
"""
212-
pois_rand(λ) = λ < 6 ? count_rand(λ) : ad_rand(λ)
123+
pois_rand(λ) = λ < 6 ? count_rand(PassthroughRNG(), λ) : ad_rand(PassthroughRNG(), λ)
213124
pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
214125

215126
end # module

0 commit comments

Comments
 (0)