Skip to content

Commit 8dca526

Browse files
committed
add program.utils namespace with SPIRV parser
1 parent d8b0afb commit 8dca526

3 files changed

Lines changed: 142 additions & 0 deletions

File tree

dpctl/program/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
2222
"""
2323

24+
from . import utils
2425
from ._program import (
2526
SpecializationConstant,
2627
SyclKernel,
@@ -45,6 +46,11 @@
4546
"SpecializationConstant",
4647
]
4748

49+
# add submodules
50+
__all__ += [
51+
"utils",
52+
]
53+
4854

4955
def __getattr__(name):
5056
if name == "SyclProgram":

dpctl/program/utils/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""
18+
A collection of utility functions for dpctl.program module.
19+
"""
20+
21+
from ._utils import parse_spirv_specializations
22+
23+
__all__ = [
24+
"parse_spirv_specializations",
25+
]

dpctl/program/utils/_utils.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Implements various utilities for the dpctl.program module."""
18+
19+
from enum import IntEnum
20+
21+
import numpy as np
22+
23+
24+
class SpirvOpCode(IntEnum):
25+
"""SPIR-V operation codes used in the dpctl.program module."""
26+
27+
OpName = 5
28+
OpTypeBool = 20
29+
OpTypeInt = 21
30+
OpTypeFloat = 22
31+
OpSpecConstantTrue = 48
32+
OpSpecConstantFalse = 49
33+
OpSpecConstant = 50
34+
OpDecorate = 71
35+
36+
37+
class SpirvDecoration(IntEnum):
38+
"""SPIR-V decoration codes used in the dpctl.program module."""
39+
40+
SpecId = 1
41+
42+
43+
def parse_spirv_specializations(
44+
spv_bytes: bytes | bytearray | memoryview,
45+
) -> dict[int, dict[str, str]]:
46+
"""Parses SPIR-V binary to extract specialization constants."""
47+
words = np.frombuffer(spv_bytes, dtype=np.uint32)
48+
49+
# verify magic number
50+
if len(words) < 5 or words[0] != 0x07230203:
51+
raise ValueError("Invalid SPIR-V binary")
52+
53+
types = {}
54+
ids = {}
55+
names = {}
56+
constants = {}
57+
58+
i = 5 # skip 5 word header
59+
while i < len(words):
60+
word = words[i]
61+
opcode = word & 0xFFFF
62+
word_count = word >> 16
63+
64+
if word_count == 0:
65+
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
66+
67+
if opcode == SpirvOpCode.OpTypeBool:
68+
result_id = int(words[i + 1])
69+
types[result_id] = "?"
70+
elif opcode == SpirvOpCode.OpTypeInt:
71+
result_id = int(words[i + 1])
72+
width = int(words[i + 2])
73+
signed = int(words[i + 3])
74+
prefix = "i" if signed else "u"
75+
types[result_id] = f"{prefix}{width // 8}"
76+
elif opcode == SpirvOpCode.OpTypeFloat:
77+
result_id = int(words[i + 1])
78+
width = int(words[i + 2])
79+
types[result_id] = f"f{width // 8}"
80+
elif opcode in (
81+
SpirvOpCode.OpSpecConstant,
82+
SpirvOpCode.OpSpecConstantTrue,
83+
SpirvOpCode.OpSpecConstantFalse,
84+
):
85+
type_id = int(words[i + 1])
86+
result_id = int(words[i + 2])
87+
constants[result_id] = type_id
88+
elif opcode == SpirvOpCode.OpDecorate:
89+
target_id = int(words[i + 1])
90+
decoration = int(words[i + 2])
91+
if decoration == SpirvDecoration.SpecId:
92+
ids[target_id] = int(words[i + 3])
93+
elif opcode == SpirvOpCode.OpName:
94+
target_id = int(words[i + 1])
95+
name_bytes = words[i + 2 : i + word_count].tobytes()
96+
names[target_id] = name_bytes.split(b"\x00", 1)[0].decode("utf-8")
97+
98+
i += word_count
99+
100+
result = {}
101+
for target_id, spec_id in ids.items():
102+
type_id = constants.get(target_id)
103+
dtype_str = types.get(type_id, "unknown_type")
104+
name = names.get(target_id, f"unnamed_spec_const_{spec_id}")
105+
106+
result[spec_id] = {
107+
"name": name,
108+
"dtype": dtype_str,
109+
}
110+
111+
return result

0 commit comments

Comments
 (0)