Skip to content

Commit 8dbb320

Browse files
committed
Refactor specialization constants to use dataclass
also adds spec_id, itemsize, and default_value fields
1 parent d2c6bb6 commit 8dbb320

3 files changed

Lines changed: 151 additions & 18 deletions

File tree

dpctl/program/utils/_utils.py

Lines changed: 104 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
"""Implements various utilities for the dpctl.program module."""
1818

19+
from dataclasses import dataclass
1920
from enum import IntEnum
2021

2122
import numpy as np
@@ -29,16 +30,60 @@ class SpirvOpCode(IntEnum):
2930
OpSpecConstantTrue = 48
3031
OpSpecConstantFalse = 49
3132
OpSpecConstant = 50
33+
OpFunction = 54
3234
OpDecorate = 71
3335

3436

3537
class SpirvDecoration(IntEnum):
    """Numeric SPIR-V decoration codes referenced by this module."""

    # decoration code carried by an OpDecorate instruction to attach a
    # specialization ID to its target
    SpecId = 1
3739

3840

41+
@dataclass(frozen=True)
class SpecializationConstantInfo:
    """Immutable description of a single SPIR-V specialization constant.

    The class is declared with ``frozen=True``, so fields cannot be
    reassigned after construction. The per-field notes below describe the
    values stored by ``parse_spirv_specializations``.
    """

    # specialization ID of the constant
    spec_id: int
    # NumPy-style type string (e.g. "u4", "f4", "?"); the parser stores
    # "unknown_type" when the constant's type was not recognised
    dtype: str
    # variable name; the parser falls back to
    # "unnamed_spec_const_{spec_id}" when no name is available
    name: str
    # size of the constant in bytes; the parser stores 0 for an unknown type
    itemsize: int
    # default value decoded from the binary, or None when none was decoded
    default_value: int | float | bool | None
50+
51+
3952
def parse_spirv_specializations(
4053
spv_bytes: bytes | bytearray | memoryview,
41-
) -> dict[int, dict[str, str]]:
54+
) -> tuple[SpecializationConstantInfo]:
55+
"""
56+
Parses SPIR-V byte stream to extract information about specializations,
57+
including the specialization IDs, types, names, and default values.
58+
59+
Note that the dtype information may be imprecise, as the compiler may
60+
choose to, for example, represent a bool as char, or may represent both
61+
signed and unsigned integers as unsigned integer bit buckets of the same
62+
length.
63+
64+
Args:
65+
spv_bytes (bytes | bytearray | memoryview):
66+
the SPIR-V byte stream.
67+
68+
Returns:
69+
tuple[SpecializationConstantInfo]:
70+
a tuple of parsed constants and their information represented by
71+
`SpecializationConstantInfo` objects, sorted by their
72+
specialization IDs. The length of the tuple is equal to the number
73+
of specialization constants found. Each
74+
`SpecializationConstantInfo` object contains the following
75+
attributes:
76+
77+
- `spec_id` (int): The specialization ID.
78+
- `dtype` (str): A NumPy style string representing the data type.
79+
- `itemsize` (int): The size of the specialization constant in
80+
bytes.
81+
- `name` (str): The variable name. If not preserved in the binary,
82+
a default name in the format `unnamed_spec_const_{spec_id}` is
83+
used.
84+
- `default_value` (int | float | bool | None): The default value of
85+
the specialization constant. If not specified, `None` is used.
86+
"""
4287
words = np.frombuffer(spv_bytes, dtype=np.uint32)
4388

