Skip to content

Commit 68cff49

Browse files
lkdvoskshyatt
andauthored
Setup Mooncake extension (#352)
* add Mooncake extension * start adding some rules * reorganize mooncake extension * reorganize tensorcontract_pullback * add tests * Add Mooncake compat --------- Co-authored-by: Katharine Hyatt <kslimes@gmail.com>
1 parent 8ccff5a commit 68cff49

8 files changed

Lines changed: 325 additions & 1 deletion

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2323
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2424
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2525
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
26+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2627

2728
[extensions]
2829
TensorKitAdaptExt = "Adapt"
2930
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
3031
TensorKitChainRulesCoreExt = "ChainRulesCore"
3132
TensorKitFiniteDifferencesExt = "FiniteDifferences"
33+
TensorKitMooncakeExt = "Mooncake"
3234

3335
[compat]
3436
Adapt = "4"
@@ -43,6 +45,7 @@ GPUArrays = "11.3.1"
4345
LRUCache = "1.0.2"
4446
LinearAlgebra = "1"
4547
MatrixAlgebraKit = "0.6.2"
48+
Mooncake = "0.4.183"
4649
OhMyThreads = "0.8.0"
4750
Printf = "1"
4851
Random = "1"
@@ -70,6 +73,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
7073
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
7174
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
7275
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
76+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
7377
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7478
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
7579
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -78,4 +82,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7882
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
7983

8084
[targets]
81-
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
85+
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
module TensorKitMooncakeExt
2+
3+
using Mooncake
4+
using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
5+
using TensorKit
6+
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
7+
import TensorOperations as TO
8+
using VectorInterface: One, Zero
9+
using TupleTools
10+
11+
12+
include("utility.jl")
13+
include("tangent.jl")
14+
include("linalg.jl")
15+
include("tensoroperations.jl")
16+
17+
end

ext/TensorKitMooncakeExt/linalg.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real}
2+
3+
function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real})
4+
t, Δt = arrayify(tΔt)
5+
p = primal(pdp)
6+
p == 2 || error("currently only implemented for p = 2")
7+
n = norm(t, p)
8+
function norm_pullback(Δn)
9+
x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
10+
add!(Δt, t, x)
11+
return NoRData(), NoRData(), NoRData()
12+
end
13+
return CoDual(n, Mooncake.NoFData()), norm_pullback
14+
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
function Mooncake.arrayify(A_dA::CoDual{<:TensorMap})
2+
A = Mooncake.primal(A_dA)
3+
dA_fw = Mooncake.tangent(A_dA)
4+
data = dA_fw.data.data
5+
dA = typeof(A)(data, A.space)
6+
return A, dA
7+
end
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
Mooncake.@is_primitive(
2+
DefaultCtx,
3+
ReverseMode,
4+
Tuple{
5+
typeof(TO.tensorcontract!),
6+
AbstractTensorMap,
7+
AbstractTensorMap, Index2Tuple, Bool,
8+
AbstractTensorMap, Index2Tuple, Bool,
9+
Index2Tuple,
10+
Number, Number,
11+
Vararg{Any},
12+
}
13+
)
14+
15+
function Mooncake.rrule!!(
16+
::CoDual{typeof(TO.tensorcontract!)},
17+
C_ΔC::CoDual{<:AbstractTensorMap},
18+
A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool},
19+
B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool},
20+
pAB_ΔpAB::CoDual{<:Index2Tuple},
21+
α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number},
22+
ba_Δba::CoDual...,
23+
)
24+
# prepare arguments
25+
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
26+
pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
27+
conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB))
28+
α, β = primal.((α_Δα, β_Δβ))
29+
ba = primal.(ba_Δba)
30+
31+
# primal call
32+
C_cache = copy(C)
33+
TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
34+
35+
function tensorcontract_pullback(::NoRData)
36+
copy!(C, C_cache)
37+
38+
ΔCr = tensorcontract_pullback_ΔC!(ΔC, β)
39+
ΔAr = tensorcontract_pullback_ΔA!(
40+
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
41+
)
42+
ΔBr = tensorcontract_pullback_ΔB!(
43+
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
44+
)
45+
Δαr = tensorcontract_pullback_Δα(
46+
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
47+
)
48+
Δβr = tensorcontract_pullback_Δβ(ΔC, C, β)
49+
50+
return NoRData(), ΔCr,
51+
ΔAr, NoRData(), NoRData(),
52+
ΔBr, NoRData(), NoRData(),
53+
NoRData(),
54+
Δαr, Δβr,
55+
map(ba_ -> NoRData(), ba)...
56+
end
57+
58+
return C_ΔC, tensorcontract_pullback
59+
end
60+
61+
tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData())
62+
63+
function tensorcontract_pullback_ΔA!(
64+
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
65+
)
66+
ipAB = invperm(linearize(pAB))
67+
pΔC = _repartition(ipAB, TO.numout(pA))
68+
ipA = _repartition(invperm(linearize(pA)), A)
69+
conjΔC = conjA
70+
conjB′ = conjA ? conjB : !conjB
71+
72+
tB = twist(
73+
B,
74+
TupleTools.vcat(
75+
filter(x -> !isdual(space(B, x)), pB[1]),
76+
filter(x -> isdual(space(B, x)), pB[2])
77+
); copy = false
78+
)
79+
80+
TO.tensorcontract!(
81+
ΔA,
82+
ΔC, pΔC, conjΔC,
83+
tB, reverse(pB), conjB′,
84+
ipA,
85+
conjA ? α : conj(α), Zero(),
86+
ba...
87+
)
88+
89+
return NoRData()
90+
end
91+
92+
function tensorcontract_pullback_ΔB!(
93+
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
94+
)
95+
ipAB = invperm(linearize(pAB))
96+
pΔC = _repartition(ipAB, TO.numout(pA))
97+
ipB = _repartition(invperm(linearize(pB)), B)
98+
conjΔC = conjB
99+
conjA′ = conjB ? conjA : !conjA
100+
101+
tA = twist(
102+
A,
103+
TupleTools.vcat(
104+
filter(x -> isdual(space(A, x)), pA[1]),
105+
filter(x -> !isdual(space(A, x)), pA[2])
106+
); copy = false
107+
)
108+
109+
TO.tensorcontract!(
110+
ΔB,
111+
tA, reverse(pA), conjA′,
112+
ΔC, pΔC, conjΔC,
113+
ipB,
114+
conjB ? α : conj(α), Zero(), ba...
115+
)
116+
117+
return NoRData()
118+
end
119+
120+
function tensorcontract_pullback_Δα(
121+
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
122+
)
123+
Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α)))
124+
Tdα === NoRData && return NoRData()
125+
126+
AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
127+
Δα = inner(AB, ΔC)
128+
return Mooncake._rdata(Δα)
129+
end
130+
131+
function tensorcontract_pullback_Δβ(ΔC, C, β)
132+
Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β)))
133+
Tdβ === NoRData && return NoRData()
134+
135+
Δβ = inner(C, ΔC)
136+
return Mooncake._rdata(Δβ)
137+
end
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
_needs_tangent(x) = _needs_tangent(typeof(x))
2+
_needs_tangent(::Type{<:Number}) = true
3+
_needs_tangent(::Type{<:Integer}) = false
4+
_needs_tangent(::Type{<:Union{One, Zero}}) = false
5+
6+
# IndexTuple utility
7+
# ------------------
8+
trivtuple(N) = ntuple(identity, N)
9+
10+
Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
11+
length(p) >= N₁ ||
12+
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
13+
return TupleTools.getindices(p, trivtuple(N₁)),
14+
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
15+
end
16+
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
17+
return _repartition(linearize(p), N₁)
18+
end
19+
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
20+
return _repartition(p, N₁)
21+
end
22+
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
23+
return _repartition(p, TensorKit.numout(t))
24+
end
25+
26+
# Ignore derivatives
27+
# ------------------
28+
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any}

