Skip to content

Commit 7dc8fe6

Browse files
authored
A few more small fixes for upstream + CUDA (#373)
* Small fixes for upstream + CUDA * Suggestion * Use default memory * Have Base.ones and zeros accept CuArray * Try to finagle ones and zeros again * Fix CUDA test
1 parent 71d6c00 commit 7dc8fe6

3 files changed

Lines changed: 21 additions & 0 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ function Base.promote_rule(
138138
return CuTensorMap{T, S, N₁, N₂}
139139
end
140140

141+
TensorKit.promote_storage_rule(::Type{CuArray{T, N}}, ::Type{<:CuArray{T, N}}) where {T, N} =
142+
CuArray{T, N, CUDA.default_memory}
143+
144+
141145
# CuTensorMap exponentation:
142146
function TensorKit.exp!(t::CuTensorMap)
143147
domain(t) == codomain(t) ||

test/cuda/tensors.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ for V in spacelist
5555
@test domain(t) == one(W)
5656
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
5757
end
58+
for f in (Base.ones, Base.zeros)
59+
t = @constinferred f(CuVector{Float64, CUDA.DeviceMemory}, W)
60+
@test scalartype(t) == Float64
61+
@test codomain(t) == W
62+
@test space(t) == (W one(W))
63+
@test domain(t) == one(W)
64+
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
65+
end
5866
for f in (rand, randn)
5967
t = @constinferred f(CuVector{Float64, CUDA.DeviceMemory}, W)
6068
@test scalartype(t) == Float64

test/tensors/tensors.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ for V in spacelist
4444
@test space(t) == (W one(W))
4545
@test domain(t) == one(W)
4646
@test typeof(t) == TensorMap{T, spacetype(t), 5, 0, Vector{T}}
47+
# Array type input
48+
t = @constinferred zeros(Vector{T}, W)
49+
@test @constinferred(hash(t)) == hash(deepcopy(t))
50+
@test scalartype(t) == T
51+
@test norm(t) == 0
52+
@test codomain(t) == W
53+
@test space(t) == (W one(W))
54+
@test domain(t) == one(W)
55+
@test typeof(t) == TensorMap{T, spacetype(t), 5, 0, Vector{T}}
4756
# blocks
4857
bs = @constinferred blocks(t)
4958
if !isempty(blocksectors(t)) # multifusion space ending on module gives empty data

0 commit comments

Comments
 (0)