@@ -3,10 +3,74 @@ import .Metal: MtlArray, MtlDevice
33
44struct MtlArrayDeviceProc <: Dagger.Processor
55 owner:: Int
6- device:: MtlDevice
6+ device_id:: UInt64
7+ end
8+
9+ # Assume that we can run anything.
10+ Dagger. iscompatible_func (proc:: MtlArrayDeviceProc , opts, f) = true
11+ Dagger. iscompatible_arg (proc:: MtlArrayDeviceProc , opts, x) = true
12+
13+ # CPUs shouldn't process our array type.
14+ Dagger. iscompatible_arg (proc:: Dagger.ThreadProc , opts, x:: MtlArray ) = false
15+
16+ function Dagger. move (from_proc:: OSProc , to_proc:: MtlArrayDeviceProc , x:: Chunk )
17+ from_pid = from_proc. pid
18+ to_pid = Dagger. get_parent (to_proc). pid
19+ @assert myid () == to_pid
20+
21+ return Dagger. move (from_proc, to_proc, remotecall_fetch (x-> poolget (x. handle), from_pid, x))
22+ end
23+
24+ function Dagger. move (from_proc:: MtlArrayDeviceProc , to_proc:: OSProc , x:: Chunk )
25+ from_pid = Dagger. get_parent (from_proc). pid
26+ to_pid = to_proc. pid
27+ @assert myid () == to_pid
28+
29+ return remotecall_fetch (from_pid, x) do x
30+ mtlarray = poolget (x. handle)
31+ return Dagger. move (from_proc, to_proc, mtlarray)
32+ end
33+ end
34+
35+ function Dagger. move (
36+ from_proc:: OSProc ,
37+ to_proc:: MtlArrayDeviceProc ,
38+ x:: Array{T, N}
39+ ) where {T, N}
40+ # If we have unified memory, we can try casting the `Array` to `MtlArray`.
41+ device = _get_metal_device (to_proc)
42+
43+ if (device != = nothing ) && device. hasUnifiedMemory
44+ marray = _cast_array_to_mtlarray (x, device)
45+ marray != = nothing && return marray
46+ end
47+
48+ return adapt (MtlArray, x)
49+ end
50+
51+ function Dagger. move (from_proc:: OSProc , to_proc:: MtlArrayDeviceProc , x)
52+ adapt (MtlArray, x)
53+ end
54+
55+ function Dagger. move (
56+ from_proc:: MtlArrayDeviceProc ,
57+ to_proc:: OSProc ,
58+ x:: Array{T, N}
59+ ) where {T, N}
60+ # If we have unified memory, we can just cast the `MtlArray` to an `Array`.
61+ device = _get_metal_device (from_proc)
62+
63+ if (device != = nothing ) && device. hasUnifiedMemory
64+ return unsafe_wrap (Array{T}, x, size (x))
65+ else
66+ return adapt (Array, x)
67+ end
68+ end
69+
70+ function Dagger. move (from_proc:: MtlArrayDeviceProc , to_proc:: OSProc , x)
71+ adapt (Array, x)
772end
873
9- @gpuproc (MtlArrayDeviceProc, MtlArray)
1074Dagger. get_parent (proc:: MtlArrayDeviceProc ) = Dagger. OSProc (proc. owner)
1175
1276function Dagger. execute! (proc:: MtlArrayDeviceProc , func, args... )
@@ -30,15 +94,58 @@ function Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)
3094end
3195
3296function Base. show (io:: IO , proc:: MtlArrayDeviceProc )
33- print (io, " MtlArrayDeviceProc on worker $(proc. owner) , device ($(proc. device . name) )" )
97+ print (io, " MtlArrayDeviceProc on worker $(proc. owner) , device ($(something ( _get_metal_device ( proc)) . name) )" )
3498end
3599
36100processor (:: Val{:Metal} ) = MtlArrayDeviceProc
37101cancompute (:: Val{:Metal} ) = length (Metal. devices ()) >= 1
38- kernel_backend (:: MtlArrayDeviceProc ) = Metal. current_device ()
102+ kernel_backend (proc:: MtlArrayDeviceProc ) = _get_metal_device (proc)
103+
104+ for dev in Metal. devices ()
105+ Dagger. add_processor_callback! (" metal_device_$(dev. registryID) " ) do
106+ MtlArrayDeviceProc (Distributed. myid (), dev. registryID)
107+ end
108+ end
109+
110+ # ###############################################################################
111+ # Private functions
112+ # ###############################################################################
113+
114+ # Try casting the array `x` to an `MtlArray`. If the casting is not possible,
115+ # return `nothing`.
116+ function _cast_array_to_mtlarray (x:: Array{T,N} , device:: MtlDevice ) where {T,N}
117+ # Try creating the buffer without copying.
118+ dims = size (x)
119+ nbytes_array = prod (dims) * sizeof (T)
120+ pagesize = ccall (:getpagesize , Cint, ())
121+ num_pages = div (nbytes_array, pagesize, RoundUp)
122+ nbytes = num_pages * pagesize
123+
124+ pbuf = Metal. MTL. mtDeviceNewBufferWithBytesNoCopy (
125+ device,
126+ pointer (x),
127+ nbytes,
128+ Metal. Shared | Metal. MTL. DefaultTracking | Metal. MTL. DefaultCPUCache
129+ )
130+
131+ if pbuf != C_NULL
132+ buf = MtlBuffer (pbuf)
133+ marray = MtlArray {T,N} (buf, dims)
134+ return marray
135+ end
136+
137+ # If we reached here, the conversion was not possible.
138+ return nothing
139+ end
140+
141+ # Return the Metal device handler given the ID recorded in `proc`.
142+ function _get_metal_device (proc:: MtlArrayDeviceProc )
143+ devices = Metal. devices ()
144+ id = findfirst (dev -> dev. registryID == proc. device_id, devices)
39145
40- if length (Metal. devices ()) >= 1
41- Dagger. add_processor_callback! (" metal_device" ) do
42- MtlArrayDeviceProc (Distributed. myid (), Metal. current_device ())
146+ if devices === nothing
147+ return nothing
148+ else
149+ return devices[id]
43150 end
44151end
0 commit comments