Skip to content

Commit 1e3d426

Browse files
authored
Merge pull request #660 from JuliaDiff/ox/subtract
Implement Tangent subtraction
2 parents d2b4b94 + c7e00c7 commit 1e3d426

2 files changed

Lines changed: 19 additions & 0 deletions

File tree

src/tangent_arithmetic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d))
146146
Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a
147147

148148
Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent)
149+
Base.:-(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} = a + (-b)
149150

150151
# We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful
151152
# In general one doesn't have to represent multiplications of 2 tangents

test/tangent_types/structural_tangent.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,24 @@ end
358358
@test -1.0 * t == -t
359359
end
360360

361+
@testset "subtraction" begin
362+
a = Tangent{Foo}(; x=2.0, y=-2.0)
363+
b = Tangent{Foo}(; x=1.0, y=2.0)
364+
@test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0)
365+
366+
a = Tangent{Foo}(; x=2.0, y=-2.0)
367+
b = Tangent{Foo}(; x=1.0)
368+
@test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0)
369+
370+
a = Tangent{Tuple{Float64,Float64}}(2.0, 3.0)
371+
b = Tangent{Tuple{Float64,Float64}}(1.0, 1.0)
372+
@test (a - b) == Tangent{Tuple{Float64,Float64}}(1.0, 2.0)
373+
374+
a = MutableTangent{MFoo}(; x=1.5, y=1.5)
375+
b = MutableTangent{MFoo}(; x=0.5, y=0.5)
376+
@test (a - b) == MutableTangent{MFoo}(; x=1.0, y=1.0)
377+
end
378+
361379
@testset "scaling" begin
362380
@test (
363381
2 * Tangent{Foo}(; y=1.5, x=2.5) ==

0 commit comments

Comments
 (0)