|
18 | 18 |
|
19 | 19 | import os |
20 | 20 |
|
| 21 | +import numpy as np |
21 | 22 | import pytest |
22 | 23 |
|
23 | 24 | import dpctl |
@@ -262,3 +263,40 @@ def test_create_kernel_bundle_from_invalid_src_ocl(): |
262 | 263 | }" |
263 | 264 | with pytest.raises(dpctl_prog.SyclKernelBundleCompilationError): |
264 | 265 | dpctl_prog.create_kernel_bundle_from_source(q, invalid_oclSrc) |
| 266 | + |
| 267 | + |
| 268 | +def test_create_kernel_bundle_with_spec_const(): |
| 269 | + try: |
| 270 | + q = dpctl.SyclQueue() |
| 271 | + except dpctl.SyclQueueCreationError: |
| 272 | + pytest.skip("Could not create default queue") |
| 273 | + |
| 274 | + spec_id = 0 |
| 275 | + sp = dpctl_prog.SpecializationConstant(spec_id, "i4", 42) |
| 276 | + |
| 277 | + spirv_file = get_spirv_abspath("specialization_constant_kernel.spv") |
| 278 | + with open(spirv_file, "br") as spv: |
| 279 | + spv_bytes = spv.read() |
| 280 | + |
| 281 | + kb = dpctl_prog.create_kernel_bundle_from_spirv( |
| 282 | + q, spv_bytes, specializations=[sp] |
| 283 | + ) |
| 284 | + kernel = kb.get_sycl_kernel("_ZTS20BasicSpecConstKernel") |
| 285 | + |
| 286 | + n = 128 |
| 287 | + x = np.ones(n, dtype="i4") |
| 288 | + y = np.zeros_like(x) |
| 289 | + |
| 290 | + x_usm = dpctl.memory.MemoryUSMDevice(x.nbytes, queue=q) |
| 291 | + y_usm = dpctl.memory.MemoryUSMDevice(y.nbytes, queue=q) |
| 292 | + |
| 293 | + e1 = q.memcpy_async(x_usm, x, x.nbytes) |
| 294 | + e2 = q.submit(kernel, [x_usm, y_usm], [n], dEvents=[e1]) |
| 295 | + e3 = q.memcpy_async(y, y_usm, y.nbytes, [e2]) |
| 296 | + |
| 297 | + ht_e = q._submit_keep_args_alive([x_usm], [e3]) |
| 298 | + |
| 299 | + e3.wait() |
| 300 | + ht_e.wait() |
| 301 | + |
| 302 | + assert np.all(y == 43) |
0 commit comments