@@ -10,18 +10,59 @@ struct CuArrayDeviceProc <: Dagger.Processor
1010 device:: Int
1111end
1212@gpuproc (CuArrayDeviceProc, CuArray)
13- #= FIXME : DtoD copies and CUDA IPC
14- function Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x)
15- if from === to
16- return x
"""
    Dagger.get_parent(proc::CuArrayDeviceProc)

Return the `Dagger.OSProc` constructed from `proc.owner`.

NOTE(review): `owner` looks like the Distributed worker id that hosts this GPU
processor -- confirm against the struct definition.
"""
function Dagger.get_parent(proc::CuArrayDeviceProc)
    return Dagger.OSProc(proc.owner)
end
14+
15+ # function can_access(this, peer)
16+ # status = Ref{Cint}()
17+ # CUDA.cuDeviceCanAccessPeer(status, this, peer)
18+ # return status[] == 1
19+ # end
20+
"""
    Dagger.move(from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.Chunk{<:CuArray})

Obtain the `CuArray` referenced by chunk `x` for use by `to`, picking a transfer
strategy from the relationship between `from` and `to`:

- `from == to`: return the pooled array as-is.
- same `owner`, different processor: allocate a matching array with `to.device`
  active and `copyto!` the data into it.
- different `owner` but equal `Dagger.system_uuid`: export a CUDA IPC handle from the
  owning process and wrap the opened pointer without taking ownership; the IPC
  mapping is closed by a finalizer on the wrapper.
- otherwise: copy to a host `Array` on the owning process, fetch it, and upload
  it into a new `CuArray`.

# Arguments
- `from`: processor that currently holds the data.
- `to`: processor that wants to use the data.
- `x`: chunk whose `handle` is resolved with `poolget`.

# Returns
A `CuArray` holding (or aliasing, in the IPC case) the chunk's data.
"""
function Dagger.move(
    from::CuArrayDeviceProc, to::CuArrayDeviceProc, x::Dagger.Chunk{T}
) where {T<:CuArray}
    if from == to
        # Same process and GPU, no change
        return poolget(x.handle)
    elseif from.owner == to.owner
        # Same process but different GPUs, use DtoD copy
        src = poolget(x.handle)
        # `T` is the CuArray type of the chunk (not an element type), so the
        # destination is derived from the source array itself: same eltype and shape.
        dst = CUDA.device!(to.device) do
            similar(src)
        end
        copyto!(dst, src)
        return dst
    elseif Dagger.system_uuid(from.owner) == Dagger.system_uuid(to.owner)
        # Same node, we can use IPC
        # The closure below runs on `from.owner`; its locals use names that do not
        # appear in the enclosing function so nothing from here is captured.
        ipc_handle, eT, shape = remotecall_fetch(from.owner, x.handle) do h
            remote_arr = poolget(h)
            handle_ref = Ref{CUDA.CUipcMemHandle}()
            # Keep the array alive while its raw device pointer is in use
            GC.@preserve remote_arr begin
                CUDA.cuIpcGetMemHandle(handle_ref, pointer(remote_arr))
            end
            return (handle_ref[], eltype(remote_arr), size(remote_arr))
        end
        r_ptr = Ref{CUDA.CUdeviceptr}()
        # FIXME: Assumes that device IDs are identical across processes
        CUDA.device!(from.device) do
            CUDA.cuIpcOpenMemHandle(
                r_ptr, ipc_handle, CUDA.CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS
            )
        end
        ptr = Base.unsafe_convert(CUDA.CuPtr{eT}, r_ptr[])
        # `own=false`: the memory belongs to the exporting process
        mapped = unsafe_wrap(CuArray, ptr, shape; own=false)
        # Close the IPC mapping once the wrapper is garbage collected
        finalizer(mapped) do a
            CUDA.cuIpcCloseMemHandle(pointer(a))
        end
        # FIXME: Deal with to_proc being a different GPU
        return mapped
    else
        # Different node, use DtoH, serialization, HtoD
        # TODO UCX
        host_arr = remotecall_fetch(from.owner, x.handle) do h
            Array(poolget(h))
        end
        # NOTE(review): the upload targets whichever device is active in the calling
        # task -- confirm that `to.device` is selected by the caller.
        return CuArray(host_arr)
    end
end
21- =#
61+
2262function Dagger. execute! (proc:: CuArrayDeviceProc , func, args... )
63+ tls = Dagger. get_tls ()
2364 fetch (Threads. @spawn begin
24- task_local_storage ( :processor , proc )
65+ Dagger . set_tls! (tls )
2566 CUDA. device! (proc. device)
2667 CUDA. @sync func (args... )
2768 end )
@@ -35,7 +76,7 @@ kernel_backend(::CuArrayDeviceProc) = CUDADevice()
3576
3677if CUDA. has_cuda ()
3778 for dev in devices ()
38- Dagger. add_callback! (proc -> begin
79+ Dagger. add_callback! (() -> begin
3980 return CuArrayDeviceProc (Distributed. myid (), #= CuContext(dev),=# dev. handle)
4081 end )
4182 end
0 commit comments