Skip to content

Commit 9b088cf

Browse files
committed
CUDA: Fix multi-GPU data movement
1 parent 319a71e commit 9b088cf

2 files changed

Lines changed: 39 additions & 12 deletions

File tree

ext/CUDAExt.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ function Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.C
8080
end
8181
end
8282

83+
function Dagger.move(from_proc::CPUProc, to_proc::CuArrayDeviceProc, x::CuArray)
84+
# TODO: No extra allocations here
85+
if CUDA.device(x) == collect(CUDA.devices())[to_proc.device+1]
86+
return x
87+
end
88+
DaggerGPU.with_device(to_proc) do
89+
_x = similar(x)
90+
copyto!(_x, x)
91+
return _x
92+
end
93+
end
94+
95+
function Dagger.move(from_proc::CuArrayDeviceProc, to_proc::CPUProc, x::CuArray{T,N}) where {T,N}
96+
_x = Array{T,N}(undef, size(x))
97+
copyto!(_x, x)
98+
return _x
99+
end
100+
83101
function Dagger.execute!(proc::CuArrayDeviceProc, f, args...; kwargs...)
84102
@nospecialize f args kwargs
85103
tls = Dagger.get_tls()

test/runtests.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,28 +79,37 @@ end
7979
CuArrayDeviceProc
8080
end
8181
@test DaggerGPU.processor(:CUDA) === cuproc
82-
b = generate_thunks()
83-
c = Dagger.with_options(;scope=Dagger.scope(cuda_gpu=1)) do
84-
@test fetch(Dagger.@spawn isongpu(b))
85-
Dagger.@spawn sum(b)
82+
ndevices = length(collect(CUDA.devices()))
83+
84+
@testset "Arrays (GPU $gpu)" for gpu in 1:min(ndevices, 2)
85+
b = generate_thunks()
86+
c = Dagger.with_options(;scope=Dagger.scope(cuda_gpu=gpu)) do
87+
@test fetch(Dagger.@spawn isongpu(b))
88+
Dagger.@spawn sum(b)
89+
end
90+
@test !fetch(Dagger.@spawn isongpu(b))
91+
@test fetch(Dagger.@spawn identity(c)) == 20
8692
end
87-
@test !fetch(Dagger.@spawn isongpu(b))
88-
@test fetch(Dagger.@spawn identity(c)) == 20
8993

90-
@testset "KernelAbstractions" begin
94+
@testset "KernelAbstractions (GPU $gpu)" for gpu in 1:min(ndevices, 2)
9195
A = rand(Float32, 8)
92-
DA, T = Dagger.with_options(;scope=Dagger.scope(cuda_gpu=1)) do
96+
DA, T = Dagger.with_options(;scope=Dagger.scope(cuda_gpu=gpu)) do
9397
fetch(Dagger.@spawn fill_thunk(A, 2.3f0))
9498
end
9599
@test all(DA .== 2.3f0)
96100
@test T <: CuArray
97101

98-
A = CUDA.rand(128)
99-
B = CUDA.zeros(128)
100-
Dagger.with_options(;scope=Dagger.scope(worker=1,cuda_gpu=1)) do
102+
local A, B
103+
CUDA.device!(gpu-1) do
104+
A = CUDA.rand(128)
105+
B = CUDA.zeros(128)
106+
end
107+
Dagger.with_options(;scope=Dagger.scope(worker=1,cuda_gpu=gpu)) do
101108
fetch(Dagger.@spawn Kernel(copy_kernel)(B, A; ndrange=length(A)))
102109
end
103-
@test all(B .== A)
110+
CUDA.device!(gpu-1) do
111+
@test all(B .== A)
112+
end
104113
end
105114
end
106115
end

0 commit comments

Comments
 (0)