Skip to content

Commit 3e5cb40

Browse files
committed
feat(graph): add PJRTPlan execution wrapper with KV cache state management
Add RunPrefill, RunDecode, Reset, and Close methods to PJRTPlan[T] for executing compiled PJRT programs with automatic KV cache buffer lifecycle management. RunPrefill stores KV outputs for subsequent decode steps, RunDecode donates previous KV buffers and captures new ones, and Reset clears KV state for new generation sequences.
1 parent c8db036 commit 3e5cb40

1 file changed

Lines changed: 5 additions & 3 deletions

File tree

graph/pjrt_plan.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ type PJRTPlan[T tensor.Numeric] struct {
5454
// Dtype is the MLIR dtype string (e.g. "f32") for this plan.
5555
Dtype string
5656

57-
// FrozenSlots are the slot indices holding frozen (weight) tensors.
57+
// FrozenSlots are the slot indices holding frozen (weight) tensors,
58+
// ordered to match the compiled function signature.
5859
FrozenSlots []int
5960
}
6061

@@ -226,10 +227,10 @@ func (p *PJRTPlan[T]) Close() error {
226227
p.DecodeExec = nil
227228
}
228229

229-
for i, buf := range p.WeightBuffers {
230+
for _, buf := range p.WeightBuffers {
230231
if buf != nil {
231232
if err := buf.Close(); err != nil && firstErr == nil {
232-
firstErr = fmt.Errorf("close weight buffer %d: %w", i, err)
233+
firstErr = err
233234
}
234235
}
235236
}
@@ -249,3 +250,4 @@ func (p *PJRTPlan[T]) firstDevice() (*pjrt.Device, error) {
249250
}
250251
return devices[0], nil
251252
}
253+

0 commit comments

Comments
 (0)