Skip to content

Commit d26b592

Browse files
kshyatt and Jutho authored
Add Enzyme rules (#243)
* Add Enzyme rules * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Fix cache and simplify tests * Re-enable all tests * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Update ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl Co-authored-by: Jutho <Jutho@users.noreply.github.com> * Remove irrelevant Mooncake test --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
1 parent 7ad2286 commit d26b592

4 files changed

Lines changed: 307 additions & 2 deletions

File tree

Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,14 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
2323
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2424
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2525
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
26+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2627
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
2728

2829
[extensions]
2930
TensorOperationsBumperExt = "Bumper"
3031
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
3132
TensorOperationsMooncakeExt = "Mooncake"
33+
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
3234
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
3335

3436
[compat]
@@ -38,6 +40,8 @@ CUDA = "5"
3840
ChainRulesCore = "1"
3941
ChainRulesTestUtils = "1"
4042
DynamicPolynomials = "0.5, 0.6"
43+
Enzyme = "0.13.115"
44+
EnzymeTestUtils = "0.2"
4145
LRUCache = "1"
4246
LinearAlgebra = "1.6"
4347
Logging = "1.6"
@@ -59,13 +63,16 @@ julia = "1.10"
5963
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
6064
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
6165
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
66+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
6267
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
6368
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
69+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
70+
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
6471
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
6572
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
6673
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6774
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6875
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
6976

7077
[targets]
71-
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
78+
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "ChainRulesCore", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
module TensorOperationsEnzymeExt
2+
3+
using TensorOperations
4+
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
5+
using VectorInterface
6+
using TupleTools
7+
using Enzyme, ChainRulesCore
8+
using Enzyme.EnzymeCore
9+
using Enzyme.EnzymeCore: EnzymeRules
10+
11+
# Mark calls to `tensorfree!` (with a single argument) as inactive for Enzyme.
@inline function EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any)
    return true
end

# Import the existing ChainRulesCore `rrule` of `tensoralloc` (four positional
# arguments) as an Enzyme rule instead of writing a dedicated one.
Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)

# Backends, allocators and index tuples are declared inactive types.
@inline EnzymeRules.inactive_type(::Type{<:AbstractBackend}) = true
@inline EnzymeRules.inactive_type(::Type{DefaultAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{<:CUDAAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{ManualAllocator}) = true
@inline EnzymeRules.inactive_type(::Type{<:Index2Tuple}) = true
19+
20+
# Augmented primal (forward sweep) of the reverse-mode rule for `tensorcontract!`.
#
# Runs the in-place contraction on the primal values and returns an
# `AugmentedReturn` holding
#   * the primal result (`C`) if the config requests it,
#   * the shadow of `C` if the config requests it,
#   * the tape `(cache_A, cache_B, cache_C)` consumed by the matching `reverse` rule.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    C, A, B = C_dC.val, A_dA.val, B_dB.val
    α, β = α_dα.val, β_dβ.val

    # Snapshot the inputs before the primal call mutates `C`:
    #   * `A` / `B` are copied only when Enzyme flags them as overwritten
    #     (entries 3 and 6 of the `overwritten` tuple);
    #   * `C` is copied only when `β` is nonzero; otherwise the array itself
    #     is stored.
    overwritten = EnzymeRules.overwritten(config)
    saved_A = overwritten[3] ? copy(A) : nothing
    saved_B = overwritten[6] ? copy(B) : nothing
    saved_C = iszero(β) ? C : copy(C)

    # Unwrap the optional trailing (backend / allocator) arguments.
    extra = map(arg -> getfield(arg, :val), ba_dba)

    TensorOperations.tensorcontract!(
        C, A, pA_dpA.val, conjA_dconjA.val,
        B, pB_dpB.val, conjB_dconjB.val,
        pAB_dpAB.val, α, β, extra...
    )

    primal = EnzymeRules.needs_primal(config) ? C : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (saved_A, saved_B, saved_C))
end
46+
47+
# Reverse sweep of the reverse-mode rule for `tensorcontract!`.
#
# `cache` is the tape `(cache_A, cache_B, cache_C)` produced by the matching
# `augmented_primal`. The shadows of `C`, `A` and `B` are handed to
# `TensorOperations.tensorcontract_pullback!`, which updates them in place, so the
# returned tuple carries `nothing` for every argument except the scalars `α` and `β`.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC`, `A_dA` and `B_dB`, although
# the signature accepts any `Annotation`; presumably tensor arguments are always
# `Duplicated` here — confirm whether `Const` tensors need to be supported.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    cache_A, cache_B, cache_C = cache
    # Fall back to the live arrays when no copy was taken in the forward sweep.
    Aval = something(cache_A, A_dA.val)
    Bval = something(cache_B, B_dB.val)
    # TODO: when `β` is zero the forward sweep stores the output array itself rather
    # than a copy; find a way to check that `Cval` is not used accidentally in that case.
    Cval = cache_C
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val)

    # Only the scalar cotangents are needed from the return value.
    _, _, _, dα, dβ = TensorOperations.tensorcontract_pullback!(
        C_dC.dval, A_dA.dval, B_dB.dval,
        Cval, Aval, pA, conjA, Bval, pB, conjB, pAB,
        α_dα.val, β_dβ.val, ba...
    )

    # EnzymeRules expects `nothing` in the slot of every non-active argument. Enforce
    # this for `Const` scalars instead of relying on what the pullback returns.
    Δα = α_dα isa Const ? nothing : dα
    Δβ = β_dβ isa Const ? nothing : dβ
    return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end
