@@ -4,19 +4,44 @@ using TensorOperations
44using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
55using VectorInterface
66using TupleTools
7- using Enzyme, ChainRulesCore
7+ using Enzyme
88using Enzyme. EnzymeCore
99using Enzyme. EnzymeCore: EnzymeRules
1010
# Declare `tensorfree!` inactive for Enzyme, whatever its single argument is.
@inline function EnzymeRules.inactive(::typeof(TensorOperations.tensorfree!), ::Any)
    return true
end
12- Enzyme. @import_rrule (typeof (TensorOperations. tensoralloc), Any, Any, Any, Any)
13-
# Declare the backend, allocator and index-tuple types inactive for Enzyme.
@inline function EnzymeRules.inactive_type(::Type{<:AbstractBackend})
    return true
end
@inline function EnzymeRules.inactive_type(::Type{DefaultAllocator})
    return true
end
@inline function EnzymeRules.inactive_type(::Type{<:CUDAAllocator})
    return true
end
@inline function EnzymeRules.inactive_type(::Type{ManualAllocator})
    return true
end
@inline function EnzymeRules.inactive_type(::Type{<:Index2Tuple})
    return true
end
1917
# Augmented forward pass of the reverse-mode rule for `TensorOperations.tensoralloc`.
#
# Arguments (all `Const`, so none of them carries a derivative):
#   - `ttype`, `structure`, `allocator`: forwarded unchanged to `tensoralloc`.
#   - `istemp`: accepted for dispatch but NOT forwarded; both buffers are requested
#     with `Val(false)`.
#
# Returns an `AugmentedReturn(primal, shadow, nothing)`: the primal and shadow are only
# allocated when `config` asks for them, and no tape/cache is needed by the reverse pass.
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensoralloc)},
        ::Type{RT},
        ttype::Const,
        structure::Const,
        istemp::Const{Bool},
        allocator::Const
    ) where {RT}
    # Single place that performs the allocation, so primal and shadow cannot drift apart.
    allocate() = TensorOperations.tensoralloc(
        ttype.val, structure.val, Val(false), allocator.val
    )
    primal = EnzymeRules.needs_primal(config) ? allocate() : nothing
    # `tensoralloc` hands back a buffer whose contents are not set here. The shadow is
    # accumulated into by later reverse rules, so it has to start out as the zero vector.
    shadow = EnzymeRules.needs_shadow(config) ? zerovector!(allocate()) : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, nothing)
end
31+
# Reverse pass of the `tensoralloc` rule. All four arguments of `tensoralloc` are
# annotated `Const`, so there is nothing to propagate: one `nothing` is returned per
# argument (`ttype`, `structure`, `istemp`, `allocator`).
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensoralloc)},
        ::Type{RT},
        cache,
        ttype::Const,
        structure::Const,
        istemp::Const{Bool},
        allocator::Const,
    ) where {RT}
    return ntuple(_ -> nothing, Val(4))
