Skip to content

Commit 56b3cfe

Browse files
authored
add support for Metal.jl (#34)
1 parent 095ed78 commit 56b3cfe

3 files changed

Lines changed: 48 additions & 4 deletions

File tree

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "StridedViews"
22
uuid = "4db3bf67-4bd7-4b4e-b153-31dc3fb37143"
3-
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
43
version = "0.4.6"
4+
authors = ["Lukas Devos <ldevos98@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -10,8 +10,9 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1010
[weakdeps]
1111
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13-
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
1413
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
14+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
15+
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
1516

1617
[extensions]
1718
StridedViewsAMDGPUExt = "AMDGPU"
@@ -26,6 +27,7 @@ CUDA = "4,5"
2627
JET = "0.9, 0.10, 0.11"
2728
JLArrays = "0.3.1"
2829
LinearAlgebra = "1.6"
30+
Metal = "1.9.3"
2931
PackageExtensionCompat = "1"
3032
PtrArrays = "1.2.0"
3133
Random = "1.6"
@@ -38,9 +40,10 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3840
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3941
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4042
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
43+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
4144
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
4245
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4346
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4447

4548
[targets]
46-
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU", "JLArrays"]
49+
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU", "JLArrays", "Metal"]

ext/StridedViewsMetalExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module StridedViewsMetalExt
2+
3+
using StridedViews
4+
using Metal
5+
using Metal: Adapt, MtlPtr
6+
7+
const MtlStridedView{T, N, A <: MtlArray{T}} = StridedView{T, N, A}
8+
9+
function Adapt.adapt_structure(to, A::MtlStridedView)
10+
return StridedView(
11+
Adapt.adapt_structure(to, parent(A)),
12+
A.size, A.strides, A.offset, A.op
13+
)
14+
end
15+
16+
function Base.pointer(x::MtlStridedView{T}) where {T}
17+
return Base.unsafe_convert(MtlPtr{T}, pointer(x.parent, x.offset + 1))
18+
end
19+
function Base.unsafe_convert(::Type{MtlPtr{T}}, a::MtlStridedView{T}) where {T}
20+
return convert(MtlPtr{T}, pointer(a))
21+
end
22+
23+
function Base.print_array(io::IO, X::MtlStridedView)
24+
return Base.print_array(io, Adapt.adapt_structure(Array, X))
25+
end
26+
27+
end # module

test/runtests.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ if !is_buildkite
315315
end
316316
end
317317

318-
using CUDA, AMDGPU
318+
using CUDA, AMDGPU, Metal
319319

320320
if CUDA.functional()
321321
@testset "CuArrays with StridedView" begin
@@ -344,3 +344,17 @@ if AMDGPU.functional()
344344
end
345345
end
346346
end
347+
348+
if Metal.functional()
349+
@testset "MtlArrays with StridedView" begin
350+
@testset for T in (Float32, ComplexF32)
351+
A = MtlArray(randn(T, 10, 10, 10, 10))
352+
@test isstrided(A)
353+
B = StridedView(A)
354+
@test B isa StridedView
355+
Metal.@allowscalar begin
356+
@test B == A
357+
end
358+
end
359+
end
360+
end

0 commit comments

Comments
 (0)