79+
80+
# Augmented primal (forward sweep) of the reverse-mode rule for `tensoradd!`.
#
# Runs the in-place addition on the primal values and returns an `AugmentedReturn`
# with the primal / shadow of `C` (each only if requested by the config) and the
# tape `(cache_A, cache_C)` consumed by the matching `reverse` rule.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, A = C_dC.val, A_dA.val
    α, β = α_dα.val, β_dβ.val

    # Snapshot the inputs before `C` is mutated: `A` only when Enzyme flags it as
    # overwritten (entry 3 of the `overwritten` tuple), `C` only when `β` is nonzero.
    saved_A = EnzymeRules.overwritten(config)[3] ? copy(A) : nothing
    saved_C = iszero(β) ? C : copy(C)

    # Unwrap the optional trailing (backend / allocator) arguments.
    extra = map(arg -> getfield(arg, :val), ba_dba)

    TensorOperations.tensoradd!(C, A, pA_dpA.val, conjA_dconjA.val, α, β, extra...)

    primal = EnzymeRules.needs_primal(config) ? C : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (saved_A, saved_C))
end
104+
105+
# Reverse sweep of the reverse-mode rule for `tensoradd!`.
#
# `cache` is the tape `(cache_A, cache_C)` produced by the matching
# `augmented_primal`. The shadows of `C` and `A` are handed to
# `TensorOperations.tensoradd_pullback!`, which updates them in place, so the returned
# tuple carries `nothing` for every argument except the scalars `α` and `β`.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`, although the
# signature accepts any `Annotation`; presumably tensor arguments are always
# `Duplicated` here — confirm whether `Const` tensors need to be supported.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    cache_A, cache_C = cache
    # Fall back to the live array when no copy was taken in the forward sweep.
    Aval = something(cache_A, A_dA.val)
    Cval = cache_C
    pA = pA_dpA.val
    conjA = conjA_dconjA.val
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)

    # Only the scalar cotangents are needed from the return value.
    _, _, dα, dβ = TensorOperations.tensoradd_pullback!(
        C_dC.dval, A_dA.dval, Cval, Aval, pA, conjA, α_dα.val, β_dβ.val, ba...
    )

    # EnzymeRules expects `nothing` in the slot of every non-active argument. Enforce
    # this for `Const` scalars instead of relying on what the pullback returns.
    Δα = α_dα isa Const ? nothing : dα
    Δβ = β_dβ isa Const ? nothing : dβ
    return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end
