Skip to content

Commit 93c06e6

Browse files
authored
Add JLArrays extension and tests (#31)
* Add JLArrays extension and tests * Add a test
1 parent 44b764a commit 93c06e6

3 files changed

Lines changed: 48 additions & 1 deletion

File tree

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,20 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1111
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1313
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
14+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1415

1516
[extensions]
1617
StridedViewsAMDGPUExt = "AMDGPU"
1718
StridedViewsCUDAExt = "CUDA"
19+
StridedViewsJLArraysExt = "JLArrays"
1820
StridedViewsPtrArraysExt = "PtrArrays"
1921

2022
[compat]
2123
AMDGPU = "2"
2224
Aqua = "0.8"
2325
CUDA = "4,5"
2426
JET = "0.9, 0.10, 0.11"
27+
JLArrays = "0.3.1"
2528
LinearAlgebra = "1.6"
2629
PackageExtensionCompat = "1"
2730
PtrArrays = "1.2.0"
@@ -34,9 +37,10 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3437
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3538
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
3639
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
40+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
3741
PtrArrays = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d"
3842
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3943
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4044

4145
[targets]
42-
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU"]
46+
test = ["Test", "Random", "Aqua", "JET", "PtrArrays", "CUDA", "AMDGPU", "JLArrays"]

ext/StridedViewsJLArraysExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module StridedViewsJLArraysExt
2+
3+
using StridedViews
4+
using JLArrays
5+
using JLArrays: Adapt
6+
7+
const JLArrayStridedView{T, N, A <: JLArray{T}} = StridedView{T, N, A}
8+
9+
function Adapt.adapt_structure(to, A::JLArrayStridedView)
10+
return StridedView(
11+
Adapt.adapt_structure(to, parent(A)),
12+
A.size, A.strides, A.offset, A.op
13+
)
14+
end
15+
16+
function Base.pointer(x::JLArrayStridedView{T}) where {T}
17+
return Base.unsafe_convert(Ptr{T}, pointer(x.parent, x.offset + 1))
18+
end
19+
function Base.unsafe_convert(::Type{Ptr{T}}, a::JLArrayStridedView{T}) where {T}
20+
return convert(Ptr{T}, pointer(a))
21+
end
22+
23+
function Base.print_array(io::IO, X::JLArrayStridedView)
24+
return Base.print_array(io, Adapt.adapt_structure(Array, X))
25+
end
26+
27+
end # module

test/runtests.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Test
22
using LinearAlgebra
33
using Random
44
using StridedViews
5+
using JLArrays
56

67
Random.seed!(1234)
78

@@ -290,6 +291,21 @@ if !is_buildkite
290291
end
291292
end
292293

294+
@testset "JLArrays with StridedView" begin
295+
@testset for T in (Float64, ComplexF64)
296+
Araw = randn(T, 10, 10, 10, 10)
297+
A = JLArray(Araw)
298+
@test isstrided(A)
299+
B = StridedView(A)
300+
@test B isa StridedView
301+
JLArrays.@allowscalar begin
302+
@test B == A
303+
end
304+
Bvec = JLArrays.Adapt.adapt(Vector{T}, B)
305+
@test Bvec == StridedView(Araw)
306+
end
307+
end
308+
293309
using Aqua
294310
Aqua.test_all(StridedViews)
295311

0 commit comments

Comments
 (0)