Skip to content

Commit c60b0e2

Browse files
Introduce AES support for AArch64
1 parent 764d71c commit c60b0e2

7 files changed

Lines changed: 653 additions & 3 deletions

File tree

src/Random123.jl

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,31 @@ else
4949
false
5050
end
5151

"""
    R123_USE_AARCH64_FEAT_AES::Bool

True when AArch64 FEAT_AES intrinsics have been detected.

On any architecture other than `:aarch64` the value is `false`. On AArch64 a child Julia
process (built from `Base.julia_cmd()`) calls the `llvm.aarch64.crypto.aesmc` intrinsic on
a fixed 16-byte input and asserts the expected 16-byte output; the constant is `true` only
when that process exits successfully. An exception raised while preparing or launching the
probe is treated as "not available".
"""
const R123_USE_AARCH64_FEAT_AES::Bool = if Sys.ARCH === :aarch64
    try
        cmd = Base.julia_cmd()
        push!(
            cmd.exec,
            "-e",
            "const uint8x16 = NTuple{16, VecElement{UInt8}};" *
            "@assert ccall(\"llvm.aarch64.crypto.aesmc\", " *
            "llvmcall, uint8x16, (uint8x16,), " *
            "uint8x16((0x4a, 0x68, 0xbd, 0xe1, 0xfe, 0x16, 0x3d, " *
            "0xec, 0xde, 0x06, 0x72, 0x86, 0xe3, 0x8c, 0x14, 0xd9))) ≡ " *
            "uint8x16((0x70, 0xa7, 0x7b, 0xd2, 0x0c, 0x79, 0xbd, " *
            "0xf1, 0x59, 0xc2, 0xad, 0x1a, 0x9f, 0x05, 0x37, 0x0f))",
        )
        # `success` is false when the child exits with a non-zero status.
        success(cmd)
    catch
        # Best-effort detection: a probe that cannot be run means "no AES support".
        false
    end
else
    false
end
74+
"""
True when AES-acceleration instructions have been detected.

This is the case when either `R123_USE_X86_AES_NI` or `R123_USE_AARCH64_FEAT_AES` is true.
"""
const R123_USE_AESNI::Bool = R123_USE_X86_AES_NI || R123_USE_AARCH64_FEAT_AES
5477

5578
@static if R123_USE_AESNI
5679
export AESNI1x, AESNI4x, aesni
@@ -63,6 +86,10 @@ end
6386
include("./x86/aesni_common.jl")
6487
include("./x86/aesni.jl")
6588
include("./x86/ars.jl")
89+
elseif R123_USE_AARCH64_FEAT_AES
90+
include("./aarch64/aesni_common.jl")
91+
include("./aarch64/aesni.jl")
92+
include("./aarch64/ars.jl")
6693
end
6794

6895
end

