Skip to content

Commit b7f8d82

Browse files
committed
Refactor specialization constants to use dataclass
also adds spec_id, itemsize, and default_value fields
1 parent bf443aa commit b7f8d82

3 files changed

Lines changed: 150 additions & 19 deletions

File tree

dpctl/program/utils/_utils.py

Lines changed: 104 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
"""Implements various utilities for the dpctl.program module."""
1818

19+
from dataclasses import dataclass
1920
from enum import IntEnum
2021

2122
import numpy as np
@@ -31,6 +32,7 @@ class SpirvOpCode(IntEnum):
3132
OpSpecConstantTrue = 48
3233
OpSpecConstantFalse = 49
3334
OpSpecConstant = 50
35+
OpFunction = 54
3436
OpDecorate = 71
3537

3638

@@ -40,10 +42,52 @@ class SpirvDecoration(IntEnum):
4042
SpecId = 1
4143

4244

45+
@dataclass(frozen=True)
46+
class SpecializationConstantInfo:
47+
"""Data class representing specialization constant information."""
48+
49+
spec_id: int
50+
dtype: str
51+
name: str
52+
itemsize: int
53+
default_value: int | float | bool | None
54+
55+
4356
def parse_spirv_specializations(
4457
spv_bytes: bytes | bytearray | memoryview,
45-
) -> dict[int, dict[str, str]]:
46-
"""Parses SPIR-V binary to extract specialization constants."""
58+
) -> tuple[SpecializationConstantInfo]:
59+
"""
60+
Parses SPIR-V byte stream to extract information about specializations,
61+
including the specialization IDs, types, and names.
62+
63+
Note that the dtype information may be imprecise, as the compiler may
64+
choose to, for example, represent a bool as char, or may represent both
65+
signed and unsigned integers as unsigned integer bit buckets of the same
66+
length.
67+
68+
Args:
69+
spv_bytes (Union[bytes, bytearray, memoryview]):
70+
the SPIR-V byte stream.
71+
72+
Returns:
73+
tuple[SpecializationConstantInfo]:
74+
a tuple of parsed constants and their information represented by
75+
`SpecializationConstantInfo` objects, sorted by their
76+
specialization IDs. The length of the tuple is equal to the number
77+
of specialization constants found. Each
78+
`SpecializationConstantInfo` object contains the following
79+
attributes:
80+
81+
- `spec_id` (int): The specialization ID.
82+
- `dtype` (str): A NumPy style string representing the data type.
83+
- `itemsize` (int): The size of the specialization constant in
84+
bytes.
85+
- `name` (str): The variable name. If not preserved in the binary,
86+
a default name in the format `unnamed_spec_const_{spec_id}` is
87+
used.
88+
- `default_value` (int | float | bool | None): The default value of
89+
the specialization constant. If not specified, `None` is used.
90+
"""
4791
words = np.frombuffer(spv_bytes, dtype=np.uint32)
4892

