Skip to content

Commit 994a94f

Browse files
pbrehmerlkdvos
andauthored
Add modified tsvd! reverse-rule with Lorentzian broadening (#181)
* Add `FullSVDReverseRule` struct for modified `tsvd!` rule from TensorKit which supports Lorentzian broadening * Per default, broadening is activated with `Defaults.svd_rrule_broadening=1e-13` * The rrule output verbosity can be controlled through `FullSVDReverseRule` as well --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 3c7ed46 commit 994a94f

8 files changed

Lines changed: 254 additions & 56 deletions

File tree

src/Defaults.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@ Module containing default algorithm parameter values and arguments.
3030
* `svd_rrule_min_krylovdim=$(Defaults.svd_rrule_min_krylovdim)` : Minimal Krylov dimension of the reverse-rule algorithm (if it is a Krylov algorithm).
3131
* `svd_rrule_verbosity=$(Defaults.svd_rrule_verbosity)` : SVD gradient output verbosity.
3232
* `svd_rrule_alg=:$(Defaults.svd_rrule_alg)` : Reverse-rule algorithm for the SVD gradient.
33-
- `:tsvd`: Uses TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [TensorKit](https://github.com/Jutho/TensorKit.jl/blob/f9cddcf97f8d001888a26f4dce7408d5c6e2228f/ext/TensorKitChainRulesCoreExt/factorizations.jl#L3)
33+
- `:full`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [`FullSVDReverseRule`](@ref).
3434
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
3535
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
3636
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
37+
* `svd_rrule_broadening=$(Defaults.svd_rrule_broadening)` : Lorentzian broadening amplitude which smoothens the divergent term in the SVD adjoint in case of (pseudo) degenerate singular values
3738
3839
## Projectors
3940
@@ -96,7 +97,8 @@ const svd_fwd_alg = :sdd # ∈ {:sdd, :svd, :iterative}
9697
const svd_rrule_tol = ctmrg_tol
9798
const svd_rrule_min_krylovdim = 48
9899
const svd_rrule_verbosity = -1
99-
const svd_rrule_alg = :tsvd # ∈ {:tsvd, :gmres, :bicgstab, :arnoldi}
100+
const svd_rrule_alg = :full # ∈ {:full, :gmres, :bicgstab, :arnoldi}
101+
const svd_rrule_broadening = 1e-13
100102
const krylovdim_factor = 1.4
101103

102104
# Projectors

src/PEPSKit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ include("algorithms/select_algorithm.jl")
7171

7272
using .Defaults: set_scheduler!
7373
export set_scheduler!
74-
export SVDAdjoint, IterSVD
74+
export SVDAdjoint, FullSVDReverseRule, IterSVD
7575
export CTMRGEnv, SequentialCTMRG, SimultaneousCTMRG
7676
export FixedSpaceTruncation, HalfInfiniteProjector, FullInfiniteProjector
7777
export LocalOperator

src/algorithms/ctmrg/projectors.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ function svd_algorithm(alg::ProjectorAlgorithm, (dir, r, c))
5555
nothing,
5656
)
5757
end
58-
return SVDAdjoint(;
59-
fwd_alg=fix_svd,
60-
rrule_alg=alg.svd_alg.rrule_alg,
61-
broadening=alg.svd_alg.broadening,
62-
)
58+
return SVDAdjoint(; fwd_alg=fix_svd, rrule_alg=alg.svd_alg.rrule_alg)
6359
else
6460
return alg.svd_alg
6561
end

src/algorithms/optimization/fixed_point_differentiation.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,6 @@ function _fix_svd_algorithm(alg::SVDAdjoint, signs, info)
263263
return SVDAdjoint(;
264264
fwd_alg=FixedSVD(U_fixed, info.S, V_fixed, U_full_fixed, info.S_full, V_full_fixed),
265265
rrule_alg=alg.rrule_alg,
266-
broadening=alg.broadening,
267266
)
268267
end
269268
function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F<:IterSVD}
@@ -272,7 +271,6 @@ function _fix_svd_algorithm(alg::SVDAdjoint{F}, signs, info) where {F<:IterSVD}
272271
return SVDAdjoint(;
273272
fwd_alg=FixedSVD(U_fixed, info.S, V_fixed, nothing, nothing, nothing),
274273
rrule_alg=alg.rrule_alg,
275-
broadening=alg.broadening,
276274
)
277275
end
278276