src/aarch64/aesni.jl

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import Base: copy, copyto!, ==, llvmcall
2+
import Random: rand, seed!
3+
import RandomNumbers: gen_seed, union_uint, seed_type, unsafe_copyto!, unsafe_compare
4+
5+
"""
The key for AESNI.

Stores eleven `uint64x2` values in the fields `key1` through `key11`. The zero-argument
constructor initialises every field from the integer `0`; `_aesni_expand!` later
overwrites all eleven fields.
"""
mutable struct AESNIKey
    key1::uint64x2
    key2::uint64x2
    key3::uint64x2
    key4::uint64x2
    key5::uint64x2
    key6::uint64x2
    key7::uint64x2
    key8::uint64x2
    key9::uint64x2
    key10::uint64x2
    key11::uint64x2

    # Blank key: each `0` is converted to the field type by `new`.
    function AESNIKey()
        return new(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
    end
end
21+
22+
# Copy `src` into `dest` by delegating to `unsafe_copyto!` with element type `UInt128`
# and a count of 11 (one per key field).
function copyto!(dest::AESNIKey, src::AESNIKey)
    return unsafe_copyto!(dest, src, UInt128, 11)
end
23+
24+
# Duplicate a key: allocate a blank `AESNIKey` and copy `src` into it.
function copy(src::AESNIKey)
    fresh = AESNIKey()
    return copyto!(fresh, src)
end
25+
26+
# Two keys are equal when `unsafe_compare` reports equality over 11 `UInt128` words.
function ==(key1::AESNIKey, key2::AESNIKey)
    return unsafe_compare(key1, key2, UInt128, 11)
end
27+
"""
    _aes_128_assist(a::uint64x2, b::uint64x2)

Helper for AES128. The LLVM IR below was originally obtained by compiling the following
C++ source for x86:
```cpp
R123_STATIC_INLINE __m128i AES_128_ASSIST (__m128i temp1, __m128i temp2) {
    uint64x2 temp3;
    temp2 = _mm_shuffle_epi32 (temp2 ,0xff);
    temp3 = _mm_slli_si128 (temp1, 0x4);
    temp1 = _mm_xor_si128 (temp1, temp3);
    temp3 = _mm_slli_si128 (temp3, 0x4);
    temp1 = _mm_xor_si128 (temp1, temp3);
    temp3 = _mm_slli_si128 (temp3, 0x4);
    temp1 = _mm_xor_si128 (temp1, temp3);
    temp1 = _mm_xor_si128 (temp1, temp2);
    return temp1;
}
```
and is used here in architecture-agnostic form. The raw vectors `a.data` and `b.data` are
passed to the IR and the result is wrapped back into a `uint64x2`.
"""
function _aes_128_assist(a::uint64x2, b::uint64x2)
    # The IR must stay a string literal inside the `llvmcall` expression.
    return uint64x2(llvmcall(
        """%3 = bitcast <2 x i64> %1 to <4 x i32>
        %4 = shufflevector <4 x i32> %3, <4 x i32> undef, <4 x i32> <i32 3, i32 3, i32 3, i32 3>
        %5 = bitcast <4 x i32> %4 to <2 x i64>
        %6 = bitcast <2 x i64> %0 to <16 x i8>
        %7 = shufflevector <16 x i8> <i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 0, i8 0, i8 0, i8 0>, <16 x i8> %6, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27>
        %8 = bitcast <16 x i8> %7 to <2 x i64>
        %9 = xor <2 x i64> %8, %0
        %10 = shufflevector <16 x i8> <i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 0, i8 0, i8 0, i8 0>, <16 x i8> %7, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27>
        %11 = bitcast <16 x i8> %10 to <2 x i64>
        %12 = xor <2 x i64> %9, %11
        %13 = shufflevector <16 x i8> <i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 undef, i8 0, i8 0, i8 0, i8 0>, <16 x i8> %10, <16 x i32> <i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27>
        %14 = bitcast <16 x i8> %13 to <2 x i64>
        %15 = xor <2 x i64> %12, %5
        %16 = xor <2 x i64> %15, %14
        ret <2 x i64> %16""",
        uint64x2_lvec, Tuple{uint64x2_lvec, uint64x2_lvec},
        a.data, b.data
    ))
end
65+
66+
# Fill all eleven fields of `k` starting from `rkey`.
# `key1` is `rkey` itself; each following field is derived from the previous one by
# `_aes_128_assist(prev, _aes_key_gen_assist(prev, Val(c)))`, where `c` runs through
# 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36. Returns `k`.
function _aesni_expand!(k::AESNIKey, rkey::uint64x2)
    k.key1 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x1)))
    k.key2 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x2)))
    k.key3 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x4)))
    k.key4 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x8)))
    k.key5 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x10)))
    k.key6 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x20)))
    k.key7 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x40)))
    k.key8 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x80)))
    k.key9 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x1b)))
    k.key10 = rkey

    rkey = _aes_128_assist(rkey, _aes_key_gen_assist(rkey, Val(0x36)))
    k.key11 = rkey

    return k
end
110+
111+
# Build a fully expanded key from a 128-bit integer.
function AESNIKey(key::UInt128)
    blank = AESNIKey()
    return _aesni_expand!(blank, uint64x2(key))
end
112+
"""
```julia
AESNI1x <: AbstractAESNI1x
AESNI1x([seed])
```

`AESNI1x` is an AESNI counter-based RNG that produces one `UInt128` number at a time.

`seed` is an `Integer`; it is converted to `UInt128` automatically.

Only available when [`R123_USE_AESNI`](@ref).
"""
mutable struct AESNI1x <: AbstractAESNI1x
    # Most recent output block, as stored by `random123_r`.
    x::uint64x2
    # Counter block handed to `aesni`.
    ctr::uint64x2
    # Expanded key.
    key::AESNIKey
