@@ -3,8 +3,6 @@ package pjrt
33import (
44 "fmt"
55 "unsafe"
6-
7- "github.com/zerfoo/ztensor/internal/cuda"
86)
97
108// LoadedExecutable wraps a PJRT_LoadedExecutable handle returned by
@@ -79,7 +77,7 @@ func (c *Client) Compile(stablehloMLIR string) (*LoadedExecutable, error) {
7977 program : uintptr (unsafe .Pointer (& program )),
8078 }
8179
82- errPtr := cuda . Ccall (c .lib .PJRT_Client_Compile , uintptr (unsafe .Pointer (& args )))
80+ errPtr := ccall (c .lib .PJRT_Client_Compile , uintptr (unsafe .Pointer (& args )))
8381 if err := c .lib .checkError (errPtr ); err != nil {
8482 return nil , fmt .Errorf ("PJRT_Client_Compile: %w" , err )
8583 }
@@ -140,11 +138,158 @@ func (e *LoadedExecutable) Close() error {
140138 structSize : unsafe .Sizeof (destroyArgs {}),
141139 executable : e .handle ,
142140 }
143- errPtr := cuda . Ccall (e .lib .PJRT_LoadedExecutable_Destroy , uintptr (unsafe .Pointer (& args )))
141+ errPtr := ccall (e .lib .PJRT_LoadedExecutable_Destroy , uintptr (unsafe .Pointer (& args )))
144142 e .handle = 0
145143 return e .lib .checkError (errPtr )
146144}
147145
type (
	// ExecOption is a functional option that adjusts how Execute runs by
	// mutating the per-call execConfig it is handed.
	ExecOption func(*execConfig)

	// execConfig collects the settings that ExecOption values write into
	// for a single Execute call.
	execConfig struct {
		// deviceOrdinal is the ordinal of the device to execute on
		// (0 = first addressable device).
		deviceOrdinal int
		// donateInputs flags, per input, that the runtime may take
		// ownership of that input buffer and so avoid a copy. The caller
		// must not use a donated buffer after Execute returns.
		donateInputs []bool
	}
)
157+
158+ // WithDeviceOrdinal selects which device to execute on.
159+ func WithDeviceOrdinal (ordinal int ) ExecOption {
160+ return func (c * execConfig ) {
161+ c .deviceOrdinal = ordinal
162+ }
163+ }
164+
165+ // WithInputDonation marks specific inputs for buffer donation.
166+ // donated[i] == true means input i may be consumed by the runtime.
167+ func WithInputDonation (donated []bool ) ExecOption {
168+ return func (c * execConfig ) {
169+ c .donateInputs = donated
170+ }
171+ }
172+
173+ // Execute runs the compiled program with the given input buffers and
174+ // returns the output buffers. The caller owns the returned buffers and
175+ // must close them when done.
176+ //
177+ //go:nocheckptr
178+ func (e * LoadedExecutable ) Execute (inputs []* Buffer , opts ... ExecOption ) ([]* Buffer , error ) {
179+ if e .handle == 0 {
180+ return nil , fmt .Errorf ("pjrt: cannot execute closed executable" )
181+ }
182+ if e .lib .PJRT_LoadedExecutable_Execute == 0 {
183+ return nil , fmt .Errorf ("pjrt: plugin does not support PJRT_LoadedExecutable_Execute" )
184+ }
185+
186+ var cfg execConfig
187+ for _ , o := range opts {
188+ o (& cfg )
189+ }
190+
191+ // Build the flat array of input buffer handles.
192+ numInputs := len (inputs )
193+ inputHandles := make ([]uintptr , numInputs )
194+ for i , buf := range inputs {
195+ if buf == nil || buf .Handle () == 0 {
196+ return nil , fmt .Errorf ("pjrt: input buffer %d is nil or closed" , i )
197+ }
198+ inputHandles [i ] = buf .Handle ()
199+ }
200+
201+ var inputHandlesPtr uintptr
202+ if numInputs > 0 {
203+ inputHandlesPtr = uintptr (unsafe .Pointer (& inputHandles [0 ]))
204+ }
205+
206+ // Allocate output buffer handle slots. PJRT writes one PJRT_Buffer*
207+ // per output per device. We execute on a single device.
208+ numOutputs := e .numOutputs
209+ outputHandles := make ([]uintptr , numOutputs )
210+ var outputHandlesPtr uintptr
211+ if numOutputs > 0 {
212+ outputHandlesPtr = uintptr (unsafe .Pointer (& outputHandles [0 ]))
213+ }
214+
215+ // PJRT expects a pointer-to-pointer for the output list (one list
216+ // per device). We execute on one device, so we have a single list.
217+ outputListPtr := outputHandlesPtr
218+ outputListsPtr := uintptr (unsafe .Pointer (& outputListPtr ))
219+
220+ // PJRT_LoadedExecutable_Execute_Args:
221+ // struct_size uintptr
222+ // executable uintptr (PJRT_LoadedExecutable*)
223+ // options uintptr (PJRT_ExecuteOptions*, may be 0)
224+ // argument_lists uintptr (PJRT_Buffer* const* const*, one list per device)
225+ // num_devices uintptr (size_t)
226+ // num_args uintptr (size_t)
227+ // output_lists uintptr (PJRT_Buffer** const*, out: one list per device)
228+ // device_complete_events uintptr (out: PJRT_Event**, one per device)
229+ // execute_device uintptr (PJRT_Device*, optional single-device execute)
230+ type executeArgs struct {
231+ structSize uintptr
232+ executable uintptr
233+ options uintptr
234+ argumentLists uintptr
235+ numDevices uintptr
236+ numArgs uintptr
237+ outputLists uintptr
238+ deviceCompleteEvents uintptr
239+ executeDevice uintptr
240+ }
241+
242+ // Build the argument list pointer (one list for one device).
243+ argListPtr := inputHandlesPtr
244+ argListsPtr := uintptr (unsafe .Pointer (& argListPtr ))
245+
246+ // Allocate event output slot (one per device).
247+ var event uintptr
248+ eventPtr := uintptr (unsafe .Pointer (& event ))
249+
250+ args := executeArgs {
251+ structSize : unsafe .Sizeof (executeArgs {}),
252+ executable : e .handle ,
253+ argumentLists : argListsPtr ,
254+ numDevices : 1 ,
255+ numArgs : uintptr (numInputs ),
256+ outputLists : outputListsPtr ,
257+ deviceCompleteEvents : eventPtr ,
258+ }
259+
260+ errPtr := ccall (e .lib .PJRT_LoadedExecutable_Execute , uintptr (unsafe .Pointer (& args )))
261+ if err := e .lib .checkError (errPtr ); err != nil {
262+ return nil , fmt .Errorf ("PJRT_LoadedExecutable_Execute: %w" , err )
263+ }
264+
265+ // Wait for execution to complete.
266+ if event != 0 {
267+ if err := e .lib .awaitEvent (event ); err != nil {
268+ return nil , fmt .Errorf ("pjrt: await execution: %w" , err )
269+ }
270+ e .lib .destroyEvent (event )
271+ }
272+
273+ // Wrap output handles in Buffer structs.
274+ outputs := make ([]* Buffer , numOutputs )
275+ for i , h := range outputHandles {
276+ if h == 0 {
277+ // Clean up already-wrapped outputs on error.
278+ for j := 0 ; j < i ; j ++ {
279+ outputs [j ].Close ()
280+ }
281+ return nil , fmt .Errorf ("pjrt: execution returned null output buffer at index %d" , i )
282+ }
283+ outputs [i ] = & Buffer {
284+ lib : e .lib ,
285+ client : 0 , // output buffers don't need the client handle for readback
286+ handle : h ,
287+ }
288+ }
289+
290+ return outputs , nil
291+ }
292+
148293// Handle returns the raw PJRT_LoadedExecutable pointer.
149294func (e * LoadedExecutable ) Handle () uintptr {
150295 return e .handle
@@ -194,7 +339,7 @@ func (e *LoadedExecutable) queryNumOutputs() (int, error) {
194339 structSize : unsafe .Sizeof (numOutputsArgs {}),
195340 executable : e .handle ,
196341 }
197- errPtr := cuda . Ccall (e .lib .PJRT_Executable_NumOutputs , uintptr (unsafe .Pointer (& args )))
342+ errPtr := ccall (e .lib .PJRT_Executable_NumOutputs , uintptr (unsafe .Pointer (& args )))
198343 if err := e .lib .checkError (errPtr ); err != nil {
199344 return 0 , fmt .Errorf ("PJRT_Executable_NumOutputs: %w" , err )
200345 }
0 commit comments