src/utility/svd.jl

Lines changed: 218 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,31 @@ using TensorKit:
99
_create_svdtensors,
1010
_compute_truncdim,
1111
_compute_truncerr
12-
const TensorKitCRCExt = Base.get_extension(TensorKit, :TensorKitChainRulesCoreExt)
1312
const KrylovKitCRCExt = Base.get_extension(KrylovKit, :KrylovKitChainRulesCoreExt)
1413

14+
"""
15+
struct FullSVDReverseRule
16+
FullSVDReverseRule(; kwargs...)
17+
18+
SVD reverse-rule algorithm which uses a modified version of TensorKit's `tsvd!` reverse-rule
19+
allowing for Lorentzian broadening and output verbosity control.
20+
21+
## Keyword arguments
22+
23+
* `broadening::Float64=$(Defaults.svd_rrule_broadening)`: Lorentzian broadening amplitude for smoothing divergent term in SVD derivative in case of (pseudo) degenerate singular values.
24+
* `verbosity::Int=0`: Suppresses all output if `≤0`, print gauge dependency warnings if `1`, and always print gauge dependency if `≥2`.
25+
"""
26+
@kwdef struct FullSVDReverseRule
27+
broadening::Float64 = Defaults.svd_rrule_broadening
28+
verbosity::Int = 0
29+
end
30+
1531
"""
1632
struct SVDAdjoint
1733
SVDAdjoint(; kwargs...)
1834
1935
Wrapper for a SVD algorithm `fwd_alg` with a defined reverse rule `rrule_alg`.
2036
If `isnothing(rrule_alg)`, Zygote differentiates the forward call automatically.
21-
In case of degenerate singular values, one might need a `broadening` scheme which
22-
removes the divergences from the adjoint.
2337
2438
## Keyword arguments
2539
@@ -28,16 +42,14 @@ removes the divergences from the adjoint.
2842
- `:svd`: TensorKit's wrapper for LAPACK's `_gesvd`
2943
- `:iterative`: Iterative SVD only computing the specifed number of singular values and vectors, see ['IterSVD'](@ref)
3044
* `rrule_alg::Union{Algorithm,NamedTuple}=(; alg::Symbol=$(Defaults.svd_rrule_alg))`: Reverse-rule algorithm for differentiating the SVD. Can be supplied by an `Algorithm` instance directly or as a `NamedTuple` where `alg` is one of the following:
31-
- `:tsvd`: Uses TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [TensorKit](https://github.com/Jutho/TensorKit.jl/blob/f9cddcf97f8d001888a26f4dce7408d5c6e2228f/ext/TensorKitChainRulesCoreExt/factorizations.jl#L3)
45+
- `:full`: Uses a modified version of TensorKit's reverse-rule for `tsvd` which doesn't solve any linear problem and instead requires access to the full SVD, see [`FullSVDReverseRule`](@ref).
3246
- `:gmres`: GMRES iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.GMRES) for details
3347
- `:bicgstab`: BiCGStab iterative linear solver, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.BiCGStab) for details
3448
- `:arnoldi`: Arnoldi Krylov algorithm, see the [KrylovKit docs](https://jutho.github.io/KrylovKit.jl/stable/man/algorithms/#KrylovKit.Arnoldi) for details
35-
* `broadening=nothing`: Broadening of singular value differences to stabilize the SVD gradient. Currently not implemented.
3649
"""
37-
struct SVDAdjoint{F,R,B}
50+
struct SVDAdjoint{F,R}
3851
fwd_alg::F
3952
rrule_alg::R
40-
broadening::B
4153
end # Keep truncation algorithm separate to be able to specify CTMRG dependent information
4254

4355
const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
@@ -48,10 +60,10 @@ const SVD_FWD_SYMBOLS = IdDict{Symbol,Any}(
4860
IterSVD(; alg=GKL(; tol, krylovdim), kwargs...),
4961
)
5062
const SVD_RRULE_SYMBOLS = IdDict{Symbol,Type{<:Any}}(
51-
:tsvd => Nothing, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
63+
:full => FullSVDReverseRule, :gmres => GMRES, :bicgstab => BiCGStab, :arnoldi => Arnoldi
5264
)
5365