4993
# verify magic number
@@ -54,6 +98,7 @@ def parse_spirv_specializations(
5498
ids = {}
5599
names = {}
56100
constants = {}
101+
defaults = {}
57102

58103
i = 5 # skip 5 word header
59104
while i < len(words):
@@ -64,27 +109,45 @@ def parse_spirv_specializations(
64109
if word_count == 0:
65110
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
66111

67-
if opcode == SpirvOpCode.OpTypeBool:
112+
if opcode == SpirvOpCode.OpFunction:
113+
# everything following is not relevant to specialization constant
114+
# parsing, so we can stop parsing at this point
115+
break
116+
elif opcode == SpirvOpCode.OpTypeBool:
68117
result_id = int(words[i + 1])
69-
types[result_id] = "?"
118+
types[result_id] = {"dtype": "?", "itemsize": 1}
70119
elif opcode == SpirvOpCode.OpTypeInt:
71120
result_id = int(words[i + 1])
72121
width = int(words[i + 2])
73122
signed = int(words[i + 3])
74123
prefix = "i" if signed else "u"
75-
types[result_id] = f"{prefix}{width // 8}"
124+
types[result_id] = {
125+
"dtype": f"{prefix}{width // 8}",
126+
"itemsize": width // 8,
127+
}
76128
elif opcode == SpirvOpCode.OpTypeFloat:
77129
result_id = int(words[i + 1])
78130
width = int(words[i + 2])
79-
types[result_id] = f"f{width // 8}"
80-
elif opcode in (
81-
SpirvOpCode.OpSpecConstant,
82-
SpirvOpCode.OpSpecConstantTrue,
83-
SpirvOpCode.OpSpecConstantFalse,
84-
):
131+
types[result_id] = {
132+
"dtype": f"f{width // 8}",
133+
"itemsize": width // 8,
134+
}
135+
elif opcode == SpirvOpCode.OpSpecConstant:
136+
type_id = int(words[i + 1])
137+
result_id = int(words[i + 2])
138+
constants[result_id] = type_id
139+
literal_words = words[i + 3 : i + word_count]
140+
defaults[result_id] = literal_words.tobytes()
141+
elif opcode == SpirvOpCode.OpSpecConstantTrue:
85142
type_id = int(words[i + 1])
86143
result_id = int(words[i + 2])
87144
constants[result_id] = type_id
145+
defaults[result_id] = True
146+
elif opcode == SpirvOpCode.OpSpecConstantFalse:
147+
type_id = int(words[i + 1])
148+
result_id = int(words[i + 2])
149+
constants[result_id] = type_id
150+
defaults[result_id] = False
88151
elif opcode == SpirvOpCode.OpDecorate:
89152
target_id = int(words[i + 1])
90153
decoration = int(words[i + 2])
@@ -97,15 +160,37 @@ def parse_spirv_specializations(
97160

98161
i += word_count
99162

100-
result = {}
163+
# a spec ID may appear multiple times in the same binary with different
164+
# target IDs. We only need to keep one, so skip duplicates
165+
unique_ids = set()
166+
result = []
101167
for target_id, spec_id in ids.items():
168+
if spec_id in unique_ids:
169+
continue
170+
unique_ids.add(spec_id)
102171
type_id = constants.get(target_id)
103-
dtype_str = types.get(type_id, "unknown_type")
172+
type_info = types.get(type_id, {"dtype": "unknown_type", "itemsize": 0})
104173
name = names.get(target_id, f"unnamed_spec_const_{spec_id}")
105174

106-
result[spec_id] = {
107-
"name": name,
108-
"dtype": dtype_str,
109-
}
110-
111-
return result
175+
dtype_str = type_info["dtype"]
176+
raw_default = defaults.get(target_id)
177+
default_value = None
178+
if isinstance(raw_default, bytes):
179+
try:
180+
default_value = np.frombuffer(raw_default, dtype=dtype_str)[
181+
0
182+
].item()
183+
except Exception:
184+
default_value = None
185+
186+
result.append(
187+
SpecializationConstantInfo(
188+
spec_id=spec_id,
189+
dtype=dtype_str,
190+
name=name,
191+
itemsize=type_info["itemsize"],
192+
default_value=default_value,
193+
)
194+
)
195+
196+
return tuple(sorted(result, key=lambda x: x.spec_id))
72 Bytes
Binary file not shown.

dpctl/tests/test_sycl_program.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,49 @@ def test_create_kernel_bundle_with_composite_spec_const():
341341

342342
# 1.0 * 10 + 2.5 = 12.5
343343
assert np.all(y == 12.5)
344+
345+
346+
def test_spirv_specializations_parser():
347+
spirv_file = get_spirv_abspath("specialization_constant_kernel.spv")
348+
with open(spirv_file, "rb") as spv:
349+
spv_bytes = spv.read()
350+
spec_consts = dpctl_prog.utils.parse_spirv_specializations(spv_bytes)
351+
assert len(spec_consts) == 1
352+
assert spec_consts[0].dtype == "u4"
353+
354+
spirv_file = get_spirv_abspath("specialization_constant_composite.spv")
355+
with open(spirv_file, "rb") as spv:
356+
spv_bytes = spv.read()
357+
358+
spec_consts = dpctl_prog.utils.parse_spirv_specializations(spv_bytes)
359+
assert len(spec_consts) == 3
360+
spec_const0, spec_const1, spec_const2 = spec_consts
361+
assert spec_const0.dtype == "u4"
362+
assert spec_const0.itemsize == 4
363+
assert spec_const0.name == "unnamed_spec_const_0"
364+
assert spec_const0.default_value == 1
365+
366+
assert spec_const1.dtype == "f4"
367+
assert spec_const1.itemsize == 4
368+
assert spec_const1.name == "unnamed_spec_const_1"
369+
assert spec_const1.default_value == 0
370+
371+
# compiler translates bool to char
372+
assert spec_const2.dtype == "u1"
373+
assert spec_const2.itemsize == 1
374+
assert spec_const2.name == "unnamed_spec_const_2"
375+
assert spec_const2.default_value == 0
376+
377+
378+
def test_spirv_specializations_parser_no_spec_consts():
379+
spirv_file = get_spirv_abspath("multi_kernel.spv")
380+
with open(spirv_file, "rb") as spv:
381+
spv_bytes = spv.read()
382+
spec_consts = dpctl_prog.utils.parse_spirv_specializations(spv_bytes)
383+
assert not spec_consts
384+
385+
386+
def test_spirv_specializations_parser_invalid_spirv():
387+
invalid_spv = b"\x00\x01\x02\x03\x04\x05"
388+
with pytest.raises(ValueError):
389+
dpctl_prog.utils.parse_spirv_specializations(invalid_spv)

0 commit comments

Comments
 (0)