Skip to content

Commit 7f010a7

Browse files
added gpu version for pois_rand
1 parent 505b336 commit 7f010a7

2 files changed

Lines changed: 93 additions & 5 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ version = "0.4.5"
55
[deps]
66
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
77
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
8+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
89

910
[compat]
1011
Aqua = "0.8"
1112
Distributions = "0.25"
1213
LogExpFunctions = "0.3"
1314
Random = "1.10"
15+
SpecialFunctions = "2"
1416
Statistics = "1"
1517
Test = "1"
1618
julia = "1.10"

src/PoissonRandom.jl

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,25 @@ module PoissonRandom
22

33
using Random
44
using LogExpFunctions: log1pmx
5+
using SpecialFunctions: loggamma
56

67
export pois_rand
78

8-
count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
9+
# GPU-compatible Poisson sampling
10+
randexp(T::Type) = -log(rand(T))
11+
randexp() = randexp(Float64)
12+
13+
function count_rand(λ)
14+
λ = Float64(λ)
15+
n = 0
16+
c = randexp(Float64)
17+
while c < λ
18+
n += 1
19+
c += randexp(Float64)
20+
end
21+
return n
22+
end
23+
924
function count_rand(rng::AbstractRNG, λ)
1025
n = 0
1126
c = randexp(rng)
@@ -24,7 +39,49 @@ end
2439
#
2540
# For μ sufficiently large, (i.e. >= 10.0)
2641
#
27-
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
42+
function ad_rand(λ)
43+
λ = Float64(λ)
44+
s = sqrt(λ)
45+
d = 6.0 * λ^2
46+
L = floor(Int, λ - 1.1484)
47+
48+
G = λ + s * randn()
49+
50+
if G >= 0
51+
K = floor(Int, G)
52+
if K >= L
53+
return K
54+
end
55+
56+
U = rand()
57+
if d * U >=- K)^3
58+
return K
59+
end
60+
61+
px, py, fx, fy = procf(λ, K, s)
62+
if fy * (1 - U) <= py * exp(px - fx)
63+
return K
64+
end
65+
end
66+
67+
while true
68+
E = randexp()
69+
U = 2 * rand() - 1
70+
T_val = 1.8 + copysign(E, U)
71+
if T_val <= -0.6744
72+
continue
73+
end
74+
75+
K = floor(Int, λ + s * T_val)
76+
px, py, fx, fy = procf(λ, K, s)
77+
c = 0.1069 / λ
78+
79+
@fastmath if c * abs(U) <= py * exp(px + E) - fy * exp(fx + E)
80+
return K
81+
end
82+
end
83+
end
84+
2885
function ad_rand(rng::AbstractRNG, λ)
2986
s = sqrt(λ)
3087
d = 6 * λ^2
@@ -75,6 +132,35 @@ function ad_rand(rng::AbstractRNG, λ)
75132
end
76133

77134
# Procedure F
135+
function procf(λ, K::Int, s::Float64)
136+
INV_SQRT_2PI = 0.3989422804014327 # 1/sqrt(2π)
137+
ω = INV_SQRT_2PI / s
138+
b1 = 1 / (24 * λ)
139+
b2 = 0.3 * b1^2
140+
c3 = b1 * b2 / 7
141+
c2 = b2 - 15 * c3
142+
c1 = b1 - 6 * b2 + 45 * c3
143+
c0 = 1 - b1 + 3 * b2 - 15 * c3
144+
145+
if K < 10
146+
px = -λ
147+
log_py = K * log(λ) - loggamma(K + 1) # log(K!) via loggamma
148+
py = exp(log_py)
149+
else
150+
δ = 1 / (12 * K)
151+
δ -= 4.8 * δ^3
152+
V =- K) / K
153+
px = K * log1pmx(V) - δ
154+
py = INV_SQRT_2PI / sqrt(K)
155+
end
156+
157+
X = (K - λ + 0.5) / s
158+
X2 = X^2
159+
fx = -X2 / 2
160+
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
161+
return px, py, fx, fy
162+
end
163+
78164
function procf(λ, K::Int, s::Float64)
79165
# can be pre-computed, but does not seem to affect performance
80166
INV_SQRT_2PI = inv(sqrt(2pi))
@@ -114,16 +200,16 @@ Generates Poisson(λ) distributed random numbers using a fast polyalgorithm.
114200
## Examples
115201
116202
```julia
117-
# Simple Poisson random
203+
# Simple Poisson random which works on GPU
118204
pois_rand(λ)
119205
120-
# Using another RNG
206+
# Using RNG
121207
using RandomNumbers
122208
rng = Xorshifts.Xoroshiro128Plus()
123209
pois_rand(rng, λ)
124210
```
125211
"""
126-
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)
212+
pois_rand(λ) = λ < 6 ? count_rand(λ) : ad_rand(λ)
127213
pois_rand(rng::AbstractRNG, λ) = λ < 6 ? count_rand(rng, λ) : ad_rand(rng, λ)
128214

129215
end # module

0 commit comments

Comments
 (0)