Skip to content

Commit 2dd41c6

Browse files
authored
Fix type issues (#141)
1 parent 04204f7 commit 2dd41c6

6 files changed

Lines changed: 29 additions & 27 deletions

File tree

src/calculus/tilt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Given function `f`, an array `a` and a constant `b` (optional), return function
1010
g(x) = f(x) + \\langle a, x \\rangle + b.
1111
```
1212
"""
13-
struct Tilt{T, S <: AbstractArray, R <: Real}
13+
struct Tilt{T, S, R}
1414
f::T
1515
a::S
1616
b::R
@@ -24,18 +24,18 @@ is_smooth(::Type{<:Tilt{T}}) where T = is_smooth(T)
2424
is_generalized_quadratic(::Type{<:Tilt{T}}) where T = is_generalized_quadratic(T)
2525
is_strongly_convex(::Type{<:Tilt{T}}) where T = is_strongly_convex(T)
2626

27-
Tilt(f::T, a::S) where {R <: Real, T, S <: AbstractArray{R}} = Tilt{T, S, R}(f, a, R(0))
27+
Tilt(f::T, a::S) where {T, S} = Tilt{T, S, real(eltype(S))}(f, a, real(eltype(S))(0))
2828

29-
function (g::Tilt)(x::AbstractArray{T}) where T <: RealOrComplex
30-
return g.f(x) + dot(g.a, x) + g.b
29+
function (g::Tilt)(x)
30+
return g.f(x) + real(dot(g.a, x)) + g.b
3131
end
3232

33-
function prox!(y::AbstractArray{T}, g::Tilt, x::AbstractArray{T}, gamma=R(1)) where {R <: Real, T <: RealOrComplex{R}}
33+
function prox!(y, g::Tilt, x, gamma)
3434
v = prox!(y, g.f, x .- gamma .* g.a, gamma)
35-
return v + dot(g.a, y) + g.b
35+
return v + real(dot(g.a, y)) + g.b
3636
end
3737

38-
function prox_naive(g::Tilt, x::AbstractArray{T}, gamma=R(1)) where {R <: Real, T <: RealOrComplex{R}}
38+
function prox_naive(g::Tilt, x, gamma)
3939
y, v = prox_naive(g.f, x .- gamma .* g.a, gamma)
40-
return y, v + dot(g.a, y) + g.b
40+
return y, v + real(dot(g.a, y)) + g.b
4141
end

src/functions/indGraphSkinny.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
using LinearAlgebra
44

5-
struct IndGraphSkinny{T <: RealOrComplex} <: IndGraph
5+
struct IndGraphSkinny{T} <: IndGraph
66
m::Int
77
n::Int
88
A::Matrix{T}

src/functions/indPSD.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333

3434
IndPSD(; scaling=false) = IndPSD(scaling)
3535

36-
function (::IndPSD)(X::HermOrSym)
36+
function (::IndPSD)(X::Union{Symmetric, Hermitian})
3737
R = real(eltype(X))
3838
F = eigen(X)
3939
for i in eachindex(F.values)
@@ -48,7 +48,7 @@ end
4848
is_convex(f::Type{<:IndPSD}) = true
4949
is_cone(f::Type{<:IndPSD}) = true
5050

51-
function prox!(Y::HermOrSym, ::IndPSD, X::HermOrSym, gamma)
51+
function prox!(Y::Union{Symmetric, Hermitian}, ::IndPSD, X::Union{Symmetric, Hermitian}, gamma)
5252
R = real(eltype(X))
5353
n = size(X, 1)
5454
F = eigen(X)
@@ -65,7 +65,7 @@ function prox!(Y::HermOrSym, ::IndPSD, X::HermOrSym, gamma)
6565
return R(0)
6666
end
6767

68-
function prox_naive(::IndPSD, X::HermOrSym, gamma)
68+
function prox_naive(::IndPSD, X::Union{Symmetric, Hermitian}, gamma)
6969
R = real(eltype(X))
7070
F = eigen(X)
7171
return F.vectors * Diagonal(max.(R(0), F.values)) * F.vectors', R(0)

src/functions/leastSquaresDirect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function LeastSquaresDirect(A::M, b, lambda) where M <: SparseMatrixCSC
4949
LeastSquaresDirect{ndims(b), R, C, M, typeof(b), SuiteSparse.CHOLMOD.Factor{C}, lambda >= 0}(A, b, R(lambda))
5050
end
5151

52-
function LeastSquaresDirect(A::TransposeOrAdjoint, b, lambda)
52+
function LeastSquaresDirect(A::Union{Transpose, Adjoint}, b, lambda)
5353
LeastSquaresDirect(copy(A), b, lambda)
5454
end
5555

src/functions/leastSquaresIterative.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ function LeastSquaresIterative(A::M, b, lambda) where M
2626
m, n = size(A)
2727
x_shape = infer_shape_of_x(A, b)
2828
shape, S, res2 = if m >= n
29-
:Tall, AcA(A), []
29+
:Tall, AcA(A, x_shape), []
3030
else
31-
:Fat, AAc(A), zero(b)
31+
:Fat, AAc(A, size(b)), zero(b)
3232
end
3333
RC = eltype(A)
3434
R = real(RC)

src/utilities/linops.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ size(Op::LinOp, i::Integer) = i <= 2 ? size(Op)[i] : 1
1717

1818
# AAc (Gram matrix)
1919

20-
mutable struct AAc{M, T} <: LinOp
20+
struct AAc{M, T} <: LinOp
2121
A::M
22-
buf::Maybe{AbstractArray{T}}
23-
function AAc{M, T}(A::M) where {M, T}
24-
new(A, nothing)
25-
end
22+
buf::T
2623
end
2724

28-
AAc(A::M) where {M} = AAc{M, eltype(A)}(A)
25+
function AAc(A::M, input_shape::Tuple) where M
26+
buffer_shape = (size(A, 2), input_shape[2:end]...)
27+
buffer = zeros(eltype(A), buffer_shape)
28+
AAc(A, buffer)
29+
end
2930

3031
function mul!(y, Op::AAc, x)
3132
if Op.buf === nothing
@@ -43,15 +44,16 @@ eltype(Op::AAc) = eltype(Op.A)
4344

4445
# AcA (Covariance matrix)
4546

46-
mutable struct AcA{M, T} <: LinOp
47+
struct AcA{M, T} <: LinOp
4748
A::M
48-
buf::Maybe{AbstractArray{T}}
49-
function AcA{M, T}(A::M) where {M, T}
50-
new(A, nothing)
51-
end
49+
buf::T
5250
end
5351

54-
AcA(A::M) where {M} = AcA{M, eltype(A)}(A)
52+
function AcA(A::M, input_shape::Tuple) where M
53+
buffer_shape = (size(A, 1), input_shape[2:end]...)
54+
buffer = zeros(eltype(A), buffer_shape)
55+
AcA(A, buffer)
56+
end
5557

5658
function mul!(y, Op::AcA, x)
5759
if Op.buf === nothing

0 commit comments

Comments
 (0)