Skip to content

Commit 3b826a7

Browse files
authored
convert(TensorMap, t) retains storagetype (#357)
1 parent b5a3ab5 commit 3b826a7

3 files changed

Lines changed: 8 additions & 4 deletions

File tree

src/tensors/adjoint.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Base.@propagate_inbounds function subblock(t::AdjointTensorMap, (f₁, f₂)::Tu
5050
return permutedims(conj(data), (domainind(tp)..., codomainind(tp)...))
5151
end
5252

53+
to_cpu(t::AdjointTensorMap) = adjoint(to_cpu(adjoint(t)))
54+
5355
# Show
5456
#------
5557
function Base.showarg(io::IO, t::AdjointTensorMap, toplevel::Bool)

src/tensors/tensor.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,8 @@ end
532532
#---------------------------
533533
Base.convert(::Type{TensorMap}, t::TensorMap) = t
534534
function Base.convert(::Type{TensorMap}, t::AbstractTensorMap)
535-
return copy!(TensorMap{scalartype(t)}(undef, space(t)), t)
535+
A = storagetype(t)
536+
return copy!(TensorMapWithStorage{scalartype(A), A}(undef, space(t)), t)
536537
end
537538

538539
function Base.convert(::Type{TensorMapWithStorage{T, A}}, t::TensorMap) where {T, A}

test/cuda/tensors.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -258,14 +258,15 @@ for V in spacelist
258258
end
259259
end
260260
end
261-
@timedtestset "Tensor conversion" begin # TODO adjoint conversion methods don't work yet
261+
@timedtestset "Tensor conversion" begin
262262
W = V1 V2
263263
t = @constinferred CUDA.randn(W W)
264-
#@test typeof(convert(TensorMap, t')) == typeof(t) # TODO Adjoint not supported yet
264+
@test typeof(convert(typeof(t), t')) == typeof(t)
265+
@test typeof(TensorKit.to_cpu(t')) == typeof(TensorKit.to_cpu(t)')
265266
tc = complex(t)
266267
@test convert(typeof(tc), t) == tc
267268
@test typeof(convert(typeof(tc), t)) == typeof(tc)
268-
# @test typeof(convert(typeof(tc), t')) == typeof(tc) # TODO Adjoint not supported yet
269+
@test typeof(convert(typeof(tc), t')) == typeof(tc)
269270
@test Base.promote_typeof(t, tc) == typeof(tc)
270271
@test Base.promote_typeof(tc, t) == typeof(tc + t)
271272
end

0 commit comments

Comments
 (0)