Skip to content

Commit 5c66243

Browse files
authored
Use LogExpFunctions
no reason to carry around our own log1pmx implimentation.
1 parent b8c1567 commit 5c66243

1 file changed

Lines changed: 14 additions & 52 deletions

File tree

src/PoissonRandom.jl

Lines changed: 14 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module PoissonRandom
22

33
using Random
4+
using LogExpFunctions: log1pmx
45

56
export pois_rand
67

@@ -26,12 +27,12 @@ end
2627
ad_rand(λ) = ad_rand(Random.GLOBAL_RNG, λ)
2728
function ad_rand(rng::AbstractRNG, λ)
2829
s = sqrt(λ)
29-
d = 6.0 * λ^2
30+
d = 6 * λ^2
3031
L = floor(Int, λ - 1.1484)
3132
# Step N
3233
G = λ + s * randn(rng)
3334

34-
if G >= 0.0
35+
if G >= 0
3536
K = floor(Int, G)
3637
# Step I
3738
if K >= L
@@ -56,7 +57,7 @@ function ad_rand(rng::AbstractRNG, λ)
5657
while true
5758
# Step E
5859
E = randexp(rng)
59-
U = 2.0 * rand(rng) - 1.0
60+
U = 2 * rand(rng) - 1
6061
T = 1.8 + copysign(E, U)
6162
if T <= -0.6744
6263
continue
@@ -73,70 +74,31 @@ function ad_rand(rng::AbstractRNG, λ)
7374
end
7475
end
7576

76-
# log(1+x)-x
77-
# accurate ~2ulps for -0.227 < x < 0.315
78-
function log1pmx_kernel(x::Float64)
79-
r = x / (x + 2.0)
80-
t = r * r
81-
w = @evalpoly(t,
82-
6.66666666666666667e-1, # 2/3
83-
4.00000000000000000e-1, # 2/5
84-
2.85714285714285714e-1, # 2/7
85-
2.22222222222222222e-1, # 2/9
86-
1.81818181818181818e-1, # 2/11
87-
1.53846153846153846e-1, # 2/13
88-
1.33333333333333333e-1, # 2/15
89-
1.17647058823529412e-1) # 2/17
90-
hxsq = 0.5 * x * x
91-
r * (hxsq + w * t) - hxsq
92-
end
93-
94-
# use naive calculation or range reduction outside kernel range.
95-
# accurate ~2ulps for all x
96-
function log1pmx(x::Float64)
97-
if !(-0.7 < x < 0.9)
98-
return log1p(x) - x
99-
elseif x > 0.315
100-
u = (x - 0.5) / 1.5
101-
return log1pmx_kernel(u) - 9.45348918918356180e-2 - 0.5 * u
102-
elseif x > -0.227
103-
return log1pmx_kernel(x)
104-
elseif x > -0.4
105-
u = (x + 0.25) / 0.75
106-
return log1pmx_kernel(u) - 3.76820724517809274e-2 + 0.25 * u
107-
elseif x > -0.6
108-
u = (x + 0.5) * 2.0
109-
return log1pmx_kernel(u) - 1.93147180559945309e-1 + 0.5 * u
110-
else
111-
u = (x + 0.625) / 0.375
112-
return log1pmx_kernel(u) - 3.55829253011726237e-1 + 0.625 * u
113-
end
114-
end
115-
11677
# Procedure F
11778
function procf(λ, K::Int, s::Float64)
11879
# can be pre-computed, but does not seem to affect performance
119-
ω = 0.3989422804014327 / s
120-
b1 = 0.041666666666666664 / λ
80+
INV_SQRT_2PI = inv(sqrt(2pi))
81+
ω = INV_SQRT_2PI / s
82+
b1 = inv(24) / λ
12183
b2 = 0.3 * b1 * b1
122-
c3 = 0.14285714285714285 * b1 * b2
123-
c2 = b2 - 15.0 * c3
124-
c1 = b1 - 6.0 * b2 + 45.0 * c3
125-
c0 = 1.0 - b1 + 3.0 * b2 - 15.0 * c3
84+
c3 = inv(7) * b1 * b2
85+
c2 = b2 - 15 * c3
86+
c1 = b1 - 6 * b2 + 45 * c3
87+
c0 = 1 - b1 + 3 * b2 - 15 * c3
12688

12789
if K < 10
12890
px = -float(λ)
12991
py = λ^K / factorial(K)
13092
else
131-
δ = 0.08333333333333333 / K
93+
δ = inv(12) / K
13294
δ -= 4.8 * δ^3
13395
V =- K) / K
13496
px = K * log1pmx(V) - δ # avoids need for table
135-
py = 0.3989422804014327 / sqrt(K)
97+
py = INV_SQRT_2PI / sqrt(K)
13698
end
13799
X = (K - λ + 0.5) / s
138100
X2 = X^2
139-
fx = -0.5 * X2 # missing negation in pseudo-algorithm, but appears in fortran code.
101+
fx = X2 / -2 # missing negation in pseudo-algorithm, but appears in fortran code.
140102
fy = ω * (((c3 * X2 + c2) * X2 + c1) * X2 + c0)
141103
return px, py, fx, fy
142104
end

0 commit comments

Comments
 (0)