Skip to content

Commit 11bf385

Browse files
committed
Fix some tests
- Fix doctests - Move _zeroed_backing() to a different file so it's defined before being used (otherwise causes failures on 1.13+) - Use `mergewith(f)` instead of the deprecated `merge(f)`
1 parent f9afcb4 commit 11bf385

5 files changed

Lines changed: 20 additions & 20 deletions

File tree

src/rules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ unary input, unary output scalar function:
1818
```jldoctest frule
1919
julia> dself = NoTangent();
2020
21-
julia> x = rand()
22-
0.8236475079774124
21+
julia> x = 1.23456
22+
1.23456
2323
2424
julia> sinx, Δsinx = frule((dself, 1), sin, x)
25-
(0.7336293678134624, 0.6795498147167869)
25+
(0.9440031218347901, 0.3299365180851773)
2626
2727
julia> sinx == sin(x)
2828
true
@@ -51,7 +51,7 @@ that return a single output that is iterable, like a `Tuple`.
5151
So this is actually a [`Tangent`](@ref):
5252
```jldoctest frule
5353
julia> Δsincosx
54-
Tangent{Tuple{Float64, Float64}}(0.6795498147167869, -0.7336293678134624)
54+
Tangent{Tuple{Float64, Float64}}(0.3299365180851773, -0.9440031218347901)
5555
```
5656
5757
The optional [`RuleConfig`](@ref) option allows specifying frules only for AD systems that

src/tangent_arithmetic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ function Base.:+(a::P, d::StructuralTangent{P}) where {P}
142142
return construct(P, net_backing)
143143
end
144144
end
145-
Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
145+
Base.:+(a::Dict, d::Tangent{P}) where {P} = mergewith(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)

src/tangent_types/abstract_zero.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,16 @@ zero_tangent(::Core.Compiler.AbstractInterpreter) = NoTangent()
200200
zero_tangent(::Core.Compiler.InstructionStream) = NoTangent()
201201
zero_tangent(::Core.CodeInfo) = NoTangent()
202202
zero_tangent(::Core.MethodInstance) = NoTangent()
203+
204+
205+
"""
206+
_zeroed_backing(P)
207+
208+
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
209+
"""
210+
@generated function _zeroed_backing(::Type{P}) where {P}
211+
nil_base = ntuple(fieldcount(P)) do i
212+
(fieldname(P, i), ZeroTangent())
213+
end
214+
return (; nil_base...)
215+
end

src/tangent_types/structural_tangent.jl

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,18 +208,6 @@ function backing(x::T)::NamedTuple where {T}
208208
end
209209
end
210210

211-
"""
212-
_zeroed_backing(P)
213-
214-
Returns a NamedTuple with same fields as `P`, and all values `ZeroTangent()`.
215-
"""
216-
@generated function _zeroed_backing(::Type{P}) where {P}
217-
nil_base = ntuple(fieldcount(P)) do i
218-
(fieldname(P, i), ZeroTangent())
219-
end
220-
return (; nil_base...)
221-
end
222-
223211
"""
224212
construct(::Type{T}, fields::[NamedTuple|Tuple])
225213
@@ -299,7 +287,7 @@ function elementwise_add(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
299287
end
300288
end
301289

302-
elementwise_add(a::Dict, b::Dict) = merge(+, a, b)
290+
elementwise_add(a::Dict, b::Dict) = mergewith(+, a, b)
303291

304292
struct PrimalAdditionFailedException{P} <: Exception
305293
primal::P

src/tangent_types/thunks.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,7 @@ To evaluate the wrapped closure, call [`unthunk`](@ref) which is a no-op when th
178178
argument is not a `Thunk`.
179179
180180
```jldoctest
181-
julia> t = @thunk(3)
182-
Thunk(var"#4#5"())
181+
julia> t = @thunk(3);
183182
184183
julia> unthunk(t)
185184
3

0 commit comments

Comments
 (0)