Skip to content

Commit acc4e44

Browse files
committed
feat(capabilityclient): forward CTX_AI_MODEL in ai.generate requests
When CTX_AI_MODEL is set in the container environment (injected by ctx from the extension's model: field in ctx.yml), Generate() now includes it as a "model" field in the ai.generate POST body. The host capability server uses this to create a one-shot provider with the overridden model instead of the cached default. The field is omitted when CTX_AI_MODEL is unset, preserving existing behaviour for callers that do not specify a per-extension model.
1 parent 33f41e1 commit acc4e44

2 files changed

Lines changed: 80 additions & 1 deletion

File tree

capabilityclient/generate.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"os"
78
"strings"
89
)
910

1011
type generateInput struct {
1112
Prompt string `json:"prompt"`
13+
Model string `json:"model,omitempty"` // optional per-call model override
1214
}
1315

1416
type generateOutput struct {
@@ -18,14 +20,22 @@ type generateOutput struct {
1820
// Generate calls the built-in ai.generate capability with the given prompt and
1921
// returns the trimmed text output.
2022
//
23+
// When CTX_AI_MODEL is set in the container environment (injected by ctx from
24+
// the extension's model: field in ctx.yml), it is forwarded in the request
25+
// body so the host capability server uses that model for this call.
26+
//
2127
// This is the typed equivalent of:
2228
//
2329
// ctx run ai.generate --input '{"prompt":"..."}'
2430
func (c *Client) Generate(ctx context.Context, prompt string) (string, error) {
2531
if prompt == "" {
2632
return "", fmt.Errorf("capability-client: Generate: prompt must not be empty")
2733
}
28-
raw, err := c.Run(ctx, "ai.generate", generateInput{Prompt: prompt})
34+
input := generateInput{Prompt: prompt}
35+
if model := os.Getenv("CTX_AI_MODEL"); model != "" {
36+
input.Model = model
37+
}
38+
raw, err := c.Run(ctx, "ai.generate", input)
2939
if err != nil {
3040
return "", err
3141
}

capabilityclient/generate_test.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package capabilityclient_test
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
7+
"io"
68
"net/http"
79
"net/http/httptest"
810
"testing"
@@ -64,3 +66,70 @@ func TestGenerate_RunError(t *testing.T) {
6466
t.Errorf("error type = %T; want *RunError", err)
6567
}
6668
}
69+
70+
func TestGenerate_ForwardsModelFromEnv(t *testing.T) {
71+
var capturedBody []byte
72+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
73+
capturedBody, _ = io.ReadAll(r.Body)
74+
w.Header().Set("Content-Type", "application/json")
75+
w.Write([]byte(`{"output":"ok"}`)) //nolint:errcheck
76+
}))
77+
defer srv.Close()
78+
t.Setenv("CTX_CAPABILITY_URL", srv.URL)
79+
t.Setenv("CTX_AI_MODEL", "claude-haiku-4-5")
80+
81+
_, err := capabilityclient.New().Generate(context.Background(), "hello")
82+
if err != nil {
83+
t.Fatalf("unexpected error: %v", err)
84+
}
85+
86+
// The body is {"input": <inputJSON>}; unwrap to check the model field.
87+
var envelope struct {
88+
Input json.RawMessage `json:"input"`
89+
}
90+
if err := json.Unmarshal(capturedBody, &envelope); err != nil {
91+
t.Fatalf("parse envelope: %v", err)
92+
}
93+
var input struct {
94+
Model string `json:"model"`
95+
}
96+
if err := json.Unmarshal(envelope.Input, &input); err != nil {
97+
t.Fatalf("parse input: %v", err)
98+
}
99+
if input.Model != "claude-haiku-4-5" {
100+
t.Errorf("model = %q; want claude-haiku-4-5", input.Model)
101+
}
102+
}
103+
104+
func TestGenerate_NoModelWhenEnvUnset(t *testing.T) {
105+
var capturedBody []byte
106+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
capturedBody, _ = io.ReadAll(r.Body)
108+
w.Header().Set("Content-Type", "application/json")
109+
w.Write([]byte(`{"output":"ok"}`)) //nolint:errcheck
110+
}))
111+
defer srv.Close()
112+
t.Setenv("CTX_CAPABILITY_URL", srv.URL)
113+
t.Setenv("CTX_AI_MODEL", "") // explicitly unset
114+
115+
_, err := capabilityclient.New().Generate(context.Background(), "hello")
116+
if err != nil {
117+
t.Fatalf("unexpected error: %v", err)
118+
}
119+
120+
var envelope struct {
121+
Input json.RawMessage `json:"input"`
122+
}
123+
if err := json.Unmarshal(capturedBody, &envelope); err != nil {
124+
t.Fatalf("parse envelope: %v", err)
125+
}
126+
var input struct {
127+
Model *string `json:"model"`
128+
}
129+
if err := json.Unmarshal(envelope.Input, &input); err != nil {
130+
t.Fatalf("parse input: %v", err)
131+
}
132+
if input.Model != nil {
133+
t.Errorf("model field present with value %q; want omitted", *input.Model)
134+
}
135+
}

0 commit comments

Comments
 (0)