Skip to content

Commit 4958dea

Browse files
authored
Issues #22, #23, #24 - QueryStats Improvements (#26)
* Add query ID and warnings to `QueryResult`, refactor `QueryStats` duration handling - Added `QueryID` and `Warnings` fields to `QueryResult`. - Refactored `QueryStats` to use `DurationMs` (in milliseconds) instead of `time.Duration`. - Updated tests to verify new fields and JSON serialization. - Introduced `queryProgressUpdater` to capture query ID during execution. - Added fallback logic to handle unsupported progress callbacks. * Update `QueryStats` tests to verify `QueryID` addition and `DurationMs` refactor * Refactor tests: replace `time.Duration` with `DurationMs` in `QueryStats` * Update `QueryResult` to use `DurationMs` instead of `Milliseconds` for execution time formatting * Refactor tests: replace `Duration` with `DurationMs` in `QueryStats` usage * Remove Warnings field from QueryResult and related tests - Removed the `Warnings` field from `QueryResult` as it is not exposed in trino-go-client v0.333.0. - Updated tests to reflect this change. - Added a comment noting the Trino REST API behavior and linked to the relevant GitHub issue. * Add tests for `queryProgressUpdater` including concurrent access handling - Introduced unit tests for `queryProgressUpdater` to verify correct handling of query ID updates. - Added a concurrency test to ensure thread safety during concurrent reads and writes. * Add codecov configuration to enforce coverage targets and enable PR comments - Introduced `codecov.yml` with project and patch coverage targets (82% and 80% respectively). - Configured Codecov to provide PR comments for coverage changes.
1 parent fd00323 commit 4958dea

7 files changed

Lines changed: 191 additions & 27 deletions

File tree

codecov.yml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
coverage:
2+
status:
3+
project:
4+
default:
5+
target: 82%
6+
threshold: 2%
7+
patch:
8+
default:
9+
target: 80%
10+
threshold: 5%
11+
12+
comment:
13+
layout: "reach,diff,flags,files"
14+
behavior: default
15+
require_changes: true

pkg/client/client.go

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import (
66
"encoding/json"
77
"fmt"
88
"strings"
9+
"sync"
910
"time"
1011

11-
_ "github.com/trinodb/trino-go-client/trino" // Trino database driver
12+
"github.com/trinodb/trino-go-client/trino"
1213
)
1314

1415
// Client is a wrapper around the Trino database connection.
@@ -68,6 +69,8 @@ type QueryResult struct {
6869
Columns []ColumnInfo `json:"columns"`
6970
Rows []map[string]any `json:"rows"`
7071
Stats QueryStats `json:"stats"`
72+
// Note: Trino REST API returns warnings, but trino-go-client v0.333.0
73+
// does not expose them. See: https://github.com/trinodb/trino-go-client/issues
7174
}
7275

7376
// ColumnInfo describes a column in the result set.
@@ -79,10 +82,11 @@ type ColumnInfo struct {
7982

8083
// QueryStats contains execution statistics.
8184
type QueryStats struct {
82-
RowCount int `json:"row_count"`
83-
Duration time.Duration `json:"duration_ms"`
84-
Truncated bool `json:"truncated"`
85-
LimitApplied int `json:"limit_applied,omitempty"`
85+
RowCount int `json:"row_count"`
86+
DurationMs int64 `json:"duration_ms"`
87+
Truncated bool `json:"truncated"`
88+
LimitApplied int `json:"limit_applied,omitempty"`
89+
QueryID string `json:"query_id,omitempty"`
8690
}
8791

8892
// QueryOptions configures query execution.
@@ -107,6 +111,28 @@ func DefaultQueryOptions() QueryOptions {
107111
}
108112
}
109113

