Skip to content

Commit 25372cb

Browse files
lkdvos and kshyatt authored
GPU-friendly truncation implementations (#349)
* try to make truncation GPU-friendly * Temporarily fix StridedViews version * Revert "Temporarily fix StridedViews version" This reverts commit 77f0ffa. * Small update for diagonal pullbacks * Fix last error * Reenable truncated CUDA tests * make truncation run on GPU * bypass scalar indexing by specializing * convenience overloads * gpu-friendly copies * retain storagetype in extended_S * avoid GPU issues with truncated adjoint tensormaps * various utility improvements * complete rewrite of implementation * GPU doesn't like `trues` * remove CUDA specializations and temporarily add missing MatrixAlgebraKit thingies * better dimension testing * fix unbound type parameter * add missing import * be careful about double method definitions * disable diagonal test * bump MatrixAlgebraKit dependency * Revert "disable diagonal test" This reverts commit f26cffe. * remove unnecessary specializations * specialize CPU implementations * add explanation TruncationByOrder * add explanation and specialization TruncationByError * fix stupidity * fix views * enfore positive and finite p-norms --------- Co-authored-by: Katharine Hyatt <kslimes@gmail.com>
1 parent 3b826a7 commit 25372cb

10 files changed

Lines changed: 211 additions & 145 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ FiniteDifferences = "0.12"
4444
GPUArrays = "11.3.1"
4545
LRUCache = "1.0.2"
4646
LinearAlgebra = "1"
47-
MatrixAlgebraKit = "0.6.2"
47+
MatrixAlgebraKit = "0.6.3"
4848
Mooncake = "0.4.183"
4949
OhMyThreads = "0.8.0"
5050
Printf = "1"

ext/TensorKitCUDAExt/TensorKitCUDAExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ using TensorKit: MatrixAlgebraKit
1717
using Random
1818

1919
include("cutensormap.jl")
20+
include("truncation.jl")
2021

2122
end

ext/TensorKitCUDAExt/truncation.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Alias for a `SectorVector` whose underlying storage is a `CuVector`, used below to
# dispatch the GPU-friendly truncation implementations.
const CuSectorVector{T, I} = TensorKit.SectorVector{T, I, <:CuVector{T}}
2+
3+
# GPU-friendly `findtruncated` for `TruncationByOrder`: sort all values jointly
# (across sectors) and keep entries while the cumulative quantum dimension stays
# within `strategy.howmany`. Every step is a vectorized primitive (`sortperm`,
# `cumsum`, broadcast), so no scalar indexing of GPU arrays occurs.
function MatrixAlgebraKit.findtruncated(
        values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByOrder
    )
    I = sectortype(values)

    # Per-entry weights: every value in sector `c` counts for `dim(c)` towards the
    # total dimension.
    dims = similar(values, Base.promote_op(dim, I))
    for (c, v) in pairs(dims)
        fill!(v, dim(c))
    end

    # Sort all values according to the strategy, then accumulate the sector
    # dimensions in that sorted order.
    perm = sortperm(parent(values); strategy.by, strategy.rev)
    cumulative_dim = cumsum(Base.permute!(parent(dims), perm))

    # Keep an entry while the running dimension total does not exceed `howmany`;
    # scatter the keep-mask back to the original (unsorted) ordering.
    result = similar(values, Bool)
    parent(result)[perm] .= cumulative_dim .<= strategy.howmany
    return result
end
20+
21+
# GPU-friendly `findtruncated` for `TruncationByError`: discard the smallest values
# (across all sectors) whose accumulated p-norm error remains below the tolerance,
# using only vectorized primitives (no scalar indexing of GPU arrays).
function MatrixAlgebraKit.findtruncated(
        values::CuSectorVector, strategy::MatrixAlgebraKit.TruncationByError
    )
    (isfinite(strategy.p) && strategy.p > 0) ||
        throw(ArgumentError(lazy"p-norm with p = $(strategy.p) is currently not supported."))

    # Threshold in p-th-power space: discarding is allowed while
    # ‖discarded‖ₚᵖ ≤ ϵmaxᵖ with ϵmax = max(atol, rtol * ‖values‖ₚ), i.e.
    # ϵᵖmax = max(atolᵖ, rtolᵖ * ‖values‖ₚᵖ).
    # BUGFIX: the norm must also be raised to the power `p`; comparing
    # rtolᵖ * ‖values‖ₚ against the cumulative sum of |v|ᵖ * dim(c) is
    # dimensionally inconsistent.
    normᵖ = norm(values, strategy.p)^strategy.p
    ϵᵖmax = max(strategy.atol^strategy.p, strategy.rtol^strategy.p * normᵖ)
    ϵᵖ = similar(values, typeof(ϵᵖmax))

    # Per-entry contribution to ‖values‖ₚᵖ is |v|ᵖ * dim(c);
    # dimensions are all 1 so no need to account for weight
    if FusionStyle(sectortype(values)) isa UniqueFusion
        parent(ϵᵖ) .= abs.(parent(values)) .^ strategy.p
    else
        for (c, v) in pairs(values)
            v′ = ϵᵖ[c]
            v′ .= abs.(v) .^ strategy.p .* dim(c)
        end
    end

    # Sort by increasing magnitude and accumulate error contributions: an entry is
    # kept as soon as discarding it (and everything smaller) would push the total
    # error above ϵᵖmax; scatter the keep-mask back to the original ordering.
    perm = sortperm(parent(values); by = abs, rev = false)
    cumulative_err = cumsum(Base.permute!(parent(ϵᵖ), perm))

    result = similar(values, Bool)
    parent(result)[perm] .= cumulative_err .> ϵᵖmax
    return result
end
46+
47+
# Needed until MatrixAlgebraKit patch hits...
# Intersect a boolean mask `A` with an index set `B`: the output is `true` exactly
# at the positions listed in `B` where `A` is also `true`, `false` everywhere else.
function MatrixAlgebraKit._ind_intersect(A::CuVector{Bool}, B::CuVector{Int})
    mask = similar(A)
    fill!(mask, false)
    @views mask[B] .= A[B]
    return mask
end

src/factorizations/adjoint.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _adjoint(alg::MAK.LAPACK_HouseholderLQ) = MAK.LAPACK_HouseholderQR(; alg.kwargs.
77
_adjoint(alg::MAK.LAPACK_HouseholderQL) = MAK.LAPACK_HouseholderRQ(; alg.kwargs...)
88
_adjoint(alg::MAK.LAPACK_HouseholderRQ) = MAK.LAPACK_HouseholderQL(; alg.kwargs...)
99
_adjoint(alg::MAK.PolarViaSVD) = MAK.PolarViaSVD(_adjoint(alg.svd_alg))
10+
# The adjoint of a truncated factorization algorithm adjoints the wrapped
# algorithm while keeping the same truncation strategy.
_adjoint(alg::TruncatedAlgorithm) = TruncatedAlgorithm(_adjoint(alg.alg), alg.trunc)
1011
_adjoint(alg::AbstractAlgorithm) = alg
1112

1213
_adjoint(alg::MAK.CUSOLVER_HouseholderQR) = MAK.LQViaTransposedQR(alg)
@@ -81,7 +82,7 @@ for (left_f, right_f) in zip(
8182
end
8283

8384
# 3-arg functions
84-
for f in (:svd_full, :svd_compact)
85+
for f in (:svd_full, :svd_compact, :svd_trunc)
8586
f! = Symbol(f, :!)
8687
@eval function MAK.copy_input(::typeof($f), t::AdjointTensorMap)
8788
return adjoint(MAK.copy_input($f, adjoint(t)))
@@ -93,9 +94,16 @@ for f in (:svd_full, :svd_compact)
9394
return reverse(adjoint.(MAK.initialize_output($f!, adjoint(t), _adjoint(alg))))
9495
end
9596

96-
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
97-
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
98-
return reverse(adjoint.(F′))
97+
if f === :svd_trunc
98+
function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
99+
U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
100+
return Vᴴ', S, U', ϵ
101+
end
102+
else
103+
@eval function MAK.$f!(t::AdjointTensorMap, F, alg::AbstractAlgorithm)
104+
F′ = $f!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
105+
return reverse(adjoint.(F′))
106+
end
99107
end
100108

101109
# disambiguate by prohibition
@@ -111,6 +119,15 @@ function MAK.svd_compact!(t::AdjointTensorMap, F, alg::DiagonalAlgorithm)
111119
F′ = svd_compact!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
112120
return reverse(adjoint.(F′))
113121
end
122+
# Output for `svd_trunc!` of an adjoint map: initialize the output for the adjoint
# problem with the adjointed algorithm, then adjoint each factor and reverse their
# order (U, S, Vᴴ of `t'` correspond to Vᴴ', S, U' of `t`).
function MAK.initialize_output(
        ::typeof(svd_trunc!), t::AdjointTensorMap, alg::TruncatedAlgorithm
    )
    return reverse(adjoint.(MAK.initialize_output(svd_trunc!, adjoint(t), _adjoint(alg))))
end
127+
# Truncated SVD of an adjoint map via the SVD of the parent: if t' = U S Vᴴ then
# t = V S Uᴴ. Unlike the plain 3-tuple factorizations, `svd_trunc!` also returns
# the truncation error ϵ, which is passed through unchanged.
function MAK.svd_trunc!(t::AdjointTensorMap, F, alg::TruncatedAlgorithm)
    U, S, Vᴴ, ϵ = svd_trunc!(adjoint(t), reverse(adjoint.(F)), _adjoint(alg))
    return Vᴴ', S, U', ϵ
end
114131

115132
function LinearAlgebra.isposdef(t::AdjointTensorMap)
116133
return isposdef(adjoint(t))

src/factorizations/diagonal.jl

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,6 @@ for f in (
1313
@eval MAK.copy_input(::typeof($f), d::DiagonalTensorMap) = copy(d)
1414
end
1515

16-
for f! in (:eig_full!, :eig_trunc!)
17-
@eval function MAK.initialize_output(
18-
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
19-
)
20-
return d, similar(d)
21-
end
22-
end
23-
24-
for f! in (:eigh_full!, :eigh_trunc!)
25-
@eval function MAK.initialize_output(
26-
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
27-
)
28-
if scalartype(d) <: Real
29-
return d, similar(d, space(d))
30-
else
31-
return similar(d, real(scalartype(d))), similar(d, space(d))
32-
end
33-
end
34-
end
35-
3616
for f! in (:qr_full!, :qr_compact!)
3717
@eval function MAK.initialize_output(
3818
::typeof($f!), d::AbstractTensorMap, ::DiagonalAlgorithm
@@ -93,7 +73,7 @@ end
9373
# For diagonal inputs we don't have to promote the scalartype since we know they are symmetric
9474
# Allocate the output vector of eigenvalues for a diagonal input. The eigenvalue
# type is the complex extension of the input scalartype, and the input's storage
# type (e.g. a GPU array) is retained for the output.
function MAK.initialize_output(::typeof(eig_vals!), t::AbstractTensorMap, alg::DiagonalAlgorithm)
    Tc = complex(scalartype(t))
    storage = similarstoragetype(t, Tc)
    return SectorVector{Tc, sectortype(t), storage}(undef, fuse(domain(t)))
end

0 commit comments

Comments
 (0)