Skip to content

Commit 8f11922

Browse files
authored
Initial support for in-place operations with unified memory (#21)
* Handle multiple GPUs * Add support for in-place ops with unified memory
1 parent c8941db commit 8f11922

2 files changed

Lines changed: 165 additions & 7 deletions

File tree

src/metal.jl

Lines changed: 114 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,74 @@ import .Metal: MtlArray, MtlDevice
33

44
struct MtlArrayDeviceProc <: Dagger.Processor
55
owner::Int
6-
device::MtlDevice
6+
device_id::UInt64
7+
end
8+
9+
# Assume that we can run anything.
10+
Dagger.iscompatible_func(proc::MtlArrayDeviceProc, opts, f) = true
11+
Dagger.iscompatible_arg(proc::MtlArrayDeviceProc, opts, x) = true
12+
13+
# CPUs shouldn't process our array type.
14+
Dagger.iscompatible_arg(proc::Dagger.ThreadProc, opts, x::MtlArray) = false
15+
16+
function Dagger.move(from_proc::OSProc, to_proc::MtlArrayDeviceProc, x::Chunk)
17+
from_pid = from_proc.pid
18+
to_pid = Dagger.get_parent(to_proc).pid
19+
@assert myid() == to_pid
20+
21+
return Dagger.move(from_proc, to_proc, remotecall_fetch(x->poolget(x.handle), from_pid, x))
22+
end
23+
24+
function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::OSProc, x::Chunk)
25+
from_pid = Dagger.get_parent(from_proc).pid
26+
to_pid = to_proc.pid
27+
@assert myid() == to_pid
28+
29+
return remotecall_fetch(from_pid, x) do x
30+
mtlarray = poolget(x.handle)
31+
return Dagger.move(from_proc, to_proc, mtlarray)
32+
end
33+
end
34+
35+
function Dagger.move(
36+
from_proc::OSProc,
37+
to_proc::MtlArrayDeviceProc,
38+
x::Array{T, N}
39+
) where {T, N}
40+
# If we have unified memory, we can try casting the `Array` to `MtlArray`.
41+
device = _get_metal_device(to_proc)
42+
43+
if (device !== nothing) && device.hasUnifiedMemory
44+
marray = _cast_array_to_mtlarray(x, device)
45+
marray !== nothing && return marray
46+
end
47+
48+
return adapt(MtlArray, x)
49+
end
50+
51+
function Dagger.move(from_proc::OSProc, to_proc::MtlArrayDeviceProc, x)
52+
adapt(MtlArray, x)
53+
end
54+
55+
function Dagger.move(
56+
from_proc::MtlArrayDeviceProc,
57+
to_proc::OSProc,
58+
x::Array{T, N}
59+
) where {T, N}
60+
# If we have unified memory, we can just cast the `MtlArray` to an `Array`.
61+
device = _get_metal_device(from_proc)
62+
63+
if (device !== nothing) && device.hasUnifiedMemory
64+
return unsafe_wrap(Array{T}, x, size(x))
65+
else
66+
return adapt(Array, x)
67+
end
68+
end
69+
70+
function Dagger.move(from_proc::MtlArrayDeviceProc, to_proc::OSProc, x)
71+
adapt(Array, x)
772
end
873

9-
@gpuproc(MtlArrayDeviceProc, MtlArray)
1074
Dagger.get_parent(proc::MtlArrayDeviceProc) = Dagger.OSProc(proc.owner)
1175

1276
function Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)
@@ -30,15 +94,58 @@ function Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)
3094
end
3195

3296
function Base.show(io::IO, proc::MtlArrayDeviceProc)
33-
print(io, "MtlArrayDeviceProc on worker $(proc.owner), device ($(proc.device.name))")
97+
print(io, "MtlArrayDeviceProc on worker $(proc.owner), device ($(something(_get_metal_device(proc)).name))")
3498
end
3599