end
130+
131+
# Construct an `AESNI1x` and seed it with `seed` (a random `UInt128` seed by default).
# The state words are created with `zero(uint64x2)`, matching the `AESNI4x` constructor
# and `seed!`, instead of relying on an implicit conversion from the literal `0`.
function AESNI1x(seed::Integer=gen_seed(UInt128))
    r = AESNI1x(zero(uint64x2), zero(uint64x2), AESNIKey())
    seed!(r, seed)
    return r
end
136+
137+
# Re-seed `r` in place: clear the output and counter words, expand the key from
# `seed` reduced to `UInt128`, then produce the first output block. Returns `r`.
function seed!(r::AESNI1x, seed::Integer=gen_seed(UInt128))
    seed128 = rem(seed, UInt128)
    r.x = zero(uint64x2)
    r.ctr = zero(uint64x2)
    _aesni_expand!(r.key, uint64x2(seed128))
    random123_r(r)
    return r
end
144+
145+
# Seeds for `AESNI1x` are `UInt128` values.
function seed_type(::Type{AESNI1x})
    return UInt128
end
146+
147+
# Copy the complete generator state (output word, counter, key) from `src` into `dest`.
# Returns `dest`.
function copyto!(dest::AESNI1x, src::AESNI1x)
    dest.x = src.x
    dest.ctr = src.ctr
    # The key is a separate mutable object, so its contents are copied rather than shared.
    copyto!(dest.key, src.key)
    return dest
end
153+
154+
# Duplicate a generator: construct a new `AESNI1x` and overwrite its state with `src`'s.
function copy(src::AESNI1x)
    fresh = AESNI1x()
    return copyto!(fresh, src)
end
155+
156+
# Generators are equal when output word, key and counter all match (checked in that order).
function ==(r1::AESNI1x, r2::AESNI1x)
    r1.x == r2.x || return false
    r1.key == r2.key || return false
    return r1.ctr == r2.ctr
end
157+
"""
```julia
AESNI4x <: AbstractAESNI4x
AESNI4x([seed])
```

`AESNI4x` is an AESNI counter-based RNG that produces four `UInt32` numbers at a time.

`seed` is a `Tuple` of four `Integer`s; each one is converted to `UInt32` automatically.

Only available when [`R123_USE_AESNI`](@ref).
"""
mutable struct AESNI4x <: AbstractAESNI4x
    # Most recent output block, as stored by `random123_r`.
    x::uint64x2
    # Counter block handed to `aesni`.
    ctr1::uint64x2
    # Expanded key.
    key::AESNIKey
    # Integer position field; reset to 0 by `seed!`.
    p::Int
end
176+
177+
# Construct an `AESNI4x` with zeroed state and seed it with `seed`
# (four random `UInt32` values by default).
function AESNI4x(seed::NTuple{4, Integer}=gen_seed(UInt32, 4))
    blank = AESNI4x(zero(uint64x2), zero(uint64x2), AESNIKey(), 0)
    # `seed!` returns the generator it was given.
    return seed!(blank, seed)
end
182+
183+
# Re-seed `r` in place. Each of the four seed integers is reduced to `UInt32` and the
# four words are merged by `union_uint` into the value the key is expanded from. The
# counter and the position `p` are reset, then the first output block is produced.
# Returns `r`.
function seed!(r::AESNI4x, seed::NTuple{4, Integer}=gen_seed(UInt32, 4))
    # `map` over the tuple yields an `NTuple{4, UInt32}` directly.
    key = union_uint(map(x -> x % UInt32, seed))
    # Use `zero(uint64x2)` like `seed!(::AESNI1x)` rather than converting the literal `0`.
    r.ctr1 = zero(uint64x2)
    _aesni_expand!(r.key, uint64x2(key))
    r.p = 0
    random123_r(r)
    return r
end
191+
192+
# Seeds for `AESNI4x` are tuples of four `UInt32` values.
function seed_type(::Type{AESNI4x})
    return NTuple{4, UInt32}
end
193+
194+
# Copy the complete generator state from `src` into `dest` and return `dest`.
function copyto!(dest::AESNI4x, src::AESNI4x)
    # Delegates the first two `UInt128`-sized words to `unsafe_copyto!`
    # (presumably the `x` and `ctr1` fields -- TODO confirm against `unsafe_copyto!`).
    unsafe_copyto!(dest, src, UInt128, 2)
    # The key is a separate mutable object, so its contents are copied rather than shared.
    copyto!(dest.key, src.key)
    dest.p = src.p
    return dest