54-
function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
66+
function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;))
5567
# parse forward SVD algorithm
5668
fwd_algorithm = if fwd_alg isa NamedTuple
5769
fwd_kwargs = (; alg=Defaults.svd_fwd_alg, fwd_alg...) # overwrite with specified kwargs
@@ -70,6 +82,7 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
7082
alg=Defaults.svd_rrule_alg,
7183
tol=Defaults.svd_rrule_tol,
7284
krylovdim=Defaults.svd_rrule_min_krylovdim,
85+
broadening=Defaults.svd_rrule_broadening,
7386
verbosity=Defaults.svd_rrule_verbosity,
7487
rrule_alg...,
7588
) # overwrite with specified kwargs
@@ -79,23 +92,23 @@ function SVDAdjoint(; fwd_alg=(;), rrule_alg=(;), broadening=nothing)
7992
rrule_type = SVD_RRULE_SYMBOLS[rrule_kwargs.alg]
8093

8194
# IterSVD is incompatible with tsvd rrule -> default to Arnoldi
82-
if rrule_type <: Nothing && fwd_algorithm isa IterSVD
95+
if rrule_type <: FullSVDReverseRule && fwd_algorithm isa IterSVD
8396
rrule_type = Arnoldi
8497
end
8598

86-
if rrule_type <: Nothing
87-
nothing
99+
if rrule_type <: FullSVDReverseRule
100+
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, tol=0.0, krylovdim=0)) # remove `alg`, `tol` and `krylovdim` keyword arguments
88101
else
89-
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing)) # remove `alg` keyword argument
102+
rrule_kwargs = Base.structdiff(rrule_kwargs, (; alg=nothing, broadening=0.0)) # remove `alg` and `broadening` keyword arguments
90103
rrule_type <: BiCGStab &&
91104
(rrule_kwargs = Base.structdiff(rrule_kwargs, (; krylovdim=nothing))) # BiCGStab doens't take `krylovdim`
92-
rrule_type(; rrule_kwargs...)
93105
end
106+
rrule_type(; rrule_kwargs...)
94107
else
95108
rrule_alg
96109
end
97110

98-
return SVDAdjoint(fwd_algorithm, rrule_algorithm, broadening)
111+
return SVDAdjoint(fwd_algorithm, rrule_algorithm)
99112
end
100113

