1616
1717"""Implements various utilities for the dpctl.program module."""
1818
19+ from dataclasses import dataclass
1920from enum import IntEnum
2021
2122import numpy as np
@@ -31,6 +32,7 @@ class SpirvOpCode(IntEnum):
3132 OpSpecConstantTrue = 48
3233 OpSpecConstantFalse = 49
3334 OpSpecConstant = 50
35+ OpFunction = 54
3436 OpDecorate = 71
3537
3638
@@ -40,10 +42,52 @@ class SpirvDecoration(IntEnum):
4042 SpecId = 1
4143
4244
45+ @dataclass (frozen = True )
46+ class SpecializationConstantInfo :
47+ """Data class representing specialization constant information."""
48+
49+ spec_id : int
50+ dtype : str
51+ name : str
52+ itemsize : int
53+ default_value : int | float | bool | None
54+
55+
4356def parse_spirv_specializations (
4457 spv_bytes : bytes | bytearray | memoryview ,
45- ) -> dict [int , dict [str , str ]]:
46- """Parses SPIR-V binary to extract specialization constants."""
58+ ) -> tuple [SpecializationConstantInfo ]:
59+ """
60+ Parses SPIR-V byte stream to extract information about specializations,
61+ including the specialization IDs, types, and names.
62+
63+ Note that the dtype information may be imprecise, as the compiler may
64+ choose to, for example, represent a bool as char, or may represent both
65+ signed and unsigned integers as unsigned integer bit buckets of the same
66+ length.
67+
68+ Args:
69+ spv_bytes (Union[bytes, bytearray, memoryview]):
70+ the SPIR-V byte stream.
71+
72+ Returns:
73+ tuple[SpecializationConstantInfo]:
74+ a tuple of parsed constants and their information represented by
75+ `SpecializationConstantInfo` objects, sorted by their
76+ specialization IDs. The length of the tuple is equal to the number
77+ of specialization constants found. Each
78+ `SpecializationConstantInfo` object contains the following
79+ attributes:
80+
81+ - `spec_id` (int): The specialization ID.
82+ - `dtype` (str): A NumPy style string representing the data type.
83+ - `itemsize` (int): The size of the specialization constant in
84+ bytes.
85+ - `name` (str): The variable name. If not preserved in the binary,
86+ a default name in the format `unnamed_spec_const_{spec_id}` is
87+ used.
88+ - `default_value` (int | float | bool | None): The default value of
89+ the specialization constant. If not specified, `None` is used.
90+ """
4791 words = np .frombuffer (spv_bytes , dtype = np .uint32 )
4892
4993 # verify magic number
@@ -54,6 +98,7 @@ def parse_spirv_specializations(
5498 ids = {}
5599 names = {}
56100 constants = {}
101+ defaults = {}
57102
58103 i = 5 # skip 5 word header
59104 while i < len (words ):
@@ -64,27 +109,45 @@ def parse_spirv_specializations(
64109 if word_count == 0 :
65110 raise ValueError (f"Invalid SPIR-V instruction at word index { i } " )
66111
67- if opcode == SpirvOpCode .OpTypeBool :
112+ if opcode == SpirvOpCode .OpFunction :
113+ # everything following is not relevant to specialization constant
114+ # parsing, so we can stop parsing at this point
115+ break
116+ elif opcode == SpirvOpCode .OpTypeBool :
68117 result_id = int (words [i + 1 ])
69- types [result_id ] = "?"
118+ types [result_id ] = { "dtype" : "?" , "itemsize" : 1 }
70119 elif opcode == SpirvOpCode .OpTypeInt :
71120 result_id = int (words [i + 1 ])
72121 width = int (words [i + 2 ])
73122 signed = int (words [i + 3 ])
74123 prefix = "i" if signed else "u"
75- types [result_id ] = f"{ prefix } { width // 8 } "
124+ types [result_id ] = {
125+ "dtype" : f"{ prefix } { width // 8 } " ,
126+ "itemsize" : width // 8 ,
127+ }
76128 elif opcode == SpirvOpCode .OpTypeFloat :
77129 result_id = int (words [i + 1 ])
78130 width = int (words [i + 2 ])
79- types [result_id ] = f"f{ width // 8 } "
80- elif opcode in (
81- SpirvOpCode .OpSpecConstant ,
82- SpirvOpCode .OpSpecConstantTrue ,
83- SpirvOpCode .OpSpecConstantFalse ,
84- ):
131+ types [result_id ] = {
132+ "dtype" : f"f{ width // 8 } " ,
133+ "itemsize" : width // 8 ,
134+ }
135+ elif opcode == SpirvOpCode .OpSpecConstant :
136+ type_id = int (words [i + 1 ])
137+ result_id = int (words [i + 2 ])
138+ constants [result_id ] = type_id
139+ literal_words = words [i + 3 : i + word_count ]
140+ defaults [result_id ] = literal_words .tobytes ()
141+ elif opcode == SpirvOpCode .OpSpecConstantTrue :
85142 type_id = int (words [i + 1 ])
86143 result_id = int (words [i + 2 ])
87144 constants [result_id ] = type_id
145+ defaults [result_id ] = True
146+ elif opcode == SpirvOpCode .OpSpecConstantFalse :
147+ type_id = int (words [i + 1 ])
148+ result_id = int (words [i + 2 ])
149+ constants [result_id ] = type_id
150+ defaults [result_id ] = False
88151 elif opcode == SpirvOpCode .OpDecorate :
89152 target_id = int (words [i + 1 ])
90153 decoration = int (words [i + 2 ])
@@ -97,15 +160,37 @@ def parse_spirv_specializations(
97160
98161 i += word_count
99162
100- result = {}
163+ # a spec ID may appear multiple times in the same binary with different
164+ # target IDs. We only need to keep one, so skip duplicates
165+ unique_ids = set ()
166+ result = []
101167 for target_id , spec_id in ids .items ():
168+ if spec_id in unique_ids :
169+ continue
170+ unique_ids .add (spec_id )
102171 type_id = constants .get (target_id )
103- dtype_str = types .get (type_id , " unknown_type" )
172+ type_info = types .get (type_id , { "dtype" : " unknown_type", "itemsize" : 0 } )
104173 name = names .get (target_id , f"unnamed_spec_const_{ spec_id } " )
105174
106- result [spec_id ] = {
107- "name" : name ,
108- "dtype" : dtype_str ,
109- }
110-
111- return result
175+ dtype_str = type_info ["dtype" ]
176+ raw_default = defaults .get (target_id )
177+ default_value = None
178+ if isinstance (raw_default , bytes ):
179+ try :
180+ default_value = np .frombuffer (raw_default , dtype = dtype_str )[
181+ 0
182+ ].item ()
183+ except Exception :
184+ default_value = None
185+
186+ result .append (
187+ SpecializationConstantInfo (
188+ spec_id = spec_id ,
189+ dtype = dtype_str ,
190+ name = name ,
191+ itemsize = type_info ["itemsize" ],
192+ default_value = default_value ,
193+ )
194+ )
195+
196+ return tuple (sorted (result , key = lambda x : x .spec_id ))
0 commit comments