Skip to content

Commit 137f1f1

Browse files
committed
switch args in make([X], p::AbstractFloat, dims...)
1 parent c4096ef commit 137f1f1

3 files changed

Lines changed: 42 additions & 29 deletions

File tree

src/containers.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ end
6161

6262
@make_container(::Type{String}, [n::Integer])
6363
# sparse vectors & matrices
64-
@make_container(p::AbstractFloat, m::Integer, [n::Integer])
6564
# Tuple as a container
6665
@make_container(T::Type{<:Tuple})
6766
@make_container(::Type{Tuple}, n::Integer)
@@ -71,26 +70,27 @@ end
7170

7271
macro make_array_container(Cont)
7372
definitions =
74-
[ :(rand(rng::AbstractRNG, T::Type{<:$Cont}, dims::Dims) = rand(rng, make(T, dims))),
75-
:(rand( T::Type{<:$Cont}, dims::Dims) = rand(GLOBAL_RNG, make(T, dims))),
76-
:(rand(rng::AbstractRNG, T::Type{<:$Cont}, dims::Integer...) = rand(rng, make(T, Dims(dims)))),
77-
:(rand( T::Type{<:$Cont}, dims::Integer...) = rand(GLOBAL_RNG, make(T, Dims(dims)))),
78-
79-
:(rand(rng::AbstractRNG, X, T::Type{<:$Cont}, dims::Dims) = rand(rng, make(T, X, dims))),
80-
:(rand( X, T::Type{<:$Cont}, dims::Dims) = rand(GLOBAL_RNG, make(T, X, dims))),
81-
:(rand(rng::AbstractRNG, X, T::Type{<:$Cont}, dims::Integer...) = rand(rng, make(T, X, Dims(dims)))),
82-
:(rand( X, T::Type{<:$Cont}, dims::Integer...) = rand(GLOBAL_RNG, make(T, X, Dims(dims)))),
83-
84-
:(rand(rng::AbstractRNG, ::Type{X}, T::Type{<:$Cont}, dims::Dims) where {X} = rand(rng, make(T, X, dims))),
85-
:(rand( ::Type{X}, T::Type{<:$Cont}, dims::Dims) where {X} = rand(GLOBAL_RNG, make(T, X, dims))),
86-
:(rand(rng::AbstractRNG, ::Type{X}, T::Type{<:$Cont}, dims::Integer...) where {X} = rand(rng, make(T, X, Dims(dims)))),
87-
:(rand( ::Type{X}, T::Type{<:$Cont}, dims::Integer...) where {X} = rand(GLOBAL_RNG, make(T, X, Dims(dims)))),
73+
[ :(rand(rng::AbstractRNG, $Cont, dims::Dims) = rand(rng, make(t, dims))),
74+
:(rand( $Cont, dims::Dims) = rand(GLOBAL_RNG, make(t, dims))),
75+
:(rand(rng::AbstractRNG, $Cont, dims::Integer...) = rand(rng, make(t, Dims(dims)))),
76+
:(rand( $Cont, dims::Integer...) = rand(GLOBAL_RNG, make(t, Dims(dims)))),
77+
78+
:(rand(rng::AbstractRNG, X, $Cont, dims::Dims) = rand(rng, make(t, X, dims))),
79+
:(rand( X, $Cont, dims::Dims) = rand(GLOBAL_RNG, make(t, X, dims))),
80+
:(rand(rng::AbstractRNG, X, $Cont, dims::Integer...) = rand(rng, make(t, X, Dims(dims)))),
81+
:(rand( X, $Cont, dims::Integer...) = rand(GLOBAL_RNG, make(t, X, Dims(dims)))),
82+
83+
:(rand(rng::AbstractRNG, ::Type{X}, $Cont, dims::Dims) where {X} = rand(rng, make(t, X, dims))),
84+
:(rand( ::Type{X}, $Cont, dims::Dims) where {X} = rand(GLOBAL_RNG, make(t, X, dims))),
85+
:(rand(rng::AbstractRNG, ::Type{X}, $Cont, dims::Integer...) where {X} = rand(rng, make(t, X, Dims(dims)))),
86+
:(rand( ::Type{X}, $Cont, dims::Integer...) where {X} = rand(GLOBAL_RNG, make(t, X, Dims(dims)))),
8887
]
8988
esc(Expr(:block, definitions...))
9089
end
9190

92-
@make_array_container(Array)
93-
@make_array_container(BitArray)
91+
@make_array_container(t::Type{<:Array})
92+
@make_array_container(t::Type{<:BitArray})
93+
@make_array_container(t::AbstractFloat)
9494

9595

9696
## sets/dicts

src/sampling.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -437,12 +437,18 @@ make(T::Type{SparseMatrixCSC}, p::AbstractFloat, d1::Integer, d2::Integer) = mak
437437
make(T::Type{SparseVector}, p::AbstractFloat, dims::Dims{1}) = make(T, default_sampling(T), p, dims)
438438
make(T::Type{SparseMatrixCSC}, p::AbstractFloat, dims::Dims{2}) = make(T, default_sampling(T), p, dims)
439439

440-
make(p::AbstractFloat, X, dims::Dims{1}) = make(SparseVector, X, p, dims)
441-
make(p::AbstractFloat, X, dims::Dims{2}) = make(SparseMatrixCSC, X, p, dims)
442-
443-
make(p::AbstractFloat, X, dims::Integer...) = make(p, X, Dims(dims))
444-
make(p::AbstractFloat, dims::Dims) = make(p, default_sampling(AbstractArray), dims)
445-
make(p::AbstractFloat, dims::Integer...) = make(p, default_sampling(AbstractArray), Dims(dims))
440+
make(X, p::AbstractFloat, dims::Dims{1}) = make(SparseVector, X, p, dims)
441+
make(::Type{X}, p::AbstractFloat, dims::Dims{1}) where {X} = make(SparseVector, X, p, dims)
442+
make(X, p::AbstractFloat, dims::Dims{2}) = make(SparseMatrixCSC, X, p, dims)
443+
make(::Type{X}, p::AbstractFloat, dims::Dims{2}) where {X} = make(SparseMatrixCSC, X, p, dims)
444+
445+
make(X, p::AbstractFloat, dims::Integer...) = make(X, p, Dims(dims))
446+
make(::Type{X}, p::AbstractFloat, dims::Integer...) where {X} = make(X, p, Dims(dims))
447+
make( p::AbstractFloat, dims::Dims) = make(default_sampling(AbstractArray), p, dims)
448+
make( p::AbstractFloat, dims::Integer...) = make(default_sampling(AbstractArray), p, Dims(dims))
449+
450+
# to make @make_array_container work:
451+
make(p::AbstractFloat, X, dims::Dims) = make(X, p, dims)
446452

447453
Sampler(RNG::Type{<:AbstractRNG}, c::Make3{A}, n::Repetition) where {A<:AbstractSparseArray} =
448454
SamplerTag{A}((sp = sampler(RNG, c.x, n),

test/runtests.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,22 @@ const spString = Sampler(MersenneTwister, String)
137137

138138
# sparse
139139
@test rand(rng..., Float64, .5, 10) isa SparseVector{Float64}
140+
@test rand(rng..., Float64, .5, (10,)) isa SparseVector{Float64}
141+
140142
@test rand(rng..., .5, 10) isa SparseVector{Float64}
143+
@test rand(rng..., .5, (10,)) isa SparseVector{Float64}
144+
141145
@test rand(rng..., Int, .5, 10) isa SparseVector{Int}
146+
@test rand(rng..., Int, .5, (10,)) isa SparseVector{Int}
147+
142148
@test rand(rng..., Float64, .5, 10, 3) isa SparseMatrixCSC{Float64}
149+
@test rand(rng..., Float64, .5, (10, 3)) isa SparseMatrixCSC{Float64}
150+
143151
@test rand(rng..., .5, 10, 3) isa SparseMatrixCSC{Float64}
152+
@test rand(rng..., .5, (10, 3)) isa SparseMatrixCSC{Float64}
153+
144154
@test rand(rng..., Int, .5, 10, 3) isa SparseMatrixCSC{Int}
155+
@test rand(rng..., Int, .5, (10, 3)) isa SparseMatrixCSC{Int}
145156

146157
# BitArray
147158
for S = ([], [Bool], [Bernoulli()])
@@ -447,16 +458,12 @@ end
447458
[Int8(2), Int16(3)] => 2),
448459
form = ([], [dim == 1 ? SparseVector : SparseMatrixCSC])
449460

450-
if form == []
451-
s = rand(make(0.3, k..., d...))
452-
else
453-
s = rand(make(form..., k..., 0.3, d...))
454-
end
461+
s = rand(make(form..., k..., 0.3, d...))
455462
@test s isa (dim == 1 ? SparseVector{Float64,Int} :
456463
SparseMatrixCSC{Float64,Int})
457464
@test length(s) == 6
458465
end
459-
@test rand(make(0.3, spString, 9)) isa SparseVector{String}
466+
@test rand(make(spString, 0.3, 9)) isa SparseVector{String}
460467
end
461468

462469
@testset "rand(make(default))" begin

0 commit comments

Comments
 (0)