Skip to content

Commit 33f41e1

Browse files
miohansen and claude committed
feat: add generation and templates packages (v0.2.0)
Promotes the shared AI generation retry loop and prompt template loading logic from ctx-git/ctx-specs into the SDK so extensions can import them instead of maintaining byte-for-byte copies. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent f2e3d37 commit 33f41e1

7 files changed

Lines changed: 656 additions & 0 deletions

File tree

generation/config.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Package generation provides AI generation with retry and validation.
2+
package generation
3+
4+
import (
5+
"context"
6+
"io"
7+
)
8+
9+
// Format names the output format a generation call is expected to produce.
type Format string

// Known Format values.
const (
	// FormatPlainText is the "plaintext" format identifier.
	FormatPlainText = Format("plaintext")
	// FormatJSON is the "json" format identifier.
	FormatJSON = Format("json")
)
16+
17+
// Config drives a single generation call.
18+
type Config struct {
19+
TypeName string
20+
Format Format
21+
Validator Validator
22+
AIFunc func(ctx context.Context, prompt string) (string, error)
23+
MaxAttempts int
24+
Debug io.Writer
25+
}
26+
27+
// Result reports what a Generate call produced.
type Result struct {
	// Output is the cleaned response from the last attempt that ran.
	Output string
	// Attempts is the attempt number that succeeded, or the attempt limit
	// when validation never passed.
	Attempts int
}

generation/generate.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package generation
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"strings"
8+
)
9+
10+
func Generate(ctx context.Context, cfg *Config, prompt string) (*Result, error) {
11+
maxAttempts := cfg.MaxAttempts
12+
if maxAttempts <= 0 {
13+
maxAttempts = defaultMaxAttempts
14+
}
15+
16+
validator := cfg.Validator
17+
if validator == nil {
18+
validator = NoopValidator{}
19+
}
20+
21+
currentPrompt := prompt
22+
var lastOutput string
23+
var lastErrs []ValidationError
24+
25+
for attempt := 1; attempt <= maxAttempts; attempt++ {
26+
debugf(cfg.Debug, "=== attempt %d/%d ===\n", attempt, maxAttempts)
27+
debugf(cfg.Debug, "--- prompt ---\n%s\n---\n", currentPrompt)
28+
29+
raw, err := cfg.AIFunc(ctx, currentPrompt)
30+
if err != nil {
31+
return nil, fmt.Errorf("%s: AI call failed (attempt %d): %w",
32+
cfg.TypeName, attempt, err)
33+
}
34+
35+
cleaned := cleanOutput(raw)
36+
debugf(cfg.Debug, "--- response (cleaned) ---\n%s\n---\n", cleaned)
37+
38+
errs := validator.Validate(cleaned)
39+
if len(errs) == 0 {
40+
debugf(cfg.Debug, "validation passed on attempt %d\n", attempt)
41+
return &Result{Output: cleaned, Attempts: attempt}, nil
42+
}
43+
44+
lastOutput = cleaned
45+
lastErrs = errs
46+
47+
if attempt < maxAttempts {
48+
currentPrompt = buildRetryPrompt(currentPrompt, errs, attempt)
49+
}
50+
}
51+
52+
errMsgs := make([]string, len(lastErrs))
53+
for i, e := range lastErrs {
54+
errMsgs[i] = e.Message
55+
}
56+
return &Result{Output: lastOutput, Attempts: maxAttempts},
57+
fmt.Errorf("%s: validation failed after %d attempts: %s",
58+
cfg.TypeName, maxAttempts, strings.Join(errMsgs, "; "))
59+
}
60+
61+
// debugf writes a formatted debug message to w. A nil w disables debug
// output, and write errors are ignored.
func debugf(w io.Writer, format string, args ...interface{}) {
	if w == nil {
		return
	}
	fmt.Fprintf(w, format, args...)
}

