Skip to content

Commit bb53c65

Browse files
pbrehmer and lkdvos
authored
Update SVD reverse-rule broadening (#194)
- Actually apply broadening in SVD rrule for (quasi) degenerate singular values instead of just using a cutoff - Add a test for differentiating SVDs with degenerate singular values --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 0cda404 commit bb53c65

2 files changed

Lines changed: 66 additions & 34 deletions

File tree

src/utility/svd.jl

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,21 @@ function ChainRulesCore.rrule(
438438
return (U, S, V, info), tsvd!_itersvd_pullback
439439
end
440440

441-
# scalar inverses with a cutoff tolerance and Lorentzian broadening
442-
function _safe_inv(x, tol, ε=0)
443-
if abs(x) < tol
444-
return zero(x)
445-
else
446-
return iszero(ε) ? inv(x) : _lorentz_broaden(x, ε)
441+
# scalar inverses with a cutoff tolerance
442+
_safe_inv(x, tol) = abs(x) < tol ? zero(x) : inv(x)
443+
444+
# compute inverse singular value difference contribution to SVD gradient with broadening ε
445+
function _broadened_inv_S(S::AbstractVector{T}, tol, ε=0) where {T}
446+
F = similar(S, (axes(S, 1), axes(S, 1)))
447+
@inbounds for j in axes(F, 2), i in axes(F, 1)
448+
F[i, j] = if i == j
449+
zero(T)
450+
else
451+
Δsᵢⱼ = S[j] - S[i]
452+
ε > 0 ? _lorentz_broaden(Δsᵢⱼ, ε) : _safe_inv(Δsᵢⱼ, tol)
453+
end
447454
end
455+
return F
448456
end
449457

450458
# Lorentzian broadening for divergent term in SVD rrule, see
@@ -554,9 +562,8 @@ function svd_pullback!(
554562
@info "`svd` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
555563
end
556564

557-
UdΔAV =
558-
(aUΔU .+ aVΔV) .* _safe_inv.(Sp' .- Sp, tol, broadening) .+
559-
(aUΔU .- aVΔV) .* _safe_inv.(Sp' .+ Sp, tol)
565+
inv_S_minus = _broadened_inv_S(Sp, tol, broadening) # possibly divergent/broadened contribution
566+
UdΔAV = @. (aUΔU + aVΔV) * inv_S_minus + (aUΔU - aVΔV) * _safe_inv(Sp' .+ Sp, tol)
560567
if !(ΔS isa ZeroTangent)
561568
UdΔAV[diagind(UdΔAV)] .+= real.(ΔS)
562569
# in principle, ΔS is real, but maybe not if coming from an anyonic tensor
@@ -585,16 +592,14 @@ function svd_pullback!(
585592
VrΔV = fill!(similar(Vd, (r - p, p)), 0)
586593
end
587594

588-
X =
589-
(1//2) .* (
590-
(UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .+
591-
(UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
592-
)
593-
Y =
594-
(1//2) .* (
595-
(UrΔU .+ VrΔV) .* _safe_inv.(Sp' .- Sr, tol, broadening) .-
596-
(UrΔU .- VrΔV) .* _safe_inv.(Sp' .+ Sr, tol)
597-
)
595+
X = @. (1//2) * (
596+
(UrΔU + VrΔV) * _safe_inv(Sp' - Sr, tol) +
597+
(UrΔU - VrΔV) * _safe_inv(Sp' + Sr, tol)
598+
)
599+
Y = @. (1//2) * (
600+
(UrΔU + VrΔV) * _safe_inv(Sp' - Sr, tol) -
601+
(UrΔU - VrΔV) * _safe_inv(Sp' + Sr, tol)
602+
)
598603

599604
# ΔA += Ur * X * Vp' + Up * Y' * Vr'
600605
mul!(ΔA, Ur, X * Vp', 1, 1)

test/utility/svd_wrapper.jl

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ rtol = 1e-9
2121
Random.seed!(123456789)
2222
r = randn(dtype, ℂ^m, ℂ^n)
2323
R = randn(space(r))
24-
broadenings = [10.0^k for k in -16:-4]
2524

2625
full_alg = SVDAdjoint(; rrule_alg=(; alg=:full, broadening=0))
2726
iter_alg = SVDAdjoint(; fwd_alg=(; alg=:iterative))
@@ -42,13 +41,29 @@ end
4241
@test g_fullsvd[1] ≈ g_itersvd[1] rtol = rtol
4342
end
4443

45-
@testset "Truncated SVD with χ= and ε=$ε broadening" for ε in broadenings
46-
broadened_alg = @set full_alg.rrule_alg.broadening = ε
47-
l_unbroadened, g_unbroadened = withgradient(A -> lossfun(A, full_alg, R, trunc), r)
48-
l_broadened, g_broadened = withgradient(A -> lossfun(A, broadened_alg, R, trunc), r)
44+
@testset "Truncated SVD broadening" begin
45+
u, s, v, = tsvd(r)
46+
s.data[1:2:m] .= s.data[2:2:m] # make every singular value two-fold degenerate
47+
r_degen = u * s * v
4948

50-
@test l_unbroadened ≈ l_broadened
51-
@test 1e1 * norm(g_broadened[1]) * ε > norm(g_unbroadened[1] - g_broadened[1]) > ε
49+
no_broadening_no_cutoff_alg = @set full_alg.rrule_alg.broadening = 1e-30
50+
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1e-13
51+
52+
l_only_cutoff, g_only_cutoff = withgradient(
53+
A -> lossfun(A, full_alg, R, trunc), r_degen
54+
) # cutoff sets degenerate difference to zero
55+
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient( # degenerate singular value differences lead to divergent contributions
56+
A -> lossfun(A, no_broadening_no_cutoff_alg, R, trunc),
57+
r_degen,
58+
)
59+
l_small_broadening, g_small_broadening = withgradient( # Lorentzian broadening smoothens divergent contributions
60+
A -> lossfun(A, small_broadening_alg, R, trunc),
61+
r_degen,
62+
)
63+
64+
@test l_only_cutoff ≈ l_no_broadening_no_cutoff ≈ l_small_broadening
65+
@test norm(g_no_broadening_no_cutoff[1] - g_small_broadening[1]) > 1e-1 # divergences mess up the gradient
66+
@test g_only_cutoff[1] ≈ g_small_broadening[1] rtol = rtol # cutoff and Lorentzian broadening have similar effect
5267
end
5368

5469
symm_m, symm_n = 18, 24
@@ -80,17 +95,29 @@ symm_R = randn(dtype, space(symm_r))
8095
@test g_fullsvd_tr[1] ≈ g_itersvd_fb[1] rtol = rtol
8196
end
8297

83-
@testset "Truncated symmetric SVD with χ= and ε=$ε broadening" for ε in broadenings
84-
broadened_alg = @set full_alg.rrule_alg.broadening = ε
85-
l_unbroadened, g_unbroadened = withgradient(
86-
A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r
98+
@testset "Truncated symmetric SVD broadening" begin
99+
u, s, v, = tsvd(symm_r)
100+
s.data[1:2:m] .= s.data[2:2:m] # make every singular value two-fold degenerate
101+
symm_r_degen = u * s * v
102+
103+
no_broadening_no_cutoff_alg = @set full_alg.rrule_alg.broadening = 1e-30
104+
small_broadening_alg = @set full_alg.rrule_alg.broadening = 1e-13
105+
106+
l_only_cutoff, g_only_cutoff = withgradient(
107+
A -> lossfun(A, full_alg, symm_R, symm_trspace), symm_r_degen
108+
) # cutoff sets degenerate difference to zero
109+
l_no_broadening_no_cutoff, g_no_broadening_no_cutoff = withgradient( # degenerate singular value differences lead to divergent contributions
110+
A -> lossfun(A, no_broadening_no_cutoff_alg, symm_R, symm_trspace),
111+
symm_r_degen,
87112
)
88-
l_broadened, g_broadened = withgradient(
89-
A -> lossfun(A, broadened_alg, symm_R, symm_trspace), symm_r
113+
l_small_broadening, g_small_broadening = withgradient( # Lorentzian broadening smoothens divergent contributions
114+
A -> lossfun(A, small_broadening_alg, symm_R, symm_trspace),
115+
symm_r_degen,
90116
)
91117

92-
@test l_unbroadened ≈ l_broadened
93-
@test 1e1 * norm(g_broadened[1]) * ε > norm(g_unbroadened[1] - g_broadened[1]) > ε
118+
@test l_only_cutoff ≈ l_no_broadening_no_cutoff ≈ l_small_broadening
119+
@test norm(g_no_broadening_no_cutoff[1] - g_small_broadening[1]) > 1e-2 # divergences mess up the gradient
120+
@test g_only_cutoff[1] ≈ g_small_broadening[1] rtol = rtol # cutoff and Lorentzian broadening have similar effect
94121
end
95122

96123
# TODO: Add when IterSVD is implemented for HalfInfiniteEnv

0 commit comments

Comments
 (0)