Skip to content

Commit dfd77a4

Browse files
committed
feat(graph): add CompilePJRT for PJRT backend compilation
Add CompilePJRT() method on Graph[T] that traces the graph to obtain primitive operations, emits StableHLO MLIR text, and compiles it via a PJRT client. For graphs with KV cache, both prefill and decode executables are compiled. Frozen weight tensors are transferred to the device. Also adds PJRTPlan[T] struct with execution, reset, and close methods, and KVPairs() accessor on Graph[T] for external KV cache slot resolution.
1 parent c8decc5 commit dfd77a4

3 files changed

Lines changed: 560 additions & 0 deletions

File tree

graph/compile_pjrt.go

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
package graph
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/zerfoo/ztensor/internal/pjrt"
8+
"github.com/zerfoo/ztensor/internal/stablehlo"
9+
"github.com/zerfoo/ztensor/tensor"
10+
)
11+
12+
// CompilePJRT compiles the graph into a PJRTPlan for execution on a PJRT
13+
// backend. It traces the graph to obtain primitive operations, emits
14+
// StableHLO MLIR, compiles it via the PJRT client, and transfers frozen
15+
// weights to the device.
16+
//
17+
// For graphs with KV cache (StatefulInputNodes), both a prefill and a decode
18+
// executable are compiled. For graphs without KV cache, only the prefill
19+
// executable is produced.
20+
func (g *Graph[T]) CompilePJRT(ctx context.Context, client *pjrt.Client, inputs ...*tensor.TensorNumeric[T]) (*PJRTPlan[T], error) {
21+
if client == nil {
22+
return nil, fmt.Errorf("CompilePJRT: client is nil")
23+
}
24+
25+
// Step 1: Trace the graph to get primitive ops and slot metadata.
26+
plan, err := g.CompileTraced(ctx, inputs...)
27+
if err != nil {
28+
return nil, fmt.Errorf("CompilePJRT: trace failed: %w", err)
29+
}
30+
31+
tracedMetas := plan.Instructions()
32+
slotShapes := plan.SlotShapes()
33+
inputSlots := plan.InputSlots()
34+
outputSlot := plan.OutputSlot()
35+
frozenSlots := plan.FrozenSlots()
36+
37+
// Build a set of frozen slot indices for quick lookup.
38+
frozenSet := make(map[int]bool, len(frozenSlots))
39+
for _, fs := range frozenSlots {
40+
frozenSet[fs.SlotIdx] = true
41+
}
42+
43+
// Determine the MLIR dtype from the Go generic type.
44+
dtype, err := mlirDtype[T]()
45+
if err != nil {
46+
return nil, fmt.Errorf("CompilePJRT: %w", err)
47+
}
48+
49+
// Step 2: Convert traced InstructionMeta to stablehlo.ProgramOp.
50+
programOps := make([]stablehlo.ProgramOp, len(tracedMetas))
51+
for i, meta := range tracedMetas {
52+
inputShapes := make([][]int, len(meta.InputIdx))
53+
for j, idx := range meta.InputIdx {
54+
if idx < len(slotShapes) && slotShapes[idx] != nil {
55+
inputShapes[j] = slotShapes[idx]
56+
}
57+
}
58+
59+
var outputShape []int
60+
if meta.OutputIdx < len(slotShapes) {
61+
outputShape = slotShapes[meta.OutputIdx]
62+
}
63+
64+
programOps[i] = stablehlo.ProgramOp{
65+
OpName: meta.OpName,
66+
InputSlots: meta.InputIdx,
67+
OutputSlot: meta.OutputIdx,
68+
InputShapes: inputShapes,
69+
OutputShape: outputShape,
70+
Dtype: dtype,
71+
Attrs: meta.ExtraArgs,
72+
}
73+
}
74+
75+
// Collect input shapes for the StableHLO function signature.
76+
// This includes both dynamic inputs and frozen (weight) slots.
77+
allInputSlots := make([]int, 0, len(inputSlots)+len(frozenSlots))
78+
allInputShapes := make([][]int, 0, len(inputSlots)+len(frozenSlots))
79+
80+
for _, idx := range inputSlots {
81+
allInputSlots = append(allInputSlots, idx)
82+
if idx < len(slotShapes) {
83+
allInputShapes = append(allInputShapes, slotShapes[idx])
84+
} else {
85+
allInputShapes = append(allInputShapes, nil)
86+
}
87+
}
88+
for _, fs := range frozenSlots {
89+
allInputSlots = append(allInputSlots, fs.SlotIdx)
90+
if fs.SlotIdx < len(slotShapes) {
91+
allInputShapes = append(allInputShapes, slotShapes[fs.SlotIdx])
92+
} else {
93+
allInputShapes = append(allInputShapes, nil)
94+
}
95+
}
96+
97+
// Step 3: Identify KV cache slots from graph's kvPairs.
98+
kvPairs := g.KVPairs()
99+
var kvSlots []stablehlo.KVCacheSlot
100+
101+
if len(kvPairs) > 0 {
102+
// We need to map the KV pair nodes to their traced slot indices.
103+
// The graph ran a Forward during CompileTraced, so memo has the
104+
// output tensors. We need the slot indices from the plan.
105+
for _, kv := range kvPairs {
106+
inputShape := kv.Input.OutputShape()
107+
outputShape := kv.Output.OutputShape()
108+
109+
// Find the slot index for the input node's output. The input
110+
// node is a StatefulInputNode — its tensor was registered as
111+
// a slot during tracing. We search plan's slot shapes.
112+
inputSlot := findNodeSlot(inputSlots, slotShapes, inputShape)
113+
outputSlot := findOutputSlot(programOps, outputShape)
114+
115+
// Default seq_axis to 1 (standard KV cache layout:
116+
// [num_heads, seq_len, head_dim]).
117+
seqAxis := 1
118+
if len(inputShape) >= 2 {
119+
seqAxis = len(inputShape) - 2
120+
}
121+
122+
kvSlots = append(kvSlots, stablehlo.KVCacheSlot{
123+
InputSlot: inputSlot,
124+
OutputSlot: outputSlot,
125+
Shape: inputShape,
126+
SeqAxis: seqAxis,
127+
})
128+
}
129+
}
130+
131+
// Step 4: Emit StableHLO MLIR and compile.
132+
var prefillMLIR string
133+
var decodeMLIR string
134+
135+
if len(kvSlots) > 0 {
136+
prefillMLIR, err = stablehlo.EmitKVCacheProgram(programOps, allInputSlots, allInputShapes, kvSlots, dtype, false)
137+
if err != nil {
138+
return nil, fmt.Errorf("CompilePJRT: emit prefill program: %w", err)
139+
}
140+
decodeMLIR, err = stablehlo.EmitKVCacheProgram(programOps, allInputSlots, allInputShapes, kvSlots, dtype, true)
141+
if err != nil {
142+
return nil, fmt.Errorf("CompilePJRT: emit decode program: %w", err)
143+
}
144+
} else {
145+
prefillMLIR, err = stablehlo.EmitProgram(programOps, allInputSlots, allInputShapes, dtype)
146+
if err != nil {
147+
return nil, fmt.Errorf("CompilePJRT: emit program: %w", err)
148+
}
149+
}
150+
151+
// Step 5: Compile the MLIR programs via PJRT.
152+
prefillExec, err := client.Compile(prefillMLIR)
153+
if err != nil {
154+
return nil, fmt.Errorf("CompilePJRT: compile prefill: %w", err)
155+
}
156+
157+
var decodeExec *pjrt.LoadedExecutable
158+
if decodeMLIR != "" {
159+
decodeExec, err = client.Compile(decodeMLIR)
160+
if err != nil {
161+
prefillExec.Close()
162+
return nil, fmt.Errorf("CompilePJRT: compile decode: %w", err)
163+
}
164+
}
165+
166+
// Step 6: Transfer frozen weights to device.
167+
devices, err := client.AddressableDevices()
168+
if err != nil {
169+
prefillExec.Close()
170+
if decodeExec != nil {
171+
decodeExec.Close()
172+
}
173+
return nil, fmt.Errorf("CompilePJRT: get devices: %w", err)
174+
}
175+
if len(devices) == 0 {
176+
prefillExec.Close()
177+
if decodeExec != nil {
178+
decodeExec.Close()
179+
}
180+
return nil, fmt.Errorf("CompilePJRT: no addressable devices")
181+
}
182+
device := devices[0]
183+
184+
weightBuffers := make([]*pjrt.Buffer, len(frozenSlots))
185+
for i, fs := range frozenSlots {
186+
data := fs.Data.Data()
187+
shape := fs.Data.Shape()
188+
buf, err := pjrt.BufferFromHost(client, data, shape, device)
189+
if err != nil {
190+
// Clean up already-transferred buffers.
191+
for j := 0; j < i; j++ {
192+
weightBuffers[j].Close()
193+
}
194+
prefillExec.Close()
195+
if decodeExec != nil {
196+
decodeExec.Close()
197+
}
198+
return nil, fmt.Errorf("CompilePJRT: transfer weight slot %d: %w", fs.SlotIdx, err)
199+
}
200+
weightBuffers[i] = buf
201+
}
202+
203+
// Build the slot shape map.
204+
slotShapeMap := make(map[int][]int, len(slotShapes))
205+
for i, s := range slotShapes {
206+
if s != nil {
207+
slotShapeMap[i] = s
208+
}
209+
}
210+
211+
return &PJRTPlan[T]{
212+
PrefillExec: prefillExec,
213+
DecodeExec: decodeExec,
214+
Client: client,
215+
WeightBuffers: weightBuffers,
216+
KVSlots: kvSlots,
217+
InputSlots: inputSlots,
218+
OutputSlot: outputSlot,
219+
SlotShapes: slotShapeMap,
220+
Dtype: dtype,
221+
FrozenSlots: frozenSlotIdxs(frozenSlots),
222+
}, nil
223+
}
224+
225+
// mlirDtype resolves the MLIR dtype string for the generic type T.
226+
func mlirDtype[T tensor.Numeric]() (string, error) {
227+
var zero T
228+
goType := fmt.Sprintf("%T", zero)
229+
// Handle package-qualified type names.
230+
switch goType {
231+
case "float16.Float16":
232+
goType = "float16"
233+
case "float16.BFloat16":
234+
goType = "bfloat16"
235+
case "float8.Float8":
236+
goType = "float8"
237+
}
238+
dtype, ok := stablehlo.GoDTypeToMLIR(goType)
239+
if !ok {
240+
return "", fmt.Errorf("unsupported type %T for StableHLO", zero)
241+
}
242+
return dtype, nil
243+
}
244+
245+
// findNodeSlot finds the slot index for a node by matching its shape against
246+
// known input slot shapes. This is a heuristic — in practice, KV cache input
247+
// nodes have unique shapes within the input set.
248+
func findNodeSlot(inputSlots []int, slotShapes [][]int, targetShape []int) int {
249+
for _, idx := range inputSlots {
250+
if idx < len(slotShapes) && shapesEqual(slotShapes[idx], targetShape) {
251+
return idx
252+
}
253+
}
254+
// Fallback: return -1 to indicate not found.
255+
return -1
256+
}
257+
258+
// findOutputSlot finds the last op whose output shape matches the target.
259+
func findOutputSlot(ops []stablehlo.ProgramOp, targetShape []int) int {
260+
for i := len(ops) - 1; i >= 0; i-- {
261+
if shapesEqual(ops[i].OutputShape, targetShape) {
262+
return ops[i].OutputSlot
263+
}
264+
}
265+
return -1
266+
}
267+
268+
// shapesEqual reports whether a and b have the same rank and the same extent
// along every axis. A nil shape and an empty shape compare equal.
func shapesEqual(a, b []int) bool {
	rank := len(a)
	mismatch := rank != len(b)
	// The loop stops at the first differing axis and is skipped entirely
	// when the ranks already differ.
	for axis := 0; !mismatch && axis < rank; axis++ {
		mismatch = a[axis] != b[axis]
	}
	return !mismatch
}
280+
281+
// frozenSlotIdxs extracts slot indices from FrozenSlot.
282+
func frozenSlotIdxs[T tensor.Numeric](frozen []FrozenSlot[T]) []int {
283+
idxs := make([]int, len(frozen))
284+
for i, fs := range frozen {
285+
idxs[i] = fs.SlotIdx
286+
}
287+
return idxs
288+
}