114+
// queryProgressUpdater captures the query ID from Trino progress updates.
115+
type queryProgressUpdater struct {
116+
mu sync.Mutex
117+
queryID string
118+
}
119+
120+
// Update implements the trino.ProgressUpdater interface.
121+
func (p *queryProgressUpdater) Update(info trino.QueryProgressInfo) {
122+
p.mu.Lock()
123+
defer p.mu.Unlock()
124+
if info.QueryId != "" {
125+
p.queryID = info.QueryId
126+
}
127+
}
128+
129+
// QueryID returns the captured query ID.
130+
func (p *queryProgressUpdater) QueryID() string {
131+
p.mu.Lock()
132+
defer p.mu.Unlock()
133+
return p.queryID
134+
}
135+
110136
// Query executes a SQL query and returns the results.
111137
func (c *Client) Query(ctx context.Context, sqlQuery string, opts QueryOptions) (*QueryResult, error) {
112138
start := time.Now()
@@ -125,10 +151,28 @@ func (c *Client) Query(ctx context.Context, sqlQuery string, opts QueryOptions)
125151
limit = 1000
126152
}
127153

128-
// Execute query
129-
rows, err := c.db.QueryContext(ctx, sqlQuery)
154+
// Set up progress updater to capture query ID
155+
progressUpdater := &queryProgressUpdater{}
156+
157+
// Execute query with progress callback to capture query ID.
158+
// The progress callback is a Trino-specific feature. If the driver doesn't
159+
// support it (e.g., when using sqlmock for testing), fall back to a simple query.
160+
rows, err := c.db.QueryContext(ctx, sqlQuery,
161+
sql.Named("X-Trino-Progress-Callback", trino.ProgressUpdater(progressUpdater)),
162+
sql.Named("X-Trino-Progress-Callback-Period", 100*time.Millisecond),
163+
)
130164
if err != nil {
131-
return nil, fmt.Errorf("query failed: %w", err)
165+
// Check if the error is due to unsupported argument type (e.g., when using sqlmock).
166+
// In that case, retry without the progress callback.
167+
if strings.Contains(err.Error(), "unsupported type") {
168+
rows, err = c.db.QueryContext(ctx, sqlQuery)
169+
if err != nil {
170+
return nil, fmt.Errorf("query failed: %w", err)
171+
}
172+
progressUpdater = nil // Clear so we don't try to get QueryID
173+
} else {
174+
return nil, fmt.Errorf("query failed: %w", err)
175+
}
132176
}
133177
defer func() { _ = rows.Close() }()
134178

@@ -187,12 +231,16 @@ func (c *Client) Query(ctx context.Context, sqlQuery string, opts QueryOptions)
187231
return nil, fmt.Errorf("row iteration error: %w", err)
188232
}
189233

190-
result.Stats = QueryStats{
234+
stats := QueryStats{
191235
RowCount: rowCount,
192-
Duration: time.Since(start),
236+
DurationMs: time.Since(start).Milliseconds(),
193237
Truncated: truncated,
194238
LimitApplied: limit,
195239
}
240+
if progressUpdater != nil {
241+
stats.QueryID = progressUpdater.QueryID()
242+
}
243+
result.Stats = stats
196244

197245
return result, nil
198246
}

pkg/client/client_test.go

Lines changed: 100 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package client
33
import (
44
"bytes"
55
"encoding/json"
6+
"sync"
67
"testing"
78
"time"
9+
10+
"github.com/trinodb/trino-go-client/trino"
811
)
912

1013
func TestNew_InvalidConfig(t *testing.T) {
@@ -202,9 +205,10 @@ func TestQueryResult_JSON(t *testing.T) {
202205
},
203206
Stats: QueryStats{
204207
RowCount: 2,
205-
Duration: 100 * time.Millisecond,
208+
DurationMs: 100,
206209
Truncated: false,
207210
LimitApplied: 1000,
211+
QueryID: "20240115_123456_00001_abcde",
208212
},
209213
}
210214

