Skip to content

Commit 382ea0a

Browse files
committed
feat(pjrt): add program execution, serialization, and full StableHLO emitter
T60.3.2: LoadedExecutable.Execute() wraps PJRT_LoadedExecutable_Execute. T60.3.3: Serialize/DeserializeAndLoad for compiled executables. T61.3.1: EmitProgram() assembles complete StableHLO MLIR modules from slot-based op sequences, routing to element-wise, structural, and reduction emitters.
1 parent aa8c170 commit 382ea0a

4 files changed

Lines changed: 735 additions & 5 deletions

File tree

internal/pjrt/executable.go

Lines changed: 150 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package pjrt
33
import (
44
"fmt"
55
"unsafe"
6-
7-
"github.com/zerfoo/ztensor/internal/cuda"
86
)
97

108
// LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by
@@ -79,7 +77,7 @@ func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error) {
7977
program: uintptr(unsafe.Pointer(&program)),
8078
}
8179

82-
errPtr := cuda.Ccall(c.lib.PJRT_Client_Compile, uintptr(unsafe.Pointer(&args)))
80+
errPtr := ccall(c.lib.PJRT_Client_Compile, uintptr(unsafe.Pointer(&args)))
8381
if err := c.lib.checkError(errPtr); err != nil {
8482
return nil, fmt.Errorf("PJRT_Client_Compile: %w", err)
8583
}
@@ -140,11 +138,158 @@ func (e *LoadedExecutable) Close() error {
140138
structSize: unsafe.Sizeof(destroyArgs{}),
141139
executable: e.handle,
142140
}
143-
errPtr := cuda.Ccall(e.lib.PJRT_LoadedExecutable_Destroy, uintptr(unsafe.Pointer(&args)))
141+
errPtr := ccall(e.lib.PJRT_LoadedExecutable_Destroy, uintptr(unsafe.Pointer(&args)))
144142
e.handle = 0
145143
return e.lib.checkError(errPtr)
146144
}
147145

