@@ -40,15 +40,28 @@ function ChainRulesCore.rrule(
4040 return output, tensoralloc_pullback
4141end
4242
# `similar_and_fill(x, y)` allocates a fresh tensor with the type and structure of `x` and
# fills it with `y`. This is roughly `fill!(similar(x), y)`, but performed in a single call
# to allow higher-order derivatives.
function similar_and_fill(x, y)
    structure = TensorOperations.tensorstructure(x)
    filled = TensorOperations.tensoralloc(typeof(x), structure)
    return fill!(filled, y)
end
# Reverse rule for `similar_and_fill`. The pullback returns `NoTangent()` for the function,
# `ZeroTangent()` for `x`, and `tensorscalar` of the unthunked output cotangent for `y`.
function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
    function similar_and_fill_pullback(Δx)
        Δy = tensorscalar(unthunk(Δx))
        return NoTangent(), ZeroTangent(), Δy
    end
    return similar_and_fill(x, y), similar_and_fill_pullback
end
# Reverse rule for `tensorscalar`: returns `tensorscalar(C)` together with its pullback.
# This span is a diff hunk. The `-` lines are the removed pullback, which unthunked `Δc`
# and passed it through `ProjectTo(C)`. The `+` line is its replacement, whose tangent for
# `C` is `similar_and_fill(C, unthunk(Δc))`. Both return `NoTangent()` for the function.
4353function ChainRulesCore. rrule (:: typeof (tensorscalar), C)
44- projectC = ProjectTo (C)
45- function tensorscalar_pullback (Δc)
46- _Δc = unthunk (Δc)
47- return NoTangent (), projectC (_Δc)
48- end
54+ tensorscalar_pullback (Δc) = NoTangent (), similar_and_fill (C, unthunk (Δc))
4955 return tensorscalar (C), tensorscalar_pullback
5056end
5157
# To avoid computing rrules for α and β when these aren't needed, we want to have a
# type-stable quick bail-out: the answer is decided from the type of the scalar alone.
#
# Values are forwarded to their type. `Number` types need a tangent by default, whereas
# `Integer` types and the `One`/`Zero` types do not.
_needs_tangent(x) = _needs_tangent(typeof(x))
# Fallback for types that are not `Number`s. Without it, such an argument keeps hitting the
# value method above (`typeof(DataType)` is again `DataType`) until the stack overflows.
# Returning `true` is the conservative choice: the tangent is simply computed.
_needs_tangent(::Type) = true
_needs_tangent(::Type{<:Number}) = true
_needs_tangent(::Type{<:Integer}) = false
_needs_tangent(::Type{<:Union{One, Zero}}) = false
64+
5265# The current `rrule` design makes sure that the implementation for custom types does
5366# not need to support the backend or allocator arguments
5467# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
@@ -99,26 +112,34 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
99112 _dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
100113 projectA (_dA)
101114 end
102- dα = @thunk let
103- _dα = tensorscalar (
104- tensorcontract (
105- A, ((), linearize (pA)), ! conjA,
106- ΔC, (trivtuple (numind (pA)), ()), false ,
107- ((), ()), One (), ba...
115+ dα = if _needs_tangent (α)
116+ @thunk let
117+ _dα = tensorscalar (
118+ tensorcontract (
119+ A, ((), linearize (pA)), ! conjA,
120+ ΔC, (trivtuple (numind (pA)), ()), false ,
121+ ((), ()), One (), ba...
122+ )
108123 )
109- )
110- projectα (_dα)
124+ projectα (_dα)
125+ end
126+ else
127+ ZeroTangent ()
111128 end
112- dβ = @thunk let
113- # TODO : consider using `inner`
114- _dβ = tensorscalar (
115- tensorcontract (
116- C, ((), trivtuple (numind (pA))), true ,
117- ΔC, (trivtuple (numind (pA)), ()), false ,
118- ((), ()), One (), ba...
129+ dβ = if _needs_tangent (β)
130+ @thunk let
131+ # TODO : consider using `inner`
132+ _dβ = tensorscalar (
133+ tensorcontract (
134+ C, ((), trivtuple (numind (pA))), true ,
135+ ΔC, (trivtuple (numind (pA)), ()), false ,
136+ ((), ()), One (), ba...
137+ )
119138 )
120- )
121- projectβ (_dβ)
139+ projectβ (_dβ)
140+ end
141+ else
142+ ZeroTangent ()
122143 end
123144 dba = map (_ -> NoTangent (), ba)
124145 return NoTangent (), dC, dA, NoTangent (), NoTangent (), dα, dβ, dba...
@@ -212,28 +233,36 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
212233 )
213234 projectB (_dB)
214235 end
215- dα = @thunk let
216- C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
217- # TODO : consider using `inner`
218- _dα = tensorscalar (
219- tensorcontract (
220- C_αβ, ((), trivtuple (numind (pAB))), true ,
221- ΔC, (trivtuple (numind (pAB)), ()), false ,
222- ((), ()), One (), ba...
236+ dα = if _needs_tangent (α)
237+ @thunk let
238+ C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
239+ # TODO : consider using `inner`
240+ _dα = tensorscalar (
241+ tensorcontract (
242+ C_αβ, ((), trivtuple (numind (pAB))), true ,
243+ ΔC, (trivtuple (numind (pAB)), ()), false ,
244+ ((), ()), One (), ba...
245+ )
223246 )
224- )
225- projectα (_dα)
247+ projectα (_dα)
248+ end
249+ else
250+ ZeroTangent ()
226251 end
227- dβ = @thunk let
228- # TODO : consider using `inner`
229- _dβ = tensorscalar (
230- tensorcontract (
231- C, ((), trivtuple (numind (pAB))), true ,
232- ΔC, (trivtuple (numind (pAB)), ()), false ,
233- ((), ()), One (), ba...
252+ dβ = if _needs_tangent (β)
253+ @thunk let
254+ # TODO : consider using `inner`
255+ _dβ = tensorscalar (
256+ tensorcontract (
257+ C, ((), trivtuple (numind (pAB))), true ,
258+ ΔC, (trivtuple (numind (pAB)), ()), false ,
259+ ((), ()), One (), ba...
260+ )
234261 )
235- )
236- projectβ (_dβ)
262+ projectβ (_dβ)
263+ end
264+ else
265+ ZeroTangent ()
237266 end
238267 dba = map (_ -> NoTangent (), ba)
239268 return NoTangent (), dC,
@@ -301,27 +330,35 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
301330 )
302331 projectA (_dA)
303332 end
304- dα = @thunk let
305- C_αβ = tensortrace (A, p, q, false , One (), ba... )
306- _dα = tensorscalar (
307- tensorcontract (
308- C_αβ, ((), trivtuple (numind (p))),
309- ! conjA,
310- ΔC, (trivtuple (numind (p)), ()), false ,
311- ((), ()), One (), ba...
333+ dα = if _needs_tangent (α)
334+ @thunk let
335+ C_αβ = tensortrace (A, p, q, false , One (), ba... )
336+ _dα = tensorscalar (
337+ tensorcontract (
338+ C_αβ, ((), trivtuple (numind (p))),
339+ ! conjA,
340+ ΔC, (trivtuple (numind (p)), ()), false ,
341+ ((), ()), One (), ba...
342+ )
312343 )
313- )
314- projectα (_dα)
344+ projectα (_dα)
345+ end
346+ else
347+ ZeroTangent ()
315348 end
316- dβ = @thunk let
317- _dβ = tensorscalar (
318- tensorcontract (
319- C, ((), trivtuple (numind (p))), true ,
320- ΔC, (trivtuple (numind (p)), ()), false ,
321- ((), ()), One (), ba...
349+ dβ = if _needs_tangent (β)
350+ @thunk let
351+ _dβ = tensorscalar (
352+ tensorcontract (
353+ C, ((), trivtuple (numind (p))), true ,
354+ ΔC, (trivtuple (numind (p)), ()), false ,
355+ ((), ()), One (), ba...
356+ )
322357 )
323- )
324- projectβ (_dβ)
358+ projectβ (_dβ)
359+ end
360+ else
361+ ZeroTangent ()
325362 end
326363 dba = map (_ -> NoTangent (), ba)
327364 return NoTangent (), dC, dA, NoTangent (), NoTangent (), NoTangent (), dα, dβ, dba...
0 commit comments