@@ -227,6 +231,9 @@ func TestQueryResult_JSON(t *testing.T) {
227231
if decoded.Stats.RowCount != 2 {
228232
t.Errorf("expected RowCount 2, got %d", decoded.Stats.RowCount)
229233
}
234+
if decoded.Stats.QueryID != "20240115_123456_00001_abcde" {
235+
t.Errorf("expected QueryID '20240115_123456_00001_abcde', got %q", decoded.Stats.QueryID)
236+
}
230237
}
231238

232239
func TestColumnInfo_JSON(t *testing.T) {
@@ -330,16 +337,23 @@ func TestColumnDef_JSON(t *testing.T) {
330337
func TestQueryStats_JSON(t *testing.T) {
331338
stats := QueryStats{
332339
RowCount: 100,
333-
Duration: 500 * time.Millisecond,
340+
DurationMs: 500,
334341
Truncated: true,
335342
LimitApplied: 100,
343+
QueryID: "test_query_id",
336344
}
337345

338346
data, err := json.Marshal(stats)
339347
if err != nil {
340348
t.Fatalf("failed to marshal QueryStats: %v", err)
341349
}
342350

351+
// Verify the JSON contains the correct field name and value
352+
expectedJSON := `{"row_count":100,"duration_ms":500,"truncated":true,"limit_applied":100,"query_id":"test_query_id"}`
353+
if string(data) != expectedJSON {
354+
t.Errorf("expected JSON %s, got %s", expectedJSON, string(data))
355+
}
356+
343357
var decoded QueryStats
344358
if err := json.Unmarshal(data, &decoded); err != nil {
345359
t.Fatalf("failed to unmarshal QueryStats: %v", err)
@@ -348,9 +362,15 @@ func TestQueryStats_JSON(t *testing.T) {
348362
if decoded.RowCount != 100 {
349363
t.Errorf("expected RowCount 100, got %d", decoded.RowCount)
350364
}
365+
if decoded.DurationMs != 500 {
366+
t.Errorf("expected DurationMs 500, got %d", decoded.DurationMs)
367+
}
351368
if !decoded.Truncated {
352369
t.Error("expected Truncated to be true")
353370
}
371+
if decoded.QueryID != "test_query_id" {
372+
t.Errorf("expected QueryID 'test_query_id', got %q", decoded.QueryID)
373+
}
354374
}
355375

356376
func TestQueryOptions_Defaults(t *testing.T) {
@@ -435,3 +455,81 @@ func TestQuoteIdentifier(t *testing.T) {
435455
})
436456
}
437457
}
458+
459+
func TestQueryProgressUpdater_Update(t *testing.T) {
460+
tests := []struct {
461+
name string
462+
updates []trino.QueryProgressInfo
463+
expectedQueryID string
464+
}{
465+
{
466+
name: "single update with query ID",
467+
updates: []trino.QueryProgressInfo{{QueryId: "20240115_123456_00001_abcde"}},
468+
expectedQueryID: "20240115_123456_00001_abcde",
469+
},
470+
{
471+
name: "empty query ID ignored",
472+
updates: []trino.QueryProgressInfo{{QueryId: ""}},
473+
expectedQueryID: "",
474+
},
475+
{
476+
name: "multiple updates keeps last non-empty",
477+
updates: []trino.QueryProgressInfo{
478+
{QueryId: "first_query_id"},
479+
{QueryId: "second_query_id"},
480+
},
481+
expectedQueryID: "second_query_id",
482+
},
483+
{
484+
name: "empty update after valid ID keeps valid",
485+
updates: []trino.QueryProgressInfo{
486+
{QueryId: "valid_query_id"},
487+
{QueryId: ""},
488+
},
489+
expectedQueryID: "valid_query_id",
490+
},
491+
}
492+
493+
for _, tt := range tests {
494+
t.Run(tt.name, func(t *testing.T) {
495+
updater := &queryProgressUpdater{}
496+
497+
for _, update := range tt.updates {
498+
updater.Update(update)
499+
}
500+
501+
if got := updater.QueryID(); got != tt.expectedQueryID {
502+
t.Errorf("QueryID() = %q, want %q", got, tt.expectedQueryID)
503+
}
504+
})
505+
}
506+
}
507+
508+
func TestQueryProgressUpdater_ConcurrentAccess(t *testing.T) {
509+
updater := &queryProgressUpdater{}
510+
var wg sync.WaitGroup
511+
512+
// Spawn multiple goroutines to test thread safety
513+
for i := 0; i < 100; i++ {
514+
wg.Add(2)
515+
516+
// Writer goroutine
517+
go func() {
518+
defer wg.Done()
519+
updater.Update(trino.QueryProgressInfo{QueryId: "query_id"})
520+
}()
521+
522+
// Reader goroutine
523+
go func() {
524+
defer wg.Done()
525+
_ = updater.QueryID()
526+
}()
527+
}
528+
529+
wg.Wait()
530+
531+
// Should have a valid query ID after all updates
532+
if got := updater.QueryID(); got != "query_id" {
533+
t.Errorf("QueryID() = %q, want %q", got, "query_id")
534+
}
535+
}

pkg/client/integration_test.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,23 +260,27 @@ func TestColumnDef_Fields(t *testing.T) {
260260
func TestQueryStats_Fields(t *testing.T) {
261261
stats := QueryStats{
262262
RowCount: 1000,
263-
Duration: 500 * time.Millisecond,
263+
DurationMs: 500,
264264
Truncated: true,
265265
LimitApplied: 1000,
266+
QueryID: "test_query_123",
266267
}
267268

268269
if stats.RowCount != 1000 {
269270
t.Errorf("expected RowCount 1000, got %d", stats.RowCount)
270271
}
271-
if stats.Duration != 500*time.Millisecond {
272-
t.Errorf("expected Duration 500ms, got %v", stats.Duration)
272+
if stats.DurationMs != 500 {
273+
t.Errorf("expected DurationMs 500, got %d", stats.DurationMs)
273274
}
274275
if !stats.Truncated {
275276
t.Error("expected Truncated to be true")
276277
}
277278
if stats.LimitApplied != 1000 {
278279
t.Errorf("expected LimitApplied 1000, got %d", stats.LimitApplied)
279280
}
281+
if stats.QueryID != "test_query_123" {
282+
t.Errorf("expected QueryID 'test_query_123', got %q", stats.QueryID)
283+
}
280284
}
281285

282286
// TestColumnInfo_Fields tests the ColumnInfo struct.

pkg/tools/mock_client_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package tools
22

33
import (
44
"context"
5-
"time"
65

76
"github.com/txn2/mcp-trino/pkg/client"
87
)
@@ -59,7 +58,7 @@ func NewMockTrinoClient() *MockTrinoClient {
5958
},
6059
Stats: client.QueryStats{
6160
RowCount: 2,
62-
Duration: 100 * time.Millisecond,
61+
DurationMs: 100,
6362
Truncated: false,
6463
LimitApplied: 1000,
6564
},

pkg/tools/query.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func formatCSV(result *client.QueryResult) string {
165165
if result.Stats.Truncated {
166166
output += fmt.Sprintf(" (truncated at limit %d)", result.Stats.LimitApplied)
167167
}
168-
output += fmt.Sprintf(", executed in %dms", result.Stats.Duration.Milliseconds())
168+
output += fmt.Sprintf(", executed in %dms", result.Stats.DurationMs)
169169

170170
return output
171171
}
@@ -209,7 +209,7 @@ func formatMarkdown(result *client.QueryResult) string {
209209
if result.Stats.Truncated {
210210
output += fmt.Sprintf(" (truncated at limit %d)", result.Stats.LimitApplied)
211211
}
212-
output += fmt.Sprintf(", executed in %dms*", result.Stats.Duration.Milliseconds())
212+
output += fmt.Sprintf(", executed in %dms*", result.Stats.DurationMs)
213213

214214
return output
215215
}

0 commit comments

Comments
 (0)