graph/graph.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,27 @@ func (g *Graph[T]) ResetStatefulNodes() {
7777
g.cachedRefCount = nil
7878
}
7979

80+
// KVPairInfo describes a KV cache pair for external consumers (e.g., PJRT
81+
// compilation). It exposes the input and output nodes so callers can
82+
// resolve their slot indices after tracing.
83+
type KVPairInfo[T tensor.Numeric] struct {
84+
Input Node[T] // the stateful input node (also satisfies StatefulInputNode)
85+
Output Node[T] // the output node whose result feeds back into Input
86+
}
87+
88+
// KVPairs returns the registered KV cache pairs. Each pair links a
89+
// stateful input node to the output node that produces its next state.
90+
func (g *Graph[T]) KVPairs() []KVPairInfo[T] {
91+
pairs := make([]KVPairInfo[T], len(g.kvPairs))
92+
for i, kv := range g.kvPairs {
93+
pairs[i] = KVPairInfo[T]{
94+
Input: kv.input.(Node[T]),
95+
Output: kv.output,
96+
}
97+
}
98+
return pairs
99+
}
100+
80101
// AddKVPair registers a stateful input node that should receive the output
81102
// of another node after each forward pass. Used for ONNX KV cache feedback.
82103
func (g *Graph[T]) AddKVPair(input StatefulInputNode[T], output Node[T]) {

0 commit comments

Comments
 (0)