|
| 1 | +package graph |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + |
| 7 | + "github.com/zerfoo/ztensor/internal/pjrt" |
| 8 | + "github.com/zerfoo/ztensor/internal/stablehlo" |
| 9 | + "github.com/zerfoo/ztensor/tensor" |
| 10 | +) |
| 11 | + |
| 12 | +// CompilePJRT compiles the graph into a PJRTPlan for execution on a PJRT |
| 13 | +// backend. It traces the graph to obtain primitive operations, emits |
| 14 | +// StableHLO MLIR, compiles it via the PJRT client, and transfers frozen |
| 15 | +// weights to the device. |
| 16 | +// |
| 17 | +// For graphs with KV cache (StatefulInputNodes), both a prefill and a decode |
| 18 | +// executable are compiled. For graphs without KV cache, only the prefill |
| 19 | +// executable is produced. |
| 20 | +func (g *Graph[T]) CompilePJRT(ctx context.Context, client *pjrt.Client, inputs ...*tensor.TensorNumeric[T]) (*PJRTPlan[T], error) { |
| 21 | + if client == nil { |
| 22 | + return nil, fmt.Errorf("CompilePJRT: client is nil") |
| 23 | + } |
| 24 | + |
| 25 | + // Step 1: Trace the graph to get primitive ops and slot metadata. |
| 26 | + plan, err := g.CompileTraced(ctx, inputs...) |
| 27 | + if err != nil { |
| 28 | + return nil, fmt.Errorf("CompilePJRT: trace failed: %w", err) |
| 29 | + } |
| 30 | + |
| 31 | + tracedMetas := plan.Instructions() |
| 32 | + slotShapes := plan.SlotShapes() |
| 33 | + inputSlots := plan.InputSlots() |
| 34 | + outputSlot := plan.OutputSlot() |
| 35 | + frozenSlots := plan.FrozenSlots() |
| 36 | + |
| 37 | + // Build a set of frozen slot indices for quick lookup. |
| 38 | + frozenSet := make(map[int]bool, len(frozenSlots)) |
| 39 | + for _, fs := range frozenSlots { |
| 40 | + frozenSet[fs.SlotIdx] = true |
| 41 | + } |
| 42 | + |
| 43 | + // Determine the MLIR dtype from the Go generic type. |
| 44 | + dtype, err := mlirDtype[T]() |
| 45 | + if err != nil { |
| 46 | + return nil, fmt.Errorf("CompilePJRT: %w", err) |
| 47 | + } |
| 48 | + |
| 49 | + // Step 2: Convert traced InstructionMeta to stablehlo.ProgramOp. |
| 50 | + programOps := make([]stablehlo.ProgramOp, len(tracedMetas)) |
| 51 | + for i, meta := range tracedMetas { |
| 52 | + inputShapes := make([][]int, len(meta.InputIdx)) |
| 53 | + for j, idx := range meta.InputIdx { |
| 54 | + if idx < len(slotShapes) && slotShapes[idx] != nil { |
| 55 | + inputShapes[j] = slotShapes[idx] |
| 56 | + } |
| 57 | + } |
| 58 | + |
| 59 | + var outputShape []int |
| 60 | + if meta.OutputIdx < len(slotShapes) { |
| 61 | + outputShape = slotShapes[meta.OutputIdx] |
| 62 | + } |
| 63 | + |
| 64 | + programOps[i] = stablehlo.ProgramOp{ |
| 65 | + OpName: meta.OpName, |
| 66 | + InputSlots: meta.InputIdx, |
| 67 | + OutputSlot: meta.OutputIdx, |
| 68 | + InputShapes: inputShapes, |
| 69 | + OutputShape: outputShape, |
| 70 | + Dtype: dtype, |
| 71 | + Attrs: meta.ExtraArgs, |
| 72 | + } |
| 73 | + } |
| 74 | + |
| 75 | + // Collect input shapes for the StableHLO function signature. |
| 76 | + // This includes both dynamic inputs and frozen (weight) slots. |
| 77 | + allInputSlots := make([]int, 0, len(inputSlots)+len(frozenSlots)) |
| 78 | + allInputShapes := make([][]int, 0, len(inputSlots)+len(frozenSlots)) |
| 79 | + |
| 80 | + for _, idx := range inputSlots { |
| 81 | + allInputSlots = append(allInputSlots, idx) |
| 82 | + if idx < len(slotShapes) { |
| 83 | + allInputShapes = append(allInputShapes, slotShapes[idx]) |
| 84 | + } else { |
| 85 | + allInputShapes = append(allInputShapes, nil) |
| 86 | + } |
| 87 | + } |
| 88 | + for _, fs := range frozenSlots { |
| 89 | + allInputSlots = append(allInputSlots, fs.SlotIdx) |
| 90 | + if fs.SlotIdx < len(slotShapes) { |
| 91 | + allInputShapes = append(allInputShapes, slotShapes[fs.SlotIdx]) |
| 92 | + } else { |
| 93 | + allInputShapes = append(allInputShapes, nil) |
| 94 | + } |
| 95 | + } |
| 96 | + |
| 97 | + // Step 3: Identify KV cache slots from graph's kvPairs. |
| 98 | + kvPairs := g.KVPairs() |
| 99 | + var kvSlots []stablehlo.KVCacheSlot |
| 100 | + |
| 101 | + if len(kvPairs) > 0 { |
| 102 | + // We need to map the KV pair nodes to their traced slot indices. |
| 103 | + // The graph ran a Forward during CompileTraced, so memo has the |
| 104 | + // output tensors. We need the slot indices from the plan. |
| 105 | + for _, kv := range kvPairs { |
| 106 | + inputShape := kv.Input.OutputShape() |
| 107 | + outputShape := kv.Output.OutputShape() |
| 108 | + |
| 109 | + // Find the slot index for the input node's output. The input |
| 110 | + // node is a StatefulInputNode — its tensor was registered as |
| 111 | + // a slot during tracing. We search plan's slot shapes. |
| 112 | + inputSlot := findNodeSlot(inputSlots, slotShapes, inputShape) |
| 113 | + outputSlot := findOutputSlot(programOps, outputShape) |
| 114 | + |
| 115 | + // Default seq_axis to 1 (standard KV cache layout: |
| 116 | + // [num_heads, seq_len, head_dim]). |
| 117 | + seqAxis := 1 |
| 118 | + if len(inputShape) >= 2 { |
| 119 | + seqAxis = len(inputShape) - 2 |
| 120 | + } |
| 121 | + |
| 122 | + kvSlots = append(kvSlots, stablehlo.KVCacheSlot{ |
| 123 | + InputSlot: inputSlot, |
| 124 | + OutputSlot: outputSlot, |
| 125 | + Shape: inputShape, |
| 126 | + SeqAxis: seqAxis, |
| 127 | + }) |
| 128 | + } |
| 129 | + } |
| 130 | + |
| 131 | + // Step 4: Emit StableHLO MLIR and compile. |
| 132 | + var prefillMLIR string |
| 133 | + var decodeMLIR string |
| 134 | + |
| 135 | + if len(kvSlots) > 0 { |
| 136 | + prefillMLIR, err = stablehlo.EmitKVCacheProgram(programOps, allInputSlots, allInputShapes, kvSlots, dtype, false) |
| 137 | + if err != nil { |
| 138 | + return nil, fmt.Errorf("CompilePJRT: emit prefill program: %w", err) |
| 139 | + } |
| 140 | + decodeMLIR, err = stablehlo.EmitKVCacheProgram(programOps, allInputSlots, allInputShapes, kvSlots, dtype, true) |
| 141 | + if err != nil { |
| 142 | + return nil, fmt.Errorf("CompilePJRT: emit decode program: %w", err) |
| 143 | + } |
| 144 | + } else { |
| 145 | + prefillMLIR, err = stablehlo.EmitProgram(programOps, allInputSlots, allInputShapes, dtype) |
| 146 | + if err != nil { |
| 147 | + return nil, fmt.Errorf("CompilePJRT: emit program: %w", err) |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + // Step 5: Compile the MLIR programs via PJRT. |
| 152 | + prefillExec, err := client.Compile(prefillMLIR) |
| 153 | + if err != nil { |
| 154 | + return nil, fmt.Errorf("CompilePJRT: compile prefill: %w", err) |
| 155 | + } |
| 156 | + |
| 157 | + var decodeExec *pjrt.LoadedExecutable |
| 158 | + if decodeMLIR != "" { |
| 159 | + decodeExec, err = client.Compile(decodeMLIR) |
| 160 | + if err != nil { |
| 161 | + prefillExec.Close() |
| 162 | + return nil, fmt.Errorf("CompilePJRT: compile decode: %w", err) |
| 163 | + } |
| 164 | + } |
| 165 | + |
| 166 | + // Step 6: Transfer frozen weights to device. |
| 167 | + devices, err := client.AddressableDevices() |
| 168 | + if err != nil { |
| 169 | + prefillExec.Close() |
| 170 | + if decodeExec != nil { |
| 171 | + decodeExec.Close() |
| 172 | + } |
| 173 | + return nil, fmt.Errorf("CompilePJRT: get devices: %w", err) |
| 174 | + } |
| 175 | + if len(devices) == 0 { |
| 176 | + prefillExec.Close() |
| 177 | + if decodeExec != nil { |
| 178 | + decodeExec.Close() |
| 179 | + } |
| 180 | + return nil, fmt.Errorf("CompilePJRT: no addressable devices") |
| 181 | + } |
| 182 | + device := devices[0] |
| 183 | + |
| 184 | + weightBuffers := make([]*pjrt.Buffer, len(frozenSlots)) |
| 185 | + for i, fs := range frozenSlots { |
| 186 | + data := fs.Data.Data() |
| 187 | + shape := fs.Data.Shape() |
| 188 | + buf, err := pjrt.BufferFromHost(client, data, shape, device) |
| 189 | + if err != nil { |
| 190 | + // Clean up already-transferred buffers. |
| 191 | + for j := 0; j < i; j++ { |
| 192 | + weightBuffers[j].Close() |
| 193 | + } |
| 194 | + prefillExec.Close() |
| 195 | + if decodeExec != nil { |
| 196 | + decodeExec.Close() |
| 197 | + } |
| 198 | + return nil, fmt.Errorf("CompilePJRT: transfer weight slot %d: %w", fs.SlotIdx, err) |
| 199 | + } |
| 200 | + weightBuffers[i] = buf |
| 201 | + } |
| 202 | + |
| 203 | + // Build the slot shape map. |
| 204 | + slotShapeMap := make(map[int][]int, len(slotShapes)) |
| 205 | + for i, s := range slotShapes { |
| 206 | + if s != nil { |
| 207 | + slotShapeMap[i] = s |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + return &PJRTPlan[T]{ |
| 212 | + PrefillExec: prefillExec, |
| 213 | + DecodeExec: decodeExec, |
| 214 | + Client: client, |
| 215 | + WeightBuffers: weightBuffers, |
| 216 | + KVSlots: kvSlots, |
| 217 | + InputSlots: inputSlots, |
| 218 | + OutputSlot: outputSlot, |
| 219 | + SlotShapes: slotShapeMap, |
| 220 | + Dtype: dtype, |
| 221 | + FrozenSlots: frozenSlotIdxs(frozenSlots), |
| 222 | + }, nil |
| 223 | +} |
| 224 | + |
| 225 | +// mlirDtype resolves the MLIR dtype string for the generic type T. |
| 226 | +func mlirDtype[T tensor.Numeric]() (string, error) { |
| 227 | + var zero T |
| 228 | + goType := fmt.Sprintf("%T", zero) |
| 229 | + // Handle package-qualified type names. |
| 230 | + switch goType { |
| 231 | + case "float16.Float16": |
| 232 | + goType = "float16" |
| 233 | + case "float16.BFloat16": |
| 234 | + goType = "bfloat16" |
| 235 | + case "float8.Float8": |
| 236 | + goType = "float8" |
| 237 | + } |
| 238 | + dtype, ok := stablehlo.GoDTypeToMLIR(goType) |
| 239 | + if !ok { |
| 240 | + return "", fmt.Errorf("unsupported type %T for StableHLO", zero) |
| 241 | + } |
| 242 | + return dtype, nil |
| 243 | +} |
| 244 | + |
| 245 | +// findNodeSlot finds the slot index for a node by matching its shape against |
| 246 | +// known input slot shapes. This is a heuristic — in practice, KV cache input |
| 247 | +// nodes have unique shapes within the input set. |
| 248 | +func findNodeSlot(inputSlots []int, slotShapes [][]int, targetShape []int) int { |
| 249 | + for _, idx := range inputSlots { |
| 250 | + if idx < len(slotShapes) && shapesEqual(slotShapes[idx], targetShape) { |
| 251 | + return idx |
| 252 | + } |
| 253 | + } |
| 254 | + // Fallback: return -1 to indicate not found. |
| 255 | + return -1 |
| 256 | +} |
| 257 | + |
| 258 | +// findOutputSlot finds the last op whose output shape matches the target. |
| 259 | +func findOutputSlot(ops []stablehlo.ProgramOp, targetShape []int) int { |
| 260 | + for i := len(ops) - 1; i >= 0; i-- { |
| 261 | + if shapesEqual(ops[i].OutputShape, targetShape) { |
| 262 | + return ops[i].OutputSlot |
| 263 | + } |
| 264 | + } |
| 265 | + return -1 |
| 266 | +} |
| 267 | + |
// shapesEqual reports whether a and b have the same rank and the same extent
// in every dimension. A nil shape and an empty shape compare equal.
func shapesEqual(a, b []int) bool {
	if len(b) != len(a) {
		return false
	}
	for dim, extent := range a {
		if b[dim] != extent {
			return false
		}
	}
	return true
}
| 280 | + |
| 281 | +// frozenSlotIdxs extracts slot indices from FrozenSlot. |
| 282 | +func frozenSlotIdxs[T tensor.Numeric](frozen []FrozenSlot[T]) []int { |
| 283 | + idxs := make([]int, len(frozen)) |
| 284 | + for i, fs := range frozen { |
| 285 | + idxs[i] = fs.SlotIdx |
| 286 | + } |
| 287 | + return idxs |
| 288 | +} |
0 commit comments