end
44+
2045function EnzymeRules. augmented_primal (
2146 config:: EnzymeRules.RevConfigWidth{1} ,
2247 func:: Const{typeof(TensorOperations.tensorcontract!)} ,
@@ -36,7 +61,7 @@ function EnzymeRules.augmented_primal(
3661 # form caches if needed
3762 cache_A = EnzymeRules. overwritten (config)[3 ] ? copy (A_dA. val) : nothing
3863 cache_B = EnzymeRules. overwritten (config)[6 ] ? copy (B_dB. val) : nothing
39- cache_C = ! iszero (β_dβ. val ) ? copy (C_dC. val) : C_dC. val
64+ cache_C = ! isa (β_dβ, Const ) ? copy (C_dC. val) : C_dC. val
4065 ba = map (ba_ -> getfield (ba_, :val ), ba_dba)
4166 TensorOperations. tensorcontract! (C_dC. val, A_dA. val, pA_dpA. val, conjA_dconjA. val, B_dB. val, pB_dpB. val, conjB_dconjB. val, pAB_dpAB. val, α_dα. val, β_dβ. val, ba... )
4267 primal = EnzymeRules. needs_primal (config) ? C_dC. val : nothing
@@ -66,15 +91,39 @@ function EnzymeRules.reverse(
6691 Bval = something (cache_B, B_dB. val)
6792 Cval = cache_C
6893 # good way to check that we don't use it accidentally when we should not be needing it?
69- dC = C_dC. dval
70- dA = A_dA. dval
71- dB = B_dB. dval
7294 ba = map (ba_ -> getfield (ba_, :val ), ba_dba)
7395 α = α_dα. val
7496 β = β_dβ. val
7597 pA, pB, pAB, conjA, conjB = getfield .((pA_dpA, pB_dpB, pAB_dpAB, conjA_dconjA, conjB_dconjB), :val )
76- dC, dA, dB, dα, dβ = TensorOperations. tensorcontract_pullback! (dC, dA, dB, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, β, ba... )
77- return nothing , nothing , nothing , nothing , nothing , nothing , nothing , nothing , dα, dβ, map (ba_ -> nothing , ba)...
98+
99+ if ! isa (A_dA, Const) && ! isa (C_dC, Const)
100+ ΔC = C_dC. dval
101+ ΔA = A_dA. dval
102+ TensorOperations. tensorcontract_pullback_dA! (ΔA, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba... )
103+ end
104+ if ! isa (B_dB, Const) && ! isa (C_dC, Const)
105+ ΔC = C_dC. dval
106+ ΔB = B_dB. dval
107+ TensorOperations. tensorcontract_pullback_dB! (ΔB, ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba... )
108+ end
109+ Δα = if ! isa (α_dα, Const) && ! isa (C_dC, Const)
110+ ΔC = C_dC. dval
111+ TensorOperations. tensorcontract_pullback_dα (ΔC, Cval, Aval, pA, conjA, Bval, pB, conjB, pAB, α, ba... )
112+ elseif ! isa (α_dα, Const)
113+ zero (α_dα. val)
114+ else
115+ nothing
116+ end
117+ Δβ = if ! isa (β_dβ, Const) && ! isa (C_dC, Const)
118+ ΔC = C_dC. dval
119+ TensorOperations. pullback_dβ (ΔC, Cval, β)
120+ elseif ! isa (β_dβ, Const)
121+ zero (β_dβ. val)
122+ else
123+ nothing
124+ end
125+ ! isa (C_dC, Const) && TensorOperations. pullback_dC! (C_dC. dval, β)
126+ return nothing , nothing , nothing , nothing , nothing , nothing , nothing , nothing , Δα, Δβ, map (ba_ -> nothing , ba)...
78127end
79128
80129function EnzymeRules. augmented_primal (
@@ -123,10 +172,30 @@ function EnzymeRules.reverse(
123172 α = α_dα. val
124173 β = β_dβ. val
125174 ba = map (ba_ -> getfield (ba_, :val ), ba_dba)
126- dC = C_dC. dval
127- dA = A_dA. dval
128- dC, dA, dα, dβ = TensorOperations. tensoradd_pullback! (dC, dA, Cval, Aval, pA, conjA, α, β, ba... )
129- return nothing , nothing , nothing , nothing , dα, dβ, map (ba_ -> nothing , ba)...
175+
176+ if ! isa (A_dA, Const) && ! isa (C_dC, Const)
177+ ΔC = C_dC. dval
178+ ΔA = A_dA. dval
179+ TensorOperations. tensoradd_pullback_dA! (ΔA, ΔC, Cval, Aval, pA, conjA, α, ba... )
180+ end
181+ Δα = if ! isa (α_dα, Const) && ! isa (C_dC, Const)
182+ ΔC = C_dC. dval
183+ TensorOperations. tensoradd_pullback_dα (ΔC, Cval, Aval, pA, conjA, α, ba... )
184+ elseif ! isa (α_dα, Const)
185+ zero (α_dα. val)
186+ else
187+ nothing
188+ end
189+ Δβ = if ! isa (β_dβ, Const) && ! isa (C_dC, Const)
190+ ΔC = C_dC. dval
191+ TensorOperations. pullback_dβ (ΔC, Cval, β)
192+ elseif ! isa (β_dβ, Const)
193+ zero (β_dβ. val)
194+ else
195+ nothing
196+ end
197+ ! isa (C_dC, Const) && TensorOperations. pullback_dC! (C_dC. dval, β)
198+ return nothing , nothing , nothing , nothing , Δα, Δβ, map (ba_ -> nothing , ba)...
130199end
131200
132201function EnzymeRules. augmented_primal (
@@ -144,7 +213,7 @@ function EnzymeRules.augmented_primal(
144213 ) where {RT, Tα <: Number , Tβ <: Number , TA <: Number , TC <: Number }
145214 # form caches if needed
146215 cache_A = EnzymeRules. overwritten (config)[3 ] ? copy (A_dA. val) : nothing
147- cache_C = ! iszero (β_dβ. val ) ? copy (C_dC. val) : C_dC . val
216+ cache_C = ! isa (β_dβ, Const ) ? copy (C_dC. val) : nothing
148217 ba = map (ba_ -> getfield (ba_, :val ), ba_dba)
149218 α = α_dα. val
150219 β = β_dβ. val
@@ -171,17 +240,37 @@ function EnzymeRules.reverse(
171240 ) where {RT, Tα <: Number , Tβ <: Number , TA <: Number , TC <: Number }
172241 cache_A, cache_C = cache
173242 Aval = something (cache_A, A_dA. val)
174- Cval = cache_C
243+ Cval = something ( cache_C, C_dC . val)
175244 p = p_dp. val
176245 q = q_dq. val
177246 conjA = conjA_dconjA. val
178247 α = α_dα. val
179248 β = β_dβ. val
180249 ba = map (ba_ -> getfield (ba_, :val ), ba_dba)
181- dC = C_dC. dval
182- dA = A_dA. dval
183- dC, dA, dα, dβ = TensorOperations. tensortrace_pullback! (dC, dA, Cval, Aval, p, q, conjA, α, β, ba... )
184- return nothing , nothing , nothing , nothing , nothing , dα, dβ, map (ba_ -> nothing , ba)...
250+
251+ if ! isa (A_dA, Const) && ! isa (C_dC, Const)
252+ ΔC = C_dC. dval
253+ ΔA = A_dA. dval
254+ TensorOperations. tensortrace_pullback_dA! (ΔA, ΔC, Cval, Aval, p, q, conjA, α, ba... )
255+ end
256+ Δα = if ! isa (α_dα, Const) && ! isa (C_dC, Const)
257+ ΔC = C_dC. dval
258+ TensorOperations. tensortrace_pullback_dα (ΔC, Cval, Aval, p, q, conjA, α, ba... )
259+ elseif ! isa (α_dα, Const)
260+ zero (α_dα. val)
261+ else
262+ nothing
263+ end
264+ Δβ = if ! isa (β_dβ, Const) && ! isa (C_dC, Const)
265+ ΔC = C_dC. dval
266+ TensorOperations. pullback_dβ (ΔC, Cval, β)
267+ elseif ! isa (β_dβ, Const)
268+ zero (β_dβ. val)
269+ else
270+ nothing
271+ end
272+ ! isa (C_dC, Const) && TensorOperations. pullback_dC! (C_dC. dval, β)
273+ return nothing , nothing , nothing , nothing , nothing , Δα, Δβ, map (ba_ -> nothing , ba)...
185274end
186275
187276end
0 commit comments