Skip to content

Commit 3efc70b

Browse files
Merge pull request #54 from sivasathyaseeelan/pois_rand-gpu
gpu support for pois_rand
2 parents 505b336 + 14e3940 commit 3efc70b

1 file changed

Lines changed: 16 additions & 2 deletions

File tree

src/PoissonRandom.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,18 @@ module PoissonRandom
33
using Random
44
using LogExpFunctions: log1pmx
55

6-
export pois_rand
6+
export pois_rand, PassthroughRNG
7+
8+
# GPU-compatible Poisson sampling via PassthroughRNG
9+
struct PassthroughRNG <: AbstractRNG end
10+
11+
rand(rng::PassthroughRNG) = Random.rand()
12+
randexp(rng::PassthroughRNG) = Random.randexp()
13+
randn(rng::PassthroughRNG) = Random.randn()
14+
15+
rand(rng::AbstractRNG) = Random.rand(rng)
16+
randexp(rng::AbstractRNG) = Random.randexp(rng)
17+
randn(rng::AbstractRNG) = Random.randn(rng)
718

819
count_rand(λ) = count_rand(Random.GLOBAL_RNG, λ)
920
function count_rand(rng::AbstractRNG, λ)
@@ -88,7 +99,7 @@ function procf(λ, K::Int, s::Float64)
8899

89100
if K < 10
90101
px = -float(λ)
91-
py = λ^K / factorial(K)
102+
py = λ^K / prod(2:K)
92103
else
93104
δ = inv(12) / K
94105
δ -= 4.8 * δ^3
@@ -121,6 +132,9 @@ pois_rand(λ)
121132
using RandomNumbers
122133
rng = Xorshifts.Xoroshiro128Plus()
123134
pois_rand(rng, λ)
135+
136+
# Simple Poisson random on GPU
137+
pois_rand(PoissonRandom.PassthroughRNG(), λ)
124138
```
125139
"""
126140
pois_rand(λ) = pois_rand(Random.GLOBAL_RNG, λ)

0 commit comments

Comments
 (0)