diff --git a/coderd/coderd.go b/coderd/coderd.go index e46df91cc202b..4feb7d8dbab8c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1184,8 +1184,8 @@ func New(options *Options) *API { r.Put("/system-prompt", api.putChatSystemPrompt) r.Get("/plan-mode-instructions", api.getChatPlanModeInstructions) r.Put("/plan-mode-instructions", api.putChatPlanModeInstructions) - r.Get("/explore-model-override", api.getChatExploreModelOverride) - r.Put("/explore-model-override", api.putChatExploreModelOverride) + r.Get("/agent-model-override/{context}", api.getChatAgentModelOverride) + r.Put("/agent-model-override/{context}", api.putChatAgentModelOverride) r.Get("/desktop-enabled", api.getChatDesktopEnabled) r.Put("/desktop-enabled", api.putChatDesktopEnabled) r.Get("/debug-logging", api.getChatDebugLogging) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 1fb81b96a55b4..d06328d24e0fa 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2737,6 +2737,13 @@ func (q *querier) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]dat return files, nil } +func (q *querier) GetChatGeneralModelOverride(ctx context.Context) (string, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { + return "", err + } + return q.db.GetChatGeneralModelOverride(ctx) +} + func (q *querier) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) { // The include-default-system-prompt flag is a deployment-wide setting read // during chat creation by every authenticated user, so no RBAC policy @@ -7390,6 +7397,13 @@ func (q *querier) UpsertChatExploreModelOverride(ctx context.Context, value stri return q.db.UpsertChatExploreModelOverride(ctx, value) } +func (q *querier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.UpsertChatGeneralModelOverride(ctx, value) +} + func (q *querier) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 1d3b9530c912c..91596cb8d35d1 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -890,6 +890,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatDesktopEnabled(gomock.Any()).Return(false, nil).AnyTimes() check.Args().Asserts() })) + s.Run("GetChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().GetChatGeneralModelOverride(gomock.Any()).Return("", nil).AnyTimes() + check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) + })) s.Run("GetChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatExploreModelOverride(gomock.Any()).Return("", nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead) @@ -1205,6 +1209,10 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpsertChatDesktopEnabled(gomock.Any(), false).Return(nil).AnyTimes() check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) + s.Run("UpsertChatGeneralModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + dbm.EXPECT().UpsertChatGeneralModelOverride(gomock.Any(), "").Return(nil).AnyTimes() + check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) + })) s.Run("UpsertChatExploreModelOverride", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().UpsertChatExploreModelOverride(gomock.Any(), "").Return(nil).AnyTimes() check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index cda2c66bbc595..e070a3c7129af 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1256,6 +1256,14 @@ func (m queryMetricsStore) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUI return r0, r1 } +func (m queryMetricsStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) { + start := time.Now() + r0, r1 := m.s.GetChatGeneralModelOverride(ctx) + m.queryLatencies.WithLabelValues("GetChatGeneralModelOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatGeneralModelOverride").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) { start := time.Now() r0, r1 := m.s.GetChatIncludeDefaultSystemPrompt(ctx) @@ -5288,6 +5296,14 @@ func (m queryMetricsStore) UpsertChatExploreModelOverride(ctx context.Context, v return r0 } +func (m queryMetricsStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error { + start := time.Now() + r0 := m.s.UpsertChatGeneralModelOverride(ctx, value) + m.queryLatencies.WithLabelValues("UpsertChatGeneralModelOverride").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertChatGeneralModelOverride").Inc() + return r0 +} + func (m queryMetricsStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error { start := time.Now() r0 := m.s.UpsertChatIncludeDefaultSystemPrompt(ctx, includeDefaultSystemPrompt) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index e7bedcf10757c..22d8cb8e8bc60 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2311,6 +2311,21 @@ func (mr *MockStoreMockRecorder) GetChatFilesByIDs(ctx, ids any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatFilesByIDs", reflect.TypeOf((*MockStore)(nil).GetChatFilesByIDs), ctx, ids) } +// GetChatGeneralModelOverride mocks base method. +func (m *MockStore) GetChatGeneralModelOverride(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatGeneralModelOverride", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatGeneralModelOverride indicates an expected call of GetChatGeneralModelOverride. +func (mr *MockStoreMockRecorder) GetChatGeneralModelOverride(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).GetChatGeneralModelOverride), ctx) +} + // GetChatIncludeDefaultSystemPrompt mocks base method. func (m *MockStore) GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) { m.ctrl.T.Helper() @@ -9933,6 +9948,20 @@ func (mr *MockStoreMockRecorder) UpsertChatExploreModelOverride(ctx, value any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatExploreModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatExploreModelOverride), ctx, value) } +// UpsertChatGeneralModelOverride mocks base method. +func (m *MockStore) UpsertChatGeneralModelOverride(ctx context.Context, value string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertChatGeneralModelOverride", ctx, value) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertChatGeneralModelOverride indicates an expected call of UpsertChatGeneralModelOverride. +func (mr *MockStoreMockRecorder) UpsertChatGeneralModelOverride(ctx, value any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertChatGeneralModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertChatGeneralModelOverride), ctx, value) +} + // UpsertChatIncludeDefaultSystemPrompt mocks base method. func (m *MockStore) UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 34858374c8a48..ca14af1b7daaa 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -308,6 +308,7 @@ type sqlcQuerier interface { // loading file content. GetChatFileMetadataByChatID(ctx context.Context, chatID uuid.UUID) ([]GetChatFileMetadataByChatIDRow, error) GetChatFilesByIDs(ctx context.Context, ids []uuid.UUID) ([]ChatFile, error) + GetChatGeneralModelOverride(ctx context.Context) (string, error) // GetChatIncludeDefaultSystemPrompt preserves the legacy default // for deployments created before the explicit include-default toggle. // When the toggle is unset, a non-empty custom prompt implies false; @@ -1174,6 +1175,7 @@ type sqlcQuerier interface { UpsertChatDiffStatus(ctx context.Context, arg UpsertChatDiffStatusParams) (ChatDiffStatus, error) UpsertChatDiffStatusReference(ctx context.Context, arg UpsertChatDiffStatusReferenceParams) (ChatDiffStatus, error) UpsertChatExploreModelOverride(ctx context.Context, value string) error + UpsertChatGeneralModelOverride(ctx context.Context, value string) error UpsertChatIncludeDefaultSystemPrompt(ctx context.Context, includeDefaultSystemPrompt bool) error UpsertChatPlanModeInstructions(ctx context.Context, value string) error UpsertChatRetentionDays(ctx context.Context, retentionDays int32) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1be67fee0dde5..cc0036ee5d0f6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -20406,6 +20406,18 @@ func (q *sqlQuerier) GetChatExploreModelOverride(ctx context.Context) (string, e return model_config_id, err } +const getChatGeneralModelOverride = `-- name: GetChatGeneralModelOverride :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id +` + +func (q *sqlQuerier) GetChatGeneralModelOverride(ctx context.Context) (string, error) { + row := q.db.QueryRowContext(ctx, getChatGeneralModelOverride) + var model_config_id string + err := row.Scan(&model_config_id) + return model_config_id, err +} + const getChatIncludeDefaultSystemPrompt = `-- name: GetChatIncludeDefaultSystemPrompt :one SELECT COALESCE( @@ -20773,6 +20785,16 @@ func (q *sqlQuerier) UpsertChatExploreModelOverride(ctx context.Context, value s return err } +const upsertChatGeneralModelOverride = `-- name: UpsertChatGeneralModelOverride :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override' +` + +func (q *sqlQuerier) UpsertChatGeneralModelOverride(ctx context.Context, value string) error { + _, err := q.db.ExecContext(ctx, upsertChatGeneralModelOverride, value) + return err +} + const upsertChatIncludeDefaultSystemPrompt = `-- name: UpsertChatIncludeDefaultSystemPrompt :exec INSERT INTO site_configs (key, value) VALUES ( diff --git a/coderd/database/queries/siteconfig.sql b/coderd/database/queries/siteconfig.sql index ffb97a95fa177..1bf22448e182c 100644 --- a/coderd/database/queries/siteconfig.sql +++ b/coderd/database/queries/siteconfig.sql @@ -175,6 +175,14 @@ SELECT INSERT INTO site_configs (key, value) VALUES ('agents_chat_explore_model_override', $1) ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_explore_model_override'; +-- name: GetChatGeneralModelOverride :one +SELECT + COALESCE((SELECT value FROM site_configs WHERE key = 'agents_chat_general_model_override'), '') :: text AS model_config_id; + +-- name: UpsertChatGeneralModelOverride :exec +INSERT INTO site_configs (key, value) VALUES ('agents_chat_general_model_override', $1) +ON CONFLICT (key) DO UPDATE SET value = $1 WHERE site_configs.key = 'agents_chat_general_model_override'; + -- name: GetChatDesktopEnabled :one SELECT COALESCE((SELECT value = 'true' FROM site_configs WHERE key = 'agents_desktop_enabled'), false) :: boolean AS enable_desktop; diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 455d84862b4bf..78b39a156d1f1 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -449,7 +449,7 @@ func validateChatPlanMode(mode codersdk.ChatPlanMode) bool { } } -func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) { +func parseChatModelOverride(raw string) (*uuid.UUID, error) { trimmed := strings.TrimSpace(raw) if trimmed == "" { //nolint:nilnil // Empty site-config value means the override is unset. @@ -457,19 +457,29 @@ func parseChatExploreModelOverride(raw string) (*uuid.UUID, error) { } modelConfigID, err := uuid.Parse(trimmed) if err != nil { - return nil, xerrors.Errorf("parse explore model override: %w", err) + return nil, xerrors.Errorf("parse chat model override: %w", err) } return &modelConfigID, nil } -func formatChatExploreModelOverride(id *uuid.UUID) string { +func formatChatModelOverride(id *uuid.UUID) string { if id == nil { return "" } return id.String() } -func validateChatExploreModelOverrideID( +func lookupEnabledChatModelConfigByID( + ctx context.Context, + db database.Store, + id uuid.UUID, +) (database.ChatModelConfig, error) { + //nolint:gocritic // Validation lookup uses AsChatd to check model + // availability independently of the caller's read permissions. + return db.GetEnabledChatModelConfigByID(dbauthz.AsChatd(ctx), id) +} + +func validateChatModelOverrideID( ctx context.Context, db database.Store, id *uuid.UUID, @@ -482,9 +492,7 @@ func validateChatExploreModelOverrideID( Message: "Invalid model_config_id.", } } - //nolint:gocritic // Validation lookup uses system context to check model - // availability independently of the caller's read permissions. - _, err := db.GetEnabledChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), *id) + _, err := lookupEnabledChatModelConfigByID(ctx, db, *id) if err == nil { return 0, nil } @@ -499,18 +507,23 @@ func validateChatExploreModelOverrideID( } } -func (api *API) getChatExploreModelOverrideConfig( +func (api *API) getChatModelOverrideConfig( ctx context.Context, + settingName string, + getter func(context.Context) (string, error), ) (*uuid.UUID, bool, error) { - raw, err := api.Database.GetChatExploreModelOverride(ctx) + raw, err := getter(ctx) if err != nil { - return nil, false, xerrors.Errorf("get explore model override: %w", err) + return nil, false, xerrors.Errorf("get %s model override: %w", settingName, err) } - id, err := parseChatExploreModelOverride(raw) + id, err := parseChatModelOverride(raw) if err != nil { // Degrade malformed values to unset so the admin settings page // remains accessible and the bad value can be cleared. - api.Logger.Warn(ctx, "malformed explore model override in site config, treating as unset", + api.Logger.Warn( + ctx, + "malformed model override in site config, treating as unset", + slog.F("setting", settingName), slog.F("raw_value", raw), slog.Error(err), ) @@ -519,6 +532,64 @@ func (api *API) getChatExploreModelOverrideConfig( return id, false, nil } +func parseChatAgentModelOverrideContext(raw string) (codersdk.ChatAgentModelOverrideContext, error) { + overrideContext := codersdk.ChatAgentModelOverrideContext(raw) + if overrideContext.Valid() { + return overrideContext, nil + } + return "", xerrors.Errorf("unknown chat agent model override context %q", raw) +} + +type chatAgentModelOverrideSiteConfig struct { + getter func(context.Context) (string, error) + upsert func(context.Context, string) error +} + +func (api *API) chatAgentModelOverrideSiteConfig( + overrideContext codersdk.ChatAgentModelOverrideContext, +) (chatAgentModelOverrideSiteConfig, error) { + switch overrideContext { + case codersdk.ChatAgentModelOverrideContextGeneral: + return chatAgentModelOverrideSiteConfig{ + getter: api.Database.GetChatGeneralModelOverride, + upsert: api.Database.UpsertChatGeneralModelOverride, + }, nil + case codersdk.ChatAgentModelOverrideContextExplore: + return chatAgentModelOverrideSiteConfig{ + getter: api.Database.GetChatExploreModelOverride, + upsert: api.Database.UpsertChatExploreModelOverride, + }, nil + default: + return chatAgentModelOverrideSiteConfig{}, xerrors.Errorf( + "unknown chat agent model override context %q", + overrideContext, + ) + } +} + +func (api *API) getChatAgentModelOverrideConfig( + ctx context.Context, + overrideContext codersdk.ChatAgentModelOverrideContext, +) (*uuid.UUID, bool, error) { + siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext) + if err != nil { + return nil, false, err + } + return api.getChatModelOverrideConfig(ctx, string(overrideContext), siteConfig.getter) +} + +func (api *API) upsertChatAgentModelOverrideConfig( + ctx context.Context, + overrideContext codersdk.ChatAgentModelOverrideContext, + modelConfigID *uuid.UUID, +) error { + siteConfig, err := api.chatAgentModelOverrideSiteConfig(overrideContext) + if err != nil { + return err + } + return siteConfig.upsert(ctx, formatChatModelOverride(modelConfigID)) +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -3827,53 +3898,103 @@ func (api *API) putChatPlanModeInstructions(rw http.ResponseWriter, r *http.Requ rw.WriteHeader(http.StatusNoContent) } +func readChatAgentModelOverrideContext( + rw http.ResponseWriter, + r *http.Request, +) (codersdk.ChatAgentModelOverrideContext, bool) { + ctx := r.Context() + rawContext := chi.URLParam(r, "context") + overrideContext, err := parseChatAgentModelOverrideContext(rawContext) + if err == nil { + return overrideContext, true + } + validContextValues := make( + []string, + 0, + len(codersdk.AllChatAgentModelOverrideContexts()), + ) + for _, overrideContext := range codersdk.AllChatAgentModelOverrideContexts() { + validContextValues = append(validContextValues, string(overrideContext)) + } + validContexts := strings.Join(validContextValues, ", ") + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid chat agent model override context.", + Detail: fmt.Sprintf( + "Expected one of %s. Got %q.", + validContexts, + rawContext, + ), + }) + return "", false +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // get-return: revive assumes get* must be a getter, but this is an HTTP handler. -func (api *API) getChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) { +func (api *API) getChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) { httpapi.ResourceNotFound(rw) return } + overrideContext, ok := readChatAgentModelOverrideContext(rw, r) + if !ok { + return + } - modelConfigID, hasMalformedOverride, err := api.getChatExploreModelOverrideConfig(ctx) + modelConfigID, isMalformed, err := api.getChatAgentModelOverrideConfig(ctx, overrideContext) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching Explore model override.", + Message: fmt.Sprintf("Internal error fetching %s model override.", overrideContext), Detail: err.Error(), }) return } - httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatExploreModelOverrideResponse{ - ModelConfigID: modelConfigID, - HasMalformedOverride: hasMalformedOverride, - }) + resp := codersdk.ChatAgentModelOverrideResponse{ + Context: overrideContext, + ModelConfigID: formatChatModelOverride(modelConfigID), + IsMalformed: isMalformed, + } + + httpapi.Write(ctx, rw, http.StatusOK, resp) } // EXPERIMENTAL: this endpoint is experimental and is subject to change. -func (api *API) putChatExploreModelOverride(rw http.ResponseWriter, r *http.Request) { +func (api *API) putChatAgentModelOverride(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { httpapi.Forbidden(rw) return } + overrideContext, ok := readChatAgentModelOverrideContext(rw, r) + if !ok { + return + } - var req codersdk.UpdateChatExploreModelOverrideRequest + var req codersdk.UpdateChatAgentModelOverrideRequest if !httpapi.Read(ctx, rw, r, &req) { return } - status, resp := validateChatExploreModelOverrideID(ctx, api.Database, req.ModelConfigID) + modelConfigID, err := parseChatModelOverride(req.ModelConfigID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid model_config_id.", + Detail: fmt.Sprintf("Value %q is not a valid UUID.", req.ModelConfigID), + }) + return + } + + status, resp := validateChatModelOverrideID(ctx, api.Database, modelConfigID) if resp != nil { httpapi.Write(ctx, rw, status, *resp) return } - if err := api.Database.UpsertChatExploreModelOverride(ctx, formatChatExploreModelOverride(req.ModelConfigID)); err != nil { + if err := api.upsertChatAgentModelOverrideConfig(ctx, overrideContext, modelConfigID); err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error updating Explore model override.", + Message: fmt.Sprintf("Internal error updating %s model override.", overrideContext), Detail: err.Error(), }) return diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 07907ebed7fc2..c3f21d1df4fbb 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1134,43 +1134,63 @@ func TestListChats(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) + client, _ := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) - modelConfig := createChatModelConfig(t, client) + _ = createChatModelConfig(t, client) - // Insert the chat that will later be pinned directly - // into the database with a completed status so we - // avoid the background chatd processor entirely. - pinnedDBChat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ - OrganizationID: firstUser.OrganizationID, - OwnerID: firstUser.UserID, - LastModelConfigID: modelConfig.ID, - Title: "pinned-chat", - Status: database.ChatStatusCompleted, - ClientType: database.ChatClientTypeUi, + // Create the chat that will later be pinned. It gets the + // earliest updated_at because it is inserted first. + pinnedChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: firstUser.OrganizationID, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "pinned-chat", + }}, }) require.NoError(t, err) - // Fill page 1 with newer chats so the pinned chat - // would normally be pushed off the first page - // (default limit 50). Insert directly into the - // database to avoid spawning 51 background chat - // processors, which causes timeouts under -race. + // Fill page 1 with newer chats so the pinned chat would + // normally be pushed off the first page (default limit 50). const fillerCount = 51 + fillerChats := make([]codersdk.Chat, 0, fillerCount) for i := range fillerCount { - _, insertErr := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ - OrganizationID: firstUser.OrganizationID, - OwnerID: firstUser.UserID, - LastModelConfigID: modelConfig.ID, - Title: fmt.Sprintf("filler-%d", i), - Status: database.ChatStatusCompleted, - ClientType: database.ChatClientTypeUi, + c, createErr := client.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: firstUser.OrganizationID, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: fmt.Sprintf("filler-%d", i), + }}, }) - require.NoError(t, insertErr) + require.NoError(t, createErr) + fillerChats = append(fillerChats, c) + } + + // Wait for all chats to reach a terminal status so + // updated_at is stable before paginating. A single + // polling loop checks every chat per tick to avoid + // O(N) separate Eventually loops. + allCreated := append([]codersdk.Chat{pinnedChat}, fillerChats...) + pending := make(map[uuid.UUID]struct{}, len(allCreated)) + for _, c := range allCreated { + pending[c.ID] = struct{}{} } + testutil.Eventually(ctx, t, func(_ context.Context) bool { + all, listErr := client.ListChats(ctx, &codersdk.ListChatsOptions{ + Pagination: codersdk.Pagination{Limit: fillerCount + 10}, + }) + if listErr != nil { + return false + } + for _, ch := range all { + if _, ok := pending[ch.ID]; ok && ch.Status != codersdk.ChatStatusPending && ch.Status != codersdk.ChatStatusRunning { + delete(pending, ch.ID) + } + } + return len(pending) == 0 + }, testutil.IntervalFast) // Pin the earliest chat. - err = client.UpdateChat(ctx, pinnedDBChat.ID, codersdk.UpdateChatRequest{ + err = client.UpdateChat(ctx, pinnedChat.ID, codersdk.UpdateChatRequest{ PinOrder: ptr.Ref(int32(1)), }) require.NoError(t, err) @@ -1186,11 +1206,11 @@ func TestListChats(t *testing.T) { for _, c := range page1 { page1IDs[c.ID] = struct{}{} } - _, found := page1IDs[pinnedDBChat.ID] + _, found := page1IDs[pinnedChat.ID] require.True(t, found, "pinned chat should appear on page 1") // The pinned chat should be the first item in the list. - require.Equal(t, pinnedDBChat.ID, page1[0].ID, "pinned chat should be first") + require.Equal(t, pinnedChat.ID, page1[0].ID, "pinned chat should be first") }) // Test cursor pagination with a mix of pinned and unpinned chats. @@ -9396,6 +9416,44 @@ func createChatModelConfig(t *testing.T, client *codersdk.ExperimentalClient) co return modelConfig } +func createAdditionalChatModelConfig( + t *testing.T, + client *codersdk.ExperimentalClient, + provider string, + model string, +) codersdk.ChatModelConfig { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + contextLimit := int64(4096) + isDefault := false + modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: provider, + Model: model, + ContextLimit: &contextLimit, + IsDefault: &isDefault, + }) + require.NoError(t, err) + return modelConfig +} + +func createDisabledChatModelConfig( + t *testing.T, + client *codersdk.ExperimentalClient, + provider string, + model string, +) codersdk.ChatModelConfig { + t.Helper() + + modelConfig := createAdditionalChatModelConfig(t, client, provider, model) + ctx := testutil.Context(t, testutil.WaitLong) + updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ + Enabled: ptr.Ref(false), + }) + require.NoError(t, err) + return updated +} + //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. func TestChatSystemPrompt(t *testing.T) { t.Parallel() @@ -9897,129 +9955,249 @@ func TestChatPlanModeInstructions(t *testing.T) { }) } -//nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. -func TestChatExploreModelOverride(t *testing.T) { +//nolint:tparallel,paralleltest // Setting subtests share per-setting coderdtest instances. +func TestChatModelOverrides(t *testing.T) { t.Parallel() - adminClient, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) - defaultModel := createChatModelConfig(t, adminClient) - memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) - memberClient := codersdk.NewExperimentalClient(memberClientRaw) + type overrideResponse struct { + context codersdk.ChatAgentModelOverrideContext + modelConfigID string + isMalformed bool + } - createAdditionalModel := func(t *testing.T, model string, enabled bool) codersdk.ChatModelConfig { - t.Helper() + type settingTest struct { + name string + context codersdk.ChatAgentModelOverrideContext + dbGet func(context.Context, database.Store) (string, error) + dbUpsert func(context.Context, database.Store, string) error + } - ctx := testutil.Context(t, testutil.WaitLong) - contextLimit := int64(4096) - isDefault := false - modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: defaultModel.Provider, - Model: model, - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) - if enabled { - return modelConfig + settingPath := func(overrideContext codersdk.ChatAgentModelOverrideContext) string { + return "/api/experimental/chats/config/agent-model-override/" + string(overrideContext) + } + + getOverride := func( + ctx context.Context, + client *codersdk.ExperimentalClient, + overrideContext codersdk.ChatAgentModelOverrideContext, + ) (overrideResponse, error) { + resp, err := client.GetChatAgentModelOverride(ctx, overrideContext) + if err != nil { + return overrideResponse{}, err } - updated, err := adminClient.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ - Enabled: ptr.Ref(false), - }) - require.NoError(t, err) - return updated + return overrideResponse{ + context: resp.Context, + modelConfigID: resp.ModelConfigID, + isMalformed: resp.IsMalformed, + }, nil } - t.Run("DefaultGETReturnsEmpty", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitLong) + putOverride := func( + ctx context.Context, + client *codersdk.ExperimentalClient, + overrideContext codersdk.ChatAgentModelOverrideContext, + modelConfigID string, + ) error { + return client.UpdateChatAgentModelOverride( + ctx, + overrideContext, + codersdk.UpdateChatAgentModelOverrideRequest{ModelConfigID: modelConfigID}, + ) + } - resp, err := adminClient.GetChatExploreModelOverride(ctx) - require.NoError(t, err) - require.Nil(t, resp.ModelConfigID) - require.False(t, resp.HasMalformedOverride) - }) + settings := []settingTest{ + { + name: "General", + context: codersdk.ChatAgentModelOverrideContextGeneral, + dbGet: func(ctx context.Context, db database.Store) (string, error) { + return db.GetChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx)) + }, + dbUpsert: func(ctx context.Context, db database.Store, value string) error { + return db.UpsertChatGeneralModelOverride(dbauthz.AsSystemRestricted(ctx), value) + }, + }, + { + name: "Explore", + context: codersdk.ChatAgentModelOverrideContextExplore, + dbGet: func(ctx context.Context, db database.Store) (string, error) { + return db.GetChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx)) + }, + dbUpsert: func(ctx context.Context, db database.Store, value string) error { + return db.UpsertChatExploreModelOverride(dbauthz.AsSystemRestricted(ctx), value) + }, + }, + } - t.Run("AdminCanSetAndClear", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitLong) - overrideModel := createAdditionalModel(t, "gpt-4.1-mini", true) + for _, setting := range settings { + t.Run(setting.name, func(t *testing.T) { + adminClient, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + defaultModel := createChatModelConfig(t, adminClient) + openAIModel := createAdditionalChatModelConfig( + t, + adminClient, + defaultModel.Provider, + "gpt-4.1-mini-"+string(setting.context), + ) + disabledModel := createDisabledChatModelConfig( + t, + adminClient, + defaultModel.Provider, + "gpt-4.1-disabled-"+string(setting.context), + ) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) - err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ - ModelConfigID: &overrideModel.ID, - }) - require.NoError(t, err) + t.Run("DefaultGETReturnsEmpty", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) - resp, err := adminClient.GetChatExploreModelOverride(ctx) - require.NoError(t, err) - require.NotNil(t, resp.ModelConfigID) - require.Equal(t, overrideModel.ID, *resp.ModelConfigID) - require.False(t, resp.HasMalformedOverride) + resp, err := getOverride(ctx, adminClient, setting.context) + require.NoError(t, err) + require.Equal(t, setting.context, resp.context) + require.Empty(t, resp.modelConfigID) + require.False(t, resp.isMalformed) - err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{}) - require.NoError(t, err) + raw, err := setting.dbGet(ctx, db) + require.NoError(t, err) + require.Empty(t, raw, "expected empty stored override for %s", settingPath(setting.context)) + }) - resp, err = adminClient.GetChatExploreModelOverride(ctx) - require.NoError(t, err) - require.Nil(t, resp.ModelConfigID) - require.False(t, resp.HasMalformedOverride) - }) + t.Run("AdminCanSetAndClear", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) - t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitLong) + err := putOverride(ctx, adminClient, setting.context, openAIModel.ID.String()) + require.NoError(t, err) - require.NoError(t, db.UpsertChatExploreModelOverride( - dbauthz.AsSystemRestricted(ctx), - "not-a-uuid", - )) + raw, err := setting.dbGet(ctx, db) + require.NoError(t, err) + require.Equal(t, openAIModel.ID.String(), raw, "expected stored override for %s", settingPath(setting.context)) - resp, err := adminClient.GetChatExploreModelOverride(ctx) - require.NoError(t, err) - require.Nil(t, resp.ModelConfigID) - require.True(t, resp.HasMalformedOverride) + resp, err := getOverride(ctx, adminClient, setting.context) + require.NoError(t, err) + require.Equal(t, setting.context, resp.context) + require.Equal(t, openAIModel.ID.String(), resp.modelConfigID) + require.False(t, resp.isMalformed) - err = adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{}) - require.NoError(t, err) + err = putOverride(ctx, adminClient, setting.context, "") + require.NoError(t, err) - resp, err = adminClient.GetChatExploreModelOverride(ctx) - require.NoError(t, err) - require.Nil(t, resp.ModelConfigID) - require.False(t, resp.HasMalformedOverride) - }) + raw, err = setting.dbGet(ctx, db) + require.NoError(t, err) + require.Empty(t, raw, "expected cleared override for %s", settingPath(setting.context)) - t.Run("DisabledModelReturns400", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitLong) - disabledModel := createAdditionalModel(t, "gpt-4.1-disabled", false) + resp, err = getOverride(ctx, adminClient, setting.context) + require.NoError(t, err) + require.Equal(t, setting.context, resp.context) + require.Empty(t, resp.modelConfigID) + require.False(t, resp.isMalformed) + }) + + t.Run("MalformedStoredOverrideIsReportedAndCanBeCleared", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + require.NoError(t, setting.dbUpsert(ctx, db, "not-a-uuid")) + + resp, err := getOverride(ctx, adminClient, setting.context) + require.NoError(t, err) + require.Equal(t, setting.context, resp.context) + require.Empty(t, resp.modelConfigID) + require.True(t, resp.isMalformed) + + err = putOverride(ctx, adminClient, setting.context, "") + require.NoError(t, err) + + raw, err := setting.dbGet(ctx, db) + require.NoError(t, err) + require.Empty(t, raw, "expected malformed override to be cleared for %s", settingPath(setting.context)) + + resp, err = getOverride(ctx, adminClient, setting.context) + require.NoError(t, err) + require.Equal(t, setting.context, resp.context) + require.Empty(t, resp.modelConfigID) + require.False(t, resp.isMalformed) + }) + + t.Run("InvalidUUIDReturns400", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + err := putOverride(ctx, adminClient, setting.context, "not-a-uuid") + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid model_config_id.", sdkErr.Message) + require.Equal(t, "Value \"not-a-uuid\" is not a valid UUID.", sdkErr.Detail) + }) + + t.Run("DisabledModelReturns400", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + err := putOverride(ctx, adminClient, setting.context, disabledModel.ID.String()) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid model_config_id.", sdkErr.Message) + }) + + t.Run("UnknownModelReturns400", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + unknownModelID := uuid.New() + + err := putOverride(ctx, adminClient, setting.context, unknownModelID.String()) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid model_config_id.", sdkErr.Message) + }) - err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ - ModelConfigID: &disabledModel.ID, + t.Run("NonAdminGETReturns404", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := getOverride(ctx, memberClient, setting.context) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("NonAdminPUTReturns403", func(t *testing.T) { + ctx := testutil.Context(t, testutil.WaitLong) + + err := putOverride(ctx, memberClient, setting.context, defaultModel.ID.String()) + requireSDKError(t, err, http.StatusForbidden) + }) }) - sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid model_config_id.", sdkErr.Message) - }) + } - t.Run("UnknownModelReturns400", func(t *testing.T) { + t.Run("UnknownContextReturns400", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - unknownModelID := uuid.New() - err := adminClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ - ModelConfigID: &unknownModelID, - }) + adminClient := newChatClient(t) + coderdtest.CreateFirstUser(t, adminClient.Client) + unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context") + + _, err := getOverride(ctx, adminClient, unknownContext) sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid model_config_id.", sdkErr.Message) + require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message) + require.Equal( + t, + `Expected one of general, explore. Got "not-a-context".`, + sdkErr.Detail, + ) + + err = putOverride(ctx, adminClient, unknownContext, "") + sdkErr = requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid chat agent model override context.", sdkErr.Message) + require.Equal( + t, + `Expected one of general, explore. Got "not-a-context".`, + sdkErr.Detail, + ) }) - t.Run("NonAdminGETReturns404", func(t *testing.T) { + t.Run("NonAdminUnknownContextUsesAuthResponse", func(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) - _, err := memberClient.GetChatExploreModelOverride(ctx) - requireSDKError(t, err, http.StatusNotFound) - }) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + unknownContext := codersdk.ChatAgentModelOverrideContext("not-a-context") - t.Run("NonAdminPUTReturns403", func(t *testing.T) { - ctx := testutil.Context(t, testutil.WaitLong) + _, err := getOverride(ctx, memberClient, unknownContext) + requireSDKError(t, err, http.StatusNotFound) - err := memberClient.UpdateChatExploreModelOverride(ctx, codersdk.UpdateChatExploreModelOverrideRequest{ - ModelConfigID: &defaultModel.ID, - }) + err = putOverride(ctx, memberClient, unknownContext, "") requireSDKError(t, err, http.StatusForbidden) }) } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 8ef082c738ae0..52617e8847915 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "slices" "sort" @@ -27,6 +28,18 @@ import ( var ErrSubagentNotDescendant = xerrors.New("target chat is not a descendant of current chat") +var errInvalidModelOverrideMetadata = xerrors.New("invalid model override metadata") + +type modelOverrideConfigResolver func( + context.Context, + uuid.UUID, +) (database.ChatModelConfig, string, error) + +type modelOverrideProviderKeysResolver func( + context.Context, + uuid.UUID, +) (chatprovider.ProviderAPIKeys, error) + const ( subagentAwaitPollInterval = 200 * time.Millisecond subagentAwaitFallbackPoll = 5 * time.Second @@ -90,66 +103,199 @@ func (p *Server) isDesktopEnabled(ctx context.Context) bool { return enabled } -func (p *Server) resolveExploreSubagentModelConfigID( +func subagentModelOverrideLogLabel( + overrideContext codersdk.ChatAgentModelOverrideContext, +) string { + switch overrideContext { + case codersdk.ChatAgentModelOverrideContextGeneral: + return "general delegated child" + case codersdk.ChatAgentModelOverrideContextExplore: + return "explore" + default: + return string(overrideContext) + } +} + +func readSubagentModelOverride( ctx context.Context, - ownerID uuid.UUID, - fallback uuid.UUID, -) (uuid.UUID, error) { - //nolint:gocritic // Chatd needs its scoped deployment-config read access here. - chatdCtx := dbauthz.AsChatd(ctx) - raw, err := p.db.GetChatExploreModelOverride(chatdCtx) + db database.Store, + overrideContext codersdk.ChatAgentModelOverrideContext, +) (string, error) { + switch overrideContext { + case codersdk.ChatAgentModelOverrideContextGeneral: + return db.GetChatGeneralModelOverride(ctx) + case codersdk.ChatAgentModelOverrideContextExplore: + return db.GetChatExploreModelOverride(ctx) + default: + return "", xerrors.Errorf( + "unknown subagent model override context %q", + overrideContext, + ) + } +} + +func validateModelConfigAndResolveProvider( + modelConfig database.ChatModelConfig, +) (database.ChatModelConfig, string, error) { + if !modelConfig.Enabled { + return database.ChatModelConfig{}, "", sql.ErrNoRows + } + providerName, _, err := chatprovider.ResolveModelWithProviderHint( + modelConfig.Model, + modelConfig.Provider, + ) if err != nil { - return uuid.Nil, xerrors.Errorf("get Explore model override: %w", err) + return database.ChatModelConfig{}, "", xerrors.Errorf( + "%w: %v", + errInvalidModelOverrideMetadata, + err, + ) } + return modelConfig, providerName, nil +} + +func enabledProviderContainsName( + providers []database.ChatProvider, + providerName string, +) bool { + normalizedProviderName := chatprovider.NormalizeProvider(providerName) + for _, provider := range providers { + if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName { + return true + } + } + return false +} + +func (p *Server) resolveConfiguredModelOverride( + ctx context.Context, + overrideContext string, + raw string, + ownerID uuid.UUID, + resolveModelConfig modelOverrideConfigResolver, + resolveProviderKeys modelOverrideProviderKeysResolver, +) (database.ChatModelConfig, bool, error) { trimmed := strings.TrimSpace(raw) if trimmed == "" { - return fallback, nil + return database.ChatModelConfig{}, false, nil } configuredModelConfigID, err := uuid.Parse(trimmed) if err != nil { - p.logger.Warn(ctx, - "invalid Explore model override, falling back to current turn model", + p.logger.Info(ctx, + "invalid model override, ignoring", + slog.F("override_context", overrideContext), slog.F("raw_model_config_id", trimmed), slog.Error(err), ) - return fallback, nil + return database.ChatModelConfig{}, false, nil } - modelConfig, err := p.db.GetEnabledChatModelConfigByID( - chatdCtx, + modelConfig, providerName, err := resolveModelConfig( + ctx, configuredModelConfigID, ) if err != nil { - if xerrors.Is(err, sql.ErrNoRows) { + switch { + case xerrors.Is(err, sql.ErrNoRows): + p.logger.Info(ctx, + "model override is unavailable, ignoring", + slog.F("override_context", overrideContext), + slog.F("model_config_id", configuredModelConfigID), + ) + case errors.Is(err, errInvalidModelOverrideMetadata): + p.logger.Info(ctx, + "model override metadata is invalid, ignoring", + slog.F("override_context", overrideContext), + slog.F("model_config_id", configuredModelConfigID), + slog.Error(err), + ) + default: p.logger.Warn(ctx, - "explore model override is unavailable, falling back to current turn model", + "failed to resolve model override, ignoring", + slog.F("override_context", overrideContext), slog.F("model_config_id", configuredModelConfigID), + slog.Error(err), ) - return fallback, nil } - return uuid.Nil, xerrors.Errorf("get enabled chat model config by id: %w", err) + return database.ChatModelConfig{}, false, nil } - providerName, _, err := chatprovider.ResolveModelWithProviderHint( - modelConfig.Model, - modelConfig.Provider, - ) + providerKeys, err := resolveProviderKeys(ctx, ownerID) if err != nil { - return uuid.Nil, xerrors.Errorf("resolve Explore model provider: %w", err) - } - providerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID) - if err != nil { - return uuid.Nil, xerrors.Errorf("resolve provider API keys: %w", err) + return database.ChatModelConfig{}, false, xerrors.Errorf( + "resolve provider API keys: %w", + err, + ) } - if providerKeys.APIKey(providerName) == "" { - p.logger.Warn(ctx, - "explore model override credentials are unavailable, falling back to current turn model", + if providerKeys.APIKey(providerName) == "" && + !(chatprovider.ProviderAllowsAmbientCredentials(providerName) && + providerKeys.HasProvider(providerName)) { + p.logger.Info(ctx, + "model override credentials are unavailable, ignoring", + slog.F("override_context", overrideContext), slog.F("model_config_id", configuredModelConfigID), slog.F("provider", providerName), ) - return fallback, nil + return database.ChatModelConfig{}, false, nil + } + return modelConfig, true, nil +} + +func (p *Server) resolveSubagentModelConfigID( + ctx context.Context, + ownerID uuid.UUID, + overrideContext codersdk.ChatAgentModelOverrideContext, +) (uuid.UUID, error) { + //nolint:gocritic // Chatd needs its scoped deployment-config read access here. + chatdCtx := dbauthz.AsChatd(ctx) + raw, err := readSubagentModelOverride(chatdCtx, p.db, overrideContext) + if err != nil { + return uuid.Nil, xerrors.Errorf( + "get %s model override: %w", + subagentModelOverrideLogLabel(overrideContext), + err, + ) + } + modelConfig, ok, err := p.resolveConfiguredModelOverride( + ctx, + string(overrideContext), + raw, + ownerID, + p.resolveModelConfigAndNormalizedProvider, + p.resolveUserProviderAPIKeys, + ) + if err != nil { + return uuid.Nil, err + } + if !ok { + return uuid.Nil, nil } return modelConfig.ID, nil } +func (p *Server) resolveModelConfigAndNormalizedProvider( + ctx context.Context, + modelConfigID uuid.UUID, +) (database.ChatModelConfig, string, error) { + if modelConfigID == uuid.Nil { + return database.ChatModelConfig{}, "", sql.ErrNoRows + } + modelConfig, err := p.configCache.ModelConfigByID(ctx, modelConfigID) + if err != nil { + return database.ChatModelConfig{}, "", err + } + modelConfig, providerName, err := validateModelConfigAndResolveProvider(modelConfig) + if err != nil { + return database.ChatModelConfig{}, "", err + } + enabledProviders, err := p.configCache.EnabledProviders(ctx) + if err != nil { + return database.ChatModelConfig{}, "", err + } + if !enabledProviderContainsName(enabledProviders, providerName) { + return database.ChatModelConfig{}, "", sql.ErrNoRows + } + return modelConfig, providerName, nil +} + func (p *Server) subagentTools( ctx context.Context, currentChat func() database.Chat, @@ -444,7 +590,6 @@ func (p *Server) loadSubagentSpawnParentChat( if err := validateSubagentSpawnParent(parent); err != nil { return database.Chat{}, err } - reloadedParent, err := p.db.GetChatByID(ctx, parent.ID) if err != nil { p.logger.Warn(ctx, "failed to load parent chat for spawn_agent", diff --git a/coderd/x/chatd/subagent_catalog.go b/coderd/x/chatd/subagent_catalog.go index 5e920a9faf0ca..2b08f45045fec 100644 --- a/coderd/x/chatd/subagent_catalog.go +++ b/coderd/x/chatd/subagent_catalog.go @@ -9,6 +9,7 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" ) const ( @@ -43,22 +44,37 @@ func allSubagentDefinitions() []subagentDefinition { { id: subagentTypeGeneral, description: "delegated work that may inspect or modify workspace files", - buildOptions: func(_ context.Context, _ *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) { - return childSubagentChatOptions{}, nil + buildOptions: func(ctx context.Context, p *Server, parent database.Chat, _ database.Chat, _ uuid.UUID, _ string) (childSubagentChatOptions, error) { + modelConfigID, err := p.resolveSubagentModelConfigID( + ctx, + parent.OwnerID, + codersdk.ChatAgentModelOverrideContextGeneral, + ) + if err != nil { + return childSubagentChatOptions{}, err + } + options := childSubagentChatOptions{} + if modelConfigID != uuid.Nil { + options.modelConfigIDOverride = &modelConfigID + } + return options, nil }, }, { id: subagentTypeExplore, description: "read-only discovery, code tracing, and system understanding", buildOptions: func(ctx context.Context, p *Server, _ database.Chat, turnParent database.Chat, currentModelConfigID uuid.UUID, _ string) (childSubagentChatOptions, error) { - modelConfigID, err := p.resolveExploreSubagentModelConfigID( + modelConfigID, err := p.resolveSubagentModelConfigID( ctx, turnParent.OwnerID, - currentModelConfigID, + codersdk.ChatAgentModelOverrideContextExplore, ) if err != nil { return childSubagentChatOptions{}, err } + if modelConfigID == uuid.Nil { + modelConfigID = currentModelConfigID + } inheritedMCPServerIDs, err := p.resolveExploreToolSnapshot( ctx, turnParent, diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index c9872d7438c8e..a1a24702b226b 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "sync" "testing" "time" @@ -12,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -71,7 +73,14 @@ func newInternalTestServer( ps pubsub.Pubsub, keys chatprovider.ProviderAPIKeys, ) *Server { - return newInternalTestServerWithClock(t, db, ps, keys, nil) + return newInternalTestServerWithLoggerAndClock( + t, + db, + ps, + keys, + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + nil, + ) } func newInternalTestServerWithClock( @@ -80,10 +89,37 @@ func newInternalTestServerWithClock( ps pubsub.Pubsub, keys chatprovider.ProviderAPIKeys, clk quartz.Clock, +) *Server { + return newInternalTestServerWithLoggerAndClock( + t, + db, + ps, + keys, + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + clk, + ) +} + +func newInternalTestServerWithLogger( + t *testing.T, + db database.Store, + ps pubsub.Pubsub, + keys chatprovider.ProviderAPIKeys, + logger slog.Logger, +) *Server { + return newInternalTestServerWithLoggerAndClock(t, db, ps, keys, logger, nil) +} + +func newInternalTestServerWithLoggerAndClock( + t *testing.T, + db database.Store, + ps pubsub.Pubsub, + keys chatprovider.ProviderAPIKeys, + logger slog.Logger, + clk quartz.Clock, ) *Server { t.Helper() - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) server := New(Config{ Logger: logger, Database: db, @@ -101,6 +137,35 @@ func newInternalTestServerWithClock( return server } +type subagentTestLogSink struct { + mu sync.Mutex + entries []slog.SinkEntry +} + +func (s *subagentTestLogSink) LogEntry(_ context.Context, entry slog.SinkEntry) { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = append(s.entries, entry) +} + +func (*subagentTestLogSink) Sync() {} + +func (s *subagentTestLogSink) entriesAtLevelWithMessage( + level slog.Level, + message string, +) []slog.SinkEntry { + s.mu.Lock() + defer s.mu.Unlock() + + entries := make([]slog.SinkEntry, 0, len(s.entries)) + for _, entry := range s.entries { + if entry.Level == level && entry.Message == message { + entries = append(entries, entry) + } + } + return entries +} + // seedInternalChatDeps inserts an OpenAI provider and model config // into the database and returns the created user, organization, // and model. This deliberately does NOT create an Anthropic @@ -218,6 +283,54 @@ func insertInternalChatModelConfig( userID uuid.UUID, model string, enabled bool, +) database.ChatModelConfig { + return insertInternalChatModelConfigForProvider( + ctx, + t, + db, + userID, + "openai", + model, + enabled, + ) +} + +func insertInternalChatProvider( + ctx context.Context, + t *testing.T, + db database.Store, + userID uuid.UUID, + provider string, + apiKey string, + centralAPIKeyEnabled bool, + allowUserAPIKey bool, + allowCentralAPIKeyFallback bool, +) database.ChatProvider { + t.Helper() + + providerConfig, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: provider, + DisplayName: provider, + APIKey: apiKey, + CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: centralAPIKeyEnabled, + AllowUserApiKey: allowUserAPIKey, + AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, + }) + require.NoError(t, err) + + return providerConfig +} + +func insertInternalChatModelConfigForProvider( + ctx context.Context, + t *testing.T, + db database.Store, + userID uuid.UUID, + provider string, + model string, + enabled bool, ) database.ChatModelConfig { t.Helper() return insertInternalChatModelConfigWithOptions( @@ -225,6 +338,7 @@ func insertInternalChatModelConfig( t, db, userID, + provider, model, enabled, json.RawMessage(`{}`), @@ -236,6 +350,7 @@ func insertInternalChatModelConfigWithOptions( t *testing.T, db database.Store, userID uuid.UUID, + provider string, model string, enabled bool, options json.RawMessage, @@ -243,7 +358,7 @@ func insertInternalChatModelConfigWithOptions( t.Helper() modelConfig, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ - Provider: "openai", + Provider: provider, Model: model, DisplayName: model, CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, @@ -574,6 +689,213 @@ func TestSpawnAgent_GeneralInheritsParentModelWhenOmitted(t *testing.T) { require.Equal(t, parentChat.LastModelConfigID, childChat.LastModelConfigID) } +func TestSpawnAgent_GeneralUsesConfiguredModelOverride(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(ctx, t, db) + overrideModel := insertInternalChatModelConfig( + ctx, t, db, user.ID, "general-override-"+uuid.NewString(), true, + ) + require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String())) + parentChat := createInternalParentChat( + ctx, t, server, db, org.ID, user.ID, model.ID, "parent-general-override", + ) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeGeneral, + Prompt: "delegate general work", + }) + childID := requireSpawnAgentChildChatID(t, resp) + + childChat, err := db.GetChatByID(ctx, childID) + require.NoError(t, err) + require.Equal(t, overrideModel.ID, childChat.LastModelConfigID) + require.False(t, childChat.PlanMode.Valid) +} + +func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenCredentialsUnavailable(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + logSink := &subagentTestLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + server := newInternalTestServerWithLogger(t, db, ps, chatprovider.ProviderAPIKeys{}, logger) + + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(ctx, t, db) + insertInternalChatProvider( + ctx, + t, + db, + user.ID, + "openai-compat", + "", + false, + true, + false, + ) + overrideModel := insertInternalChatModelConfigForProvider( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini", + true, + ) + require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String())) + parent, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "parent-general-credentials-fallback", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("delegate work"), + }, + }) + require.NoError(t, err) + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeGeneral, + Prompt: "inspect provider credentials", + }) + childID := requireSpawnAgentChildChatID(t, resp) + + childChat, err := db.GetChatByID(ctx, childID) + require.NoError(t, err) + require.Equal(t, model.ID, childChat.LastModelConfigID) + require.False(t, childChat.PlanMode.Valid) + require.Len(t, logSink.entriesAtLevelWithMessage( + slog.LevelInfo, + "model override credentials are unavailable, ignoring", + ), 1) +} + +func TestSpawnAgent_GeneralOverrideLogsAndFallsBackWhenProviderDisabled(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + logSink := &subagentTestLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + server := newInternalTestServerWithLogger( + t, + db, + ps, + chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + "openai-compat": "fallback-key", + }, + }, + logger, + ) + + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(ctx, t, db) + _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai-compat", + DisplayName: "openai-compat", + APIKey: "", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: false, + CentralApiKeyEnabled: false, + AllowUserApiKey: true, + AllowCentralApiKeyFallback: false, + }) + require.NoError(t, err) + overrideModel := insertInternalChatModelConfigForProvider( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini", + true, + ) + require.NoError(t, db.UpsertChatGeneralModelOverride(ctx, overrideModel.ID.String())) + parent, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "parent-general-disabled-provider-fallback", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("delegate work"), + }, + }) + require.NoError(t, err) + parentChat, err := db.GetChatByID(ctx, parent.ID) + require.NoError(t, err) + + resp := runSpawnAgentTool(ctx, t, server, parentChat, spawnAgentArgs{ + Type: subagentTypeGeneral, + Prompt: "inspect disabled providers", + }) + childID := requireSpawnAgentChildChatID(t, resp) + + childChat, err := db.GetChatByID(ctx, childID) + require.NoError(t, err) + require.Equal(t, model.ID, childChat.LastModelConfigID) + require.False(t, childChat.PlanMode.Valid) + require.Len(t, logSink.entriesAtLevelWithMessage( + slog.LevelInfo, + "model override is unavailable, ignoring", + ), 1) +} + +func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider( + t *testing.T, +) { + t.Parallel() + + logSink := &subagentTestLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + server := &Server{logger: logger} + ctx := chatdTestContext(t) + ownerID := uuid.New() + modelConfig := database.ChatModelConfig{ + ID: uuid.New(), + Provider: "bedrock", + Model: "anthropic.claude-haiku-4-5-20251001-v1:0", + DisplayName: "Ambient Bedrock Override", + Enabled: true, + } + + resolvedModelConfig, ok, err := server.resolveConfiguredModelOverride( + ctx, + "plan", + modelConfig.ID.String(), + ownerID, + func( + _ context.Context, + configuredModelConfigID uuid.UUID, + ) (database.ChatModelConfig, string, error) { + require.Equal(t, modelConfig.ID, configuredModelConfigID) + return modelConfig, "bedrock", nil + }, + func( + _ context.Context, + resolvedOwnerID uuid.UUID, + ) (chatprovider.ProviderAPIKeys, error) { + require.Equal(t, ownerID, resolvedOwnerID) + return chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{"bedrock": ""}, + }, nil + }, + ) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, modelConfig, resolvedModelConfig) + require.Empty(t, logSink.entriesAtLevelWithMessage( + slog.LevelInfo, + "model override credentials are unavailable, ignoring", + )) +} + func TestCreateChildSubagentChat_OverrideWorksWhenParentHasNoModel(t *testing.T) { t.Parallel() @@ -1328,7 +1650,6 @@ func TestSubagentLifecycleToolsIncludePersistedSubagentTypeAcrossVariants(t *tes } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() @@ -1450,7 +1771,6 @@ func TestSubagentLifecycleToolErrorsIncludePersistedSubagentType(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() diff --git a/codersdk/chats.go b/codersdk/chats.go index 5814464bba48e..7f66cb4203823 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -562,19 +562,46 @@ type UpdateChatPlanModeInstructionsRequest struct { PlanModeInstructions string `json:"plan_mode_instructions"` } -// ChatExploreModelOverrideResponse is the response body for the Explore -// subagent model override configuration endpoint. -type ChatExploreModelOverrideResponse struct { - ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` - // HasMalformedOverride reports whether the saved override is malformed and - // is currently being treated as unset. - HasMalformedOverride bool `json:"has_malformed_override"` +// ChatAgentModelOverrideContext identifies which chat or subagent context +// a deployment override applies to. +type ChatAgentModelOverrideContext string + +const ( + ChatAgentModelOverrideContextGeneral ChatAgentModelOverrideContext = "general" + ChatAgentModelOverrideContextExplore ChatAgentModelOverrideContext = "explore" +) + +// Valid reports whether the override context is one of the supported values. +func (c ChatAgentModelOverrideContext) Valid() bool { + switch c { + case ChatAgentModelOverrideContextGeneral, + ChatAgentModelOverrideContextExplore: + return true + default: + return false + } } -// UpdateChatExploreModelOverrideRequest is the request body for updating the -// Explore subagent model override configuration endpoint. -type UpdateChatExploreModelOverrideRequest struct { - ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` +// AllChatAgentModelOverrideContexts returns all supported override contexts. +func AllChatAgentModelOverrideContexts() []ChatAgentModelOverrideContext { + return []ChatAgentModelOverrideContext{ + ChatAgentModelOverrideContextGeneral, + ChatAgentModelOverrideContextExplore, + } +} + +// ChatAgentModelOverrideResponse is the response body for the chat agent +// model override configuration endpoint. +type ChatAgentModelOverrideResponse struct { + Context ChatAgentModelOverrideContext `json:"context"` + ModelConfigID string `json:"model_config_id"` + IsMalformed bool `json:"is_malformed"` +} + +// UpdateChatAgentModelOverrideRequest is the request body for updating the +// chat agent model override configuration endpoint. +type UpdateChatAgentModelOverrideRequest struct { + ModelConfigID string `json:"model_config_id"` } // UserChatCustomPrompt is the request and response body for the @@ -2024,25 +2051,33 @@ func (c *ExperimentalClient) UpdateChatPlanModeInstructions(ctx context.Context, return nil } -// GetChatExploreModelOverride returns the deployment-wide Explore subagent -// model override. -func (c *ExperimentalClient) GetChatExploreModelOverride(ctx context.Context) (ChatExploreModelOverrideResponse, error) { - res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/explore-model-override", nil) +// GetChatAgentModelOverride returns the deployment-wide chat agent model +// override for the requested context. +func (c *ExperimentalClient) GetChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext) (ChatAgentModelOverrideResponse, error) { + path := fmt.Sprintf( + "/api/experimental/chats/config/agent-model-override/%s", + url.PathEscape(string(override)), + ) + res, err := c.Request(ctx, http.MethodGet, path, nil) if err != nil { - return ChatExploreModelOverrideResponse{}, err + return ChatAgentModelOverrideResponse{}, err } defer res.Body.Close() if res.StatusCode != http.StatusOK { - return ChatExploreModelOverrideResponse{}, ReadBodyAsError(res) + return ChatAgentModelOverrideResponse{}, ReadBodyAsError(res) } - var resp ChatExploreModelOverrideResponse + var resp ChatAgentModelOverrideResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } -// UpdateChatExploreModelOverride updates the deployment-wide Explore subagent -// model override. -func (c *ExperimentalClient) UpdateChatExploreModelOverride(ctx context.Context, req UpdateChatExploreModelOverrideRequest) error { - res, err := c.Request(ctx, http.MethodPut, "/api/experimental/chats/config/explore-model-override", req) +// UpdateChatAgentModelOverride updates the deployment-wide chat agent model +// override for the requested context. +func (c *ExperimentalClient) UpdateChatAgentModelOverride(ctx context.Context, override ChatAgentModelOverrideContext, req UpdateChatAgentModelOverrideRequest) error { + path := fmt.Sprintf( + "/api/experimental/chats/config/agent-model-override/%s", + url.PathEscape(string(override)), + ) + res, err := c.Request(ctx, http.MethodPut, path, req) if err != nil { return err } diff --git a/site/src/api/api.ts b/site/src/api/api.ts index f47d4972a35f4..dd35d9a10f53b 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3257,20 +3257,22 @@ class ExperimentalApiMethods { ); }; - getChatExploreModelOverride = - async (): Promise => { - const response = - await this.axios.get( - "/api/experimental/chats/config/explore-model-override", - ); - return response.data; - }; + getChatAgentModelOverride = async ( + context: TypesGen.ChatAgentModelOverrideContext, + ): Promise => { + const response = + await this.axios.get( + `/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`, + ); + return response.data; + }; - updateChatExploreModelOverride = async ( - req: TypesGen.UpdateChatExploreModelOverrideRequest, + updateChatAgentModelOverride = async ( + context: TypesGen.ChatAgentModelOverrideContext, + req: TypesGen.UpdateChatAgentModelOverrideRequest, ): Promise => { await this.axios.put( - "/api/experimental/chats/config/explore-model-override", + `/api/experimental/chats/config/agent-model-override/${encodeURIComponent(context)}`, req, ); }; diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index fdf33f66c6169..5a59772971120 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -1127,23 +1127,6 @@ export const updateChatPlanModeInstructions = (queryClient: QueryClient) => ({ }, }); -const chatExploreModelOverrideKey = ["chat-explore-model-override"] as const; - -export const chatExploreModelOverride = () => ({ - queryKey: chatExploreModelOverrideKey, - queryFn: () => API.experimental.getChatExploreModelOverride(), -}); - -export const updateChatExploreModelOverride = (queryClient: QueryClient) => ({ - mutationFn: (req: TypesGen.UpdateChatExploreModelOverrideRequest) => - API.experimental.updateChatExploreModelOverride(req), - onSuccess: async () => { - await queryClient.invalidateQueries({ - queryKey: chatExploreModelOverrideKey, - }); - }, -}); - const chatDesktopEnabledKey = ["chat-desktop-enabled"] as const; export const chatDesktopEnabled = () => ({ diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c0abec1ba3c19..4876c797103b6 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1272,6 +1272,25 @@ export interface Chat { readonly children: readonly Chat[]; } +// From codersdk/chats.go +export type ChatAgentModelOverrideContext = "explore" | "general"; + +export const ChatAgentModelOverrideContexts: ChatAgentModelOverrideContext[] = [ + "explore", + "general", +]; + +// From codersdk/chats.go +/** + * ChatAgentModelOverrideResponse is the response body for the chat agent + * model override configuration endpoint. + */ +export interface ChatAgentModelOverrideResponse { + readonly context: ChatAgentModelOverrideContext; + readonly model_config_id: string; + readonly is_malformed: boolean; +} + // From codersdk/chats.go export type ChatBusyBehavior = "interrupt" | "queue"; @@ -1590,20 +1609,6 @@ export interface ChatDiffStatus { readonly stale_at?: string; } -// From codersdk/chats.go -/** - * ChatExploreModelOverrideResponse is the response body for the Explore - * subagent model override configuration endpoint. - */ -export interface ChatExploreModelOverrideResponse { - readonly model_config_id?: string; - /** - * HasMalformedOverride reports whether the saved override is malformed and - * is currently being treated as unset. - */ - readonly has_malformed_override: boolean; -} - // From codersdk/chats.go /** * ChatFileMetadata contains lightweight metadata about a file @@ -7657,6 +7662,15 @@ export interface UpdateAppearanceConfig { readonly announcement_banners: readonly BannerConfig[]; } +// From codersdk/chats.go +/** + * UpdateChatAgentModelOverrideRequest is the request body for updating the + * chat agent model override configuration endpoint. + */ +export interface UpdateChatAgentModelOverrideRequest { + readonly model_config_id: string; +} + // From codersdk/chats.go /** * UpdateChatDebugLoggingAllowUsersRequest is the admin request to @@ -7674,15 +7688,6 @@ export interface UpdateChatDesktopEnabledRequest { readonly enable_desktop: boolean; } -// From codersdk/chats.go -/** - * UpdateChatExploreModelOverrideRequest is the request body for updating the - * Explore subagent model override configuration endpoint. - */ -export interface UpdateChatExploreModelOverrideRequest { - readonly model_config_id?: string; -} - // From codersdk/chats.go /** * UpdateChatModelConfigRequest updates a chat model config. diff --git a/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx b/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx index 8aceb61d656b9..b406a173ea27b 100644 --- a/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsAgentsPage.tsx @@ -1,34 +1,83 @@ import type { FC } from "react"; -import { useMutation, useQuery, useQueryClient } from "react-query"; import { - chatExploreModelOverride, - chatModelConfigs, - updateChatExploreModelOverride, -} from "#/api/queries/chats"; + type QueryClient, + useMutation, + useQuery, + useQueryClient, +} from "react-query"; +import { API } from "#/api/api"; +import { chatModelConfigs } from "#/api/queries/chats"; +import type * as TypesGen from "#/api/typesGenerated"; import { useAuthenticated } from "#/hooks/useAuthenticated"; import { RequirePermission } from "#/modules/permissions/RequirePermission"; import { AgentSettingsAgentsPageView } from "./AgentSettingsAgentsPageView"; +const generalOverrideContext: TypesGen.ChatAgentModelOverrideContext = + "general"; +const exploreOverrideContext: TypesGen.ChatAgentModelOverrideContext = + "explore"; + +const chatAgentModelOverrideKey = ( + context: TypesGen.ChatAgentModelOverrideContext, +) => ["chat-agent-model-override", context] as const; + +const chatAgentModelOverrideQuery = ( + context: TypesGen.ChatAgentModelOverrideContext, +) => ({ + queryKey: chatAgentModelOverrideKey(context), + queryFn: () => API.experimental.getChatAgentModelOverride(context), +}); + +const updateChatAgentModelOverrideMutation = ( + queryClient: QueryClient, + context: TypesGen.ChatAgentModelOverrideContext, +) => ({ + mutationFn: (req: TypesGen.UpdateChatAgentModelOverrideRequest) => + API.experimental.updateChatAgentModelOverride(context, req), + onSuccess: async () => { + await queryClient.invalidateQueries({ + queryKey: chatAgentModelOverrideKey(context), + exact: true, + }); + }, +}); + const AgentSettingsAgentsPage: FC = () => { const { permissions } = useAuthenticated(); const queryClient = useQueryClient(); + const canEditDeploymentConfig = permissions.editDeploymentConfig; + const generalModelOverrideQuery = useQuery({ + ...chatAgentModelOverrideQuery(generalOverrideContext), + enabled: canEditDeploymentConfig, + }); const exploreModelOverrideQuery = useQuery({ - ...chatExploreModelOverride(), - enabled: permissions.editDeploymentConfig, + ...chatAgentModelOverrideQuery(exploreOverrideContext), + enabled: canEditDeploymentConfig, }); const modelConfigsQuery = useQuery(chatModelConfigs()); + const saveGeneralModelOverrideMutation = useMutation( + updateChatAgentModelOverrideMutation(queryClient, generalOverrideContext), + ); const saveExploreModelOverrideMutation = useMutation( - updateChatExploreModelOverride(queryClient), + updateChatAgentModelOverrideMutation(queryClient, exploreOverrideContext), ); return ( - + , +): TypesGen.ChatModelConfig => ({ + id: "model-default", + provider: "openai", + model: "gpt-4.1-mini", + display_name: "GPT 4.1 Mini", + enabled: true, + is_default: false, + context_limit: 1_000_000, + compression_threshold: 70, + created_at: "2026-03-12T12:00:00.000Z", + updated_at: "2026-03-12T12:00:00.000Z", + ...overrides, +}); + +const buildOverrideData = ( + context: TypesGen.ChatAgentModelOverrideContext, + overrides: Partial = {}, +): TypesGen.ChatAgentModelOverrideResponse => ({ + context, + model_config_id: "", + is_malformed: false, + ...overrides, +}); + +const generalModelConfig = buildModelConfig({ + id: "model-general-gpt-4.1-mini", + display_name: "GPT 4.1 Mini", +}); + +const claudeSonnetModelConfig = buildModelConfig({ + id: "model-claude-sonnet-4", + provider: "anthropic", + model: "claude-sonnet-4", + display_name: "Claude Sonnet 4", + context_limit: 200_000, +}); + +const exploreFallbackModelConfig = buildModelConfig({ + id: "model-explore-blank-display", + provider: "anthropic", + model: "claude-sonnet-4-20250514", + display_name: "", + context_limit: 200_000, +}); + +const generalDisabledModelConfig = buildModelConfig({ + id: "model-general-disabled", + model: "gpt-4.1-legacy", + display_name: "GPT 4.1 Legacy", + enabled: false, +}); + +const exploreDisabledModelConfig = buildModelConfig({ + id: "model-explore-disabled", + provider: "anthropic", + model: "claude-haiku-legacy", + display_name: "Claude Haiku Legacy", + enabled: false, + context_limit: 200_000, +}); + +const allModelConfigs: TypesGen.ChatModelConfig[] = [ + generalModelConfig, + claudeSonnetModelConfig, + exploreFallbackModelConfig, + generalDisabledModelConfig, + exploreDisabledModelConfig, +]; + +const makeArgs = ( + overrides: Partial = {}, +): AgentSettingsAgentsPageViewProps => ({ + generalModelOverrideData: buildOverrideData("general"), + exploreModelOverrideData: buildOverrideData("explore"), + modelConfigsData: allModelConfigs, modelConfigsError: undefined, isLoadingModelConfigs: false, + onSaveGeneralModelOverride: fn(), + isSavingGeneralModelOverride: false, + isSaveGeneralModelOverrideError: false, onSaveExploreModelOverride: fn(), isSavingExploreModelOverride: false, isSaveExploreModelOverrideError: false, + ...overrides, +}); + +const getSection = async ( + canvasElement: HTMLElement, + headingName: string, +): Promise => { + const canvas = within(canvasElement); + const heading = await canvas.findByRole("heading", { name: headingName }); + const section = heading.closest("section"); + if (!(section instanceof HTMLElement)) { + throw new Error( + `Expected ${headingName} heading to live inside a section.`, + ); + } + return section; +}; + +const selectModelInSection = async ( + section: HTMLElement, + canvasElement: HTMLElement, + currentSelectionName: string | RegExp, + optionName: string, +) => { + const trigger = within(section).getByRole("combobox", { + name: currentSelectionName, + }); + await userEvent.click(trigger); + const body = within(canvasElement.ownerDocument.body); + await userEvent.click(await body.findByRole("option", { name: optionName })); }; const meta = { title: "pages/AgentsPage/AgentSettingsAgentsPageView", component: AgentSettingsAgentsPageView, - args: baseArgs, + args: makeArgs(), } satisfies Meta; export default meta; type Story = StoryObj; -export const ExploreModelOverrideSetting: Story = { - args: { - exploreModelOverrideData: { - model_config_id: "model-explore-1", - has_malformed_override: false, - }, - modelConfigsData: [ - { - id: "model-explore-1", - provider: "openai", - model: "gpt-4.1-mini", - display_name: "GPT 4.1 Mini", - enabled: true, - is_default: false, - context_limit: 1_000_000, - compression_threshold: 70, - created_at: "2026-03-12T12:00:00.000Z", - updated_at: "2026-03-12T12:00:00.000Z", - }, - { - id: "model-explore-2", - provider: "anthropic", - model: "claude-sonnet-4", - display_name: "Claude Sonnet 4", - enabled: true, - is_default: false, - context_limit: 200_000, - compression_threshold: 70, - created_at: "2026-03-12T12:00:00.000Z", - updated_at: "2026-03-12T12:00:00.000Z", - }, - ] as TypesGen.ChatModelConfig[], - }, - play: async ({ canvasElement, args }) => { +export const AllOverridesUnset: Story = { + args: makeArgs(), + play: async ({ canvasElement }) => { const canvas = within(canvasElement); await canvas.findByText("Agents"); - await canvas.findByText("Explore subagent model"); - const trigger = canvas.getByRole("combobox", { - name: /gpt 4.1 mini/i, - }); - await userEvent.click(trigger); - const body = within(canvasElement.ownerDocument.body); - await userEvent.click( - await body.findByRole("option", { name: "Claude Sonnet 4" }), - ); - const form = trigger.closest("form"); - if (!(form instanceof HTMLFormElement)) { - throw new Error("Expected Explore model selector to live inside a form."); + + const headings = await canvas.findAllByRole("heading", { level: 3 }); + expect(headings.map((heading) => heading.textContent?.trim())).toEqual([ + "General model", + "Explore subagent model", + ]); + + for (const headingName of ["General model", "Explore subagent model"]) { + const section = await getSection(canvasElement, headingName); + expect( + within(section).getByRole("combobox", { name: "Use chat default" }), + ).toBeInTheDocument(); + expect( + within(section).getByRole("button", { name: "Save" }), + ).toBeDisabled(); } - const saveButton = within(form).getByRole("button", { name: "Save" }); + }, +}; + +export const EachOverrideSetToEnabledModel: Story = { + args: makeArgs({ + generalModelOverrideData: buildOverrideData("general", { + model_config_id: generalModelConfig.id, + }), + exploreModelOverrideData: buildOverrideData("explore", { + model_config_id: exploreFallbackModelConfig.id, + }), + }), + play: async ({ canvasElement, args }) => { + const generalSection = await getSection(canvasElement, "General model"); + const exploreSection = await getSection( + canvasElement, + "Explore subagent model", + ); + + expect( + within(exploreSection).getByRole("combobox", { + name: /claude-sonnet-4-20250514/i, + }), + ).toHaveTextContent("claude-sonnet-4-20250514"); + + await selectModelInSection( + generalSection, + canvasElement, + /gpt 4\.1 mini/i, + "Claude Sonnet 4", + ); + const generalSaveButton = within(generalSection).getByRole("button", { + name: "Save", + }); await waitFor(() => { - expect(saveButton).toBeEnabled(); + expect(generalSaveButton).toBeEnabled(); }); - await userEvent.click(saveButton); + await userEvent.click(generalSaveButton); await waitFor(() => { - expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith( - { model_config_id: "model-explore-2" }, + expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith( + { model_config_id: claudeSonnetModelConfig.id }, expect.anything(), ); }); - }, -}; - -export const ExploreModelOverrideAllowsExplicitClear: Story = { - args: { - exploreModelOverrideData: { - model_config_id: "model-explore-clear", - has_malformed_override: false, - }, - modelConfigsData: [ - { - id: "model-explore-clear", - provider: "openai", - model: "gpt-4.1-mini", - display_name: "GPT 4.1 Mini", - enabled: true, - is_default: false, - context_limit: 1_000_000, - compression_threshold: 70, - created_at: "2026-03-12T12:00:00.000Z", - updated_at: "2026-03-12T12:00:00.000Z", - }, - ] as TypesGen.ChatModelConfig[], - onSaveExploreModelOverride: fn(), - }, - play: async ({ canvasElement, args }) => { - const canvas = within(canvasElement); - const clearButton = await canvas.findByRole("button", { name: "Clear" }); - const form = clearButton.closest("form"); - if (!(form instanceof HTMLFormElement)) { - throw new Error( - "Expected Explore model clear button to live inside a form.", - ); - } - const saveButton = within(form).getByRole("button", { name: "Save" }); - await userEvent.click(clearButton); - expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled(); + const exploreClearButton = within(exploreSection).getByRole("button", { + name: "Clear", + }); + await userEvent.click(exploreClearButton); + const exploreSaveButton = within(exploreSection).getByRole("button", { + name: "Save", + }); await waitFor(() => { - expect(saveButton).toBeEnabled(); + expect(exploreSaveButton).toBeEnabled(); }); - await userEvent.click(saveButton); + await userEvent.click(exploreSaveButton); await waitFor(() => { expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith( - {}, + { model_config_id: "" }, expect.anything(), ); }); }, }; -export const ExploreModelOverrideClearsMalformedSavedValue: Story = { - args: { - exploreModelOverrideData: { - has_malformed_override: true, - }, - modelConfigsData: [], - onSaveExploreModelOverride: fn(), - }, +export const MalformedOverridesRemainClearableAndSaveable: Story = { + args: makeArgs({ + generalModelOverrideData: buildOverrideData("general", { + is_malformed: true, + }), + exploreModelOverrideData: buildOverrideData("explore", { + is_malformed: true, + }), + }), play: async ({ canvasElement, args }) => { - const canvas = within(canvasElement); - await canvas.findByText( - "The saved override is malformed and is being treated as unset. Click Save to clear it.", + const generalSection = await getSection(canvasElement, "General model"); + const exploreSection = await getSection( + canvasElement, + "Explore subagent model", ); - const clearButton = await canvas.findByRole("button", { name: "Clear" }); - const form = clearButton.closest("form"); - if (!(form instanceof HTMLFormElement)) { - throw new Error( - "Expected Explore model clear button to live inside a form.", - ); + + for (const section of [generalSection, exploreSection]) { + await within(section).findByText(OVERRIDE_MALFORMED_WARNING); } - const saveButton = within(form).getByRole("button", { name: "Save" }); + const generalSaveButton = within(generalSection).getByRole("button", { + name: "Save", + }); + await waitFor(() => { + expect(generalSaveButton).toBeEnabled(); + }); + await userEvent.click(generalSaveButton); + await waitFor(() => { + expect(args.onSaveGeneralModelOverride).toHaveBeenCalledWith( + { model_config_id: "" }, + expect.anything(), + ); + }); + + const exploreSaveButton = within(exploreSection).getByRole("button", { + name: "Save", + }); await waitFor(() => { - expect(saveButton).toBeEnabled(); + expect(exploreSaveButton).toBeEnabled(); }); - await userEvent.click(clearButton); - expect(args.onSaveExploreModelOverride).not.toHaveBeenCalled(); - await userEvent.click(saveButton); + await userEvent.click(exploreSaveButton); await waitFor(() => { expect(args.onSaveExploreModelOverride).toHaveBeenCalledWith( - {}, + { model_config_id: "" }, expect.anything(), ); }); }, }; -export const ExploreModelOverrideFallsBackToModelName: Story = { - args: { - exploreModelOverrideData: { - model_config_id: "model-explore-empty-name", - has_malformed_override: false, - }, - modelConfigsData: [ - { - id: "model-explore-empty-name", - provider: "anthropic", - model: "claude-sonnet-4-20250514", - display_name: "", - enabled: true, - is_default: false, - context_limit: 200_000, - compression_threshold: 70, - created_at: "2026-03-12T12:00:00.000Z", - updated_at: "2026-03-12T12:00:00.000Z", - }, - ] as TypesGen.ChatModelConfig[], - }, +export const UnavailableSavedModels: Story = { + args: makeArgs({ + generalModelOverrideData: buildOverrideData("general", { + model_config_id: generalDisabledModelConfig.id, + }), + exploreModelOverrideData: buildOverrideData("explore", { + model_config_id: exploreDisabledModelConfig.id, + }), + }), play: async ({ canvasElement }) => { - const canvas = within(canvasElement); - const trigger = await canvas.findByRole("combobox", { - name: /claude-sonnet-4-20250514/i, - }); - expect(trigger).toHaveTextContent("claude-sonnet-4-20250514"); - await userEvent.click(trigger); - const body = within(canvasElement.ownerDocument.body); - expect( - await body.findByRole("option", { - name: "claude-sonnet-4-20250514", - }), - ).toBeInTheDocument(); + const generalSection = await getSection(canvasElement, "General model"); + const exploreSection = await getSection( + canvasElement, + "Explore subagent model", + ); + + for (const section of [generalSection, exploreSection]) { + await within(section).findByText(UNAVAILABLE_SAVED_MODEL_WARNING); + expect( + within(section).getByRole("combobox", { name: "Unavailable model" }), + ).toBeInTheDocument(); + } }, }; diff --git a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx index 8e4d25e406aa0..94006dd8e9e42 100644 --- a/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx +++ b/site/src/pages/AgentsPage/AgentSettingsAgentsPageView.tsx @@ -1,24 +1,26 @@ import type { FC } from "react"; import type * as TypesGen from "#/api/typesGenerated"; -import { ExploreModelOverrideSettings } from "./components/ExploreModelOverrideSettings"; import { SectionHeader } from "./components/SectionHeader"; +import { + type MutationCallbacks, + SubagentModelOverrideSettings, +} from "./components/SubagentModelOverrideSettings"; -interface MutationCallbacks { - onSuccess?: () => void; - onError?: () => void; -} +type SaveModelOverride = ( + req: TypesGen.UpdateChatAgentModelOverrideRequest, + options?: MutationCallbacks, +) => void; export interface AgentSettingsAgentsPageViewProps { - exploreModelOverrideData: - | TypesGen.ChatExploreModelOverrideResponse - | undefined; + generalModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse; + exploreModelOverrideData?: TypesGen.ChatAgentModelOverrideResponse; modelConfigsData: TypesGen.ChatModelConfig[] | undefined; modelConfigsError: unknown; isLoadingModelConfigs: boolean; - onSaveExploreModelOverride: ( - req: TypesGen.UpdateChatExploreModelOverrideRequest, - options?: MutationCallbacks, - ) => void; + onSaveGeneralModelOverride?: SaveModelOverride; + isSavingGeneralModelOverride?: boolean; + isSaveGeneralModelOverrideError?: boolean; + onSaveExploreModelOverride: SaveModelOverride; isSavingExploreModelOverride: boolean; isSaveExploreModelOverrideError: boolean; } @@ -26,37 +28,84 @@ export interface AgentSettingsAgentsPageViewProps { export const AgentSettingsAgentsPageView: FC< AgentSettingsAgentsPageViewProps > = ({ + generalModelOverrideData, exploreModelOverrideData, modelConfigsData, modelConfigsError, isLoadingModelConfigs, + onSaveGeneralModelOverride, + isSavingGeneralModelOverride = false, + isSaveGeneralModelOverrideError = false, onSaveExploreModelOverride, isSavingExploreModelOverride, isSaveExploreModelOverrideError, }) => { + const enabledModelConfigs = (modelConfigsData ?? []).filter( + (modelConfig) => modelConfig.enabled, + ); + const showGeneralModelSection = + onSaveGeneralModelOverride !== undefined || + generalModelOverrideData !== undefined || + isSavingGeneralModelOverride || + isSaveGeneralModelOverrideError; + return (
-
+ {showGeneralModelSection && onSaveGeneralModelOverride && ( +
+ + +
+ )} +
- + Deployment-wide model override for read-only Explore subagents + launched through the spawn_agent tool with a + type=explore argument. + + } + modelOverrideData={exploreModelOverrideData} + enabledModelConfigs={enabledModelConfigs} modelConfigsError={modelConfigsError} - isLoadingModelConfigs={isLoadingModelConfigs} - onSaveExploreModelOverride={onSaveExploreModelOverride} - isSavingExploreModelOverride={isSavingExploreModelOverride} - isSaveExploreModelOverrideError={isSaveExploreModelOverrideError} + isLoading={isLoadingModelConfigs} + onSaveModelOverride={onSaveExploreModelOverride} + isSaving={isSavingExploreModelOverride} + isSaveError={isSaveExploreModelOverrideError} + saveErrorMessage="Failed to save Explore model override." showHeader={false} /> -
+
); }; diff --git a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx index 180f9fd64b3ce..66fcb0b36e928 100644 --- a/site/src/pages/AgentsPage/AgentsPageView.stories.tsx +++ b/site/src/pages/AgentsPage/AgentsPageView.stories.tsx @@ -156,7 +156,11 @@ const fixedNow = dayjs("2026-03-12T12:00:00"); const AgentsRouteElement = () => ( { const canvas = within(canvasElement); - const link = await waitFor(() => - canvas.getByRole("link", { name: "Settings" }), + const settingsLink = canvas.queryByRole("link", { name: "Settings" }); + if (settingsLink) { + await userEvent.click(settingsLink); + return; + } + + const mobileMoreOptionsButton = canvas + .getAllByRole("button", { name: "More options" }) + .find((button) => button.getAttribute("aria-haspopup") === "menu"); + if (!mobileMoreOptionsButton) { + throw new Error("Expected a mobile More options menu button."); + } + await userEvent.click(mobileMoreOptionsButton); + const body = within(canvasElement.ownerDocument.body); + await userEvent.click( + await body.findByRole("menuitem", { name: "Settings" }), ); - await userEvent.click(link); }; export const OpensAnalyticsForAdmins: Story = { diff --git a/site/src/pages/AgentsPage/components/ExploreModelOverrideSettings.tsx b/site/src/pages/AgentsPage/components/ExploreModelOverrideSettings.tsx deleted file mode 100644 index ed0c275347e45..0000000000000 --- a/site/src/pages/AgentsPage/components/ExploreModelOverrideSettings.tsx +++ /dev/null @@ -1,176 +0,0 @@ -import { useFormik } from "formik"; -import type { FC } from "react"; -import type * as TypesGen from "#/api/typesGenerated"; -import { Alert, AlertDescription } from "#/components/Alert/Alert"; -import { Button } from "#/components/Button/Button"; -import type { ModelSelectorOption } from "./ChatElements/ModelSelector"; -import { ModelSelector } from "./ChatElements/ModelSelector"; - -interface MutationCallbacks { - onSuccess?: () => void; - onError?: () => void; -} - -interface ExploreModelOverrideSettingsProps { - exploreModelOverrideData: - | TypesGen.ChatExploreModelOverrideResponse - | undefined; - modelConfigs: readonly TypesGen.ChatModelConfig[]; - modelConfigsError: unknown; - isLoadingModelConfigs: boolean; - onSaveExploreModelOverride: ( - req: TypesGen.UpdateChatExploreModelOverrideRequest, - options?: MutationCallbacks, - ) => void; - isSavingExploreModelOverride: boolean; - isSaveExploreModelOverrideError: boolean; - showHeader?: boolean; -} - -const toModelSelectorOption = ( - modelConfig: TypesGen.ChatModelConfig, -): ModelSelectorOption => ({ - id: modelConfig.id, - provider: modelConfig.provider, - model: modelConfig.model, - displayName: modelConfig.display_name.trim() || modelConfig.model, - contextLimit: modelConfig.context_limit, -}); - -export const ExploreModelOverrideSettings: FC< - ExploreModelOverrideSettingsProps -> = ({ - exploreModelOverrideData, - modelConfigs, - modelConfigsError, - isLoadingModelConfigs, - onSaveExploreModelOverride, - isSavingExploreModelOverride, - isSaveExploreModelOverrideError, - showHeader = true, -}) => { - const hasLoadedExploreModelOverride = exploreModelOverrideData !== undefined; - const enabledModelOptions = modelConfigs - .filter((modelConfig) => modelConfig.enabled) - .map(toModelSelectorOption); - - const form = useFormik({ - enableReinitialize: true, - initialValues: { - model_config_id: exploreModelOverrideData?.model_config_id ?? "", - }, - onSubmit: (values, { resetForm }) => { - onSaveExploreModelOverride( - { - model_config_id: values.model_config_id || undefined, - }, - { - onSuccess: () => { - resetForm({ values }); - }, - }, - ); - }, - }); - - const isUnavailableSavedModel = - form.values.model_config_id !== "" && - !enabledModelOptions.some( - (option) => option.id === form.values.model_config_id, - ); - const hasMalformedOverride = - exploreModelOverrideData?.has_malformed_override ?? false; - const isExploreModelOverrideDisabled = - isSavingExploreModelOverride || - isLoadingModelConfigs || - !hasLoadedExploreModelOverride; - const canSaveExploreModelOverride = - hasLoadedExploreModelOverride && (form.dirty || hasMalformedOverride); - - return ( -
- {showHeader && ( - <> -
-

- Explore subagent model -

-
-

- Optional deployment-wide model override for read-only Explore - subagents spawned with spawn_agent using - type=explore. -

- - )} -
- - form.setFieldValue("model_config_id", value) - } - disabled={isExploreModelOverrideDisabled} - placeholder={ - isUnavailableSavedModel ? "Unavailable model" : "Use chat default" - } - emptyMessage={ - isLoadingModelConfigs - ? "Loading models..." - : "No enabled models found." - } - className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm" - contentClassName="min-w-[18rem]" - /> -
- {isUnavailableSavedModel && ( - - - The saved model is no longer enabled and will be ignored until you - choose a new override. - - - )} - {hasMalformedOverride && ( - - - The saved override is malformed and is being treated as unset. Click - Save to clear it. - - - )} - {Boolean(modelConfigsError) && ( -

- Failed to load model configs. -

- )} -
- - -
- {isSaveExploreModelOverrideError && ( -

- Failed to save Explore model override. -

- )} -
- ); -}; diff --git a/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx b/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx new file mode 100644 index 0000000000000..9e46313ea1e35 --- /dev/null +++ b/site/src/pages/AgentsPage/components/SubagentModelOverrideSettings.tsx @@ -0,0 +1,174 @@ +import { useFormik } from "formik"; +import type { FC, ReactNode } from "react"; +import type * as TypesGen from "#/api/typesGenerated"; +import { Alert, AlertDescription } from "#/components/Alert/Alert"; +import { Button } from "#/components/Button/Button"; +import type { ModelSelectorOption } from "./ChatElements/ModelSelector"; +import { ModelSelector } from "./ChatElements/ModelSelector"; + +export interface MutationCallbacks { + onSuccess?: () => void; + onError?: () => void; +} + +interface ModelOverrideData { + readonly model_config_id: string; + readonly is_malformed: boolean; +} + +interface UpdateModelOverrideRequest { + readonly model_config_id: string; +} + +interface SubagentModelOverrideSettingsProps { + title: string; + description: ReactNode; + modelOverrideData: ModelOverrideData | undefined; + enabledModelConfigs: readonly TypesGen.ChatModelConfig[]; + modelConfigsError: unknown; + isLoading: boolean; + onSaveModelOverride: ( + req: UpdateModelOverrideRequest, + options?: MutationCallbacks, + ) => void; + isSaving: boolean; + isSaveError: boolean; + saveErrorMessage: string; + showHeader?: boolean; + disabled?: boolean; +} + +const toModelSelectorOption = ( + modelConfig: TypesGen.ChatModelConfig, +): ModelSelectorOption => ({ + id: modelConfig.id, + provider: modelConfig.provider, + model: modelConfig.model, + displayName: modelConfig.display_name.trim() || modelConfig.model, + contextLimit: modelConfig.context_limit, +}); + +export const SubagentModelOverrideSettings: FC< + SubagentModelOverrideSettingsProps +> = ({ + title, + description, + modelOverrideData, + enabledModelConfigs, + modelConfigsError, + isLoading, + onSaveModelOverride, + isSaving, + isSaveError, + saveErrorMessage, + showHeader = true, + disabled = false, +}) => { + const hasLoadedModelOverride = modelOverrideData !== undefined; + const enabledModelOptions = enabledModelConfigs.map(toModelSelectorOption); + + const form = useFormik({ + enableReinitialize: true, + initialValues: { + model_config_id: modelOverrideData?.model_config_id ?? "", + }, + onSubmit: (values, { resetForm }) => { + onSaveModelOverride( + { + model_config_id: values.model_config_id, + }, + { + onSuccess: () => { + resetForm({ values }); + }, + }, + ); + }, + }); + + const isUnavailableSavedModel = + form.values.model_config_id !== "" && + !enabledModelOptions.some( + (option) => option.id === form.values.model_config_id, + ); + const isMalformedOverride = modelOverrideData?.is_malformed ?? false; + const isModelOverrideDisabled = + disabled || isSaving || isLoading || !hasLoadedModelOverride; + const canSaveModelOverride = + hasLoadedModelOverride && (form.dirty || isMalformedOverride); + + return ( +
+ {showHeader && ( + <> +

+ {title} +

+

+ {description} +

+ + )} + form.setFieldValue("model_config_id", value)} + disabled={isModelOverrideDisabled} + placeholder={ + isUnavailableSavedModel ? "Unavailable model" : "Use chat default" + } + emptyMessage={ + isLoading ? "Loading models..." : "No enabled models found." + } + className="h-10 w-full justify-between rounded-md border border-border border-solid bg-transparent px-3 text-sm shadow-sm" + contentClassName="min-w-[18rem]" + /> + {isUnavailableSavedModel && ( + + + The saved model is no longer enabled and will be ignored until you + choose a new override. + + + )} + {isMalformedOverride && ( + + + The saved override is malformed and is being treated as unset. Click + Save to clear it. + + + )} + {Boolean(modelConfigsError) && ( +

+ Failed to load model configs. +

+ )} +
+ + +
+ {isSaveError && ( +

+ {saveErrorMessage} +

+ )} + + ); +};