generation/generation_test.go

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
package generation
2+
3+
import (
4+
"context"
5+
"errors"
6+
"strings"
7+
"testing"
8+
)
9+
10+
// --- helpers ----------------------------------------------------------------
11+
12+
// fixedAIFunc returns an AIFunc stub that replays responses in order. Once
// the sequence is exhausted every subsequent call returns the last response.
// With no responses configured, the stub returns an error instead of
// panicking on an out-of-range index.
func fixedAIFunc(responses ...string) func(context.Context, string) (string, error) {
	next := 0
	return func(_ context.Context, _ string) (string, error) {
		if len(responses) == 0 {
			return "", errors.New("fixedAIFunc: no responses configured")
		}
		resp := responses[next]
		// Stay on the final element so later calls keep returning it.
		if next < len(responses)-1 {
			next++
		}
		return resp, nil
	}
}
24+
25+
// errorAIFunc builds an AIFunc stub that fails every call with err and
// returns an empty output.
func errorAIFunc(err error) func(context.Context, string) (string, error) {
	failing := func(context.Context, string) (string, error) {
		return "", err
	}
	return failing
}
31+
32+
// fixedValidator fails for a set of outputs and passes for everything else.
33+
type fixedValidator struct {
34+
failOutputs map[string][]ValidationError
35+
}
36+
37+
func (v *fixedValidator) Validate(output string) []ValidationError {
38+
if errs, ok := v.failOutputs[output]; ok {
39+
return errs
40+
}
41+
return nil
42+
}
43+
44+
// --- Generate ----------------------------------------------------------------
45+
46+
func TestGenerate_successFirstAttempt(t *testing.T) {
47+
cfg := &Config{
48+
TypeName: "TestType",
49+
AIFunc: fixedAIFunc("valid output"),
50+
Validator: &fixedValidator{
51+
failOutputs: map[string][]ValidationError{},
52+
},
53+
}
54+
55+
result, err := Generate(context.Background(), cfg, "prompt")
56+
if err != nil {
57+
t.Fatalf("unexpected error: %v", err)
58+
}
59+
if result.Attempts != 1 {
60+
t.Errorf("Attempts = %d, want 1", result.Attempts)
61+
}
62+
if result.Output != "valid output" {
63+
t.Errorf("Output = %q, want %q", result.Output, "valid output")
64+
}
65+
}
66+
67+
func TestGenerate_retryOnValidationError(t *testing.T) {
68+
// First call returns "bad", second call returns "good".
69+
cfg := &Config{
70+
TypeName: "TestType",
71+
AIFunc: fixedAIFunc("bad", "good"),
72+
Validator: &fixedValidator{
73+
failOutputs: map[string][]ValidationError{
74+
"bad": {{Message: "output is bad"}},
75+
},
76+
},
77+
}
78+
79+
result, err := Generate(context.Background(), cfg, "prompt")
80+
if err != nil {
81+
t.Fatalf("unexpected error: %v", err)
82+
}
83+
if result.Attempts != 2 {
84+
t.Errorf("Attempts = %d, want 2", result.Attempts)
85+
}
86+
if result.Output != "good" {
87+
t.Errorf("Output = %q, want %q", result.Output, "good")
88+
}
89+
}
90+
91+
func TestGenerate_exhaustsMaxAttempts(t *testing.T) {
92+
cfg := &Config{
93+
TypeName: "TestType",
94+
MaxAttempts: 2,
95+
AIFunc: fixedAIFunc("always bad"),
96+
Validator: &fixedValidator{
97+
failOutputs: map[string][]ValidationError{
98+
"always bad": {{Message: "still bad"}},
99+
},
100+
},
101+
}
102+
103+
result, err := Generate(context.Background(), cfg, "prompt")
104+
if err == nil {
105+
t.Fatal("expected error but got nil")
106+
}
107+
if result == nil {
108+
t.Fatal("expected non-nil Result even on failure")
109+
}
110+
if result.Attempts != 2 {
111+
t.Errorf("Attempts = %d, want 2", result.Attempts)
112+
}
113+
if result.Output != "always bad" {
114+
t.Errorf("Output = %q, want %q", result.Output, "always bad")
115+
}
116+
if !strings.Contains(err.Error(), "still bad") {
117+
t.Errorf("error message %q does not contain validation error text", err.Error())
118+
}
119+
}
120+
121+
func TestGenerate_AIFuncError(t *testing.T) {
122+
aiErr := errors.New("AI unavailable")
123+
cfg := &Config{
124+
TypeName: "TestType",
125+
AIFunc: errorAIFunc(aiErr),
126+
}
127+
128+
_, err := Generate(context.Background(), cfg, "prompt")
129+
if err == nil {
130+
t.Fatal("expected error but got nil")
131+
}
132+
if !errors.Is(err, aiErr) {
133+
t.Errorf("error %v does not wrap %v", err, aiErr)
134+
}
135+
}
136+
137+
func TestGenerate_defaultMaxAttempts(t *testing.T) {
138+
// MaxAttempts=0 should fall back to defaultMaxAttempts (3).
139+
// We count actual calls via a counter AIFunc that always returns invalid output.
140+
calls := 0
141+
cfg := &Config{
142+
TypeName: "TestType",
143+
MaxAttempts: 0,
144+
AIFunc: func(_ context.Context, _ string) (string, error) {
145+
calls++
146+
return "invalid", nil
147+
},
148+
Validator: &fixedValidator{
149+
failOutputs: map[string][]ValidationError{
150+
"invalid": {{Message: "invalid"}},
151+
},
152+
},
153+
}
154+
155+
_, err := Generate(context.Background(), cfg, "prompt")
156+
if err == nil {
157+
t.Fatal("expected error but got nil")
158+
}
159+
if calls != defaultMaxAttempts {
160+
t.Errorf("AIFunc called %d times, want %d (defaultMaxAttempts)", calls, defaultMaxAttempts)
161+
}
162+
}
163+
164+
func TestGenerate_nilValidator(t *testing.T) {
165+
cfg := &Config{
166+
TypeName: "TestType",
167+
Validator: nil,
168+
AIFunc: fixedAIFunc("output"),
169+
}
170+
171+
// Should not panic — nil Validator must be treated as NoopValidator.
172+
result, err := Generate(context.Background(), cfg, "prompt")
173+
if err != nil {
174+
t.Fatalf("unexpected error: %v", err)
175+
}
176+
if result.Attempts != 1 {
177+
t.Errorf("Attempts = %d, want 1", result.Attempts)
178+
}
179+
}
180+
181+
// --- cleanOutput -------------------------------------------------------------
182+
183+
func TestCleanOutput_stripsFences(t *testing.T) {
184+
input := "```json\n{\"key\":\"value\"}\n```"
185+
got := cleanOutput(input)
186+
want := `{"key":"value"}`
187+
if got != want {
188+
t.Errorf("cleanOutput() = %q, want %q", got, want)
189+
}
190+
}
191+
192+
func TestCleanOutput_noFences(t *testing.T) {
193+
input := "plain text output"
194+
got := cleanOutput(input)
195+
if got != input {
196+
t.Errorf("cleanOutput() = %q, want %q", got, input)
197+
}
198+
}
199+
200+
func TestCleanOutput_trimsWhitespace(t *testing.T) {
201+
input := " \n trimmed \n "
202+
got := cleanOutput(input)
203+
want := "trimmed"
204+
if got != want {
205+
t.Errorf("cleanOutput() = %q, want %q", got, want)
206+
}
207+
}
208+
209+
// --- buildRetryPrompt --------------------------------------------------------
210+
211+
func TestBuildRetryPrompt_format(t *testing.T) {
212+
errs := []ValidationError{{Message: "field missing"}}
213+
result := buildRetryPrompt("original prompt", errs, 1)
214+
215+
checks := []struct {
216+
desc string
217+
contain string
218+
}{
219+
{"contains original prompt", "original prompt"},
220+
{"contains numbered error", "1. field missing"},
221+
{"contains attempt number", "attempt 1"},
222+
{"contains correction header", "INSTRUCTIONS FOR CORRECTION"},
223+
{"contains regenerate instruction", "Please regenerate the output"},
224+
}
225+
226+
for _, c := range checks {
227+
if !strings.Contains(result, c.contain) {
228+
t.Errorf("%s: output does not contain %q\nFull output:\n%s", c.desc, c.contain, result)
229+
}
230+
}
231+
}
232+
233+
func TestBuildRetryPrompt_multipleErrors(t *testing.T) {
234+
errs := []ValidationError{
235+
{Message: "first error"},
236+
{Message: "second error"},
237+
{Message: "third error"},
238+
}
239+
result := buildRetryPrompt("prompt", errs, 2)
240+
241+
for i, e := range errs {
242+
numbered := strings.Join([]string{
243+
// e.g. "1. first error"
244+
string(rune('1'+i)) + ". " + e.Message,
245+
}, "")
246+
if !strings.Contains(result, numbered) {
247+
t.Errorf("output does not contain %q\nFull output:\n%s", numbered, result)
248+
}
249+
}
250+
}
251+
252+
// --- JSONValidator -----------------------------------------------------------
253+
254+
func TestJSONValidator_valid(t *testing.T) {
255+
v := JSONValidator{RequiredFields: []string{"name", "age"}}
256+
errs := v.Validate(`{"name":"Alice","age":30}`)
257+
if len(errs) != 0 {
258+
t.Errorf("expected no errors, got %v", errs)
259+
}
260+
}
261+
262+
func TestJSONValidator_missingField(t *testing.T) {
263+
v := JSONValidator{RequiredFields: []string{"name", "age"}}
264+
errs := v.Validate(`{"name":"Alice"}`)
265+
if len(errs) != 1 {
266+
t.Fatalf("expected 1 error, got %d: %v", len(errs), errs)
267+
}
268+
if !strings.Contains(errs[0].Message, "age") {
269+
t.Errorf("error message %q does not mention missing field %q", errs[0].Message, "age")
270+
}
271+
}
272+
273+
func TestJSONValidator_invalidJSON(t *testing.T) {
274+
v := JSONValidator{}
275+
errs := v.Validate("not json at all")
276+
if len(errs) != 1 {
277+
t.Fatalf("expected 1 error, got %d: %v", len(errs), errs)
278+
}
279+
if !strings.Contains(errs[0].Message, "invalid JSON") {
280+
t.Errorf("error message %q does not indicate JSON parse failure", errs[0].Message)
281+
}
282+
}
283+
284+
// --- NoopValidator -----------------------------------------------------------
285+
286+
func TestNoopValidator(t *testing.T) {
287+
v := NoopValidator{}
288+
inputs := []string{
289+
"",
290+
"some output",
291+
"not json {{{",
292+
" ",
293+
}
294+
for _, input := range inputs {
295+
errs := v.Validate(input)
296+
if len(errs) != 0 {
297+
t.Errorf("NoopValidator.Validate(%q) = %v, want nil", input, errs)
298+
}
299+
}
300+
}

0 commit comments

Comments
 (0)