Skip to content

Commit 732a068

Browse files
authored
Merge pull request #26 from QuantumKitHub/ksh/gpu
Add AMDGPU extension and some GPU tests
2 parents e41566d + 9730884 commit 732a068

5 files changed

Lines changed: 92 additions & 3 deletions

File tree

.buildkite/pipeline.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,35 @@ steps:
3131
cuda: "*"
3232
if: build.message !~ /\[skip tests\]/
3333
timeout_in_minutes: 30
34+
35+
- label: "Julia v1 -- AMDGPU"
36+
plugins:
37+
- JuliaCI/julia#v1:
38+
version: "1"
39+
- JuliaCI/julia-test#v1: ~
40+
- JuliaCI/julia-coverage#v1:
41+
dirs:
42+
- src
43+
- ext
44+
agents:
45+
queue: "juliagpu"
46+
rocm: "*"
47+
rocmgpu: "*"
48+
if: build.message !~ /\[skip tests\]/
49+
timeout_in_minutes: 30
50+
51+
- label: "Julia LTS -- AMDGPU"
52+
plugins:
53+
- JuliaCI/julia#v1:
54+
version: "1.10" # "lts" isn't valid
55+
- JuliaCI/julia-test#v1: ~
56+
- JuliaCI/julia-coverage#v1:
57+
dirs:
58+
- src
59+
- ext
60+
agents:
61+
queue: "juliagpu"
62+
rocm: "*"
63+
rocmgpu: "*"
64+
if: build.message !~ /\[skip tests\]/
65+
timeout_in_minutes: 30

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
99

1010
[weakdeps]
11+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1112
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1213
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
1314

1415
[extensions]
16+
StridedViewsAMDGPUExt = "AMDGPU"
1517
StridedViewsCUDAExt = "CUDA"
1618
StridedViewsPtrArraysExt = "PtrArrays"
1719

1820
[compat]
21+
AMDGPU = "2"
1922
Aqua = "0.8"
2023
CUDA = "4,5"
2124
JET = "0.9, 0.10, 0.11"
@@ -27,6 +30,7 @@ Test = "1.6"
2730
julia = "1.10"
2831

2932
[extras]
33+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3034
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3135
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3236
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -35,4 +39,4 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3539
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3640

3741
[targets]
38-
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA"]
42+
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU"]

ext/StridedViewsAMDGPUExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module StridedViewsAMDGPUExt
2+
3+
using StridedViews
4+
using AMDGPU
5+
using AMDGPU: Adapt, ROCPtr
6+
7+
const ROCStridedView{T, N, A <: ROCArray{T}} = StridedView{T, N, A}
8+
9+
function Adapt.adapt_structure(to, A::ROCStridedView)
10+
return StridedView(
11+
Adapt.adapt_structure(to, parent(A)),
12+
A.size, A.strides, A.offset, A.op
13+
)
14+
end
15+
16+
function Base.pointer(x::ROCStridedView{T}) where {T}
17+
return Base.unsafe_convert(Ptr{T}, pointer(x.parent, x.offset + 1))
18+
end
19+
function Base.unsafe_convert(::Type{Ptr{T}}, a::ROCStridedView{T}) where {T}
20+
return convert(Ptr{T}, pointer(a))
21+
end
22+
23+
function Base.print_array(io::IO, X::ROCStridedView)
24+
return Base.print_array(io, Adapt.adapt_structure(Array, X))
25+
end
26+
27+
end # module

ext/StridedViewsCUDAExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ using CUDA: Adapt, CuPtr
66

77
const CuStridedView{T, N, A <: CuArray{T}} = StridedView{T, N, A}
88

9-
function Adapt.adapt_structure(::Type{T}, A::StridedView) where {T}
9+
function Adapt.adapt_structure(to, A::CuStridedView)
1010
return StridedView(
11-
Adapt.adapt_structure(T, parent(A)),
11+
Adapt.adapt_structure(to, parent(A)),
1212
A.size, A.strides, A.offset, A.op
1313
)
1414
end

test/runtests.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,29 @@ if !is_buildkite
298298
JET.test_package(StridedViews; target_modules = (StridedViews,))
299299
end
300300
end
301+
302+
using CUDA, AMDGPU
303+
304+
if CUDA.functional()
305+
@testset "CuArrays with StridedView" begin
306+
@testset for T in (Float64, ComplexF64)
307+
A = CUDA.randn!(T, 10, 10, 10, 10)
308+
@test isstrided(A)
309+
B = StridedView(A)
310+
@test B isa StridedView
311+
@test B == A
312+
end
313+
end
314+
end
315+
316+
if AMDGPU.functional()
317+
@testset "ROCArrays with StridedView" begin
318+
@testset for T in (Float64, ComplexF64)
319+
A = AMDGPU.randn!(T, 10, 10, 10, 10)
320+
@test isstrided(A)
321+
B = StridedView(A)
322+
@test B isa StridedView
323+
@test B == A
324+
end
325+
end
326+
end

0 commit comments

Comments
 (0)