Skip to content

Commit 77e98eb

Browse files
Introduce and use _aes_enc_full
1 parent 2d2ae22 commit 77e98eb

3 files changed

Lines changed: 56 additions & 52 deletions

File tree

src/aarch64/aesni.jl

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -211,22 +211,8 @@ get_ctr_uint64x2(o::AESNI1x)::Tuple{uint64x2} = (o.ctr,)
211211
get_key(o::Union{AESNI1x, AESNI4x})::NTuple{11,UInt128} = map(UInt128, get_key_uint64x2(o))
212212
get_ctr(o::Union{AESNI1x, AESNI4x})::Tuple{UInt128} = map(UInt128, get_ctr_uint64x2(o))
213213

214-
@inline function aesni(key::NTuple{11,uint64x2}, ctr::Tuple{uint64x2})::Tuple{uint64x2}
215-
key1, key2, key3, key4, key5, key6, key7, key8, key9, key10, key11 = key
216-
ctr1 = only(ctr)
217-
x = key1 ctr1
218-
x = _aes_enc(x, key2)
219-
x = _aes_enc(x, key3)
220-
x = _aes_enc(x, key4)
221-
x = _aes_enc(x, key5)
222-
x = _aes_enc(x, key6)
223-
x = _aes_enc(x, key7)
224-
x = _aes_enc(x, key8)
225-
x = _aes_enc(x, key9)
226-
x = _aes_enc(x, key10)
227-
x = _aes_enc_last(x, key11)
228-
(x,)
229-
end
214+
@inline aesni(key::NTuple{11,uint64x2}, ctr::Tuple{uint64x2})::Tuple{uint64x2} =
215+
(_aes_enc_full(only(ctr), key),)
230216