101114
"""
@@ -245,7 +258,7 @@ end
245258
function TensorKit._compute_svddata!(
246259
f, alg::IterSVD, trunc::Union{NoTruncation,TruncationSpace}
247260
)
248-
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:tsvd!)
261+
InnerProductStyle(f) === EuclideanInnerProduct() || throw_invalid_innerproduct(:full!)
249262
I = sectortype(f)
250263
dims = SectorDict{I,Int}()
251264

@@ -285,10 +298,10 @@ end
285298
function ChainRulesCore.rrule(
286299
::typeof(PEPSKit.tsvd!),
287300
t::AbstractTensorMap,
288-
alg::SVDAdjoint{F,R,B};
301+
alg::SVDAdjoint{F,R};
289302
trunc::TruncationScheme=TensorKit.NoTruncation(),
290303
p::Real=2,
291-
) where {F,R<:Nothing,B}
304+
) where {F,R<:FullSVDReverseRule}
292305
@assert !(alg.fwd_alg isa IterSVD) "IterSVD is not compatible with tsvd reverse-rule"
293306
Ũ, S̃, Ṽ⁺, info = tsvd(t, alg; trunc, p)
294307
U, S, V⁺ = info.U_full, info.S_full, info.V_full # untruncated SVD decomposition
@@ -306,8 +319,17 @@ function ChainRulesCore.rrule(
306319
ΔUc, ΔSc, ΔV⁺c = block(ΔU, c), block(ΔS, c), block(ΔV⁺, c)
307320
Sdc = view(Sc, diagind(Sc))
308321
ΔSdc = (ΔSc isa AbstractZero) ? ΔSc : view(ΔSc, diagind(ΔSc))
309-
TensorKitCRCExt.svd_pullback!(
310-
b, Uc, Sdc, V⁺c, ΔUc, ΔSdc, ΔV⁺c; tol=pullback_tol
322+
svd_pullback!(
323+
b,
324+
Uc,
325+
Sdc,
326+
V⁺c,
327+
ΔUc,
328+
ΔSdc,
329+
ΔV⁺c;
330+
tol=pullback_tol,
331+
broadening=alg.rrule_alg.broadening,
332+
verbosity=alg.rrule_alg.verbosity,
311333
)
312334
end
313335
return NoTangent(), Δt, NoTangent()
@@ -323,10 +345,10 @@ end
323345
function ChainRulesCore.rrule(
324346
::typeof(PEPSKit.tsvd!),
325347
f,
326-
alg::SVDAdjoint{F,R,B};
348+
alg::SVDAdjoint{F,R};
327349
trunc::TruncationScheme=notrunc(),
328350
p::Real=2,
329-
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi},B}
351+
) where {F,R<:Union{GMRES,BiCGStab,Arnoldi}}
330352
U, S, V, info = tsvd(f, alg; trunc, p)
331353

332354
# update rrule_alg tolerance to be compatible with smallest singular value
@@ -389,3 +411,177 @@ function ChainRulesCore.rrule(
389411

390412
return (U, S, V, info), tsvd!_itersvd_pullback
391413
end
414+
415+
# scalar inverses with a cutoff tolerance and Lorentzian broadening
416+
function _safe_inv(x, tol, ε=0)
417+
if abs(x) < tol
418+
return zero(x)
419+
else
420+
return iszero(ε) ? inv(x) : _lorentz_broaden(x, ε)
421+
end
422+
end
423+
424+
# Lorentzian broadening for divergent term in SVD rrule, see
425+
# https://journals.aps.org/prresearch/abstract/10.1103/PhysRevResearch.7.013237
426+
function _lorentz_broaden(x, ε=eps(real(scalartype(x)))^(3 / 4))
427+
return x / (x^2 + ε)
428+
end
429+
430+
function _default_pullback_gaugetol(x)
431+
n = norm(x, Inf)
432+
return eps(eltype(n))^(3 / 4) * max(n, one(n))
433+
end
434+
435+
# SVD_pullback: pullback implementation for general (possibly truncated) SVD
436+
#
437+
# This is a modified version of TensorKit's pullback
438+
# https://github.com/Jutho/TensorKit.jl/blob/fa1551472ac74d7f2a61bdb2135cf418c8c53378/ext/TensorKitChainRulesCoreExt/factorizations.jl#L190)
439+
# with support for Lorentzian broadening and improved verbosity control
440+
#
441+
# Arguments are U, S and Vd of full (non-truncated, but still thin) SVD, as well as
442+
# cotangent ΔU, ΔS, ΔVd variables of truncated SVD
443+
#
444+
# Checks whether the cotangent variables are such that they would couple to gauge-dependent
445+
# degrees of freedom (phases of singular vectors), and prints a warning if this is the case
446+
#
447+
# An implementation that only uses U, S, and Vd from truncated SVD is also possible, but
448+
# requires solving a Sylvester equation, which does not seem to be supported on GPUs.
449+
#
450+
# Other implementation considerations for GPU compatibility:
451+
# no scalar indexing, lots of broadcasting and views
452+
#
453+
function svd_pullback!(
454+
ΔA::AbstractMatrix,
455+
U::AbstractMatrix,
456+
S::AbstractVector,
457+
Vd::AbstractMatrix,
458+
ΔU,
459+
ΔS,
460+
ΔVd;
461+
tol::Real=_default_pullback_gaugetol(S),
462+
broadening::Real=0,
463+
verbosity=1,
464+
)
465+
466+
# Basic size checks and determination
467+
m, n = size(U, 1), size(Vd, 2)
468+
size(U, 2) == size(Vd, 1) == length(S) == min(m, n) || throw(DimensionMismatch())
469+
p = -1
470+
if !(ΔU isa AbstractZero)
471+
m == size(ΔU, 1) || throw(DimensionMismatch())
472+
p = size(ΔU, 2)
473+
end
474+
if !(ΔVd isa AbstractZero)
475+
n == size(ΔVd, 2) || throw(DimensionMismatch())
476+
if p == -1
477+
p = size(ΔVd, 1)
478+
else
479+
p == size(ΔVd, 1) || throw(DimensionMismatch())
480+
end
481+
end
482+
if !(ΔS isa AbstractZero)
483+
if p == -1
484+
p = length(ΔS)
485+
else
486+
p == length(ΔS) || throw(DimensionMismatch())
487+
end
488+
end
489+
Up = view(U, :, 1:p)
490+
Vp = view(Vd, 1:p, :)'
491+
Sp = view(S, 1:p)
492+
493+
# rank
494+
r = searchsortedlast(S, tol; rev=true)
495+
496+
# compute antihermitian part of projection of ΔU and ΔV onto U and V
497+
# also already subtract this projection from ΔU and ΔV
498+
if !(ΔU isa AbstractZero)
499+
UΔU = Up' * ΔU
500+
aUΔU = rmul!(UΔU - UΔU', 1 / 2)
501+
if m > p
502+
ΔU -= Up * UΔU
503+
end
504+
else
505+
aUΔU = fill!(similar(U, (p, p)), 0)
506+
end
507+
if !(ΔVd isa AbstractZero)
508+
VΔV = Vp' * ΔVd'
509+
aVΔV = rmul!(VΔV - VΔV', 1 / 2)
510+
if n > p
511+
ΔVd -= VΔV' * Vp'
512+
end
513+
else
514+
aVΔV = fill!(similar(Vd, (p, p)), 0)
515+
end
516+
517+
# check whether cotangents arise from gauge-invariance objective function
518+
mask = abs.(Sp' .- Sp) .< tol
519+
Δgauge = norm(view(aUΔU, mask) + view(aVΔV, mask), Inf)
520+
if p > r
521+
rprange = (r + 1):p
522+
Δgauge = max(Δgauge, norm(view(aUΔU, rprange, rprange), Inf))
523+
Δgauge = max(Δgauge, norm(view(aVΔV, rprange, rprange), Inf))
524+
end
525+
if verbosity == 1 && Δgauge > tol # warn if verbosity is 1
526+
@warn "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
527+
elseif verbosity 2 # always info for debugging purposes
528+
@info "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
529+
end
530+
531+
UdΔAV =
532+
(aUΔU .+ aVΔV) .* _safe_inv.(Sp' .- Sp, tol, broadening) .+
533+
(aUΔU .- aVΔV) .* _safe_inv.(Sp' .+ Sp, tol)
534+
if !(ΔS isa ZeroTangent)
535+
UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
536+
# in principle, ΔS is real, but maybe not if coming from an anyonic tensor
537+
end
538+
mul!(ΔA, Up, UdΔAV * Vp')
539+
540+
if r > p # contribution from truncation
541+
Ur = view(U, :, (p + 1):r)
542+
Vr = view(Vd, (p + 1):r, :)'
543+
Sr = view(S, (p + 1):r)
544+
545+
if !(ΔU isa AbstractZero)
546+
UrΔU = Ur' * ΔU
547+
if m > r
548+
ΔU -= Ur * UrΔU # subtract this part from ΔU
549+
end
550+
else
551+
UrΔU = fill!(similar(U, (r - p, p)), 0)
552+
end
553+
if !(ΔVd isa AbstractZero)
554+
VrΔV = Vr' * ΔVd'
555+
if n > r
556+
ΔVd -= VrΔV' * Vr' # subtract this part from ΔV
557+
end
558+
else
559+
VrΔV = fill!(similar(Vd, (r - p, p)), 0)
560+
end
561+
562+
X =
563+
(1//2) .* (
564+
(UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .+
565+
(UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
566+
)
567+
Y =
568+
(1//2) .* (
569+
(UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .-
570+
(UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
571+
)
572+
573+
# ΔA += Ur * X * Vp' + Up * Y' * Vr'
574+
mul!(ΔA, Ur, X * Vp', 1, 1)
575+
mul!(ΔA, Up * Y', Vr', 1, 1)
576+
end
577+
578+
if m > max(r, p) && !(ΔU isa AbstractZero) # remaining ΔU is already orthogonal to U[:,1:max(p,r)]
579+
# ΔA += (ΔU .* _safe_inv.(Sp', tol)) * Vp'
580+
mul!(ΔA, ΔU .* _safe_inv.(Sp', tol), Vp', 1, 1)
581+
end
582+
if n > max(r, p) && !(ΔVd isa AbstractZero) # remaining ΔV is already orthogonal to V[:,1:max(p,r)]
583+
# ΔA += U * (_safe_inv.(Sp, tol) .* ΔVd)
584+
mul!(ΔA, Up, _safe_inv.(Sp, tol) .* ΔVd, 1, 1)
585+
end
586+
return ΔA
587+
end

0 commit comments

Comments
 (0)