Skip to content

Commit 8b70a79

Browse files
authored
Move bugfixes Zygote tests to AD tests (#339)
1 parent 2631631 commit 8b70a79

2 files changed

Lines changed: 42 additions & 43 deletions

File tree

test/autodiff/ad.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,3 +628,45 @@ for V in spacelist
628628
end
629629
end
630630
end
631+
632+
# https://github.com/quantumkithub/TensorKit.jl/issues/201
633+
@testset "Issue #201" begin
634+
function f(A::AbstractTensorMap)
635+
U, S, V, = svd_compact(A)
636+
return tr(S)
637+
end
638+
function f(A::AbstractMatrix)
639+
S = LinearAlgebra.svdvals(A)
640+
return sum(S)
641+
end
642+
A₀ = randn(Z2Space(4, 4) Z2Space(4, 4))
643+
grad1, = Zygote.gradient(f, A₀)
644+
grad2, = Zygote.gradient(f, convert(Array, A₀))
645+
@test convert(Array, grad1) grad2
646+
647+
function g(A::AbstractTensorMap)
648+
U, S, V, = svd_compact(A)
649+
return tr(U * V)
650+
end
651+
function g(A::AbstractMatrix)
652+
U, S, V, = LinearAlgebra.svd(A)
653+
return tr(U * V')
654+
end
655+
B₀ = randn(ComplexSpace(4) ComplexSpace(4))
656+
grad3, = Zygote.gradient(g, B₀)
657+
grad4, = Zygote.gradient(g, convert(Array, B₀))
658+
@test convert(Array, grad3) grad4
659+
end
660+
661+
# https://github.com/quantumkithub/TensorKit.jl/issues/209
662+
@testset "Issue #209" begin
663+
function f(T, D)
664+
@tensor T[1, 4, 1, 3] * D[3, 4]
665+
end
666+
V = Z2Space(2, 2)
667+
D = DiagonalTensorMap(randn(4), V)
668+
T = randn(V V V V)
669+
g1, = Zygote.gradient(f, T, D)
670+
g2, = Zygote.gradient(f, T, TensorMap(D))
671+
@test g1 g2
672+
end

test/other/bugfixes.jl

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Test, TestExtras
22
using TensorKit
33
using TensorOperations
44
using LinearAlgebra: LinearAlgebra
5-
using Zygote
65

76
@testset "BugfixConvert" begin
87
v = randn(
@@ -43,45 +42,3 @@ end
4342
@test storagetype(t5) == Vector{Float64}
4443
tensorfree!(t2)
4544
end
46-
47-
# https://github.com/quantumkithub/TensorKit.jl/issues/201
48-
@testset "Issue #201" begin
49-
function f(A::AbstractTensorMap)
50-
U, S, V, = svd_compact(A)
51-
return tr(S)
52-
end
53-
function f(A::AbstractMatrix)
54-
S = LinearAlgebra.svdvals(A)
55-
return sum(S)
56-
end
57-
A₀ = randn(Z2Space(4, 4) Z2Space(4, 4))
58-
grad1, = Zygote.gradient(f, A₀)
59-
grad2, = Zygote.gradient(f, convert(Array, A₀))
60-
@test convert(Array, grad1) grad2
61-
62-
function g(A::AbstractTensorMap)
63-
U, S, V, = svd_compact(A)
64-
return tr(U * V)
65-
end
66-
function g(A::AbstractMatrix)
67-
U, S, V, = LinearAlgebra.svd(A)
68-
return tr(U * V')
69-
end
70-
B₀ = randn(ComplexSpace(4) ComplexSpace(4))
71-
grad3, = Zygote.gradient(g, B₀)
72-
grad4, = Zygote.gradient(g, convert(Array, B₀))
73-
@test convert(Array, grad3) grad4
74-
end
75-
76-
# https://github.com/quantumkithub/TensorKit.jl/issues/209
77-
@testset "Issue #209" begin
78-
function f(T, D)
79-
@tensor T[1, 4, 1, 3] * D[3, 4]
80-
end
81-
V = Z2Space(2, 2)
82-
D = DiagonalTensorMap(randn(4), V)
83-
T = randn(V V V V)
84-
g1, = Zygote.gradient(f, T, D)
85-
g2, = Zygote.gradient(f, T, TensorMap(D))
86-
@test g1 g2
87-
end

0 commit comments

Comments
 (0)