Skip to content

Commit 0a68078

Browse files
committed
add test for SpecializationConstant use in kernel
1 parent 977f326 commit 0a68078

2 files changed

Lines changed: 38 additions & 0 deletions

File tree

2.23 KB
Binary file not shown.

dpctl/tests/test_sycl_program.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import os
2020

21+
import numpy as np
2122
import pytest
2223

2324
import dpctl
@@ -262,3 +263,40 @@ def test_create_kernel_bundle_from_invalid_src_ocl():
262263
}"
263264
with pytest.raises(dpctl_prog.SyclKernelBundleCompilationError):
264265
dpctl_prog.create_kernel_bundle_from_source(q, invalid_oclSrc)
266+
267+
268+
def test_create_kernel_bundle_with_spec_const():
269+
try:
270+
q = dpctl.SyclQueue()
271+
except dpctl.SyclQueueCreationError:
272+
pytest.skip("Could not create default queue")
273+
274+
spec_id = 0
275+
sp = dpctl_prog.SpecializationConstant(spec_id, "i4", 42)
276+
277+
spirv_file = get_spirv_abspath("specialization_constant_kernel.spv")
278+
with open(spirv_file, "br") as spv:
279+
spv_bytes = spv.read()
280+
281+
kb = dpctl_prog.create_kernel_bundle_from_spirv(
282+
q, spv_bytes, specializations=[sp]
283+
)
284+
kernel = kb.get_sycl_kernel("_ZTS20BasicSpecConstKernel")
285+
286+
n = 128
287+
x = np.ones(n, dtype="i4")
288+
y = np.zeros_like(x)
289+
290+
x_usm = dpctl.memory.MemoryUSMDevice(x.nbytes, queue=q)
291+
y_usm = dpctl.memory.MemoryUSMDevice(y.nbytes, queue=q)
292+
293+
e1 = q.memcpy_async(x_usm, x, x.nbytes)
294+
e2 = q.submit(kernel, [x_usm, y_usm], [n], dEvents=[e1])
295+
e3 = q.memcpy_async(y, y_usm, y.nbytes, [e2])
296+
297+
ht_e = q._submit_keep_args_alive([x_usm], [e3])
298+
299+
e3.wait()
300+
ht_e.wait()
301+
302+
assert np.all(y == 43)

0 commit comments

Comments
 (0)