end
200+
201+
# Duplicate a generator: construct a new `AESNI4x` and overwrite its state with `src`'s.
function copy(src::AESNI4x)
    fresh = AESNI4x()
    return copyto!(fresh, src)
end
202+
# Generators are equal when `unsafe_compare` reports equality over the first two
# `UInt128`-sized words, and the keys and the positions `p` match (checked in that order).
function ==(r1::AESNI4x, r2::AESNI4x)
    unsafe_compare(r1, r2, UInt128, 2) || return false
    r1.key == r2.key || return false
    return r1.p == r2.p
end
204+
205+
# Return the eleven key fields of `o.key` as a tuple, in order `key1` .. `key11`.
function get_key_uint64x2(o::Union{AESNI1x, AESNI4x})::NTuple{11, uint64x2}
    rk = o.key
    return (
        rk.key1, rk.key2, rk.key3, rk.key4, rk.key5, rk.key6,
        rk.key7, rk.key8, rk.key9, rk.key10, rk.key11,
    )
end
209+
# Counter accessors: wrap the generator's counter word in a 1-tuple, the form
# `aesni` takes for its `ctr` argument.
function get_ctr_uint64x2(o::AESNI4x)::Tuple{uint64x2}
    return (o.ctr1,)
end
function get_ctr_uint64x2(o::AESNI1x)::Tuple{uint64x2}
    return (o.ctr,)
end
211+
# The eleven key words of `o`, each converted to `UInt128`.
function get_key(o::Union{AESNI1x, AESNI4x})::NTuple{11,UInt128}
    words = get_key_uint64x2(o)
    return map(UInt128, words)
end
212+
# The counter of `o` as a 1-tuple holding a `UInt128`.
function get_ctr(o::Union{AESNI1x, AESNI4x})::Tuple{UInt128}
    words = get_ctr_uint64x2(o)
    return map(UInt128, words)
end
213+
214+
# Encrypt the single counter block in `ctr` with the eleven key words in `key`.
# The block is first XOR-ed with `key1`, then passed through nine `_aes_enc` rounds
# (`key2` .. `key10`) and a final `_aes_enc_last` round with `key11`.
# Returns the result wrapped in a 1-tuple.
@inline function aesni(key::NTuple{11,uint64x2}, ctr::Tuple{uint64x2})::Tuple{uint64x2}
    key1, key2, key3, key4, key5, key6, key7, key8, key9, key10, key11 = key
    ctr1 = only(ctr)
    # Spelled `xor(...)`, which is the same function as the infix `⊻` operator.
    x = xor(key1, ctr1)
    x = _aes_enc(x, key2)
    x = _aes_enc(x, key3)
    x = _aes_enc(x, key4)
    x = _aes_enc(x, key5)
    x = _aes_enc(x, key6)
    x = _aes_enc(x, key7)
    x = _aes_enc(x, key8)
    x = _aes_enc(x, key9)
    x = _aes_enc(x, key10)
    x = _aes_enc_last(x, key11)
    return (x,)
end
230+
"""
    aesni(key::NTuple{11,UInt128}, ctr::Tuple{UInt128})::Tuple{UInt128}

Functional variant of [`AESNI1x`](@ref) and [`AESNI4x`](@ref).
This function is free of mutability and side effects.
"""
@inline function aesni(key::NTuple{11,UInt128}, ctr::Tuple{UInt128})::Tuple{UInt128}
    # Convert to the vector representation, run the vector method, convert back.
    encrypted = aesni(map(uint64x2, key), map(uint64x2, ctr))
    return map(UInt128, encrypted)
end
242+
243+
244+
# Compute the output block for the current key and counter, store it in `r.x`,
# and return it as a 1-tuple holding a `UInt128`.
@inline function random123_r(r::AESNI1x)
    block = aesni(get_key_uint64x2(r), get_ctr_uint64x2(r))
    r.x = only(block)
    return (UInt128(r.x),)
end
248+
249+
# Compute the output block for the current key and counter, store it in `r.x`,
# and return it split into `UInt32` pieces by `split_uint`.
@inline function random123_r(r::AESNI4x)
    block = aesni(get_key_uint64x2(r), get_ctr_uint64x2(r))
    r.x = only(block)
    return split_uint(UInt128(r.x), UInt32)
end

0 commit comments

Comments
 (0)