Skip to content

Commit f1fd025

Browse files
authored
Avoid computing derivatives with respect to non-differentiable α, β (#236)
* insert opt-out for non-derivable alpha and beta * revert change to allocate device 0-dim arrays * restore higher-order differentiability
1 parent 4f03381 commit f1fd025

1 file changed

Lines changed: 96 additions & 59 deletions

File tree

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 96 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,28 @@ function ChainRulesCore.rrule(
4040
return output, tensoralloc_pullback
4141
end
4242

43+
# this function more or less boils down to `fill!(similar(x), y)` but does so in a single
44+
# call to allow higher-order derivatives
45+
# Allocate a new object from the type and structure of `x` and fill it with `y`.
# `x` itself is only inspected (via `typeof` and `tensorstructure`), never written to.
function similar_and_fill(x, y)
46+
# allocation goes through `TensorOperations.tensoralloc` rather than `similar`
x′ = TensorOperations.tensoralloc(typeof(x), TensorOperations.tensorstructure(x))
47+
# set every entry of the new object to `y` and return the result of `fill!`
return fill!(x′, y)
48+
end
49+
# Reverse rule for `similar_and_fill`: returns the primal result together with its pullback.
function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
50+
# Cotangents in argument order: `NoTangent()` for the function itself, `ZeroTangent()`
# for `x`, and `tensorscalar` of the unthunked incoming cotangent for `y`.
# NOTE(review): presumably relies on `Δx` holding a single scalar entry, since it is
# passed to `tensorscalar` — confirm against the callers.
similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx))
51+
return similar_and_fill(x, y), similar_and_fill_pullback
52+
end
4353
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
44-
projectC = ProjectTo(C)
45-
function tensorscalar_pullback(Δc)
46-
_Δc = unthunk(Δc)
47-
return NoTangent(), projectC(_Δc)
48-
end
54+
tensorscalar_pullback(Δc) = NoTangent(), similar_and_fill(C, unthunk(Δc))
4955
return tensorscalar(C), tensorscalar_pullback
5056
end
5157

