@@ -8,8 +8,8 @@ If you want to learn about `frule`s, you should still read and understand this e
 
 We define a struct `Foo`
 ```julia
-struct Foo
-    A::Matrix
+struct Foo{T}
+    A::Matrix{T}
     c::Float64
 end
 ```
@@ -25,17 +25,27 @@ Note that field `c` is ignored in the calculation.
 
 The `rrule` method for our primal computation should extend the `ChainRulesCore.rrule` function.
 ```julia
-function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
+function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
     y = foo_mul(foo, b)
     function foo_mul_pullback(ȳ)
         f̄ = NoTangent()
-        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
+        f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
         b̄ = @thunk(foo.A' * ȳ)
         return f̄, f̄oo, b̄
     end
     return y, foo_mul_pullback
 end
 ```
+
+We can check this rule against a finite-differences approach using [`ChainRulesTestUtils`](https://github.com/JuliaDiff/ChainRulesTestUtils.jl):
+```julia
+julia> using ChainRulesTestUtils
+julia> test_rrule(foo_mul, Foo(rand(3, 3), 3.0), rand(3, 3))
+Test Summary:                                       | Pass  Total
+test_rrule: foo_mul on Foo{Float64},Matrix{Float64} |   10     10
+Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)
+```
+
 Now let's examine the rule in more detail:
 ```julia
 function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
@@ -84,5 +94,5 @@ The idea is that in case the tangent is not used anywhere, the computation never
 Use [`InplaceableThunk`](@ref) if you are interested in [accumulating gradients inplace](@ref grad_acc).
 Note that in practice one would also `@thunk` the `f̄oo.A` tangent, but it was omitted in this example for clarity.
 
-As a final note, Since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace.
+As a final note, since `b` is an `AbstractArray`, its tangent `b̄` should be projected to the right subspace.
 See the [`ProjectTo` the primal subspace](@ref projectto) section for more information and an example that motivates the projection operation.
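
The projection mentioned in the final note could look roughly like the sketch below. It is not part of this change. It assumes that `foo_mul` is defined as `foo.A * b` (the definition is not shown in these hunks, only the remark that field `c` is ignored) and uses `ProjectTo` from `ChainRulesCore` to map `b̄` back onto the structure of `b`. The `Diagonal` input is just an illustrative choice.

```julia
using ChainRulesCore
using LinearAlgebra

struct Foo{T}
    A::Matrix{T}
    c::Float64
end

# Assumed primal definition: field `c` is ignored in the calculation.
foo_mul(foo, b::AbstractArray) = foo.A * b

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
    y = foo_mul(foo, b)
    # Build the projector once, from the primal input `b`.
    project_b = ProjectTo(b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()
        f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
        # Project inside the thunk so the work stays lazy.
        b̄ = @thunk(project_b(foo.A' * ȳ))
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

# Illustrative usage with a structured (diagonal) input.
foo = Foo(rand(3, 3), 3.0)
b = Diagonal(rand(3))
y, pullback = rrule(foo_mul, foo, b)
_, f̄oo, b̄ = pullback(ones(3, 3))
b̄_value = unthunk(b̄)  # tangent for `b`, projected onto the structure of `b`
```

Without the projection, `foo.A' * ȳ` would be a dense matrix even though `b` only has diagonal degrees of freedom; with it, the off-diagonal entries are dropped before the tangent is passed on.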