Skip to content

Commit c8941db

Browse files
authored
Initial support for Metal GPUs (#19)
Fix CPU tests
1 parent 2bfacdf commit c8941db

5 files changed

Lines changed: 74 additions & 5 deletions

File tree

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ julia = "1"
2424
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
2525
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2626
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
27+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
**GPU integrations for Dagger.jl**
44

5-
DaggerGPU.jl makes use of the `Dagger.Processor` infrastructure to dispatch Dagger kernels to NVIDIA and AMD GPUs, via CUDA.jl and AMDGPU.jl respectively. Usage is simple: `add` or `dev` DaggerGPU.jl and CUDA.jl/AMDGPU.jl appropriately, load it with `using DaggerGPU`, and add `DaggerGPU.CuArrayDeviceProc`/`DaggerGPU.ROCArrayProc` to your scheduler or thunk options (see Dagger.jl documentation for details on how to do this).
5+
DaggerGPU.jl makes use of the `Dagger.Processor` infrastructure to dispatch Dagger kernels to NVIDIA, AMD, and Apple GPUs, via CUDA.jl, AMDGPU.jl, and Metal.jl respectively. Usage is simple: `add` or `dev` DaggerGPU.jl and CUDA.jl/AMDGPU.jl/Metal.jl appropriately, load it with `using DaggerGPU`, and add `DaggerGPU.CuArrayDeviceProc`/`DaggerGPU.ROCArrayProc`/`DaggerGPU.MtlArrayDeviceProc` to your scheduler or thunk options (see Dagger.jl documentation for details on how to do this).
66

77
DaggerGPU.jl is still experimental, but we welcome GPU-owning users to try it out and report back on any issues or sharp edges that they encounter. When filing an issue about DaggerGPU.jl, please provide:
88
- The complete error message and backtrace

src/DaggerGPU.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ function __init__()
5656
@require AMDGPU="21141c5a-9bdb-4563-92ae-f87d6854732e" begin
5757
include("roc.jl")
5858
end
59+
@require Metal="dde4c033-4e86-420c-a63e-0dd931031962" begin
60+
include("metal.jl")
61+
end
5962
end
6063

6164
end

src/metal.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using .Metal
2+
import .Metal: MtlArray, MtlDevice
3+
4+
"""
    MtlArrayDeviceProc <: Dagger.Processor

Dagger processor representing a single Apple Metal GPU device owned by a
specific Distributed worker process.
"""
struct MtlArrayDeviceProc <: Dagger.Processor
    owner::Int         # ID of the Distributed worker that owns this processor
    device::MtlDevice  # the Metal GPU device this processor executes on
end
8+
9+
# Generate the standard DaggerGPU move/execute plumbing for this
# processor/array-type pair (same pattern as the CUDA and ROC backends).
@gpuproc(MtlArrayDeviceProc, MtlArray)
# The parent of a GPU processor is the OS process (worker) that owns it.
Dagger.get_parent(proc::MtlArrayDeviceProc) = Dagger.OSProc(proc.owner)
11+
12+
"""
    Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)

Execute `func(args...)` on behalf of this Metal GPU processor.

The call runs on a freshly spawned task so that Dagger's task-local
state can be propagated into it; `Metal.@sync` blocks until all GPU
work enqueued by `func` has completed. If the task fails, the original
exception and its backtrace are recovered and rethrown as a
`CapturedException` so the caller sees the real error rather than a
bare task-failure wrapper.
"""
function Dagger.execute!(proc::MtlArrayDeviceProc, func, args...)
    # Capture Dagger's task-local state so it survives the task boundary.
    tls = Dagger.get_tls()
    task = Threads.@spawn begin
        Dagger.set_tls!(tls)
        # Wait for all GPU operations launched by `func` to finish.
        Metal.@sync func(args...)
    end

    try
        fetch(task)
    catch err
        @static if VERSION >= v"1.1"
            # NOTE(review): `Base.catch_stack` is internal and was renamed
            # to `current_exceptions` in Julia 1.7 — consider migrating once
            # the minimum supported Julia version allows.
            stk = Base.catch_stack(task)
            err, frames = stk[1]
            rethrow(CapturedException(err, frames))
        else
            # Pre-1.1 Julia has no exception stacks; rethrow the task's
            # stored failure result directly.
            rethrow(task.result)
        end
    end
end
31+
32+
# Human-readable display for a Metal processor, e.g.
# "MtlArrayDeviceProc on worker 1, device (Apple M1)".
function Base.show(io::IO, proc::MtlArrayDeviceProc)
    print(io, "MtlArrayDeviceProc on worker ", proc.owner,
              ", device (", proc.device.name, ")")
end
35+
36+
# Backend-registration hooks used by DaggerGPU's generic dispatch:
# map the `:Metal` backend symbol to its processor type.
processor(::Val{:Metal}) = MtlArrayDeviceProc
# Metal computation is possible whenever at least one GPU device is present.
cancompute(::Val{:Metal}) = length(Metal.devices()) >= 1
# NOTE(review): other backends appear to return a KernelAbstractions backend
# here; this returns the raw `MtlDevice` — confirm callers accept that.
kernel_backend(::MtlArrayDeviceProc) = Metal.current_device()
39+
40+
# At load time, register this worker's current Metal GPU with Dagger's
# scheduler — but only when a device is actually present.
if length(Metal.devices()) >= 1
    Dagger.add_processor_callback!("metal_device") do
        MtlArrayDeviceProc(Distributed.myid(), Metal.current_device())
    end
end

test/runtests.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ using Test
33
addprocs(2, exeflags="--project")
44

55
@everywhere begin
6+
using CUDA, AMDGPU, Metal, KernelAbstractions
67
using Distributed, Dagger, DaggerGPU
7-
using CUDA, AMDGPU, KernelAbstractions
88
end
99
@everywhere begin
1010
function myfunc(X)
@@ -29,13 +29,15 @@ function generate_thunks()
2929
delayed((xs...)->[sum(xs)])(as...)
3030
end
3131

32-
@test DaggerGPU.cancompute(:CUDA) || DaggerGPU.cancompute(:ROC)
32+
@test DaggerGPU.cancompute(:CUDA) ||
33+
DaggerGPU.cancompute(:ROC) ||
34+
DaggerGPU.cancompute(:Metal)
3335

3436
@testset "CPU" begin
3537
@testset "KernelAbstractions" begin
3638
A = rand(Float32, 8)
37-
_A = collect(delayed(fill_thunk)(A, 2.3))
38-
@test all(_A .== 2.3)
39+
_A = collect(delayed(fill_thunk)(A, 2.3f0))
40+
@test all(_A .== 2.3f0)
3941
end
4042
end
4143

@@ -89,3 +91,22 @@ end
8991
=#
9092
end
9193
end
94+
95+
# End-to-end scheduling test for the Metal backend; skipped on machines
# without an Apple GPU.
@testset "Metal" begin
    if !DaggerGPU.cancompute(:Metal)
        @warn "No Metal devices available, skipping tests"
    else
        metalproc = DaggerGPU.processor(:Metal)
        b = generate_thunks()
        # Restrict the GPU-side thunks to the Metal processor.
        opts = Dagger.Sch.ThunkOptions(;proclist = [metalproc])
        # NOTE(review): `c_pre` is never collected or referenced below —
        # confirm whether it is intentional warm-up or dead code.
        c_pre = delayed(myfunc; options = opts)(b)
        c = delayed(sum; options = opts)(b)

        # Move the result back to a CPU thread processor for verification.
        opts = Dagger.Sch.ThunkOptions(;proclist = [Dagger.ThreadProc])
        d = delayed(identity; options = opts)(c)
        @test collect(d) == 20

        # It seems KernelAbstractions does not support Metal.jl.
        @test_skip "KernelAbstractions"
    end
end

0 commit comments

Comments
 (0)