Skip to content

Commit 01cecdc

Browse files
authored
Issue #30 - Enhance tool descriptions with AI agent decision context (#32)
* Issue #30 - Enhance tool descriptions with AI agent decision context Add decision-making context to MCP tool descriptions so AI agents know when to use each tool, not just what it does. Remove draft releases from goreleaser config. Fix prealloc lint issue in multiserver config. * Issue #30 - Add tool description override API Add WithDescription() for per-registration overrides and WithDescriptions() for toolkit-level bulk overrides, with YAML/JSON config file support for standalone deployments.
1 parent f27fa20 commit 01cecdc

15 files changed

Lines changed: 417 additions & 24 deletions

.goreleaser.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ release:
170170
name: mcp-trino
171171
name_template: "{{.ProjectName}}-v{{.Version}}"
172172
prerelease: auto
173-
draft: true
174173
mode: append
175174
footer: |
176175
## Installation

internal/server/server.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ type Options struct {
3232
// ExtensionsConfig configures middleware, interceptors, and transformers.
3333
ExtensionsConfig extensions.Config
3434

35+
// Descriptions provides custom tool descriptions that override defaults.
36+
// Keys are tool names (e.g., tools.ToolQuery), values are description strings.
37+
Descriptions map[tools.ToolName]string
38+
3539
// SemanticProvider is an optional semantic metadata provider.
3640
// If nil and SEMANTIC_FILE env var is set, a static provider will be created.
3741
SemanticProvider semantic.Provider
@@ -85,6 +89,11 @@ func New(opts Options) (*mcp.Server, *multiserver.Manager, error) {
8589
// Build toolkit options from extensions configuration
8690
toolkitOpts := extensions.BuildToolkitOptions(opts.ExtensionsConfig)
8791

92+
// Apply description overrides if provided
93+
if len(opts.Descriptions) > 0 {
94+
toolkitOpts = append(toolkitOpts, tools.WithDescriptions(opts.Descriptions))
95+
}
96+
8897
// If unconfigured, add middleware that returns helpful error for all tools
8998
if configErr != nil {
9099
toolkitOpts = append(toolkitOpts, tools.WithMiddleware(

internal/server/server_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,42 @@ func TestNew_MultipleConnections(t *testing.T) {
258258
}
259259
}
260260

261+
func TestDefaultOptions_DescriptionsNil(t *testing.T) {
262+
opts := DefaultOptions()
263+
264+
if opts.Descriptions != nil {
265+
t.Error("Descriptions should be nil by default")
266+
}
267+
}
268+
269+
func TestNew_WithDescriptions(t *testing.T) {
270+
opts := Options{
271+
MultiServerConfig: &multiserver.Config{
272+
Default: "default",
273+
Primary: client.Config{
274+
Host: "localhost",
275+
Port: 8080,
276+
User: "admin",
277+
},
278+
},
279+
ToolkitConfig: tools.DefaultConfig(),
280+
Descriptions: map[tools.ToolName]string{
281+
tools.ToolQuery: "Custom query description",
282+
tools.ToolExplain: "Custom explain description",
283+
},
284+
}
285+
286+
server, mgr, err := New(opts)
287+
if err != nil {
288+
t.Fatalf("unexpected error: %v", err)
289+
}
290+
defer func() { _ = mgr.Close() }()
291+
292+
if server == nil {
293+
t.Error("server should not be nil")
294+
}
295+
}
296+
261297
func TestDefaultOptions_SemanticFields(t *testing.T) {
262298
opts := DefaultOptions()
263299

pkg/extensions/config_file.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,11 @@ type TrinoConfig struct {
7171

7272
// ToolkitConfig maps to tools.Config for file-based loading.
7373
type ToolkitConfig struct {
74-
DefaultLimit int `json:"default_limit" yaml:"default_limit"`
75-
MaxLimit int `json:"max_limit" yaml:"max_limit"`
76-
DefaultTimeout Duration `json:"default_timeout" yaml:"default_timeout"`
77-
MaxTimeout Duration `json:"max_timeout" yaml:"max_timeout"`
74+
DefaultLimit int `json:"default_limit" yaml:"default_limit"`
75+
MaxLimit int `json:"max_limit" yaml:"max_limit"`
76+
DefaultTimeout Duration `json:"default_timeout" yaml:"default_timeout"`
77+
MaxTimeout Duration `json:"max_timeout" yaml:"max_timeout"`
78+
Descriptions map[string]string `json:"descriptions,omitempty" yaml:"descriptions,omitempty"`
7879
}
7980

8081
// ExtFileConfig maps to Config for file-based loading.
@@ -258,6 +259,19 @@ func (c ServerConfig) ToolsConfig() tools.Config {
258259
return cfg
259260
}
260261

262+
// DescriptionsMap converts the Toolkit.Descriptions string map to a tools.ToolName map.
263+
// Returns nil if no descriptions are configured.
264+
func (c ServerConfig) DescriptionsMap() map[tools.ToolName]string {
265+
if len(c.Toolkit.Descriptions) == 0 {
266+
return nil
267+
}
268+
m := make(map[tools.ToolName]string, len(c.Toolkit.Descriptions))
269+
for k, v := range c.Toolkit.Descriptions {
270+
m[tools.ToolName(k)] = v
271+
}
272+
return m
273+
}
274+
261275
// ExtConfig converts the Extensions section to a Config.
262276
func (c ServerConfig) ExtConfig() Config {
263277
cfg := DefaultConfig()

pkg/extensions/config_file_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,67 @@ func TestEnvironmentVariableExpansion(t *testing.T) {
370370
}
371371
}
372372

373+
func TestFromBytes_YAML_WithDescriptions(t *testing.T) {
374+
yamlConfig := `
375+
trino:
376+
host: trino.example.com
377+
user: admin
378+
379+
toolkit:
380+
default_limit: 1000
381+
descriptions:
382+
trino_query: "Query the retail analytics data warehouse."
383+
trino_explain: "Check query performance before running."
384+
`
385+
386+
cfg, err := FromBytes([]byte(yamlConfig), ".yaml")
387+
if err != nil {
388+
t.Fatalf("FromBytes failed: %v", err)
389+
}
390+
391+
if len(cfg.Toolkit.Descriptions) != 2 {
392+
t.Fatalf("expected 2 descriptions, got %d", len(cfg.Toolkit.Descriptions))
393+
}
394+
if cfg.Toolkit.Descriptions["trino_query"] != "Query the retail analytics data warehouse." {
395+
t.Errorf("unexpected query description: %q", cfg.Toolkit.Descriptions["trino_query"])
396+
}
397+
if cfg.Toolkit.Descriptions["trino_explain"] != "Check query performance before running." {
398+
t.Errorf("unexpected explain description: %q", cfg.Toolkit.Descriptions["trino_explain"])
399+
}
400+
}
401+
402+
func TestDescriptionsMap(t *testing.T) {
403+
cfg := ServerConfig{
404+
Toolkit: ToolkitConfig{
405+
Descriptions: map[string]string{
406+
"trino_query": "Custom query",
407+
"trino_explain": "Custom explain",
408+
},
409+
},
410+
}
411+
412+
m := cfg.DescriptionsMap()
413+
if len(m) != 2 {
414+
t.Fatalf("expected 2 entries, got %d", len(m))
415+
}
416+
417+
if m["trino_query"] != "Custom query" {
418+
t.Errorf("expected 'Custom query', got %q", m["trino_query"])
419+
}
420+
if m["trino_explain"] != "Custom explain" {
421+
t.Errorf("expected 'Custom explain', got %q", m["trino_explain"])
422+
}
423+
}
424+
425+
func TestDescriptionsMap_Empty(t *testing.T) {
426+
cfg := ServerConfig{}
427+
428+
m := cfg.DescriptionsMap()
429+
if m != nil {
430+
t.Errorf("expected nil for empty descriptions, got %v", m)
431+
}
432+
}
433+
373434
func TestDefaultServerConfig(t *testing.T) {
374435
cfg := DefaultServerConfig()
375436

pkg/multiserver/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ func (c Config) ClientConfig(name string) (client.Config, error) {
153153
// ConnectionNames returns the names of all available connections.
154154
// Always includes the primary connection name as the first entry.
155155
func (c Config) ConnectionNames() []string {
156-
names := []string{c.Default}
156+
names := make([]string, 1, 1+len(c.Connections))
157+
names[0] = c.Default
157158
for name := range c.Connections {
158159
names = append(names, name)
159160
}

pkg/tools/connections.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,8 @@ func (t *Toolkit) registerListConnectionsTool(server *mcp.Server, cfg *toolConfi
4040

4141
// Register with MCP
4242
mcp.AddTool(server, &mcp.Tool{
43-
Name: "trino_list_connections",
44-
Description: "List all configured Trino server connections. " +
45-
"Use this to discover available connections before querying specific servers. " +
46-
"Pass the connection name to other tools via the 'connection' parameter.",
43+
Name: string(ToolListConnections),
44+
Description: t.getDescription(ToolListConnections, cfg),
4745
}, func(ctx context.Context, req *mcp.CallToolRequest, input ListConnectionsInput) (*mcp.CallToolResult, any, error) {
4846
return wrappedHandler(ctx, req, input)
4947
})

pkg/tools/descriptions.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package tools
2+
3+
// defaultDescriptions holds the default description for each built-in tool.
4+
// These are used when no override is provided via WithDescription or WithDescriptions.
5+
var defaultDescriptions = map[ToolName]string{
6+
ToolQuery: "Execute a SELECT query against Trino and return results. Use the table path format " +
7+
"catalog.schema.table. Results are limited to prevent excessive data transfer — use " +
8+
"the limit parameter to control result size. For large tables, always include WHERE " +
9+
"clauses to filter results. Consider using trino_explain first for expensive queries.",
10+
11+
ToolExplain: "Get the execution plan for a SQL query to understand performance characteristics " +
12+
"before running expensive queries. Use this when querying large tables (millions of " +
13+
"rows) to verify the query plan uses appropriate filters. Also useful for debugging " +
14+
"slow queries or understanding join strategies.",
15+
16+
ToolListCatalogs: "List all available catalogs in the Trino cluster. Catalogs are the top-level " +
17+
"containers for schemas and tables.",
18+
19+
ToolListSchemas: "List all schemas in a catalog. Schemas are containers for tables within a catalog.",
20+
21+
ToolListTables: "List all tables in a schema. Optionally filter by a LIKE pattern.",
22+
23+
ToolDescribeTable: "Get detailed table information with columns, types, and optional sample data. " +
24+
"Set include_sample=true to see actual data values, which helps understand column " +
25+
"meaning and data formats. This is the richest single-call way to understand a " +
26+
"table's structure. Requires catalog, schema, and table name — use trino_list_tables " +
27+
"to discover available tables if needed.",
28+
29+
ToolListConnections: "List all configured Trino server connections. " +
30+
"Use this to discover available connections before querying specific servers. " +
31+
"Pass the connection name to other tools via the 'connection' parameter.",
32+
}
33+
34+
// DefaultDescription returns the default description for a tool.
35+
// Returns an empty string for unknown tool names.
36+
func DefaultDescription(name ToolName) string {
37+
return defaultDescriptions[name]
38+
}
39+
40+
// getDescription resolves the description for a tool using the priority chain:
41+
// 1. Per-registration override (cfg.description) — highest priority
42+
// 2. Toolkit-level override (t.descriptions) — medium priority
43+
// 3. Default description — lowest priority.
44+
func (t *Toolkit) getDescription(name ToolName, cfg *toolConfig) string {
45+
// Per-registration override (highest priority)
46+
if cfg != nil && cfg.description != nil {
47+
return *cfg.description
48+
}
49+
50+
// Toolkit-level override (medium priority)
51+
if desc, ok := t.descriptions[name]; ok {
52+
return desc
53+
}
54+
55+
// Default description (lowest priority)
56+
return defaultDescriptions[name]
57+
}

pkg/tools/descriptions_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package tools
2+
3+
import "testing"
4+
5+
func TestDefaultDescription(t *testing.T) {
6+
tests := []struct {
7+
name ToolName
8+
wantNon bool // expect non-empty
9+
}{
10+
{ToolQuery, true},
11+
{ToolExplain, true},
12+
{ToolListCatalogs, true},
13+
{ToolListSchemas, true},
14+
{ToolListTables, true},
15+
{ToolDescribeTable, true},
16+
{ToolListConnections, true},
17+
{ToolName("unknown_tool"), false},
18+
}
19+
20+
for _, tt := range tests {
21+
t.Run(string(tt.name), func(t *testing.T) {
22+
desc := DefaultDescription(tt.name)
23+
if tt.wantNon && desc == "" {
24+
t.Errorf("expected non-empty description for %s", tt.name)
25+
}
26+
if !tt.wantNon && desc != "" {
27+
t.Errorf("expected empty description for %s, got %q", tt.name, desc)
28+
}
29+
})
30+
}
31+
}
32+
33+
func TestDefaultDescriptions_AllToolsCovered(t *testing.T) {
34+
for _, name := range AllTools() {
35+
if DefaultDescription(name) == "" {
36+
t.Errorf("tool %s has no default description", name)
37+
}
38+
}
39+
}
40+
41+
func TestGetDescription_Priority(t *testing.T) {
42+
tests := []struct {
43+
name string
44+
toolkitDescs map[ToolName]string
45+
cfgDesc *string
46+
toolName ToolName
47+
wantDesc string
48+
}{
49+
{
50+
name: "default only",
51+
toolkitDescs: nil,
52+
cfgDesc: nil,
53+
toolName: ToolQuery,
54+
wantDesc: defaultDescriptions[ToolQuery],
55+
},
56+
{
57+
name: "toolkit override",
58+
toolkitDescs: map[ToolName]string{ToolQuery: "Toolkit override"},
59+
cfgDesc: nil,
60+
toolName: ToolQuery,
61+
wantDesc: "Toolkit override",
62+
},
63+
{
64+
name: "per-registration override",
65+
toolkitDescs: nil,
66+
cfgDesc: strPtr("Per-reg override"),
67+
toolName: ToolQuery,
68+
wantDesc: "Per-reg override",
69+
},
70+
{
71+
name: "per-registration beats toolkit",
72+
toolkitDescs: map[ToolName]string{ToolQuery: "Toolkit override"},
73+
cfgDesc: strPtr("Per-reg override"),
74+
toolName: ToolQuery,
75+
wantDesc: "Per-reg override",
76+
},
77+
{
78+
name: "toolkit override for one tool, default for another",
79+
toolkitDescs: map[ToolName]string{ToolQuery: "Custom query desc"},
80+
cfgDesc: nil,
81+
toolName: ToolExplain,
82+
wantDesc: defaultDescriptions[ToolExplain],
83+
},
84+
}
85+
86+
for _, tt := range tests {
87+
t.Run(tt.name, func(t *testing.T) {
88+
tk := &Toolkit{
89+
descriptions: make(map[ToolName]string),
90+
}
91+
if tt.toolkitDescs != nil {
92+
tk.descriptions = tt.toolkitDescs
93+
}
94+
95+
var cfg *toolConfig
96+
if tt.cfgDesc != nil {
97+
cfg = &toolConfig{description: tt.cfgDesc}
98+
}
99+
100+
got := tk.getDescription(tt.toolName, cfg)
101+
if got != tt.wantDesc {
102+
t.Errorf("getDescription() = %q, want %q", got, tt.wantDesc)
103+
}
104+
})
105+
}
106+
}
107+
108+
func TestGetDescription_NilConfig(t *testing.T) {
109+
tk := &Toolkit{
110+
descriptions: make(map[ToolName]string),
111+
}
112+
113+
got := tk.getDescription(ToolQuery, nil)
114+
if got != defaultDescriptions[ToolQuery] {
115+
t.Errorf("expected default description, got %q", got)
116+
}
117+
}
118+
119+
func strPtr(s string) *string {
120+
return &s
121+
}

pkg/tools/explain.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ func (t *Toolkit) registerExplainTool(server *mcp.Server, cfg *toolConfig) {
4040

4141
// Register with MCP using typed handler that calls wrapped handler
4242
mcp.AddTool(server, &mcp.Tool{
43-
Name: "trino_explain",
44-
Description: "Get the execution plan for a SQL query. " +
45-
"Use this to understand how Trino will execute a query and identify potential performance issues.",
43+
Name: string(ToolExplain),
44+
Description: t.getDescription(ToolExplain, cfg),
4645
}, func(ctx context.Context, req *mcp.CallToolRequest, input ExplainInput) (*mcp.CallToolResult, any, error) {
4746
return wrappedHandler(ctx, req, input)
4847
})

0 commit comments

Comments
 (0)