Skip to content

Commit c4096ef

Browse files
committed
add make(Sparse..., [X], p, dims...)
1 parent 70eadb3 commit c4096ef

2 files changed

Lines changed: 35 additions & 10 deletions

File tree

src/sampling.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -419,21 +419,41 @@ find_type(::Type{BitArray}, _, ::Dims{N}) where {N} = BitArray{N}
419419

420420
### sparse vectors & matrices
421421

422-
make(p::AbstractFloat, X, dims::Dims{1}) = Make3{SparseVector{ val_gentype(X), Int}}(X, dims, p)
423-
make(p::AbstractFloat, X, dims::Dims{2}) = Make3{SparseMatrixCSC{val_gentype(X), Int}}(X, dims, p)
422+
find_type(::Type{SparseVector}, X, p::AbstractFloat, dims::Dims{1}) = SparseVector{ val_gentype(X), Int}
423+
find_type(::Type{SparseMatrixCSC}, X, p::AbstractFloat, dims::Dims{2}) = SparseMatrixCSC{val_gentype(X), Int}
424424

425-
make(p::AbstractFloat, X, dims::Integer...) = make(p, X, Dims(dims))
426-
make(p::AbstractFloat, dims::Dims) = make(p, Float64, dims)
427-
make(p::AbstractFloat, dims::Integer...) = make(p, Float64, Dims(dims))
425+
find_type(::Type{SparseVector{X}}, _, p::AbstractFloat, dims::Dims{1}) where {X} = SparseVector{ X, Int}
426+
find_type(::Type{SparseMatrixCSC{X}}, _, p::AbstractFloat, dims::Dims{2}) where {X} = SparseMatrixCSC{X, Int}
427+
428+
# need to be explicit and split these defs in 2 (or 4) to avoid ambiguities
429+
make(T::Type{SparseVector}, X, p::AbstractFloat, d1::Integer) = make(T, X, p, Dims((d1,)))
430+
make(T::Type{SparseVector}, ::Type{X}, p::AbstractFloat, d1::Integer) where {X} = make(T, X, p, Dims((d1,)))
431+
make(T::Type{SparseMatrixCSC}, X, p::AbstractFloat, d1::Integer, d2::Integer) = make(T, X, p, Dims((d1, d2)))
432+
make(T::Type{SparseMatrixCSC}, ::Type{X}, p::AbstractFloat, d1::Integer, d2::Integer) where {X} = make(T, X, p, Dims((d1, d2)))
433+
434+
make(T::Type{SparseVector}, p::AbstractFloat, d1::Integer) = make(T, default_sampling(T), p, Dims((d1,)))
435+
make(T::Type{SparseMatrixCSC}, p::AbstractFloat, d1::Integer, d2::Integer) = make(T, default_sampling(T), p, Dims((d1, d2)))
436+
437+
make(T::Type{SparseVector}, p::AbstractFloat, dims::Dims{1}) = make(T, default_sampling(T), p, dims)
438+
make(T::Type{SparseMatrixCSC}, p::AbstractFloat, dims::Dims{2}) = make(T, default_sampling(T), p, dims)
439+
440+
make(p::AbstractFloat, X, dims::Dims{1}) = make(SparseVector, X, p, dims)
441+
make(p::AbstractFloat, X, dims::Dims{2}) = make(SparseMatrixCSC, X, p, dims)
442+
443+
make(p::AbstractFloat, X, dims::Integer...) = make(p, X, Dims(dims))
444+
make(p::AbstractFloat, dims::Dims) = make(p, default_sampling(AbstractArray), dims)
445+
make(p::AbstractFloat, dims::Integer...) = make(p, default_sampling(AbstractArray), Dims(dims))
428446

429447
Sampler(RNG::Type{<:AbstractRNG}, c::Make3{A}, n::Repetition) where {A<:AbstractSparseArray} =
430-
SamplerTag{A}((sampler(RNG, c.x, n), c.y, c.z))
448+
SamplerTag{A}((sp = sampler(RNG, c.x, n),
449+
p = c.y,
450+
dims = c.z))
431451

432452
rand(rng::AbstractRNG, sp::SamplerTag{A}) where {A<:SparseVector} =
433-
sprand(rng, sp.data[2][1], sp.data[3], (r, n)->rand(r, sp.data[1], n))
453+
sprand(rng, sp.data.dims[1], sp.data.p, (r, n)->rand(r, sp.data.sp, n))
434454

435455
rand(rng::AbstractRNG, sp::SamplerTag{A}) where {A<:SparseMatrixCSC} =
436-
sprand(rng, sp.data[2][1], sp.data[2][2], sp.data[3], (r, n)->rand(r, sp.data[1], n), gentype(sp.data[1]))
456+
sprand(rng, sp.data.dims[1], sp.data.dims[2], sp.data.p, (r, n)->rand(r, sp.data.sp, n), gentype(sp.data.sp))
437457

438458

439459
### String as a scalar

test/runtests.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,14 @@ end
444444
[(2,3)] => 2,
445445
[6] => 1,
446446
[2, 3] => 2,
447-
[Int8(2), Int16(3)] => 2)
447+
[Int8(2), Int16(3)] => 2),
448+
form = ([], [dim == 1 ? SparseVector : SparseMatrixCSC])
448449

449-
s = rand(make(0.3, k..., d...))
450+
if form == []
451+
s = rand(make(0.3, k..., d...))
452+
else
453+
s = rand(make(form..., k..., 0.3, d...))
454+
end
450455
@test s isa (dim == 1 ? SparseVector{Float64,Int} :
451456
SparseMatrixCSC{Float64,Int})
452457
@test length(s) == 6

0 commit comments

Comments
 (0)