36100
processor(::Val{:Metal}) = MtlArrayDeviceProc
37101
cancompute(::Val{:Metal}) = length(Metal.devices()) >= 1
38-
kernel_backend(::MtlArrayDeviceProc) = Metal.current_device()
102+
kernel_backend(proc::MtlArrayDeviceProc) = _get_metal_device(proc)
103+
104+
for dev in Metal.devices()
105+
Dagger.add_processor_callback!("metal_device_$(dev.registryID)") do
106+
MtlArrayDeviceProc(Distributed.myid(), dev.registryID)
107+
end
108+
end
109+
110+
################################################################################
111+
# Private functions
112+
################################################################################
113+
114+
# Try casting the array `x` to an `MtlArray`. If the casting is not possible,
115+
# return `nothing`.
116+
function _cast_array_to_mtlarray(x::Array{T,N}, device::MtlDevice) where {T,N}
117+
# Try creating the buffer without copying.
118+
dims = size(x)
119+
nbytes_array = prod(dims) * sizeof(T)
120+
pagesize = ccall(:getpagesize, Cint, ())
121+
num_pages = div(nbytes_array, pagesize, RoundUp)
122+
nbytes = num_pages * pagesize
123+
124+
pbuf = Metal.MTL.mtDeviceNewBufferWithBytesNoCopy(
125+
device,
126+
pointer(x),
127+
nbytes,
128+
Metal.Shared | Metal.MTL.DefaultTracking | Metal.MTL.DefaultCPUCache
129+
)
130+
131+
if pbuf != C_NULL
132+
buf = MtlBuffer(pbuf)
133+
marray = MtlArray{T,N}(buf, dims)
134+
return marray
135+
end
136+
137+
# If we reached here, the conversion was not possible.
138+
return nothing
139+
end
140+
141+
# Return the Metal device handler given the ID recorded in `proc`.
142+
function _get_metal_device(proc::MtlArrayDeviceProc)
143+
devices = Metal.devices()
144+
id = findfirst(dev -> dev.registryID == proc.device_id, devices)
39145

40-
if length(Metal.devices()) >= 1
41-
Dagger.add_processor_callback!("metal_device") do
42-
MtlArrayDeviceProc(Distributed.myid(), Metal.current_device())
146+
if devices === nothing
147+
return nothing
148+
else
149+
return devices[id]
43150
end
44151
end

test/runtests.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ end
2222
@show A
2323
A
2424
end
25+
26+
# Create a function to perform an in-place operation.
27+
function addarray!(x)
28+
x .= x .+ 1.0f0
29+
end
2530
end
2631

2732
function generate_thunks()
@@ -108,5 +113,51 @@ end
108113

109114
# It seems KernelAbstractions does not support Metal.jl.
110115
@test_skip "KernelAbstractions"
116+
117+
@testset "In-place operations" begin
118+
# Create a page-aligned array.
119+
dims = (2, 2)
120+
T = Float32
121+
pagesize = ccall(:getpagesize, Cint, ())
122+
addr = Ref(C_NULL)
123+
124+
ccall(
125+
:posix_memalign,
126+
Cint,
127+
(Ptr{Ptr{Cvoid}}, Csize_t, Csize_t), addr,
128+
pagesize,
129+
prod(dims) * sizeof(T)
130+
)
131+
132+
array = unsafe_wrap(
133+
Array{T, length(dims)},
134+
reinterpret(Ptr{T}, addr[]),
135+
dims,
136+
own = false
137+
)
138+
139+
# Initialize the array.
140+
array[1, 1] = 1
141+
array[1, 2] = 2
142+
array[2, 1] = 3
143+
array[2, 2] = 4
144+
145+
# Perform the computation only on a local `MtlArrayDeviceProc`
146+
t = Dagger.@spawn single=myid() proclist = [metalproc] addarray!(array)
147+
148+
# Fetch and check the results.
149+
ret = fetch(t)
150+
151+
@test ret[1, 1] == 2.0f0
152+
@test ret[1, 2] == 3.0f0
153+
@test ret[2, 1] == 4.0f0
154+
@test ret[2, 2] == 5.0f0
155+
156+
# Check if the operation happened in-place.
157+
@test array[1, 1] == 2.0f0
158+
@test array[1, 2] == 3.0f0
159+
@test array[2, 1] == 4.0f0
160+
@test array[2, 2] == 5.0f0
161+
end
111162
end
112163
end

0 commit comments

Comments
 (0)