Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/coverage.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 77 additions & 28 deletions control-plane/internal/cli/agent_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package cli
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
Expand All @@ -13,6 +13,12 @@ import (
"github.com/stretchr/testify/require"
)

type roundTripFunc func(*http.Request) (*http.Response, error)

func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}

func TestAgentHelpers(t *testing.T) {
t.Run("output agent json supports pretty and compact", func(t *testing.T) {
tests := []struct {
Expand Down Expand Up @@ -41,17 +47,30 @@ func TestAgentHelpers(t *testing.T) {
t.Run("agent http covers success headers and failures", func(t *testing.T) {
var gotAPIKey string
var gotContentType string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotAPIKey = r.Header.Get("X-API-Key")
gotContentType = r.Header.Get("Content-Type")
require.Equal(t, "/api/test", r.URL.Path)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
var gotMethod string
var gotPath string
var gotBody []byte

oldTransport := http.DefaultTransport
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
gotMethod = req.Method
gotPath = req.URL.Path
gotAPIKey = req.Header.Get("X-API-Key")
gotContentType = req.Header.Get("Content-Type")
bodyBytes, err := io.ReadAll(req.Body)
require.NoError(t, err)
gotBody = bodyBytes
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader([]byte(`{"ok":true}`))),
Request: req,
}, nil
})
defer func() { http.DefaultTransport = oldTransport }()

oldServer, oldKey, oldTimeout := serverURL, apiKey, requestTimeout
serverURL, apiKey, requestTimeout = server.URL+"/", "api-secret", 1
serverURL, apiKey, requestTimeout = "http://agent.test", "api-secret", 1
defer func() {
serverURL, apiKey, requestTimeout = oldServer, oldKey, oldTimeout
}()
Expand All @@ -60,8 +79,11 @@ func TestAgentHelpers(t *testing.T) {
require.NoError(t, err)
require.Equal(t, http.StatusOK, status)
require.JSONEq(t, `{"ok":true}`, string(body))
require.Equal(t, http.MethodPost, gotMethod)
require.Equal(t, "/api/test", gotPath)
require.Equal(t, "api-secret", gotAPIKey)
require.Equal(t, "application/json", gotContentType)
require.JSONEq(t, `{"name":"demo"}`, string(gotBody))

_, _, err = agentHTTP(http.MethodPost, "/api/test", map[string]func(){"bad": func() {}})
require.ErrorContains(t, err, "encode request body")
Expand All @@ -71,6 +93,21 @@ func TestAgentHelpers(t *testing.T) {
require.ErrorContains(t, err, "build request")
})

t.Run("agent command help path returns structured output", func(t *testing.T) {
oldServer := serverURL
serverURL = "http://example.test"
defer func() { serverURL = oldServer }()

cmd := NewAgentCommand()
cmd.SetArgs([]string{})
output := captureOutput(t, func() {
require.NoError(t, cmd.Execute())
})
require.Contains(t, output, `"ok": true`)
require.Contains(t, output, `"command": "af agent"`)
require.Contains(t, output, `"server": "http://example.test"`)
})

t.Run("read batch input covers stdin and files", func(t *testing.T) {
withStdin(t, `{"operations":[]}`, func() {
data, err := readBatchInput("-")
Expand Down Expand Up @@ -224,22 +261,29 @@ func TestAgentCommandSubcommands(t *testing.T) {
}

var records []requestRecord
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
oldTransport := http.DefaultTransport
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
var body bytes.Buffer
_, _ = body.ReadFrom(r.Body)
if req.Body != nil {
_, _ = body.ReadFrom(req.Body)
}
records = append(records, requestRecord{
Method: r.Method,
Path: r.URL.Path,
Query: r.URL.RawQuery,
Method: req.Method,
Path: req.URL.Path,
Query: req.URL.RawQuery,
Body: body.String(),
})
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"ok":true,"data":{"id":"demo"}}`))
}))
defer server.Close()
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader([]byte(`{"ok":true,"data":{"id":"demo"}}`))),
Request: req,
}, nil
})
defer func() { http.DefaultTransport = oldTransport }()

oldServer, oldFormat, oldTimeout := serverURL, outputFormat, requestTimeout
serverURL, outputFormat, requestTimeout = server.URL, "json", 1
serverURL, outputFormat, requestTimeout = "http://agent.test", "json", 1
defer func() {
serverURL, outputFormat, requestTimeout = oldServer, oldFormat, oldTimeout
}()
Expand Down Expand Up @@ -281,7 +325,7 @@ func TestAgentCommandSubcommands(t *testing.T) {
if tt.wantBodyPart != "" {
require.Contains(t, records[0].Body, tt.wantBodyPart)
}
require.Contains(t, output, `"server": "`+server.URL+`"`)
require.Contains(t, output, `"server": "http://agent.test"`)
})
}

Expand Down Expand Up @@ -339,14 +383,19 @@ func TestSpinnerAndPrintHelpers(t *testing.T) {
}

func TestProxyToServerArrayResponse(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`[{"id":"one"}]`))
}))
defer server.Close()
oldTransport := http.DefaultTransport
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader([]byte(`[{"id":"one"}]`))),
Request: req,
}, nil
})
defer func() { http.DefaultTransport = oldTransport }()

oldServer, oldFormat, oldTimeout := serverURL, outputFormat, requestTimeout
serverURL, outputFormat, requestTimeout = server.URL, "json", 1
serverURL, outputFormat, requestTimeout = "http://agent.test", "json", 1
defer func() {
serverURL, outputFormat, requestTimeout = oldServer, oldFormat, oldTimeout
}()
Expand All @@ -355,7 +404,7 @@ func TestProxyToServerArrayResponse(t *testing.T) {
proxyToServer(http.MethodGet, "/array", nil)
})
require.Contains(t, output, `"ok": true`)
require.Contains(t, output, `"server": "`+server.URL+`"`)
require.Contains(t, output, `"server": "http://agent.test"`)
}

func batchFile(t *testing.T, content string) string {
Expand Down
25 changes: 15 additions & 10 deletions control-plane/internal/cli/command_helpers_additional_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package cli

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -131,17 +131,22 @@ installed:
require.NoError(t, err)
require.JSONEq(t, `{"operations":[{"id":"1"}]}`, string(data))

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "application/json", r.Header.Get("Accept"))
require.Equal(t, "secret", r.Header.Get("X-API-Key"))
require.Equal(t, "/api/test", r.URL.Path)
w.WriteHeader(http.StatusCreated)
_, _ = w.Write([]byte(`{"ok":true}`))
}))
defer server.Close()
oldTransport := http.DefaultTransport
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
require.Equal(t, "application/json", req.Header.Get("Accept"))
require.Equal(t, "secret", req.Header.Get("X-API-Key"))
require.Equal(t, "/api/test", req.URL.Path)
return &http.Response{
StatusCode: http.StatusCreated,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(bytes.NewReader([]byte(`{"ok":true}`))),
Request: req,
}, nil
})
defer func() { http.DefaultTransport = oldTransport }()

oldServer, oldAPIKey, oldTimeout := serverURL, apiKey, requestTimeout
serverURL, apiKey, requestTimeout = server.URL, "secret", 1
serverURL, apiKey, requestTimeout = "http://agent.test", "secret", 1
defer func() {
serverURL, apiKey, requestTimeout = oldServer, oldAPIKey, oldTimeout
}()
Expand Down
18 changes: 12 additions & 6 deletions control-plane/internal/cli/commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func TestStopCommand(t *testing.T) {
errorMsg := err.Error()
require.True(t,
strings.Contains(strings.ToLower(errorMsg), "not installed") ||
strings.Contains(strings.ToLower(errorMsg), "not running") ||
strings.Contains(strings.ToLower(errorMsg), "not found"),
strings.Contains(strings.ToLower(errorMsg), "not running") ||
strings.Contains(strings.ToLower(errorMsg), "not found"),
"Expected error about agent not found/installed/running, got: %s", errorMsg)
}
}
Expand Down Expand Up @@ -148,17 +148,23 @@ func TestLogsCommand(t *testing.T) {

// TestInitCommand tests the init command
func TestInitCommand(t *testing.T) {
wd := t.TempDir()
t.Setenv("HOME", t.TempDir())
oldWD, err := os.Getwd()
require.NoError(t, err)
require.NoError(t, os.Chdir(wd))
t.Cleanup(func() {
_ = os.Chdir(oldWD)
})
resetCLIStateForTest()

cmd := NewInitCommand()
cmd.SetOut(io.Discard)
cmd.SetErr(io.Discard)
cmd.SetArgs([]string{})
cmd.SetArgs([]string{"demo-agent", "--non-interactive", "--language", "python", "--author", "Jane Doe", "--email", "jane@example.com"})

err := cmd.Execute()
// May error if already initialized, but validates command structure
_ = err
err = cmd.Execute()
require.NoError(t, err)
}

// TestRootCommandFlags tests various root command flags
Expand Down
16 changes: 10 additions & 6 deletions control-plane/internal/core/services/coverage_targeted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func TestRunDevHelperProcess(t *testing.T) {
}
}

func runDevPythonHelperMain() string {
return "import os\nimport sys\nimport threading\nimport time\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\n\nport = int(os.environ['PORT'])\nmode = os.environ.get('AF_TEST_HELPER_MODE', 'serve-success')\n\nclass Handler(BaseHTTPRequestHandler):\n def log_message(self, fmt, *args):\n pass\n\n def do_GET(self):\n if self.path == '/health':\n self.send_response(200)\n self.end_headers()\n elif self.path == '/reasoners':\n self.send_response(200)\n self.end_headers()\n if mode == 'serve-capabilities-error':\n self.wfile.write(b'{\\\"reasoners\\\":[')\n else:\n self.wfile.write(b'{\\\"reasoners\\\":[{\\\"id\\\":\\\"helper-reasoner\\\"}]}')\n elif self.path == '/skills':\n self.send_response(200)\n self.end_headers()\n self.wfile.write(b'{\\\"skills\\\":[{\\\"id\\\":\\\"helper-skill\\\"}]}')\n else:\n self.send_response(404)\n self.end_headers()\n\nserver = HTTPServer(('127.0.0.1', port), Handler)\nthread = threading.Thread(target=server.serve_forever)\nthread.daemon = True\nthread.start()\ntime.sleep(10)\nserver.shutdown()\nserver.server_close()\nsys.exit(3 if mode == 'serve-exit-error' else 0)\n"
}

func TestDevServiceRunDev(t *testing.T) {
// These subtests spawn a real python3 subprocess that serves HTTP on a
// port in the 8001-8999 range. discoverAgentPort scans that range with a
Expand All @@ -82,8 +86,8 @@ func TestDevServiceRunDev(t *testing.T) {

port := findPortInAgentRange(t)
t.Setenv("AF_TEST_HELPER_MODE", "serve-success")
script := "#!/bin/sh\npython3 - <<'PY'\nimport os, sys, threading, time\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nport = int(os.environ['PORT'])\nmode = os.environ.get('AF_TEST_HELPER_MODE', 'serve-success')\nclass Handler(BaseHTTPRequestHandler):\n def log_message(self, fmt, *args):\n pass\n def do_GET(self):\n if self.path == '/health':\n self.send_response(200)\n self.end_headers()\n elif self.path == '/reasoners':\n self.send_response(200)\n self.end_headers()\n if mode == 'serve-capabilities-error':\n self.wfile.write(b'{\"reasoners\":[')\n else:\n self.wfile.write(b'{\"reasoners\":[{\"id\":\"helper-reasoner\"}]}')\n elif self.path == '/skills':\n self.send_response(200)\n self.end_headers()\n self.wfile.write(b'{\"skills\":[{\"id\":\"helper-skill\"}]}')\n else:\n self.send_response(404)\n self.end_headers()\nserver = HTTPServer(('127.0.0.1', port), Handler)\nthread = threading.Thread(target=server.serve_forever)\nthread.daemon = True\nthread.start()\ntime.sleep(1.2)\nserver.shutdown()\nserver.server_close()\nsys.exit(3 if mode == 'serve-exit-error' else 0)\nPY\n"
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte(script), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(dir, "main.py"), []byte(runDevPythonHelperMain()), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte("#!/bin/sh\nexec python3 \"$@\"\n"), 0o755))

service := &DefaultDevService{}
err := service.runDev(dir, domain.DevOptions{Port: port})
Expand All @@ -97,8 +101,8 @@ func TestDevServiceRunDev(t *testing.T) {

port := findPortInAgentRange(t)
t.Setenv("AF_TEST_HELPER_MODE", "serve-exit-error")
script := "#!/bin/sh\npython3 - <<'PY'\nimport os, sys, threading, time\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nport = int(os.environ['PORT'])\nmode = os.environ.get('AF_TEST_HELPER_MODE', 'serve-success')\nclass Handler(BaseHTTPRequestHandler):\n def log_message(self, fmt, *args):\n pass\n def do_GET(self):\n if self.path == '/health':\n self.send_response(200)\n self.end_headers()\n elif self.path == '/reasoners':\n self.send_response(200)\n self.end_headers()\n if mode == 'serve-capabilities-error':\n self.wfile.write(b'{\"reasoners\":[')\n else:\n self.wfile.write(b'{\"reasoners\":[{\"id\":\"helper-reasoner\"}]}')\n elif self.path == '/skills':\n self.send_response(200)\n self.end_headers()\n self.wfile.write(b'{\"skills\":[{\"id\":\"helper-skill\"}]}')\n else:\n self.send_response(404)\n self.end_headers()\nserver = HTTPServer(('127.0.0.1', port), Handler)\nthread = threading.Thread(target=server.serve_forever)\nthread.daemon = True\nthread.start()\ntime.sleep(1.2)\nserver.shutdown()\nserver.server_close()\nsys.exit(3 if mode == 'serve-exit-error' else 0)\nPY\n"
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte(script), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(dir, "main.py"), []byte(runDevPythonHelperMain()), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte("#!/bin/sh\nexec python3 \"$@\"\n"), 0o755))

service := &DefaultDevService{}
err := service.runDev(dir, domain.DevOptions{Port: port})
Expand All @@ -124,8 +128,8 @@ func TestDevServiceRunDev(t *testing.T) {

port := findPortInAgentRange(t)
t.Setenv("AF_TEST_HELPER_MODE", "serve-capabilities-error")
script := "#!/bin/sh\npython3 - <<'PY'\nimport os, sys, threading, time\nfrom http.server import BaseHTTPRequestHandler, HTTPServer\nport = int(os.environ['PORT'])\nmode = os.environ.get('AF_TEST_HELPER_MODE', 'serve-success')\nclass Handler(BaseHTTPRequestHandler):\n def log_message(self, fmt, *args):\n pass\n def do_GET(self):\n if self.path == '/health':\n self.send_response(200)\n self.end_headers()\n elif self.path == '/reasoners':\n self.send_response(200)\n self.end_headers()\n if mode == 'serve-capabilities-error':\n self.wfile.write(b'{\"reasoners\":[')\n else:\n self.wfile.write(b'{\"reasoners\":[{\"id\":\"helper-reasoner\"}]}')\n elif self.path == '/skills':\n self.send_response(200)\n self.end_headers()\n self.wfile.write(b'{\"skills\":[{\"id\":\"helper-skill\"}]}')\n else:\n self.send_response(404)\n self.end_headers()\nserver = HTTPServer(('127.0.0.1', port), Handler)\nthread = threading.Thread(target=server.serve_forever)\nthread.daemon = True\nthread.start()\ntime.sleep(1.2)\nserver.shutdown()\nserver.server_close()\nsys.exit(3 if mode == 'serve-exit-error' else 0)\nPY\n"
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte(script), 0o755))
require.NoError(t, os.WriteFile(filepath.Join(dir, "main.py"), []byte(runDevPythonHelperMain()), 0o644))
require.NoError(t, os.WriteFile(filepath.Join(venvBin, "python"), []byte("#!/bin/sh\nexec python3 \"$@\"\n"), 0o755))

service := &DefaultDevService{}
err := service.runDev(dir, domain.DevOptions{Port: port})
Expand Down
2 changes: 0 additions & 2 deletions control-plane/internal/packages/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ func waitForHTTPServer(t *testing.T, addr string) {
}

func TestAgentNodeRunnerPortEnvAndRegistry(t *testing.T) {
t.Parallel()

home := t.TempDir()
runner := &AgentNodeRunner{AgentFieldHome: home}

Expand Down
30 changes: 24 additions & 6 deletions sdk/go/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1826,20 +1826,38 @@ func (a *Agent) AIWithTools(ctx context.Context, prompt string, config ai.ToolCa
},
}

allowedTargets := make(map[string]struct{}, len(tools))
for _, tool := range tools {
allowedTargets[normalizeToolInvocationTarget(agentToolNameToInvocationTarget(tool.Function.Name))] = struct{}{}
}

callFn := func(ctx context.Context, target string, input map[string]interface{}) (map[string]interface{}, error) {
if strings.Contains(target, ":skill:") {
parts := strings.SplitN(target, ":skill:", 2)
target = parts[0] + "." + parts[1]
} else if strings.Contains(target, ":") {
parts := strings.SplitN(target, ":", 2)
target = parts[0] + "." + parts[1]
target = normalizeToolInvocationTarget(target)
if _, ok := allowedTargets[target]; !ok {
return nil, fmt.Errorf("tool call target %q is not a discovered capability", target)
}
return a.Call(ctx, target, input)
}

return a.aiClient.ExecuteToolCallLoop(ctx, messages, tools, config, callFn)
}

func normalizeToolInvocationTarget(target string) string {
if strings.Contains(target, ":skill:") {
parts := strings.SplitN(target, ":skill:", 2)
return parts[0] + "." + parts[1]
}
if strings.Contains(target, ":") {
parts := strings.SplitN(target, ":", 2)
return parts[0] + "." + parts[1]
}
return target
}

func agentToolNameToInvocationTarget(name string) string {
return strings.ReplaceAll(name, "__", ":")
}

// AIStream makes a streaming AI/LLM call.
// Returns channels for streaming chunks and errors.
//
Expand Down
Loading
Loading