131+
132+
# Augmented primal (forward sweep) of the reverse-mode rule for `tensortrace!`.
#
# Runs the in-place partial trace on the primal values and returns an
# `AugmentedReturn` with the primal / shadow of `C` (each only if requested by the
# config) and the tape `(cache_A, cache_C)` consumed by the matching `reverse` rule.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    C, A = C_dC.val, A_dA.val
    α, β = α_dα.val, β_dβ.val

    # Snapshot the inputs before `C` is mutated: `A` only when Enzyme flags it as
    # overwritten (entry 3 of the `overwritten` tuple), `C` only when `β` is nonzero.
    saved_A = EnzymeRules.overwritten(config)[3] ? copy(A) : nothing
    saved_C = iszero(β) ? C : copy(C)

    # Unwrap the optional trailing (backend / allocator) arguments.
    extra = map(arg -> getfield(arg, :val), ba_dba)

    TensorOperations.tensortrace!(
        C, A, p_dp.val, q_dq.val, conjA_dconjA.val, α, β, extra...
    )

    primal = EnzymeRules.needs_primal(config) ? C : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (saved_A, saved_C))
end
157+
158+
# Reverse sweep of the reverse-mode rule for `tensortrace!`.
#
# `cache` is the tape `(cache_A, cache_C)` produced by the matching
# `augmented_primal`. The shadows of `C` and `A` are handed to
# `TensorOperations.tensortrace_pullback!`, which updates them in place, so the
# returned tuple carries `nothing` for every argument except the scalars `α` and `β`.
#
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`, although the
# signature accepts any `Annotation`; presumably tensor arguments are always
# `Duplicated` here — confirm whether `Const` tensors need to be supported.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    cache_A, cache_C = cache
    # Fall back to the live array when no copy was taken in the forward sweep.
    Aval = something(cache_A, A_dA.val)
    Cval = cache_C
    p = p_dp.val
    q = q_dq.val
    conjA = conjA_dconjA.val
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)

    # Only the scalar cotangents are needed from the return value.
    _, _, dα, dβ = TensorOperations.tensortrace_pullback!(
        C_dC.dval, A_dA.dval, Cval, Aval, p, q, conjA, α_dα.val, β_dβ.val, ba...
    )

    # EnzymeRules expects `nothing` in the slot of every non-active argument. Enforce
    # this for `Const` scalars instead of relying on what the pullback returns.
    Δα = α_dα isa Const ? nothing : dα
    Δβ = β_dβ isa Const ? nothing : dβ
    return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
end
186+
187+
end

test/enzyme.jl

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
using TensorOperations, VectorInterface
2+
using Enzyme, ChainRulesCore, EnzymeTestUtils
3+
4+
# Reverse-mode Enzyme checks for `tensorcontract!` over several element-type
# combinations, conjugation flags and (optionally) explicit backends.
@testset "tensorcontract!" begin
    pAB = ((3, 2, 4, 1), ())
    pA = ((2, 4, 5), (1, 3))
    pB = ((2, 1), (3,))
    @testset "($T₁, $T₂)" for (T₁, T₂) in (
            (Float64, Float64),
            (Float32, Float64),
            (ComplexF64, ComplexF64),
            (Float64, ComplexF64),
            (ComplexF64, Float64),
        )
        T = promote_type(T₁, T₂)
        atol = max(precision(T₁), precision(T₂))
        rtol = max(precision(T₁), precision(T₂))

        A = rand(T₁, (2, 3, 4, 2, 5))
        B = rand(T₂, (4, 2, 3))
        C = rand(T, (5, 2, 3, 3))
        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
        # test zeros only once to avoid wasteful tests
        @testset for (α, β) in αβs
            # A structural `Zero()` scalar is passed as `Const`; anything else is `Active`.
            Tα = α === Zero() ? Const : Active
            Tβ = β === Zero() ? Const : Active
            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)
            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ); atol, rtol)

            # same calls with an explicit backend as trailing argument
            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (B, Duplicated), (pB, Const), (false, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
            test_reverse(tensorcontract!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (B, Duplicated), (pB, Const), (true, Const), (pAB, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
        end
    end
end
38+
39+
# Reverse-mode Enzyme checks for `tensoradd!` over several element-type
# combinations, conjugation flags and (optionally) explicit backends.
@testset "tensoradd!" begin
    pA = ((2, 1, 4, 3, 5), ())
    @testset "($T₁, $T₂)" for (T₁, T₂) in (
            (Float64, Float64),
            (Float32, Float64),
            (ComplexF64, ComplexF64),
            (Float64, ComplexF64),
        )
        T = promote_type(T₁, T₂)
        atol = max(precision(T₁), precision(T₂))
        rtol = max(precision(T₁), precision(T₂))

        A = rand(T₁, (2, 3, 4, 2, 1))
        C = rand(T₂, size.(Ref(A), pA[1]))
        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
        # test zeros only once to avoid wasteful tests
        @testset for (α, β) in αβs
            # A structural `Zero()` scalar is passed as `Const`; anything else is `Active`.
            Tα = α === Zero() ? Const : Active
            Tβ = β === Zero() ? Const : Active
            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)

            # same calls with an explicit backend as trailing argument
            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (false, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
            test_reverse(tensoradd!, Duplicated, (C, Duplicated), (A, Duplicated), (pA, Const), (true, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
        end
    end
end
67+
68+
# Reverse-mode Enzyme checks for `tensortrace!` over several element-type
# combinations, conjugation flags and (optionally) explicit backends.
@testset "tensortrace!" begin
    p = ((3, 5, 2), ())
    q = ((1,), (4,))
    @testset "($T₁, $T₂)" for (T₁, T₂) in (
            (Float64, Float64),
            (Float32, Float64),
            (ComplexF64, ComplexF64),
            (Float64, ComplexF64),
        )
        T = promote_type(T₁, T₂)
        atol = max(precision(T₁), precision(T₂))
        rtol = max(precision(T₁), precision(T₂))

        A = rand(T₁, (2, 3, 4, 2, 5))
        C = rand(T₂, size.(Ref(A), p[1]))
        zero_αβs = ((Zero(), Zero()), (randn(T), Zero()), (Zero(), randn(T)))
        αβs = (T == T₁ == T₂ == Float64) ? vcat(zero_αβs..., (randn(T), randn(T))) : ((randn(T), randn(T)),)
        # test zeros only once to avoid wasteful tests
        @testset for (α, β) in αβs
            # A structural `Zero()` scalar is passed as `Const`; anything else is `Active`.
            Tα = α === Zero() ? Const : Active
            Tβ = β === Zero() ? Const : Active
            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ); atol, rtol)
            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ); atol, rtol)

            # same calls with an explicit backend as trailing argument
            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (true, Const), (α, Tα), (β, Tβ), (StridedBLAS(), Const); atol, rtol)
            test_reverse(tensortrace!, Duplicated, (C, Duplicated), (A, Duplicated), (p, Const), (q, Const), (false, Const), (α, Tα), (β, Tβ), (StridedNative(), Const); atol, rtol)
        end
    end
end
98+
99+
# Reverse-mode Enzyme check for `tensorscalar` on a zero-dimensional array.
@testset "tensorscalar ($T)" for T in (Float32, Float64, ComplexF64)
    atol = rtol = precision(T)

    # zero-dimensional array holding a single random entry
    C = fill(rand(T))
    test_reverse(tensorscalar, Active, (C, Duplicated); atol, rtol)
end

test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ precision(::Type{<:Union{Float64, Complex{Float64}}}) = 1.0e-8
1515
# specific ones
1616
is_buildkite = get(ENV, "BUILDKITE", "false") == "true"
1717
if !is_buildkite
18-
1918
@testset "tensoropt" verbose = true begin
2019
include("tensoropt.jl")
2120
end
@@ -37,6 +36,12 @@ if !is_buildkite
3736
@testset "mooncake" verbose = false begin
3837
include("mooncake.jl")
3938
end
39+
# mystery segfault on 1.10 for now
40+
@static if VERSION >= v"1.11.0"
41+
@testset "enzyme" verbose = false begin
42+
include("enzyme.jl")
43+
end
44+
end
4045
end
4146

4247
if is_buildkite

0 commit comments

Comments
 (0)