Skip to content

Commit ba3719a

Browse files
authored
Adapt for BraidingTensor (#374)
* implement adapt for BraidingTensor * add tests
1 parent 227b5f3 commit ba3719a

2 files changed

Lines changed: 7 additions & 0 deletions

File tree

ext/TensorKitAdaptExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ function Adapt.adapt_structure(to, x::DiagonalTensorMap)
1515
data′ = adapt(to, x.data)
1616
return DiagonalTensorMap(data′, x.domain)
1717
end
18+
function Adapt.adapt_structure(::Type{TorA}, x::BraidingTensor) where {TorA <: Union{Number, DenseArray{<:Number}}}
19+
return BraidingTensor{scalartype(TorA)}(space(x), x.adjoint)
20+
end
1821

1922
end

test/tensors/planar.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test, TestExtras
2+
using Adapt
23
using TensorKit
34
using TensorKit: PlanarTrivial, ℙ
45
using TensorKit: planaradd!, planartrace!, planarcontract!
@@ -19,6 +20,9 @@ using .TestSetup
1920
t2 = @constinferred BraidingTensor{ComplexF64}(W)
2021
@test scalartype(t2) == ComplexF64
2122
@test storagetype(t2) == Vector{ComplexF64}
23+
t3 = @testinferred adapt(storagetype(t2), t1)
24+
@test storagetype(t3) == storagetype(t2)
25+
@test t3 == t2
2226

2327
W2 = reverse(codomain(W)) domain(W)
2428
@test_throws SpaceMismatch BraidingTensor(W2)

0 commit comments

Comments
 (0)