Skip to content

Commit 5d005f8

Browse files
committed
add support for StaticArrays
1 parent 889303b commit 5d005f8

1 file changed

Lines changed: 28 additions & 7 deletions

File tree

src/sampling.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -385,11 +385,11 @@ rand(rng::AbstractRNG, sp::SamplerTag{Cont{S}}) where {S<:Base.ImmutableDict} =
385385
default_sampling(::Type{<:AbstractArray{T}}) where {T} = Uniform(T)
386386
default_sampling(::Type{<:AbstractArray}) = Uniform(Float64)
387387

388-
make(A::Type{<:AbstractArray}, X, dims::Integer...) = make(A, X, Dims(dims))
389-
make(A::Type{<:AbstractArray}, ::Type{X}, dims::Integer...) where {X} = make(A, X, Dims(dims))
388+
make(A::Type{<:AbstractArray}, X, d1::Integer, dims::Integer...) = make(A, X, Dims((d1, dims...)))
389+
make(A::Type{<:AbstractArray}, ::Type{X}, d1::Integer, dims::Integer...) where {X} = make(A, X, Dims((d1, dims...)))
390390

391-
make(A::Type{<:AbstractArray}, dims::Dims) = make(A, default_sampling(A), dims)
392-
make(A::Type{<:AbstractArray}, dims::Integer...) = make(A, default_sampling(A), Dims(dims))
391+
make(A::Type{<:AbstractArray}, dims::Dims) = make(A, default_sampling(A), dims)
392+
make(A::Type{<:AbstractArray}, d1::Integer, dims::Integer...) = make(A, default_sampling(A), Dims((d1, dims...)))
393393

394394

395395
Sampler(RNG::Type{<:AbstractRNG}, c::Make2{A}, n::Repetition) where {A<:AbstractArray} =
@@ -411,9 +411,9 @@ find_type(A::Type{Array}, X, ::Dims{N}) where {N} = Array{val_ge
411411
# special shortcut
412412

413413
make(X, dims::Dims) = make(Array, X, dims)
414-
make(X, d1::Integer, dims::Integer...) = make(Array, X, Dims(tuple(d1, dims...)))
414+
make(X, d1::Integer, dims::Integer...) = make(Array, X, Dims((d1, dims...)))
415415
make(::Type{X}, dims::Dims) where {X} = make(Array, X, dims)
416-
make(::Type{X}, d1::Integer, dims::Integer...) where {X} = make(Array, X, Dims(tuple(d1, dims...)))
416+
make(::Type{X}, d1::Integer, dims::Integer...) where {X} = make(Array, X, Dims((d1, dims...)))
417417
make( dims::Integer...) = make(Array, default_sampling(Array), Dims(dims))
418418

419419
# omitted: make(dims::Dims)
@@ -427,7 +427,7 @@ find_type(::Type{BitArray{N}}, _, ::Dims{N}) where {N} = BitArray{N}
427427
find_type(::Type{BitArray}, _, ::Dims{N}) where {N} = BitArray{N}
428428

429429

430-
### sparse vectors & matrices
430+
#### sparse vectors & matrices
431431

432432
find_type(::Type{SparseVector}, X, p::AbstractFloat, dims::Dims{1}) = SparseVector{ val_gentype(X), Int}
433433
find_type(::Type{SparseMatrixCSC}, X, p::AbstractFloat, dims::Dims{2}) = SparseMatrixCSC{val_gentype(X), Int}
@@ -470,6 +470,27 @@ rand(rng::AbstractRNG, sp::SamplerTag{A}) where {A<:SparseMatrixCSC} =
470470
sprand(rng, sp.data.dims[1], sp.data.dims[2], sp.data.p, (r, n)->rand(r, sp.data.sp, n), gentype(sp.data.sp))
471471

472472

473+
#### StaticArrays
474+
475+
function random_staticarrays()
476+
@eval using StaticArrays: tuple_length, tuple_prod, SArray, MArray
477+
for Arr = (:SArray, :MArray)
478+
@eval begin
479+
find_type(::Type{<:$Arr{S}} , X) where {S<:Tuple} = $Arr{S,val_gentype(X),tuple_length(S),tuple_prod(S)}
480+
find_type(::Type{<:$Arr{S,T}}, _) where {S<:Tuple,T} = $Arr{S,T,tuple_length(S),tuple_prod(S)}
481+
482+
Sampler(RNG::Type{<:AbstractRNG}, c::Make1{A}, n::Repetition) where {A<:$Arr} =
483+
SamplerTag{Cont{A}}(Sampler(RNG, c.x, n))
484+
485+
rand(rng::AbstractRNG, sp::SamplerTag{Cont{$Arr{S,T,N,L}}}) where {S,T,N,L} =
486+
$Arr{S,T,N,L}(rand(rng, make(NTuple{L}, sp.data)))
487+
488+
@make_container(T::Type{<:$Arr})
489+
end
490+
end
491+
end
492+
493+
473494
### String as a scalar
474495

475496
let b = UInt8['0':'9';'A':'Z';'a':'z'],

0 commit comments

Comments
 (0)