test/autodiff/mooncake.jl

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
using Test, TestExtras
2+
using TensorKit
3+
using TensorOperations
4+
using Mooncake
5+
using Random
6+
7+
mode = Mooncake.ReverseMode
8+
rng = Random.default_rng()
9+
is_primitive = false
10+
11+
function randindextuple(N::Int, k::Int = rand(0:N))
12+
@assert 0 k N
13+
_p = randperm(N)
14+
return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
15+
end
16+
17+
const _repartition = @static if isdefined(Base, :get_extension)
18+
Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition
19+
else
20+
TensorKit.TensorKitMooncakeExt._repartition
21+
end
22+
23+
spacelist = (
24+
(ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
25+
(
26+
Vect[Z2Irrep](0 => 1, 1 => 1),
27+
Vect[Z2Irrep](0 => 1, 1 => 2)',
28+
Vect[Z2Irrep](0 => 2, 1 => 2)',
29+
Vect[Z2Irrep](0 => 2, 1 => 3),
30+
Vect[Z2Irrep](0 => 2, 1 => 2),
31+
),
32+
(
33+
Vect[FermionParity](0 => 1, 1 => 1),
34+
Vect[FermionParity](0 => 1, 1 => 2)',
35+
Vect[FermionParity](0 => 2, 1 => 1)',
36+
Vect[FermionParity](0 => 2, 1 => 3),
37+
Vect[FermionParity](0 => 2, 1 => 2),
38+
),
39+
(
40+
Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
41+
Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
42+
Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)',
43+
Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2),
44+
Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)',
45+
),
46+
(
47+
Vect[SU2Irrep](0 => 2, 1 // 2 => 1),
48+
Vect[SU2Irrep](0 => 1, 1 => 1),
49+
Vect[SU2Irrep](1 // 2 => 1, 1 => 1)',
50+
Vect[SU2Irrep](1 // 2 => 2),
51+
Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)',
52+
),
53+
(
54+
Vect[FibonacciAnyon](:I => 2, => 1),
55+
Vect[FibonacciAnyon](:I => 1, => 2)',
56+
Vect[FibonacciAnyon](:I => 2, => 2)',
57+
Vect[FibonacciAnyon](:I => 2, => 3),
58+
Vect[FibonacciAnyon](:I => 2, => 2),
59+
),
60+
)
61+
62+
for V in spacelist
63+
I = sectortype(eltype(V))
64+
Istr = TensorKit.type_repr(I)
65+
66+
symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
67+
println("---------------------------------------")
68+
println("Mooncake with symmetry: $Istr")
69+
println("---------------------------------------")
70+
eltypes = (Float64,) # no complex support yet
71+
symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes
72+
atol = precision(T)
73+
rtol = precision(T)
74+
75+
@timedtestset "tensorcontract!" begin
76+
for _ in 1:5
77+
d = 0
78+
local V1, V2, V3
79+
# retry a couple times to make sure there are at least some nonzero elements
80+
for _ in 1:10
81+
k1 = rand(0:3)
82+
k2 = rand(0:2)
83+
k3 = rand(0:2)
84+
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1]))
85+
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1]))
86+
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1]))
87+
d = min(dim(V1 V2), dim(V1' V2), dim(V2 V3), dim(V2' V3))
88+
d > 0 && break
89+
end
90+
ipA = randindextuple(length(V1) + length(V2))
91+
pA = _repartition(invperm(linearize(ipA)), length(V1))
92+
ipB = randindextuple(length(V2) + length(V3))
93+
pB = _repartition(invperm(linearize(ipB)), length(V2))
94+
pAB = randindextuple(length(V1) + length(V3))
95+
96+
α = randn(T)
97+
β = randn(T)
98+
V2_conj = prod(conj, V2; init = one(V[1]))
99+
100+
for conjA in (false, true), conjB in (false, true)
101+
A = randn(T, permute(V1 (conjA ? V2_conj : V2), ipA))
102+
B = randn(T, permute((conjB ? V2_conj : V2) V3, ipB))
103+
C = randn!(
104+
TensorOperations.tensoralloc_contract(
105+
T, A, pA, conjA, B, pB, conjB, pAB, Val(false)
106+
)
107+
)
108+
Mooncake.TestUtils.test_rule(
109+
rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
110+
atol, rtol, mode, is_primitive
111+
)
112+
113+
end
114+
end
115+
end
116+
end
117+
end

0 commit comments

Comments
 (0)