Skip to content

Commit cd6bcf2

Browse files
authored
Updates for more Enzyme activities (#255)
Properly test and handle `Const`-annotated floats and cases in which alpha has a different element type to A.
1 parent cce51ce commit cd6bcf2

7 files changed

Lines changed: 207 additions & 47 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
3030
TensorOperationsBumperExt = "Bumper"
3131
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
3232
TensorOperationsMooncakeExt = "Mooncake"
33-
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]
33+
TensorOperationsEnzymeExt = "Enzyme"
3434
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
3535

3636
[compat]

ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl

Lines changed: 108 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,44 @@ using TensorOperations
44
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
55
using VectorInterface
66
using TupleTools
7-
using Enzyme, ChainRulesCore
7+
using Enzyme
88
using Enzyme.EnzymeCore
99
using Enzyme.EnzymeCore: EnzymeRules
1010

1111
@inline EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any) = true
12-
Enzyme.@import_rrule(typeof(TensorOperations.tensoralloc), Any, Any, Any, Any)
13-
1412
@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
1513
@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
1614
@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
1715
@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
1816
@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true
1917

18+
function EnzymeRules.augmented_primal(
19+
config::EnzymeRules.RevConfigWidth{1},
20+
func::Const{typeof(TensorOperations.tensoralloc)},
21+
::Type{RT},
22+
ttype::Const,
23+
structure::Const,
24+
istemp::Const{Bool},
25+
allocator::Const
26+
) where {RT}
27+
primal = EnzymeRules.needs_primal(config) ? TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val) : nothing
28+
shadow = EnzymeRules.needs_shadow(config) ? TensorOperations.tensoralloc(ttype.val, structure.val, Val(false), allocator.val) : nothing
29+
return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
30+
end
31+
32+
function EnzymeRules.reverse(
33+
config::EnzymeRules.RevConfigWidth{1},
34+
func::Const{typeof(TensorOperations.tensoralloc)},
35+
::Type{RT},
36+
cache,
37+
ttype::Const,
38+
structure::Const,
39+
istemp::Const{Bool},
40+
allocator::Const,
41+
) where {RT}
42+
return nothing, nothing, nothing, nothing
43+
end
44+
2045
function EnzymeRules.augmented_primal(
2146
config::EnzymeRules.RevConfigWidth{1},
2247
func::Const{typeof(TensorOperations.tensorcontract!)},
@@ -36,7 +61,7 @@ function EnzymeRules.augmented_primal(
3661
# form caches if needed
3762
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
3863
cache_B = EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
39-
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
64+
cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : C_dC.val
4065
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
4166
TensorOperations.tensorcontract!(C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val, B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val, α_dα.val, β_dβ.val, ba...)
4267
primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
@@ -66,15 +91,39 @@ function EnzymeRules.reverse(
6691
Bval = something(cache_B, B_dB.val)
6792
Cval = cache_C
6893
# good way to check that we don't use it accidentally when we should not be needing it?
69-
dC = C_dC.dval
70-
dA = A_dA.dval
71-
dB = B_dB.dval
7294
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
7395
α = α_dα.val
7496
β = β_dβ.val
7597
pA, pB, pAB, conjA, conjB = getfield.((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val)
76-
dC, dA, dB, dα, dβ = TensorOperations.tensorcontract_pullback!(dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba...)
77-
return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
98+
99+
if !isa(A_dA, Const) && !isa(C_dC, Const)
100+
ΔC = C_dC.dval
101+
ΔA = A_dA.dval
102+
TensorOperations.tensorcontract_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
103+
end
104+
if !isa(B_dB, Const) && !isa(C_dC, Const)
105+
ΔC = C_dC.dval
106+
ΔB = B_dB.dval
107+
TensorOperations.tensorcontract_pullback_dB!(ΔB, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
108+
end
109+
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
110+
ΔC = C_dC.dval
111+
TensorOperations.tensorcontract_pullback_dα(ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba...)
112+
elseif !isa(α_dα, Const)
113+
zero(α_dα.val)
114+
else
115+
nothing
116+
end
117+
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
118+
ΔC = C_dC.dval
119+
TensorOperations.pullback_dβ(ΔC, Cval, β)
120+
elseif !isa(β_dβ, Const)
121+
zero(β_dβ.val)
122+
else
123+
nothing
124+
end
125+
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
126+
return nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
78127
end
79128

80129
function EnzymeRules.augmented_primal(
@@ -123,10 +172,30 @@ function EnzymeRules.reverse(
123172
α = α_dα.val
124173
β = β_dβ.val
125174
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
126-
dC = C_dC.dval
127-
dA = A_dA.dval
128-
dC, dA, dα, dβ = TensorOperations.tensoradd_pullback!(dC, dA, Cval, Aval, pA, conjA, α, β, ba...)
129-
return nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
175+
176+
if !isa(A_dA, Const) && !isa(C_dC, Const)
177+
ΔC = C_dC.dval
178+
ΔA = A_dA.dval
179+
TensorOperations.tensoradd_pullback_dA!(ΔA, ΔC, Cval, Aval, pA, conjA, α, ba...)
180+
end
181+
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
182+
ΔC = C_dC.dval
183+
TensorOperations.tensoradd_pullback_dα(ΔC, Cval, Aval, pA, conjA, α, ba...)
184+
elseif !isa(α_dα, Const)
185+
zero(α_dα.val)
186+
else
187+
nothing
188+
end
189+
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
190+
ΔC = C_dC.dval
191+
TensorOperations.pullback_dβ(ΔC, Cval, β)
192+
elseif !isa(β_dβ, Const)
193+
zero(β_dβ.val)
194+
else
195+
nothing
196+
end
197+
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
198+
return nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
130199
end
131200

132201
function EnzymeRules.augmented_primal(
@@ -144,7 +213,7 @@ function EnzymeRules.augmented_primal(
144213
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
145214
# form caches if needed
146215
cache_A = EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
147-
cache_C = !iszero(β_dβ.val) ? copy(C_dC.val) : C_dC.val
216+
cache_C = !isa(β_dβ, Const) ? copy(C_dC.val) : nothing
148217
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
149218
α = α_dα.val
150219
β = β_dβ.val
@@ -171,17 +240,37 @@ function EnzymeRules.reverse(
171240
) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
172241
cache_A, cache_C = cache
173242
Aval = something(cache_A, A_dA.val)
174-
Cval = cache_C
243+
Cval = something(cache_C, C_dC.val)
175244
p = p_dp.val
176245
q = q_dq.val
177246
conjA = conjA_dconjA.val
178247
α = α_dα.val
179248
β = β_dβ.val
180249
ba = map(ba_ -> getfield(ba_, :val), ba_dba)
181-
dC = C_dC.dval
182-
dA = A_dA.dval
183-
dC, dA, dα, dβ = TensorOperations.tensortrace_pullback!(dC, dA, Cval, Aval, p, q, conjA, α, β, ba...)
184-
return nothing, nothing, nothing, nothing, nothing, dα, dβ, map(ba_ -> nothing, ba)...
250+
251+
if !isa(A_dA, Const) && !isa(C_dC, Const)
252+
ΔC = C_dC.dval
253+
ΔA = A_dA.dval
254+
TensorOperations.tensortrace_pullback_dA!(ΔA, ΔC, Cval, Aval, p, q, conjA, α, ba...)
255+
end
256+
Δα = if !isa(α_dα, Const) && !isa(C_dC, Const)
257+
ΔC = C_dC.dval
258+
TensorOperations.tensortrace_pullback_dα(ΔC, Cval, Aval, p, q, conjA, α, ba...)
259+
elseif !isa(α_dα, Const)
260+
zero(α_dα.val)
261+
else
262+
nothing
263+
end
264+
Δβ = if !isa(β_dβ, Const) && !isa(C_dC, Const)
265+
ΔC = C_dC.dval
266+
TensorOperations.pullback_dβ(ΔC, Cval, β)
267+
elseif !isa(β_dβ, Const)
268+
zero(β_dβ.val)
269+
else
270+
nothing
271+
end
272+
!isa(C_dC, Const) && TensorOperations.pullback_dC!(C_dC.dval, β)
273+
return nothing, nothing, nothing, nothing, nothing, Δα, Δβ, map(ba_ -> nothing, ba)...
185274
end
186275

187276
end

src/pullbacks/add.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,12 @@ Compute the pullback for [`tensoradd!]`(ref) with respect to scaling coefficient
4646
"""
4747
function tensoradd_pullback_dα(ΔC, C, A, pA::Index2Tuple, conjA::Bool, α, ba...)
4848
_needs_tangent(α) || return nothing
49-
return tensorscalar(
49+
Δα = tensorscalar(
5050
tensorcontract(
5151
A, repartition(pA, 0), !conjA,
5252
ΔC, trivialpermutation(numind(pA), 0), false,
5353
((), ()), One(), ba...
5454
)
5555
)
56+
return project_scalar(α, Δα)
5657
end

src/pullbacks/common.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ _needs_tangent(::Type{<:Integer}) = false
1414
_needs_tangent(::Type{<:Union{One, Zero}}) = false
1515
_needs_tangent(::Type{Complex{T}}) where {T} = _needs_tangent(T)
1616

17+
"""
18+
project_scalar(x::Number, dx::Number)
19+
20+
Project a computed tangent `dx` onto the correct tangent type for `x`.
21+
For example, we might compute a complex `dx` but only require the real part.
22+
"""
23+
project_scalar(x::Number, dx::Number) = oftype(x, dx)
24+
project_scalar(x::Real, dx::Complex) = project_scalar(x, real(dx))
25+
1726
# (partial) pullbacks that are shared
1827
@doc """
1928
pullback_dC(ΔC, β)
@@ -31,4 +40,4 @@ pullback_dC(ΔC, β) = scale(ΔC, conj(β))
3140
For functions of the form `f!(C, β, ...) = βC + ...`, compute the pullback with respect to `β`.
3241
""" pullback_dβ
3342

34-
pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? inner(C, ΔC) : nothing
43+
pullback_dβ(ΔC, C, β) = _needs_tangent(β) ? project_scalar(β, inner(C, ΔC)) : nothing

src/pullbacks/contract.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,5 @@ function tensorcontract_pullback_dα(
122122
)
123123
_needs_tangent(α) || return nothing
124124
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
125-
return inner(C_αβ, ΔC)
125+
return project_scalar(α, inner(C_αβ, ΔC))
126126
end

src/pullbacks/trace.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,13 @@ function tensortrace_pullback_dα(
8888
)
8989
_needs_tangent(α) || return nothing
9090
C_αβ = tensortrace(A, p, q, false, One(), ba...)
91-
return tensorscalar(
91+
Δα = tensorscalar(
9292
tensorcontract(
9393
C_αβ, trivialpermutation(0, numind(p)),
9494
!conjA,
9595
ΔC, trivialpermutation(numind(p), 0), false,
9696
((), ()), One(), ba...
9797
)
9898
)
99+
return project_scalar(α, Δα)
99100
end

0 commit comments

Comments
 (0)