231217
"""
232218
aesni(key::NTuple{11,UInt128}, ctr::Tuple{UInt128})::Tuple{UInt128}

src/aarch64/aesni_common.jl

Lines changed: 52 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,51 +10,51 @@ const uint64x2_lvec = NTuple{2, VecElement{UInt64}}
1010
struct uint64x2
1111
data::uint64x2_lvec
1212
end
13-
Base.convert(::Type{uint64x2}, x::UInt128) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
14-
Base.convert(::Type{UInt128}, x::uint64x2) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
15-
UInt128(x::uint64x2) = convert(UInt128, x)
16-
uint64x2(x::UInt128) = convert(uint64x2, x)
17-
Base.convert(::Type{uint64x2}, x::Union{Signed, Unsigned}) = convert(uint64x2, UInt128(x))
18-
Base.convert(::Type{T}, x::uint64x2) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
19-
20-
uint64x2(hi::UInt64, lo::UInt64) = if LITTLE_ENDIAN
13+
@inline Base.convert(::Type{uint64x2}, x::UInt128) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
14+
@inline Base.convert(::Type{UInt128}, x::uint64x2) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
15+
@inline UInt128(x::uint64x2) = convert(UInt128, x)
16+
@inline uint64x2(x::UInt128) = convert(uint64x2, x)
17+
@inline Base.convert(::Type{uint64x2}, x::Union{Signed, Unsigned}) = convert(uint64x2, UInt128(x))
18+
@inline Base.convert(::Type{T}, x::uint64x2) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
19+
20+
@inline uint64x2(hi::UInt64, lo::UInt64) = @static if LITTLE_ENDIAN
2121
uint64x2((VecElement(lo), VecElement(hi)))
2222
else
2323
uint64x2((VecElement(hi), VecElement(lo)))
2424
end
2525

26-
Base.zero(::Type{uint64x2}) = convert(uint64x2, 0)
27-
Base.one(::Type{uint64x2}) = uint64x2(zero(UInt64), one(UInt64))
28-
Base.xor(a::uint64x2, b::uint64x2) = llvmcall(
26+
@inline Base.zero(::Type{uint64x2}) = convert(uint64x2, 0)
27+
@inline Base.one(::Type{uint64x2}) = uint64x2(zero(UInt64), one(UInt64))
28+
@inline Base.xor(a::uint64x2, b::uint64x2) = llvmcall(
2929
"""%3 = xor <2 x i64> %1, %0
3030
ret <2 x i64> %3""",
3131
uint64x2_lvec, Tuple{uint64x2_lvec, uint64x2_lvec},
3232
a.data, b.data,
3333
) |> uint64x2
34-
(+)(a::uint64x2, b::uint64x2) = llvmcall(
34+
@inline (+)(a::uint64x2, b::uint64x2) = llvmcall(
3535
"""%3 = add <2 x i64> %1, %0
3636
ret <2 x i64> %3""",
3737
uint64x2_lvec, Tuple{uint64x2_lvec, uint64x2_lvec},
3838
a.data, b.data,
3939
) |> uint64x2
40-
(+)(a::uint64x2, b::Integer) = a + uint64x2(UInt128(b))
40+
@inline (+)(a::uint64x2, b::Integer) = a + uint64x2(UInt128(b))
4141

4242
const uint8x16_lvec = NTuple{16, VecElement{UInt8}}
4343
struct uint8x16
4444
data::uint8x16_lvec
4545
end
46-
Base.convert(::Type{uint64x2}, x::uint8x16) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
47-
Base.convert(::Type{uint8x16}, x::uint64x2) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
48-
uint8x16(x::uint64x2) = convert(uint8x16, x)
49-
uint64x2(x::uint8x16) = convert(uint64x2, x)
50-
Base.convert(::Type{uint8x16}, x::UInt128) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
51-
Base.convert(::Type{UInt128}, x::uint8x16) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
52-
UInt128(x::uint8x16) = convert(UInt128, x)
53-
uint8x16(x::UInt128) = convert(uint8x16, x)
54-
Base.convert(::Type{uint8x16}, x::Union{Signed, Unsigned}) = convert(uint8x16, UInt128(x))
55-
Base.convert(::Type{T}, x::uint8x16) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
56-
57-
function uint8x16(bytes::Vararg{UInt8, 16})
46+
@inline Base.convert(::Type{uint64x2}, x::uint8x16) = unsafe_load(Ptr{uint64x2}(pointer_from_objref(Ref(x))))
47+
@inline Base.convert(::Type{uint8x16}, x::uint64x2) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
48+
@inline uint8x16(x::uint64x2) = convert(uint8x16, x)
49+
@inline uint64x2(x::uint8x16) = convert(uint64x2, x)
50+
@inline Base.convert(::Type{uint8x16}, x::UInt128) = unsafe_load(Ptr{uint8x16}(pointer_from_objref(Ref(x))))
51+
@inline Base.convert(::Type{UInt128}, x::uint8x16) = unsafe_load(Ptr{UInt128}(pointer_from_objref(Ref(x))))
52+
@inline UInt128(x::uint8x16) = convert(UInt128, x)
53+
@inline uint8x16(x::UInt128) = convert(uint8x16, x)
54+
@inline Base.convert(::Type{uint8x16}, x::Union{Signed, Unsigned}) = convert(uint8x16, UInt128(x))
55+
@inline Base.convert(::Type{T}, x::uint8x16) where T <: Union{Signed, Unsigned} = convert(T, UInt128(x))
56+
57+
@inline function uint8x16(bytes::Vararg{UInt8, 16})
5858
bytes_prepped = bytes
5959
if LITTLE_ENDIAN
6060
bytes_prepped = reverse(bytes_prepped)
@@ -63,23 +63,23 @@ function uint8x16(bytes::Vararg{UInt8, 16})
6363
return uint8x16(bytes_vec)
6464
end
6565

66-
Base.zero(::Type{uint8x16}) = convert(uint8x16, 0)
67-
Base.xor(a::uint8x16, b::uint8x16) = llvmcall(
66+
@inline Base.zero(::Type{uint8x16}) = convert(uint8x16, 0)
67+
@inline Base.xor(a::uint8x16, b::uint8x16) = llvmcall(
6868
"""%3 = xor <16 x i8> %1, %0
6969
ret <16 x i8> %3""",
7070
uint8x16_lvec, Tuple{uint8x16_lvec, uint8x16_lvec},
7171
a.data, b.data,
7272
) |> uint8x16
7373

7474
# Raw NEON instrinsics, provided by FEAT_AES
75-
_vaese(a::uint8x16, b::uint8x16) = ccall(
75+
@inline _vaese(a::uint8x16, b::uint8x16) = ccall(
7676
"llvm.aarch64.crypto.aese",
7777
llvmcall,
7878
uint8x16_lvec,
7979
(uint8x16_lvec, uint8x16_lvec),
8080
a.data, b.data,
8181
) |> uint8x16
82-
_vaesmc(a::uint8x16) = ccall(
82+
@inline _vaesmc(a::uint8x16) = ccall(
8383
"llvm.aarch64.crypto.aesmc",
8484
llvmcall,
8585
uint8x16_lvec,
@@ -104,7 +104,7 @@ uint8x16_t _mm_aeskeygenassist_helper(uint8x16_t a)
104104
```
105105
Then made architecture-agnostic as LLVM IR.
106106
"""
107-
_aes_key_gen_shuffle_helper(a::uint8x16) = llvmcall(
107+
@inline _aes_key_gen_shuffle_helper(a::uint8x16) = llvmcall(
108108
"""%2 = shufflevector <16 x i8> %0, <16 x i8> undef, <16 x i32> <i32 4, i32 1, i32 14, i32 11, i32 1, i32 14, i32 11, i32 4, i32 12, i32 9, i32 6, i32 3, i32 9, i32 6, i32 3, i32 12>
109109
ret <16 x i8> %2""",
110110
uint8x16_lvec, Tuple{uint8x16_lvec},
@@ -116,21 +116,39 @@ _aes_key_gen_shuffle_helper(a::uint8x16) = llvmcall(
116116
# Algorithm translations courtesy of the SIMD Everywhere and SSE2NEON projects:
117117
# https://github.com/simd-everywhere/simde/blob/v0.8.0-rc1/simde/x86/aes.h
118118
# https://github.com/DLTcollab/sse2neon/blob/v1.6.0/sse2neon.h
119-
function _aes_enc(a::uint64x2, round_key::uint64x2)
119+
@inline function _aes_enc(a::uint64x2, round_key::uint64x2)
120120
res = _vaesmc(_vaese(uint8x16(a), zero(uint8x16)))
121121
return uint64x2(res) round_key
122122
end
123-
function _aes_enc_last(a::uint64x2, round_key::uint64x2)
123+
@inline function _aes_enc_last(a::uint64x2, round_key::uint64x2)
124124
res = _vaese(uint8x16(a), zero(uint8x16))
125125
return uint64x2(res) round_key
126126
end
127-
128-
function _aes_key_gen_assist(a::uint64x2, ::Val{R}) where {R}
127+
@inline function _aes_key_gen_assist(a::uint64x2, ::Val{R}) where {R}
129128
res = _aes_key_gen_shuffle_helper(_vaese(uint8x16(a), zero(uint8x16)))
130129
r = R % UInt64
131130
return uint64x2(res) uint64x2(r, r)
132131
end
133132

133+
"""
134+
_aes_enc_full(a::uint64x2, round_keys::NTuple{N,uint64x2})::uint64x2 where {N}
135+
136+
Full AES encryption flow for N rounds.
137+
"""
138+
@inline function _aes_enc_full(a::uint64x2, round_keys::NTuple{N,uint64x2})::uint64x2 where {N}
139+
res = uint8x16(a)
140+
for (i, key) in enumerate(round_keys)
141+
if i N
142+
res = _vaese(res, uint8x16(key))
143+
if i N - 1
144+
res = _vaesmc(res)
145+
end
146+
else
147+
return uint64x2(res uint8x16(key))
148+
end
149+
end
150+
end
151+
134152
"Abstract RNG that generates one number at a time and is based on AESNI."
135153
abstract type AbstractAESNI1x <: R123Generator1x{UInt128} end
136154
"Abstract RNG that generates four numbers at a time and is based on AESNI."

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Printf: @printf
2323
(Philox4x(UInt32 , seed2) , philox , (Val(10),)) ,
2424
(Philox4x(UInt64 , seed2) , philox , (Val(10),)) ,
2525
]
26-
if R123_USE_AESNI
26+
@static if R123_USE_AESNI
2727
append!(alg_choices, AlgChoice[
2828
(AESNI1x(seed1) , aesni , () ) ,
2929
(AESNI4x(seed4) , aesni , () ) ,
@@ -172,7 +172,7 @@ redirect_stdout(stdout_)
172172
compare_dirs("expected", "actual")
173173
cd(pwd_)
174174

175-
if Random123.R123_USE_X86_AES_NI
175+
@static if Random123.R123_USE_X86_AES_NI
176176
include("./x86/aesni.jl")
177177
include("./x86/ars.jl")
178178
elseif Random123.R123_USE_AARCH64_FEAT_AES

0 commit comments

Comments
 (0)