Skip to content

Commit d2c6bb6

Browse files
committed
add program.utils namespace with SPIRV parser
1 parent d8b0afb commit d2c6bb6

3 files changed

Lines changed: 136 additions & 0 deletions

File tree

dpctl/program/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
"SpecializationConstant",
4646
]
4747

48+
# add submodules
49+
__all__ += [
50+
"utils",
51+
]
52+
4853

4954
def __getattr__(name):
5055
if name == "SyclProgram":

dpctl/program/utils/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""
18+
A collection of utility functions for dpctl.program module.
19+
"""
20+
21+
from ._utils import parse_spirv_specializations
22+
23+
__all__ = [
24+
"parse_spirv_specializations",
25+
]

dpctl/program/utils/_utils.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Implements various utilities for the dpctl.program module."""
18+
19+
from enum import IntEnum
20+
21+
import numpy as np
22+
23+
24+
class SpirvOpCode(IntEnum):
25+
OpName = 5
26+
OpTypeBool = 20
27+
OpTypeInt = 21
28+
OpTypeFloat = 22
29+
OpSpecConstantTrue = 48
30+
OpSpecConstantFalse = 49
31+
OpSpecConstant = 50
32+
OpDecorate = 71
33+
34+
35+
class SpirvDecoration(IntEnum):
36+
SpecId = 1
37+
38+
39+
def parse_spirv_specializations(
40+
spv_bytes: bytes | bytearray | memoryview,
41+
) -> dict[int, dict[str, str]]:
42+
words = np.frombuffer(spv_bytes, dtype=np.uint32)
43+
44+
# verify magic number
45+
if len(words) < 5 or words[0] != 0x07230203:
46+
raise ValueError("Invalid SPIR-V binary")
47+
48+
types = {}
49+
ids = {}
50+
names = {}
51+
constants = {}
52+
53+
i = 5 # skip 5 word header
54+
while i < len(words):
55+
word = words[i]
56+
opcode = word & 0xFFFF
57+
word_count = word >> 16
58+
59+
if word_count == 0:
60+
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
61+
62+
if opcode == SpirvOpCode.OpTypeBool:
63+
result_id = int(words[i + 1])
64+
types[result_id] = "?"
65+
elif opcode == SpirvOpCode.OpTypeInt:
66+
result_id = int(words[i + 1])
67+
width = int(words[i + 2])
68+
signed = int(words[i + 3])
69+
prefix = "i" if signed else "u"
70+
types[result_id] = f"{prefix}{width // 8}"
71+
elif opcode == SpirvOpCode.OpTypeFloat:
72+
result_id = int(words[i + 1])
73+
width = int(words[i + 2])
74+
types[result_id] = f"f{width // 8}"
75+
elif opcode in (
76+
SpirvOpCode.OpSpecConstant,
77+
SpirvOpCode.OpSpecConstantTrue,
78+
SpirvOpCode.OpSpecConstantFalse,
79+
):
80+
type_id = int(words[i + 1])
81+
result_id = int(words[i + 2])
82+
constants[result_id] = type_id
83+
elif opcode == SpirvOpCode.OpDecorate:
84+
target_id = int(words[i + 1])
85+
decoration = int(words[i + 2])
86+
if decoration == SpirvDecoration.SpecId:
87+
ids[target_id] = int(words[i + 3])
88+
elif opcode == SpirvOpCode.OpName:
89+
target_id = int(words[i + 1])
90+
name_bytes = words[i + 2 : i + word_count].tobytes()
91+
names[target_id] = name_bytes.split(b"\x00", 1)[0].decode("utf-8")
92+
93+
i += word_count
94+
95+
result = {}
96+
for target_id, spec_id in ids.items():
97+
type_id = constants.get(target_id)
98+
dtype_str = types.get(type_id, "unknown_type")
99+
name = names.get(target_id, f"unnamed_spec_const_{spec_id}")
100+
101+
result[spec_id] = {
102+
"name": name,
103+
"dtype": dtype_str,
104+
}
105+
106+
return result

0 commit comments

Comments
 (0)