|
| 1 | +module ChainRulesCoreSparseArraysExt |
| 2 | + |
| 3 | +using ChainRulesCore |
| 4 | +using ChainRulesCore: project_type, _projection_mismatch |
| 5 | +using SparseArrays: SparseVector, SparseMatrixCSC, nzrange, rowvals |
| 6 | + |
| 7 | +ChainRulesCore.is_inplaceable_destination(::SparseVector) = true |
| 8 | +ChainRulesCore.is_inplaceable_destination(::SparseMatrixCSC) = true |
| 9 | + |
| 10 | +# Word from on high is that we should regard all un-stored values of sparse arrays as |
| 11 | +# structural zeros. Thus ProjectTo needs to store nzind, and get only those. |
| 12 | +# This implementation very naiive, can probably be made more efficient. |
| 13 | + |
| 14 | +function ChainRulesCore.ProjectTo(x::SparseVector{T}) where {T<:Number} |
| 15 | + return ProjectTo{SparseVector}(; |
| 16 | + element=ProjectTo(zero(T)), nzind=x.nzind, axes=axes(x) |
| 17 | + ) |
| 18 | +end |
| 19 | +function (project::ProjectTo{SparseVector})(dx::AbstractArray) |
| 20 | + dy = if axes(dx) == project.axes |
| 21 | + dx |
| 22 | + else |
| 23 | + if size(dx, 1) != length(project.axes[1]) |
| 24 | + throw(_projection_mismatch(project.axes, size(dx))) |
| 25 | + end |
| 26 | + reshape(dx, project.axes) |
| 27 | + end |
| 28 | + nzval = map(i -> project.element(dy[i]), project.nzind) |
| 29 | + return SparseVector(length(dx), project.nzind, nzval) |
| 30 | +end |
| 31 | +function (project::ProjectTo{SparseVector})(dx::SparseVector) |
| 32 | + if size(dx) != map(length, project.axes) |
| 33 | + throw(_projection_mismatch(project.axes, size(dx))) |
| 34 | + end |
| 35 | + # When sparsity pattern is unchanged, all the time is in checking this, |
| 36 | + # perhaps some simple hash/checksum might be good enough? |
| 37 | + samepattern = project.nzind == dx.nzind |
| 38 | + # samepattern = length(project.nzind) == length(dx.nzind) |
| 39 | + if eltype(dx) <: project_type(project.element) && samepattern |
| 40 | + return dx |
| 41 | + elseif samepattern |
| 42 | + nzval = map(project.element, dx.nzval) |
| 43 | + SparseVector(length(dx), dx.nzind, nzval) |
| 44 | + else |
| 45 | + nzind = project.nzind |
| 46 | + # Or should we intersect? Can this exploit sorting? |
| 47 | + # nzind = intersect(project.nzind, dx.nzind) |
| 48 | + nzval = map(i -> project.element(dx[i]), nzind) |
| 49 | + return SparseVector(length(dx), nzind, nzval) |
| 50 | + end |
| 51 | +end |
| 52 | + |
| 53 | +function ChainRulesCore.ProjectTo(x::SparseMatrixCSC{T}) where {T<:Number} |
| 54 | + return ProjectTo{SparseMatrixCSC}(; |
| 55 | + element=ProjectTo(zero(T)), |
| 56 | + axes=axes(x), |
| 57 | + rowval=rowvals(x), |
| 58 | + nzranges=nzrange.(Ref(x), axes(x, 2)), |
| 59 | + colptr=x.colptr, |
| 60 | + ) |
| 61 | +end |
| 62 | +# You need not really store nzranges, you can get them from colptr -- TODO |
| 63 | +# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1) |
| 64 | +function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray) |
| 65 | + dy = if axes(dx) == project.axes |
| 66 | + dx |
| 67 | + else |
| 68 | + if size(dx) != (length(project.axes[1]), length(project.axes[2])) |
| 69 | + throw(_projection_mismatch(project.axes, size(dx))) |
| 70 | + end |
| 71 | + reshape(dx, project.axes) |
| 72 | + end |
| 73 | + nzval = Vector{project_type(project.element)}(undef, length(project.rowval)) |
| 74 | + k = 0 |
| 75 | + for col in project.axes[2] |
| 76 | + for i in project.nzranges[col] |
| 77 | + row = project.rowval[i] |
| 78 | + val = dy[row, col] |
| 79 | + nzval[k += 1] = project.element(val) |
| 80 | + end |
| 81 | + end |
| 82 | + m, n = map(length, project.axes) |
| 83 | + return SparseMatrixCSC(m, n, project.colptr, project.rowval, nzval) |
| 84 | +end |
| 85 | + |
| 86 | +function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) |
| 87 | + if size(dx) != map(length, project.axes) |
| 88 | + throw(_projection_mismatch(project.axes, size(dx))) |
| 89 | + end |
| 90 | + samepattern = dx.colptr == project.colptr && dx.rowval == project.rowval |
| 91 | + # samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end] |
| 92 | + if eltype(dx) <: project_type(project.element) && samepattern |
| 93 | + return dx |
| 94 | + elseif samepattern |
| 95 | + nzval = map(project.element, dx.nzval) |
| 96 | + m, n = size(dx) |
| 97 | + return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval) |
| 98 | + else |
| 99 | + invoke(project, Tuple{AbstractArray}, dx) |
| 100 | + end |
| 101 | +end |
| 102 | + |
| 103 | +end # module |
0 commit comments