58+
# To avoid computing rrules for α and β when these aren't needed, we want to have a
59+
# type-stable quick bail-out
60+
# `_needs_tangent` answers whether a tangent should be computed for a coefficient.
# A value is forwarded to its type, so the answer is selected by dispatch on the type.
_needs_tangent(x) = _needs_tangent(typeof(x))
60+
# generic fallback for number types: a tangent is needed
_needs_tangent(::Type{<:Number}) = true
61+
# integer types (a more specific method than the `Number` one): no tangent
_needs_tangent(::Type{<:Integer}) = false
62+
# the `One` and `Zero` types: no tangent
_needs_tangent(::Type{<:Union{One, Zero}}) = false
64+
5265
# The current `rrule` design makes sure that the implementation for custom types does
5366
# not need to support the backend or allocator arguments
5467
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
@@ -99,26 +112,34 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
99112
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
100113
projectA(_dA)
101114
end
102-
dα = @thunk let
103-
_dα = tensorscalar(
104-
tensorcontract(
105-
A, ((), linearize(pA)), !conjA,
106-
ΔC, (trivtuple(numind(pA)), ()), false,
107-
((), ()), One(), ba...
115+
dα = if _needs_tangent(α)
116+
@thunk let
117+
_dα = tensorscalar(
118+
tensorcontract(
119+
A, ((), linearize(pA)), !conjA,
120+
ΔC, (trivtuple(numind(pA)), ()), false,
121+
((), ()), One(), ba...
122+
)
108123
)
109-
)
110-
projectα(_dα)
124+
projectα(_dα)
125+
end
126+
else
127+
ZeroTangent()
111128
end
112-
dβ = @thunk let
113-
# TODO: consider using `inner`
114-
_dβ = tensorscalar(
115-
tensorcontract(
116-
C, ((), trivtuple(numind(pA))), true,
117-
ΔC, (trivtuple(numind(pA)), ()), false,
118-
((), ()), One(), ba...
129+
dβ = if _needs_tangent(β)
130+
@thunk let
131+
# TODO: consider using `inner`
132+
_dβ = tensorscalar(
133+
tensorcontract(
134+
C, ((), trivtuple(numind(pA))), true,
135+
ΔC, (trivtuple(numind(pA)), ()), false,
136+
((), ()), One(), ba...
137+
)
119138
)
120-
)
121-
projectβ(_dβ)
139+
projectβ(_dβ)
140+
end
141+
else
142+
ZeroTangent()
122143
end
123144
dba = map(_ -> NoTangent(), ba)
124145
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
@@ -212,28 +233,36 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
212233
)
213234
projectB(_dB)
214235
end
215-
dα = @thunk let
216-
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
217-
# TODO: consider using `inner`
218-
_dα = tensorscalar(
219-
tensorcontract(
220-
C_αβ, ((), trivtuple(numind(pAB))), true,
221-
ΔC, (trivtuple(numind(pAB)), ()), false,
222-
((), ()), One(), ba...
236+
dα = if _needs_tangent(α)
237+
@thunk let
238+
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
239+
# TODO: consider using `inner`
240+
_dα = tensorscalar(
241+
tensorcontract(
242+
C_αβ, ((), trivtuple(numind(pAB))), true,
243+
ΔC, (trivtuple(numind(pAB)), ()), false,
244+
((), ()), One(), ba...
245+
)
223246
)
224-
)
225-
projectα(_dα)
247+
projectα(_dα)
248+
end
249+
else
250+
ZeroTangent()
226251
end
227-
dβ = @thunk let
228-
# TODO: consider using `inner`
229-
_dβ = tensorscalar(
230-
tensorcontract(
231-
C, ((), trivtuple(numind(pAB))), true,
232-
ΔC, (trivtuple(numind(pAB)), ()), false,
233-
((), ()), One(), ba...
252+
dβ = if _needs_tangent(β)
253+
@thunk let
254+
# TODO: consider using `inner`
255+
_dβ = tensorscalar(
256+
tensorcontract(
257+
C, ((), trivtuple(numind(pAB))), true,
258+
ΔC, (trivtuple(numind(pAB)), ()), false,
259+
((), ()), One(), ba...
260+
)
234261
)
235-
)
236-
projectβ(_dβ)
262+
projectβ(_dβ)
263+
end
264+
else
265+
ZeroTangent()
237266
end
238267
dba = map(_ -> NoTangent(), ba)
239268
return NoTangent(), dC,
@@ -301,27 +330,35 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
301330
)
302331
projectA(_dA)
303332
end
304-
dα = @thunk let
305-
C_αβ = tensortrace(A, p, q, false, One(), ba...)
306-
_dα = tensorscalar(
307-
tensorcontract(
308-
C_αβ, ((), trivtuple(numind(p))),
309-
!conjA,
310-
ΔC, (trivtuple(numind(p)), ()), false,
311-
((), ()), One(), ba...
333+
dα = if _needs_tangent(α)
334+
@thunk let
335+
C_αβ = tensortrace(A, p, q, false, One(), ba...)
336+
_dα = tensorscalar(
337+
tensorcontract(
338+
C_αβ, ((), trivtuple(numind(p))),
339+
!conjA,
340+
ΔC, (trivtuple(numind(p)), ()), false,
341+
((), ()), One(), ba...
342+
)
312343
)
313-
)
314-
projectα(_dα)
344+
projectα(_dα)
345+
end
346+
else
347+
ZeroTangent()
315348
end
316-
dβ = @thunk let
317-
_dβ = tensorscalar(
318-
tensorcontract(
319-
C, ((), trivtuple(numind(p))), true,
320-
ΔC, (trivtuple(numind(p)), ()), false,
321-
((), ()), One(), ba...
349+
dβ = if _needs_tangent(β)
350+
@thunk let
351+
_dβ = tensorscalar(
352+
tensorcontract(
353+
C, ((), trivtuple(numind(p))), true,
354+
ΔC, (trivtuple(numind(p)), ()), false,
355+
((), ()), One(), ba...
356+
)
322357
)
323-
)
324-
projectβ(_dβ)
358+
projectβ(_dβ)
359+
end
360+
else
361+
ZeroTangent()
325362
end
326363
dba = map(_ -> NoTangent(), ba)
327364
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...

0 commit comments

Comments
 (0)