Skip to content

Commit 7d7d710

Browse files
authored
🔧 Modify zygote ext to handle thunks (#35)
1 parent 9d3ac03 commit 7d7d710

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

ext/ReferenceFrameRotationsZygoteExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ using ReferenceFrameRotations
1010
using ForwardDiff
1111

1212
using Zygote.ChainRulesCore: ChainRulesCore
13-
import Zygote.ChainRulesCore: NoTangent
13+
import Zygote.ChainRulesCore: NoTangent, unthunk
1414

1515
function ChainRulesCore.rrule(::Type{<:DCM}, data::NTuple{9, T}) where {T}
1616
y = DCM(data)
1717

1818
function DCM_pullback(Δ)
19-
return (NoTangent(), Tuple(Δ))
19+
Δ_unthunked = unthunk(Δ)
20+
return (NoTangent(), Tuple(Δ_unthunked))
2021
end
2122

2223
return y, DCM_pullback
@@ -26,8 +27,9 @@ function ChainRulesCore.rrule(::typeof(orthonormalize), dcm::DCM)
2627
y = orthonormalize(dcm)
2728

2829
function orthonormalize_pullback(Δ)
30+
Δ_unthunked = unthunk(Δ)
2931
jac = ForwardDiff.jacobian(orthonormalize, dcm)
30-
return (NoTangent(), reshape(vcat(Δ...)' * jac, 3, 3))
32+
return (NoTangent(), reshape(vcat(Δ_unthunked...)' * jac, 3, 3))
3133
end
3234

3335
return y, orthonormalize_pullback

0 commit comments

Comments
 (0)