146+
// ExecOption configures Execute behavior. It is a functional option: each
// value is a function that mutates the execConfig assembled for one Execute
// call.
147+
type ExecOption func(*execConfig)
148+
149+
// execConfig collects the settings produced by applying ExecOption values
// inside Execute.
//
// NOTE(review): the Execute method visible in this file applies the options
// to an execConfig but does not read either field afterwards — confirm
// whether forwarding them to the PJRT call is still pending.
type execConfig struct {
150+
// deviceOrdinal is the device ordinal to execute on (0 = first addressable
// device). Set by WithDeviceOrdinal.
151+
deviceOrdinal int
152+
// donateInputs indicates that the runtime may take ownership of
153+
// input buffers, avoiding a copy. The caller must not use the
154+
// donated buffers after Execute returns.
// Set by WithInputDonation; donateInputs[i] == true marks input i.
155+
donateInputs []bool
156+
}
157+
158+
// WithDeviceOrdinal selects which device to execute on.
159+
func WithDeviceOrdinal(ordinal int) ExecOption {
160+
return func(c *execConfig) {
161+
c.deviceOrdinal = ordinal
162+
}
163+
}
164+
165+
// WithInputDonation marks specific inputs for buffer donation.
166+
// donated[i] == true means input i may be consumed by the runtime.
167+
func WithInputDonation(donated []bool) ExecOption {
168+
return func(c *execConfig) {
169+
c.donateInputs = donated
170+
}
171+
}
172+
173+
// Execute runs the compiled program with the given input buffers and
174+
// returns the output buffers. The caller owns the returned buffers and
175+
// must close them when done.
176+
//
177+
//go:nocheckptr
178+
func (e *LoadedExecutable) Execute(inputs []*Buffer, opts ...ExecOption) ([]*Buffer, error) {
179+
if e.handle == 0 {
180+
return nil, fmt.Errorf("pjrt: cannot execute closed executable")
181+
}
182+
if e.lib.PJRT_LoadedExecutable_Execute == 0 {
183+
return nil, fmt.Errorf("pjrt: plugin does not support PJRT_LoadedExecutable_Execute")
184+
}
185+
186+
var cfg execConfig
187+
for _, o := range opts {
188+
o(&cfg)
189+
}
190+
191+
// Build the flat array of input buffer handles.
192+
numInputs := len(inputs)
193+
inputHandles := make([]uintptr, numInputs)
194+
for i, buf := range inputs {
195+
if buf == nil || buf.Handle() == 0 {
196+
return nil, fmt.Errorf("pjrt: input buffer %d is nil or closed", i)
197+
}
198+
inputHandles[i] = buf.Handle()
199+
}
200+
201+
var inputHandlesPtr uintptr
202+
if numInputs > 0 {
203+
inputHandlesPtr = uintptr(unsafe.Pointer(&inputHandles[0]))
204+
}
205+
206+
// Allocate output buffer handle slots. PJRT writes one PJRT_Buffer*
207+
// per output per device. We execute on a single device.
208+
numOutputs := e.numOutputs
209+
outputHandles := make([]uintptr, numOutputs)
210+
var outputHandlesPtr uintptr
211+
if numOutputs > 0 {
212+
outputHandlesPtr = uintptr(unsafe.Pointer(&outputHandles[0]))
213+
}
214+
215+
// PJRT expects a pointer-to-pointer for the output list (one list
216+
// per device). We execute on one device, so we have a single list.
217+
outputListPtr := outputHandlesPtr
218+
outputListsPtr := uintptr(unsafe.Pointer(&outputListPtr))
219+
220+
// PJRT_LoadedExecutable_Execute_Args:
221+
// struct_size uintptr
222+
// executable uintptr (PJRT_LoadedExecutable*)
223+
// options uintptr (PJRT_ExecuteOptions*, may be 0)
224+
// argument_lists uintptr (PJRT_Buffer* const* const*, one list per device)
225+
// num_devices uintptr (size_t)
226+
// num_args uintptr (size_t)
227+
// output_lists uintptr (PJRT_Buffer** const*, out: one list per device)
228+
// device_complete_events uintptr (out: PJRT_Event**, one per device)
229+
// execute_device uintptr (PJRT_Device*, optional single-device execute)
230+
type executeArgs struct {
231+
structSize uintptr
232+
executable uintptr
233+
options uintptr
234+
argumentLists uintptr
235+
numDevices uintptr
236+
numArgs uintptr
237+
outputLists uintptr
238+
deviceCompleteEvents uintptr
239+
executeDevice uintptr
240+
}
241+
242+
// Build the argument list pointer (one list for one device).
243+
argListPtr := inputHandlesPtr
244+
argListsPtr := uintptr(unsafe.Pointer(&argListPtr))
245+
246+
// Allocate event output slot (one per device).
247+
var event uintptr
248+
eventPtr := uintptr(unsafe.Pointer(&event))
249+
250+
args := executeArgs{
251+
structSize: unsafe.Sizeof(executeArgs{}),
252+
executable: e.handle,
253+
argumentLists: argListsPtr,
254+
numDevices: 1,
255+
numArgs: uintptr(numInputs),
256+
outputLists: outputListsPtr,
257+
deviceCompleteEvents: eventPtr,
258+
}
259+
260+
errPtr := ccall(e.lib.PJRT_LoadedExecutable_Execute, uintptr(unsafe.Pointer(&args)))
261+
if err := e.lib.checkError(errPtr); err != nil {
262+
return nil, fmt.Errorf("PJRT_LoadedExecutable_Execute: %w", err)
263+
}
264+
265+
// Wait for execution to complete.
266+
if event != 0 {
267+
if err := e.lib.awaitEvent(event); err != nil {
268+
return nil, fmt.Errorf("pjrt: await execution: %w", err)
269+
}
270+
e.lib.destroyEvent(event)
271+
}
272+
273+
// Wrap output handles in Buffer structs.
274+
outputs := make([]*Buffer, numOutputs)
275+
for i, h := range outputHandles {
276+
if h == 0 {
277+
// Clean up already-wrapped outputs on error.
278+
for j := 0; j < i; j++ {
279+
outputs[j].Close()
280+
}
281+
return nil, fmt.Errorf("pjrt: execution returned null output buffer at index %d", i)
282+
}
283+
outputs[i] = &Buffer{
284+
lib: e.lib,
285+
client: 0, // output buffers don't need the client handle for readback
286+
handle: h,
287+
}
288+
}
289+
290+
return outputs, nil
291+
}
292+
148293
// Handle returns the raw PJRT_LoadedExecutable pointer.
149294
func (e *LoadedExecutable) Handle() uintptr {
150295
return e.handle
@@ -194,7 +339,7 @@ func (e *LoadedExecutable) queryNumOutputs() (int, error) {
194339
structSize: unsafe.Sizeof(numOutputsArgs{}),
195340
executable: e.handle,
196341
}
197-
errPtr := cuda.Ccall(e.lib.PJRT_Executable_NumOutputs, uintptr(unsafe.Pointer(&args)))
342+
errPtr := ccall(e.lib.PJRT_Executable_NumOutputs, uintptr(unsafe.Pointer(&args)))
198343
if err := e.lib.checkError(errPtr); err != nil {
199344
return 0, fmt.Errorf("PJRT_Executable_NumOutputs: %w", err)
200345
}

internal/pjrt/serialize.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package pjrt
2+
3+
import (
4+
"fmt"
5+
"unsafe"
6+
7+
"github.com/zerfoo/ztensor/internal/cuda"
8+
)
9+
10+
// Serialize serializes the compiled executable to bytes. The serialized
11+
// form can be cached to disk and later restored with Client.DeserializeAndLoad,
12+
// skipping recompilation on subsequent runs with the same model and hardware.
13+
func (e *LoadedExecutable) Serialize() ([]byte, error) {
14+
if e.handle == 0 {
15+
return nil, fmt.Errorf("pjrt: cannot serialize closed executable")
16+
}
17+
if e.lib.PJRT_Executable_Serialize == 0 {
18+
return nil, fmt.Errorf("pjrt: plugin does not support PJRT_Executable_Serialize")
19+
}
20+
21+
// PJRT_Executable_Serialize_Args:
22+
// struct_size uintptr
23+
// executable uintptr (PJRT_LoadedExecutable*)
24+
// serialized_bytes uintptr (out: const char*)
25+
// serialized_bytes_size uintptr (out: size_t)
26+
type serializeArgs struct {
27+
structSize uintptr
28+
executable uintptr
29+
serializedBytes uintptr
30+
serializedBytesSize uintptr
31+
}
32+
args := serializeArgs{
33+
structSize: unsafe.Sizeof(serializeArgs{}),
34+
executable: e.handle,
35+
}
36+
37+
errPtr := cuda.Ccall(e.lib.PJRT_Executable_Serialize, uintptr(unsafe.Pointer(&args)))
38+
if err := e.lib.checkError(errPtr); err != nil {
39+
return nil, fmt.Errorf("PJRT_Executable_Serialize: %w", err)
40+
}
41+
if args.serializedBytes == 0 || args.serializedBytesSize == 0 {
42+
return nil, fmt.Errorf("pjrt: PJRT_Executable_Serialize returned empty result")
43+
}
44+
45+
// Copy the serialized bytes into a Go-managed slice. The C pointer
46+
// is owned by the PJRT runtime and may be invalidated when the
47+
// executable is destroyed.
48+
n := int(args.serializedBytesSize)
49+
src := unsafe.Slice((*byte)(unsafe.Pointer(args.serializedBytes)), n)
50+
out := make([]byte, n)
51+
copy(out, src)
52+
53+
return out, nil
54+
}
55+
56+
// DeserializeAndLoad restores a previously serialized executable, returning
57+
// a LoadedExecutable ready for execution. This skips the compilation step
58+
// entirely, which can save significant time for large models.
59+
//
60+
// The serialized data must have been produced by Serialize() on the same
61+
// plugin and hardware platform.
62+
func (c *Client) DeserializeAndLoad(data []byte) (*LoadedExecutable, error) {
63+
if c.handle == 0 {
64+
return nil, fmt.Errorf("pjrt: cannot deserialize on closed client")
65+
}
66+
if c.lib.PJRT_Executable_DeserializeAndLoad == 0 {
67+
return nil, fmt.Errorf("pjrt: plugin does not support PJRT_Executable_DeserializeAndLoad")
68+
}
69+
if len(data) == 0 {
70+
return nil, fmt.Errorf("pjrt: cannot deserialize empty data")
71+
}
72+
73+
// PJRT_Executable_DeserializeAndLoad_Args:
74+
// struct_size uintptr
75+
// client uintptr (PJRT_Client*)
76+
// serialized_executable uintptr (const char*)
77+
// serialized_executable_size uintptr (size_t)
78+
// loaded_executable uintptr (out: PJRT_LoadedExecutable*)
79+
type deserializeArgs struct {
80+
structSize uintptr
81+
client uintptr
82+
serializedExecutable uintptr
83+
serializedExecutableSize uintptr
84+
loadedExecutable uintptr
85+
}
86+
args := deserializeArgs{
87+
structSize: unsafe.Sizeof(deserializeArgs{}),
88+
client: c.handle,
89+
serializedExecutable: uintptr(unsafe.Pointer(&data[0])),
90+
serializedExecutableSize: uintptr(len(data)),
91+
}
92+
93+
errPtr := cuda.Ccall(c.lib.PJRT_Executable_DeserializeAndLoad, uintptr(unsafe.Pointer(&args)))
94+
if err := c.lib.checkError(errPtr); err != nil {
95+
return nil, fmt.Errorf("PJRT_Executable_DeserializeAndLoad: %w", err)
96+
}
97+
if args.loadedExecutable == 0 {
98+
return nil, fmt.Errorf("pjrt: PJRT_Executable_DeserializeAndLoad returned null executable")
99+
}
100+
101+
exec := &LoadedExecutable{lib: c.lib, handle: args.loadedExecutable}
102+
103+
// Query and cache output metadata, same as after Compile.
104+
if err := exec.queryOutputMetadata(); err != nil {
105+
exec.Close()
106+
return nil, fmt.Errorf("pjrt: query output metadata after deserialize: %w", err)
107+
}
108+
109+
return exec, nil
110+
}

0 commit comments

Comments
 (0)