Skip to content

Commit bd64864

Browse files
committed
Implement Tangent subtraction
1 parent d2b4b94 commit bd64864

2 files changed

Lines changed: 15 additions & 0 deletions

File tree

src/tangent_arithmetic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)
149+
Base.:-(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} = a + (-b)
149150

150151
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
151152
# In general one doesn't have to represent multiplications of 2 tangents

test/tangent_types/structural_tangent.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,20 @@ end
358358
@test -1.0 * t == -t
359359
end
360360

361+
@test "subtraction" begin
362+
a = Tangent{Foo}(; x=2.0, y=-2.0)
363+
b = Tangent{Foo}(; x=1.0, y=2.0)
364+
@test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0)
365+
366+
a = Tangent{Foo}(; x=2.0, y=-2.0)
367+
b = Tangent{Foo}(; x=1.0)
368+
@test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0)
369+
370+
a = Tangent{Tuple{Float64, Float64}}(2.0, 3.0)
371+
b = Tangent{Tuple{Float64, Float64}}(1.0, 1.0)
372+
@test (a - b) == Tangent{Tuple{Float64, Float64}}(1.0, 2.0)
373+
end
374+
361375
@testset "scaling" begin
362376
@test (
363377
2 * Tangent{Foo}(; y=1.5, x=2.5) ==

0 commit comments

Comments
 (0)