4489
# verify magic number
@@ -49,6 +94,7 @@ def parse_spirv_specializations(
4994
ids = {}
5095
names = {}
5196
constants = {}
97+
defaults = {}
5298

5399
i = 5 # skip 5 word header
54100
while i < len(words):
@@ -59,27 +105,45 @@ def parse_spirv_specializations(
59105
if word_count == 0:
60106
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
61107

62-
if opcode == SpirvOpCode.OpTypeBool:
108+
if opcode == SpirvOpCode.OpFunction:
109+
# everything following is not relevant to specialization constant
110+
# parsing, so we can stop parsing at this point
111+
break
112+
elif opcode == SpirvOpCode.OpTypeBool:
63113
result_id = int(words[i + 1])
64-
types[result_id] = "?"
114+
types[result_id] = {"dtype": "?", "itemsize": 1}
65115
elif opcode == SpirvOpCode.OpTypeInt:
66116
result_id = int(words[i + 1])
67117
width = int(words[i + 2])
68118
signed = int(words[i + 3])
69119
prefix = "i" if signed else "u"
70-
types[result_id] = f"{prefix}{width // 8}"
120+
types[result_id] = {
121+
"dtype": f"{prefix}{width // 8}",
122+
"itemsize": width // 8,
123+
}
71124
elif opcode == SpirvOpCode.OpTypeFloat:
72125
result_id = int(words[i + 1])
73126
width = int(words[i + 2])
74-
types[result_id] = f"f{width // 8}"
75-
elif opcode in (
76-
SpirvOpCode.OpSpecConstant,
77-
SpirvOpCode.OpSpecConstantTrue,
78-
SpirvOpCode.OpSpecConstantFalse,
79-
):
127+
types[result_id] = {
128+
"dtype": f"f{width // 8}",
129+
"itemsize": width // 8,
130+
}
131+
elif opcode == SpirvOpCode.OpSpecConstant:
132+
type_id = int(words[i + 1])
133+
result_id = int(words[i + 2])
134+
constants[result_id] = type_id
135+
literal_words = words[i + 3 : i + word_count]
136+
defaults[result_id] = literal_words.tobytes()
137+
elif opcode == SpirvOpCode.OpSpecConstantTrue:
80138
type_id = int(words[i + 1])
81139
result_id = int(words[i + 2])
82140
constants[result_id] = type_id
141+
defaults[result_id] = True
142+
elif opcode == SpirvOpCode.OpSpecConstantFalse:
143+
type_id = int(words[i + 1])
144+
result_id = int(words[i + 2])
145+
constants[result_id] = type_id
146+
defaults[result_id] = False
83147
elif opcode == SpirvOpCode.OpDecorate:
84148
target_id = int(words[i + 1])
85149
decoration = int(words[i + 2])
@@ -92,15 +156,37 @@ def parse_spirv_specializations(
92156

93157
i += word_count
94158

95-
result = {}
159+
# a spec ID may appear multiple times in the same binary with different
160+
# target IDs. We only need to keep one, so skip duplicates
161+
unique_ids = set()
162+
result = []
96163
for target_id, spec_id in ids.items():
164+
if spec_id in unique_ids:
165+
continue
166+
unique_ids.add(spec_id)
97167
type_id = constants.get(target_id)
98-
dtype_str = types.get(type_id, "unknown_type")
168+
type_info = types.get(type_id, {"dtype": "unknown_type", "itemsize": 0})
99169
name = names.get(target_id, f"unnamed_spec_const_{spec_id}")
100170

101-
result[spec_id] = {
102-
"name": name,
103-
"dtype": dtype_str,
104-
}
105-
106-
return result
171+
dtype_str = type_info["dtype"]
172+
raw_default = defaults.get(target_id)
173+
default_value = None
174+
if isinstance(raw_default, bytes):
175+
try:
176+
default_value = np.frombuffer(raw_default, dtype=dtype_str)[
177+
0
178+
].item()
179+
except Exception:
180+
default_value = None
181+
182+
result.append(
183+
SpecializationConstantInfo(
184+
spec_id=spec_id,
185+
dtype=dtype_str,
186+
name=name,
187+
itemsize=type_info["itemsize"],
188+
default_value=default_value,
189+
)
190+
)
191+
192+
return tuple(sorted(result, key=lambda x: x.spec_id))
72 Bytes
Binary file not shown.

dpctl/tests/test_sycl_program.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import dpctl
2525
import dpctl.program as dpctl_prog
26+
from dpctl.program.utils import parse_spirv_specializations
2627

2728

2829
def get_spirv_abspath(fn):
@@ -341,3 +342,49 @@ def test_create_kernel_bundle_with_composite_spec_const():
341342

342343
# 1.0 * 10 + 2.5 = 12.5
343344
assert np.all(y == 12.5)
345+
346+
347+
def test_spirv_specializations_parser():
    """Check the metadata parsed from the sample specialization binaries."""
    # kernel with a single specialization constant: only dtype is checked
    kernel_path = get_spirv_abspath("specialization_constant_kernel.spv")
    with open(kernel_path, "rb") as fh:
        parsed = parse_spirv_specializations(fh.read())
    assert len(parsed) == 1
    assert parsed[0].dtype == "u4"

    # composite kernel: three constants, returned in spec ID order
    composite_path = get_spirv_abspath("specialization_constant_composite.spv")
    with open(composite_path, "rb") as fh:
        parsed = parse_spirv_specializations(fh.read())
    assert len(parsed) == 3

    # one row per constant: (dtype, itemsize, name, default_value)
    expected_rows = (
        ("u4", 4, "unnamed_spec_const_0", 1),
        ("f4", 4, "unnamed_spec_const_1", 0),
        # compiler translates bool to char
        ("u1", 1, "unnamed_spec_const_2", 0),
    )
    for info, row in zip(parsed, expected_rows):
        exp_dtype, exp_itemsize, exp_name, exp_default = row
        assert info.dtype == exp_dtype
        assert info.itemsize == exp_itemsize
        assert info.name == exp_name
        assert info.default_value == exp_default
377+
378+
379+
def test_spirv_specializations_parser_no_spec_consts():
    """Parsing ``multi_kernel.spv`` must yield an empty (falsy) result."""
    binary_path = get_spirv_abspath("multi_kernel.spv")
    with open(binary_path, "rb") as fh:
        parsed = parse_spirv_specializations(fh.read())
    assert not parsed
385+
386+
387+
def test_spirv_specializations_parser_invalid_spirv():
    """Bytes that are not a SPIR-V module must raise ``ValueError``."""
    # NOTE(review): the payload is 6 bytes, which is not a whole number of
    # 32-bit words, so the ValueError may come from ``np.frombuffer`` rather
    # than from the magic-number check -- confirm which path is intended.
    garbage = b"\x00\x01\x02\x03\x04\x05"
    with pytest.raises(ValueError):
        parse_spirv_specializations(garbage)

0 commit comments

Comments
 (0)