diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 49cac265d35b8..1db74da36f093 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1517,10 +1517,11 @@ func ChatQueuedMessage(message database.ChatQueuedMessage) codersdk.ChatQueuedMe } return codersdk.ChatQueuedMessage{ - ID: message.ID, - ChatID: message.ChatID, - Content: parts, - CreatedAt: message.CreatedAt, + ID: message.ID, + ChatID: message.ChatID, + ModelConfigID: nullUUIDPtr(message.ModelConfigID), + Content: parts, + CreatedAt: message.CreatedAt, } } diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 9706259cd5234..a0b2c8f933080 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1422,7 +1422,8 @@ CREATE TABLE chat_queued_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, content jsonb NOT NULL, - created_at timestamp with time zone DEFAULT now() NOT NULL + created_at timestamp with time zone DEFAULT now() NOT NULL, + model_config_id uuid ); CREATE SEQUENCE chat_queued_messages_id_seq diff --git a/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql b/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql new file mode 100644 index 0000000000000..aa655e7a9c1fa --- /dev/null +++ b/coderd/database/migrations/000477_chat_queued_message_model_config.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE chat_queued_messages +DROP COLUMN model_config_id; diff --git a/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql b/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql new file mode 100644 index 0000000000000..fb4fc16410164 --- /dev/null +++ b/coderd/database/migrations/000477_chat_queued_message_model_config.up.sql @@ -0,0 +1,8 @@ +ALTER TABLE chat_queued_messages +ADD COLUMN model_config_id uuid; + +UPDATE chat_queued_messages AS cqm +SET model_config_id = chats.last_model_config_id +FROM chats +WHERE chats.id = cqm.chat_id + AND cqm.model_config_id IS NULL; diff --git a/coderd/database/models.go b/coderd/database/models.go index d5ffc05e82cf2..e406ee9e1c2d5 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4509,10 +4509,11 @@ type ChatProvider struct { } type ChatQueuedMessage struct { - ID int64 `db:"id" json:"id"` - ChatID uuid.UUID `db:"chat_id" json:"chat_id"` - Content json.RawMessage `db:"content" json:"content"` - CreatedAt time.Time `db:"created_at" json:"created_at"` + ID int64 `db:"id" json:"id"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Content json.RawMessage `db:"content" json:"content"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` } type ChatUsageLimitConfig struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 355bdff168bf4..588720f15483f 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -6612,7 +6612,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get } const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many -SELECT id, chat_id, content, created_at FROM chat_queued_messages +SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages WHERE chat_id = $1 ORDER BY id ASC ` @@ -6631,6 +6631,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ); err != nil { return nil, err } @@ -7501,24 +7502,30 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa } const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content) -VALUES ($1, $2) -RETURNING id, chat_id, content, created_at +INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +VALUES ( + $1, + $2, + $3::uuid +) +RETURNING id, chat_id, content, created_at, model_config_id ` type InsertChatQueuedMessageParams struct { - ChatID uuid.UUID `db:"chat_id" json:"chat_id"` - Content json.RawMessage `db:"content" json:"content"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Content json.RawMessage `db:"content" json:"content"` + ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` } func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) { - row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content) + row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content, arg.ModelConfigID) var i ChatQueuedMessage err := row.Scan( &i.ID, &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ) return i, err } @@ -7743,7 +7750,7 @@ WHERE id = ( ORDER BY cqm.id ASC LIMIT 1 ) -RETURNING id, chat_id, content, created_at +RETURNING id, chat_id, content, created_at, model_config_id ` func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error) { @@ -7754,6 +7761,7 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) &i.ChatID, &i.Content, &i.CreatedAt, + &i.ModelConfigID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 3365e943bbb2d..48852f3ef05c2 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -882,8 +882,12 @@ RETURNING *; -- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content) -VALUES (@chat_id, @content) +INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +VALUES ( + @chat_id, + @content, + sqlc.narg('model_config_id')::uuid +) RETURNING *; -- name: GetChatQueuedMessages :many diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 78b39a156d1f1..ea4a15affa49e 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -2531,13 +2531,18 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { return } + modelConfigID := uuid.Nil + if req.ModelConfigID != nil { + modelConfigID = *req.ModelConfigID + } + sendResult, sendErr := api.chatDaemon.SendMessage( ctx, chatd.SendMessageOptions{ ChatID: chatID, CreatedBy: apiKey.UserID, Content: contentBlocks, - ModelConfigID: req.ModelConfigID, + ModelConfigID: modelConfigID, BusyBehavior: busyBehavior, PlanMode: sendPlanMode, MCPServerIDs: req.MCPServerIDs, @@ -2560,6 +2565,12 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { }) return } + if xerrors.Is(sendErr, chatd.ErrInvalidModelConfigID) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid model config ID.", + }) + return + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Failed to create chat message.", Detail: sendErr.Error(), diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index c3f21d1df4fbb..730246b220ca1 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -5870,6 +5870,329 @@ func TestPostChatMessages(t *testing.T) { }) } +func waitForChatWatchStatusChangeEvent( + ctx context.Context, + t *testing.T, + conn *websocket.Conn, + chatID uuid.UUID, +) codersdk.ChatWatchEvent { + t.Helper() + + for { + var payload codersdk.ChatWatchEvent + err := wsjson.Read(ctx, conn, &payload) + require.NoError(t, err) + if payload.Kind == codersdk.ChatWatchEventKindStatusChange && payload.Chat.ID == chatID { + return payload + } + } +} + +func TestSendMessageWithModelOverrideUpdatesLastModelConfigID(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-override-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "mid-chat model switch direct send", + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "switch to model b", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + }) + require.NoError(t, err) + require.False(t, resp.Queued) + require.NotNil(t, resp.Message) + require.NotNil(t, resp.Message.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.Message.ModelConfigID) + + storedChat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(dbauthz.AsSystemRestricted(ctx), database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + require.True(t, messages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, messages[0].ModelConfigID.UUID) +} + +func TestSendMessageQueuesEffectiveModelConfigID(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-queued-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "mid-chat model switch queued send", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue this with model b", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, resp.Queued) + require.NotNil(t, resp.QueuedMessage) + require.NotNil(t, resp.QueuedMessage.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.QueuedMessage.ModelConfigID) + + queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queuedMessages, 1) + require.True(t, queuedMessages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, queuedMessages[0].ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigA.ID, storedChat.LastModelConfigID) +} + +func TestQueuedMessageWithoutOverrideCapturesEnqueueTimeModel(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-later-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "capture queued enqueue-time model", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue with stored model", + }}, + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, resp.Queued) + require.NotNil(t, resp.QueuedMessage) + require.NotNil(t, resp.QueuedMessage.ModelConfigID) + require.Equal(t, modelConfigA.ID, *resp.QueuedMessage.ModelConfigID) + + _, err = db.UpdateChatLastModelConfigByID(dbauthz.AsSystemRestricted(ctx), database.UpdateChatLastModelConfigByIDParams{ + ID: chat.ID, + LastModelConfigID: modelConfigB.ID, + }) + require.NoError(t, err) + + queuedMessages, err := db.GetChatQueuedMessages(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Len(t, queuedMessages, 1) + require.True(t, queuedMessages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigA.ID, queuedMessages[0].ModelConfigID.UUID) +} + +func TestSubsequentSendWithoutOverrideUsesPersistedModel(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-persisted-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigB.ID, + Title: "subsequent send uses persisted model", + }) + require.NoError(t, err) + + resp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "reuse the persisted model", + }}, + }) + require.NoError(t, err) + require.False(t, resp.Queued) + require.NotNil(t, resp.Message) + require.NotNil(t, resp.Message.ModelConfigID) + require.Equal(t, modelConfigB.ID, *resp.Message.ModelConfigID) + + messages, err := db.GetChatMessagesByChatID(dbauthz.AsSystemRestricted(ctx), database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + require.True(t, messages[0].ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, messages[0].ModelConfigID.UUID) +} + +func TestWatchChatsStatusChangeCarriesUpdatedLastModelConfigID(t *testing.T) { + t.Parallel() + + t.Run("DirectSend", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-watch-direct-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "watch direct model switch", + }) + require.NoError(t, err) + + conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) + require.NoError(t, err) + defer conn.Close(websocket.StatusNormalClosure, "done") + + _, err = client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "watch the direct send override", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + }) + require.NoError(t, err) + + event := waitForChatWatchStatusChangeEvent(ctx, t, conn, chat.ID) + require.Equal(t, modelConfigB.ID, event.Chat.LastModelConfigID) + }) + + t.Run("QueuedPromotion", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfigA := createChatModelConfig(t, client) + modelConfigB := createAdditionalChatModelConfig(t, client, "openai", "gpt-4o-mini-watch-promote-"+uuid.NewString()) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OrganizationID: user.OrganizationID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.UserID, + LastModelConfigID: modelConfigA.ID, + Title: "watch queued promotion model switch", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + WorkerID: uuid.NullUUID{UUID: uuid.New(), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + queuedResp, err := client.CreateChatMessage(ctx, chat.ID, codersdk.CreateChatMessageRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "queue the promoted model override", + }}, + ModelConfigID: ptr.Ref(modelConfigB.ID), + BusyBehavior: codersdk.ChatBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedResp.Queued) + require.NotNil(t, queuedResp.QueuedMessage) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusWaiting, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + conn, err := client.Dial(ctx, "/api/experimental/chats/watch", nil) + require.NoError(t, err) + defer conn.Close(websocket.StatusNormalClosure, "done") + + promoteRes, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/queue/%d/promote", chat.ID, queuedResp.QueuedMessage.ID), + nil, + ) + require.NoError(t, err) + defer promoteRes.Body.Close() + require.Equal(t, http.StatusOK, promoteRes.StatusCode) + + event := waitForChatWatchStatusChangeEvent(ctx, t, conn, chat.ID) + require.Equal(t, modelConfigB.ID, event.Chat.LastModelConfigID) + }) +} + func TestChatMessageWithFileReferences(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index bf6ddcd4ee424..fe9c206279160 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -871,6 +871,8 @@ func (c *streamStateCollector) Collect(ch chan<- prometheus.Metric) { const MaxQueueSize = 20 var ( + // ErrInvalidModelConfigID indicates the requested model config does not exist. + ErrInvalidModelConfigID = xerrors.New("invalid model config ID") // ErrMessageQueueFull indicates the per-chat queue limit was reached. ErrMessageQueueFull = xerrors.New("chat message queue is full") // ErrEditedMessageNotFound indicates the edited message does not exist @@ -950,7 +952,7 @@ type SendMessageOptions struct { ChatID uuid.UUID CreatedBy uuid.UUID Content []codersdk.ChatMessagePart - ModelConfigID *uuid.UUID + ModelConfigID uuid.UUID BusyBehavior SendMessageBusyBehavior PlanMode *database.NullChatPlanMode MCPServerIDs *[]uuid.UUID @@ -983,7 +985,6 @@ type PromoteQueuedOptions struct { ChatID uuid.UUID CreatedBy uuid.UUID QueuedMessageID int64 - ModelConfigID *uuid.UUID } // PromoteQueuedResult contains post-promotion message metadata. @@ -1217,9 +1218,14 @@ func (p *Server) SendMessage( } } - modelConfigID := lockedChat.LastModelConfigID - if opts.ModelConfigID != nil { - modelConfigID = *opts.ModelConfigID + modelConfigID, err := resolveSendMessageModelConfigID( + ctx, + tx, + lockedChat, + opts.ModelConfigID, + ) + if err != nil { + return err } // Update MCP server IDs on the chat when explicitly provided. @@ -1264,6 +1270,10 @@ func (p *Server) SendMessage( queued, err := tx.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ ChatID: opts.ChatID, Content: content.RawMessage, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigID, + Valid: modelConfigID != uuid.Nil, + }, }) if err != nil { return xerrors.Errorf("insert queued message: %w", err) @@ -1368,6 +1378,90 @@ func (p *Server) checkUsageLimit(ctx context.Context, store database.Store, owne return nil } +func chatdModelConfigLookupContext(ctx context.Context) context.Context { + //nolint:gocritic // Chat message admission needs daemon-scoped + // deployment-config reads for model config validation. + return dbauthz.AsChatd(ctx) +} + +func resolveSendMessageModelConfigID( + ctx context.Context, + store database.Store, + chat database.Chat, + requested uuid.UUID, +) (uuid.UUID, error) { + if requested == uuid.Nil { + return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID) + } + + chatdCtx := chatdModelConfigLookupContext(ctx) + if _, err := store.GetChatModelConfigByID(chatdCtx, requested); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "%w: %s", + ErrInvalidModelConfigID, + requested, + ) + } + return uuid.Nil, xerrors.Errorf( + "get requested model config %s: %w", + requested, + err, + ) + } + return requested, nil +} + +func resolveQueuedMessageModelConfigID( + ctx context.Context, + store database.Store, + chat database.Chat, + queuedModelConfigID uuid.NullUUID, +) (uuid.UUID, error) { + chatdCtx := chatdModelConfigLookupContext(ctx) + if queuedModelConfigID.Valid && queuedModelConfigID.UUID != uuid.Nil { + if _, err := store.GetChatModelConfigByID(chatdCtx, queuedModelConfigID.UUID); err == nil { + return queuedModelConfigID.UUID, nil + } else if !errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "get queued model config %s: %w", + queuedModelConfigID.UUID, + err, + ) + } + } + + return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID) +} + +func resolveFallbackModelConfigID( + ctx context.Context, + store database.Store, + modelConfigID uuid.UUID, +) (uuid.UUID, error) { + chatdCtx := chatdModelConfigLookupContext(ctx) + if modelConfigID != uuid.Nil { + if _, err := store.GetChatModelConfigByID(chatdCtx, modelConfigID); err == nil { + return modelConfigID, nil + } else if !errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.Errorf( + "get chat model config %s: %w", + modelConfigID, + err, + ) + } + } + + defaultConfig, err := store.GetDefaultChatModelConfig(chatdCtx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return uuid.Nil, xerrors.New("no default chat model config is available") + } + return uuid.Nil, xerrors.Errorf("get default chat model config: %w", err) + } + return defaultConfig.ID, nil +} + // EditMessage marks the old user message as deleted, soft-deletes all // following messages, inserts a new message with the updated content, // clears queued messages, and moves the chat into pending status. @@ -1768,23 +1862,20 @@ func (p *Server) PromoteQueued( return ErrChatArchived } - modelConfigID := lockedChat.LastModelConfigID - if opts.ModelConfigID != nil { - modelConfigID = *opts.ModelConfigID - } - queuedMessages, err := tx.GetChatQueuedMessages(ctx, opts.ChatID) if err != nil { return xerrors.Errorf("get queued messages: %w", err) } var ( - targetContent json.RawMessage - found bool + targetContent json.RawMessage + targetModelConfigID uuid.NullUUID + found bool ) for _, qm := range queuedMessages { if qm.ID == opts.QueuedMessageID { targetContent = qm.Content + targetModelConfigID = qm.ModelConfigID found = true break } @@ -1793,6 +1884,16 @@ func (p *Server) PromoteQueued( return xerrors.New("queued message not found") } + effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( + ctx, + tx, + lockedChat, + targetModelConfigID, + ) + if err != nil { + return err + } + err = tx.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{ ID: opts.QueuedMessageID, ChatID: opts.ChatID, @@ -1805,7 +1906,7 @@ func (p *Server) PromoteQueued( ctx, tx, lockedChat, - modelConfigID, + effectiveModelConfigID, pqtype.NullRawMessage{ RawMessage: targetContent, Valid: len(targetContent) > 0, @@ -3313,6 +3414,8 @@ func BuildSingleChatMessageInsertParams( return params } +// insertUserMessageAndSetPending inserts a user message, transitions the +// chat to pending when needed, and returns the refreshed chat row. func insertUserMessageAndSetPending( ctx context.Context, store database.Store, @@ -3338,7 +3441,16 @@ func insertUserMessageAndSetPending( message := messages[0] if lockedChat.Status == database.ChatStatusPending { - return message, lockedChat, nil + if modelConfigID == uuid.Nil || lockedChat.LastModelConfigID == modelConfigID { + return message, lockedChat, nil + } + // The InsertChatMessages CTE updates chats.last_model_config_id when + // the message's model config differs. Reload to surface that change. + updatedChat, err := store.GetChatByID(ctx, lockedChat.ID) + if err != nil { + return database.ChatMessage{}, database.Chat{}, xerrors.Errorf("get chat after model config update: %w", err) + } + return message, updatedChat, nil } updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ @@ -4752,13 +4864,31 @@ func (p *Server) tryAutoPromoteQueuedMessage( ) (*database.ChatMessage, []database.ChatQueuedMessage, bool, error) { logger := p.logger.With(slog.F("chat_id", chat.ID)) - nextQueued, err := tx.PopNextQueuedMessage(ctx, chat.ID) - if errors.Is(err, sql.ErrNoRows) { + queuedMessages, err := tx.GetChatQueuedMessages(ctx, chat.ID) + if err != nil { + return nil, nil, false, xerrors.Errorf("get queued messages: %w", err) + } + if len(queuedMessages) == 0 { return nil, nil, false, nil } + nextQueued := queuedMessages[0] + effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( + ctx, + tx, + chat, + nextQueued.ModelConfigID, + ) + if err != nil { + return nil, nil, false, err + } + + poppedQueued, err := tx.PopNextQueuedMessage(ctx, chat.ID) if err != nil { return nil, nil, false, xerrors.Errorf("pop next queued message: %w", err) } + if poppedQueued.ID != nextQueued.ID { + return nil, nil, false, xerrors.New("popped queued message out of order") + } msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. ChatID: chat.ID, @@ -4770,7 +4900,7 @@ func (p *Server) tryAutoPromoteQueuedMessage( Valid: len(nextQueued.Content) > 0, }, database.ChatMessageVisibilityBoth, - chat.LastModelConfigID, + effectiveModelConfigID, chatprompt.CurrentContentVersion, ).withCreatedBy(chat.OwnerID)) msgs, err := insertChatMessageWithStore(ctx, tx, msgParams) diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 7f545974c36ad..beb1216315dac 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -39,6 +39,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/workspacestats" @@ -2045,6 +2046,38 @@ func TestSendMessageQueuesWhenWaitingWithQueuedBacklog(t *testing.T) { require.Len(t, messages, 1) } +func TestSendMessageRejectsInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfig := seedChatDependencies(ctx, t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusPending, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfig.ID, + Title: "reject invalid queued model config", + }) + require.NoError(t, err) + + invalidModelConfigID := uuid.New() + _, err = replica.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")}, + ModelConfigID: invalidModelConfigID, + }) + require.ErrorIs(t, err, chatd.ErrInvalidModelConfigID) + + queued, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queued) +} + func TestSendMessageInterruptBehaviorQueuesAndInterruptsWhenBusy(t *testing.T) { t.Parallel() @@ -2501,6 +2534,463 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing require.Equal(t, database.ChatMessageRoleUser, messages[3].Role) } +func TestPromoteQueuedMessageUsesQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai", + "gpt-4o-mini-promote-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued uses stored model", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with model b")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigB.ID, + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + require.Equal(t, database.ChatStatusPending, storedChat.Status) +} + +func TestPromoteQueuedMessageReloadsChatWhenModelConfigChangesDuringPending(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai", + "gpt-4o-mini-promote-pending-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + watchEvents := make(chan struct { + payload codersdk.ChatWatchEvent + err error + }, 1) + cancelWatch, err := ps.SubscribeWithErr( + coderdpubsub.ChatWatchEventChannel(user.ID), + coderdpubsub.HandleChatWatchEvent(func(_ context.Context, payload codersdk.ChatWatchEvent, err error) { + select { + case watchEvents <- struct { + payload codersdk.ChatWatchEvent + err error + }{payload: payload, err: err}: + default: + } + }), + ) + require.NoError(t, err) + defer cancelWatch() + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusPending, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued reloads pending chat", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("queued with new model")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: modelConfigB.ID, + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigB.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, storedChat.Status) + require.Equal(t, modelConfigB.ID, storedChat.LastModelConfigID) + + select { + case event := <-watchEvents: + require.NoError(t, event.err) + require.Equal(t, codersdk.ChatWatchEventKindStatusChange, event.payload.Kind) + require.Equal(t, chat.ID, event.payload.Chat.ID) + require.Equal(t, codersdk.ChatStatusPending, event.payload.Chat.Status) + require.Equal(t, modelConfigB.ID, event.payload.Chat.LastModelConfigID) + case <-ctx.Done(): + t.Fatal("timed out waiting for status change watch event") + } +} + +func TestAutoPromoteQueuedMessagesPreservesPerTurnModelOrder(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + firstRunStarted := make(chan struct{}) + allowFirstRunFinish := make(chan struct{}) + var requestCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + switch requestCount.Add(1) { + case 1: + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("first run partial")[0] + select { + case <-firstRunStarted: + default: + close(firstRunStarted) + } + <-allowFirstRunFinish + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + case 2: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("second run done")...) + case 3: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("third run done")...) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("extra run done")...) + } + }) + + server := newActiveTestServer(t, db, ps) + user, org, modelConfigA := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + modelConfigB := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini-queue-b-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + modelConfigC := insertChatModelConfigWithCallConfig( + ctx, + t, + db, + user.ID, + "openai-compat", + "gpt-4o-mini-queue-c-"+uuid.NewString(), + codersdk.ChatModelCallConfig{}, + ) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "auto-promote per-turn model order", + ModelConfigID: modelConfigA.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, firstRunStarted) + + queuedB, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued b")}, + ModelConfigID: modelConfigB.ID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedB.Queued) + + queuedC, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued c")}, + ModelConfigID: modelConfigC.ID, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + require.True(t, queuedC.Queued) + + close(allowFirstRunFinish) + + require.Eventually(t, func() bool { + return requestCount.Load() >= 3 + }, testutil.WaitLong, testutil.IntervalFast) + chatd.WaitUntilIdleForTest(server) + + queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queuedMessages) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, storedChat.Status) + require.Equal(t, modelConfigC.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var userTexts []string + var userModelConfigIDs []uuid.UUID + for _, message := range messages { + if message.Role != database.ChatMessageRoleUser { + continue + } + sdkMessage := db2sdk.ChatMessage(message) + require.Len(t, sdkMessage.Content, 1) + userTexts = append(userTexts, sdkMessage.Content[0].Text) + require.True(t, message.ModelConfigID.Valid) + userModelConfigIDs = append(userModelConfigIDs, message.ModelConfigID.UUID) + } + require.Equal(t, []string{"hello", "queued b", "queued c"}, userTexts) + require.Equal(t, []uuid.UUID{modelConfigA.ID, modelConfigB.ID, modelConfigC.ID}, userModelConfigIDs) +} + +func TestAutoPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) { + t.Parallel() + + testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{}) +} + +func TestAutoPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }) +} + +func testAutoPromoteQueuedMessageFallback(t *testing.T, queuedModelConfigID uuid.NullUUID) { + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + firstRunStarted := make(chan struct{}) + allowFirstRunFinish := make(chan struct{}) + var requestCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + switch requestCount.Add(1) { + case 1: + chunks := make(chan chattest.OpenAIChunk, 1) + go func() { + defer close(chunks) + chunks <- chattest.OpenAITextChunks("first run partial")[0] + select { + case <-firstRunStarted: + default: + close(firstRunStarted) + } + <-allowFirstRunFinish + }() + return chattest.OpenAIResponse{StreamingChunks: chunks} + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("fallback run done")...) + } + }) + + server := newActiveTestServer(t, db, ps) + user, org, modelConfig := seedChatDependenciesWithProvider(ctx, t, db, "openai-compat", openAIURL) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "auto-promote queued fallback", + ModelConfigID: modelConfig.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, firstRunStarted) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")}) + require.NoError(t, err) + _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: queuedModelConfigID, + }) + require.NoError(t, err) + + close(allowFirstRunFinish) + + require.Eventually(t, func() bool { + return requestCount.Load() >= 2 + }, testutil.WaitLong, testutil.IntervalFast) + chatd.WaitUntilIdleForTest(server) + + queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID) + require.NoError(t, err) + require.Empty(t, queuedMessages) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, storedChat.Status) + require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var found bool + for _, message := range messages { + if message.Role != database.ChatMessageRoleUser { + continue + } + sdkMessage := db2sdk.ChatMessage(message) + require.Len(t, sdkMessage.Content, 1) + if sdkMessage.Content[0].Text != "legacy queued row" { + continue + } + require.True(t, message.ModelConfigID.Valid) + require.Equal(t, modelConfig.ID, message.ModelConfigID.UUID) + found = true + } + require.True(t, found) +} + +func TestPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfigA := seedChatDependencies(ctx, t, db) + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfigA.ID, + Title: "promote queued legacy fallback", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfigA.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfigA.ID, storedChat.LastModelConfigID) +} + +func TestPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, modelConfig := seedChatDependencies(ctx, t, db) + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: org.ID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: user.ID, + LastModelConfigID: modelConfig.ID, + Title: "promote queued invalid fallback", + }) + require.NoError(t, err) + + queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("invalid queued model")}) + require.NoError(t, err) + queuedMessage, err := db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ + ChatID: chat.ID, + Content: queuedContent, + ModelConfigID: uuid.NullUUID{ + UUID: uuid.New(), + Valid: true, + }, + }) + require.NoError(t, err) + + result, err := replica.PromoteQueued(ctx, chatd.PromoteQueuedOptions{ + ChatID: chat.ID, + QueuedMessageID: queuedMessage.ID, + CreatedBy: user.ID, + }) + require.NoError(t, err) + require.True(t, result.PromotedMessage.ModelConfigID.Valid) + require.Equal(t, modelConfig.ID, result.PromotedMessage.ModelConfigID.UUID) + + storedChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID) +} + func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) { t.Parallel() diff --git a/codersdk/chats.go b/codersdk/chats.go index 7f66cb4203823..08e9acc54d8b6 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1247,10 +1247,11 @@ const ( // ChatQueuedMessage represents a queued message waiting to be processed. type ChatQueuedMessage struct { - ID int64 `json:"id"` - ChatID uuid.UUID `json:"chat_id" format:"uuid"` - Content []ChatMessagePart `json:"content"` - CreatedAt time.Time `json:"created_at" format:"date-time"` + ID int64 `json:"id"` + ChatID uuid.UUID `json:"chat_id" format:"uuid"` + ModelConfigID *uuid.UUID `json:"model_config_id,omitempty" format:"uuid"` + Content []ChatMessagePart `json:"content"` + CreatedAt time.Time `json:"created_at" format:"date-time"` } // ChatStreamMessagePart is a streamed message part update. diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index 2501b600f132f..54eb5765695c1 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -25,6 +25,8 @@ import { infiniteChats, interruptChat, invalidateChatListQueries, + mergeWatchedChatIntoCaches, + mergeWatchedChatSummary, paginatedChatCostUsers, pinChat, promoteChatQueuedMessage, @@ -1892,6 +1894,382 @@ describe("updateChildInParentCache", () => { }); }); +describe("mergeWatchedChatSummary", () => { + it("merges fresh status updates without clobbering a newer title snapshot", () => { + const cachedChat = makeChat("chat-1", { + status: "pending", + title: "Fresh title", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "running", + title: "Stale title", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("merges last_model_config_id when watched updated_at equals cached updated_at", () => { + const cachedChat = makeChat("chat-1", { + last_model_config_id: "11111111-1111-4111-8111-111111111111", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + last_model_config_id: "22222222-2222-4222-8222-222222222222", + updated_at: "2025-01-01T00:00:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }).last_model_config_id, + ).toBe("22222222-2222-4222-8222-222222222222"); + }); + + it("compares updated_at values as instants instead of strings", () => { + const cachedChat = makeChat("chat-1", { + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.12Z", + }); + const watchedChat = makeChat("chat-1", { + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:00:00.1203Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + }), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:00:00.1203Z", + }); + }); + + it("merges fresh title updates without clobbering a newer status snapshot", () => { + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Updated title", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "title_change", + }), + ).toMatchObject({ + status: "running", + title: "Updated title", + }); + }); + + it("merges title updates even when chat updated_at is older", () => { + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + updated_at: "2025-01-01T00:10:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Newer generated title", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "title_change", + }), + ).toMatchObject({ + status: "running", + title: "Newer generated title", + updated_at: "2025-01-01T00:10:00.000Z", + }); + }); + + it("merges fresh diff status updates without clobbering status or title", () => { + const cachedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/1", + pull_request_state: "open", + pull_request_title: "Old title", + pull_request_draft: false, + changes_requested: false, + additions: 1, + deletions: 2, + changed_files: 3, + refreshed_at: "2025-01-01T00:00:00.000Z", + stale_at: "2025-01-01T01:00:00.000Z", + }; + const watchedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/2", + pull_request_state: "merged", + pull_request_title: "New title", + pull_request_draft: false, + changes_requested: true, + additions: 4, + deletions: 5, + changed_files: 6, + refreshed_at: "2025-01-01T00:05:00.000Z", + stale_at: "2025-01-01T01:05:00.000Z", + }; + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + diff_status: cachedDiffStatus, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Stale title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "diff_status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + diff_status: watchedDiffStatus, + }); + }); + + it("merges diff status updates even when chat updated_at is older", () => { + const cachedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/1", + pull_request_state: "open", + pull_request_title: "Old title", + pull_request_draft: false, + changes_requested: false, + additions: 1, + deletions: 2, + changed_files: 3, + refreshed_at: "2025-01-01T00:00:00.000Z", + stale_at: "2025-01-01T01:00:00.000Z", + }; + const watchedDiffStatus = { + chat_id: "chat-1", + url: "https://example.com/pr/2", + pull_request_state: "open", + pull_request_title: "New title", + pull_request_draft: true, + changes_requested: true, + additions: 4, + deletions: 5, + changed_files: 6, + refreshed_at: "2025-01-01T00:10:00.000Z", + stale_at: "2025-01-01T01:10:00.000Z", + }; + const cachedChat = makeChat("chat-1", { + status: "running", + title: "Fresh title", + diff_status: cachedDiffStatus, + updated_at: "2025-01-01T00:10:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + title: "Stale title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "diff_status_change", + }), + ).toMatchObject({ + status: "running", + title: "Fresh title", + diff_status: watchedDiffStatus, + updated_at: "2025-01-01T00:10:00.000Z", + }); + }); + + it("marks other chats unread on fresh status updates", () => { + const cachedChat = makeChat("chat-1", { + has_unread: false, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + activeChatId: "chat-2", + }).has_unread, + ).toBe(true); + }); + + it("preserves has_unread for the active chat", () => { + const cachedChat = makeChat("chat-1", { + has_unread: false, + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat("chat-1", { + status: "completed", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + expect( + mergeWatchedChatSummary(cachedChat, watchedChat, { + eventKind: "status_change", + activeChatId: "chat-1", + }).has_unread, + ).toBe(false); + }); +}); + +describe("mergeWatchedChatIntoCaches", () => { + it("merges last_model_config_id into the root list cache and per-chat cache", () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + const cachedChat = makeChat(chatId, { + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const watchedChat = makeChat(chatId, { + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + seedInfiniteChats(queryClient, [cachedChat]); + queryClient.setQueryData(chatKey(chatId), cachedChat); + + mergeWatchedChatIntoCaches(queryClient, watchedChat, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(chatId)), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("merges last_model_config_id into the parent-embedded child snapshot and child cache", () => { + const queryClient = createTestQueryClient(); + const childId = "child-1"; + const cachedChild = makeChat(childId, { + parent_chat_id: "parent-1", + root_chat_id: "parent-1", + status: "pending", + last_model_config_id: "model-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + const parent = makeChat("parent-1", { children: [cachedChild] }); + const watchedChild = makeChat(childId, { + parent_chat_id: "parent-1", + root_chat_id: "parent-1", + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + + seedInfiniteChats(queryClient, [parent]); + queryClient.setQueryData(chatKey(childId), cachedChild); + + mergeWatchedChatIntoCaches(queryClient, watchedChild, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0].children?.[0]).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(childId)), + ).toMatchObject({ + status: "running", + last_model_config_id: "model-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); + + it("does not let an older watch payload clobber newer cached metadata", () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + const cachedChat = makeChat(chatId, { + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + const staleWatchChat = makeChat(chatId, { + status: "running", + title: "Stale title", + last_model_config_id: "model-old", + workspace_id: "workspace-old", + build_id: "build-old", + updated_at: "2025-01-01T00:00:00.000Z", + }); + + seedInfiniteChats(queryClient, [cachedChat]); + queryClient.setQueryData(chatKey(chatId), cachedChat); + + mergeWatchedChatIntoCaches(queryClient, staleWatchChat, { + eventKind: "status_change", + }); + + expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({ + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + expect( + queryClient.getQueryData(chatKey(chatId)), + ).toMatchObject({ + status: "completed", + title: "Fresh title", + last_model_config_id: "model-new", + workspace_id: "workspace-new", + build_id: "build-new", + updated_at: "2025-01-01T00:05:00.000Z", + }); + }); +}); + describe("removeChildFromParentInCache", () => { it("removes the child from its parent's children array", () => { const queryClient = createTestQueryClient(); diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 5a59772971120..5ab48bfbce1bc 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -187,6 +187,180 @@ export const removeChildFromParentInCache = ( return found; }; +const parseUpdatedAtInstant = (updatedAt: string) => { + const match = updatedAt.match(/^(.*?)(?:\.(\d+))?(Z|[+-]\d\d:\d\d)$/); + if (!match) { + const epochMs = Date.parse(updatedAt); + return Number.isNaN(epochMs) ? undefined : { epochMs, fractionalNanos: 0 }; + } + + const [, timestampWithoutFraction, fractionalSeconds = "", timezone] = match; + const epochMs = Date.parse(`${timestampWithoutFraction}${timezone}`); + if (Number.isNaN(epochMs)) { + return undefined; + } + return { + epochMs, + fractionalNanos: Number(fractionalSeconds.slice(0, 9).padEnd(9, "0")), + }; +}; + +const compareUpdatedAtInstants = (a: string, b: string): number => { + const parsedA = parseUpdatedAtInstant(a); + const parsedB = parseUpdatedAtInstant(b); + if (!parsedA || !parsedB) { + return a.localeCompare(b); + } + if (parsedA.epochMs !== parsedB.epochMs) { + return parsedA.epochMs - parsedB.epochMs; + } + return parsedA.fractionalNanos - parsedB.fractionalNanos; +}; + +type MergeWatchedChatOptions = { + readonly eventKind: TypesGen.ChatWatchEventKind; + readonly activeChatId?: string; +}; + +// Shallow-compare two ChatDiffStatus objects by their meaningful +// fields, ignoring refreshed_at/stale_at which change on every poll. +const diffStatusEqual = ( + a: TypesGen.ChatDiffStatus | undefined, + b: TypesGen.ChatDiffStatus | undefined, +): boolean => { + if (a === b) { + return true; + } + if (!a || !b) { + return false; + } + return ( + a.url === b.url && + a.pull_request_state === b.pull_request_state && + a.pull_request_title === b.pull_request_title && + a.pull_request_draft === b.pull_request_draft && + a.changes_requested === b.changes_requested && + a.additions === b.additions && + a.deletions === b.deletions && + a.changed_files === b.changed_files && + a.pr_number === b.pr_number && + a.approved === b.approved && + a.commits === b.commits + ); +}; + +/** + * Merges event-scoped chat fields into a cached summary, using updated_at + * as a stale guard while still adopting the latest DB-backed model config. + */ +export const mergeWatchedChatSummary = ( + cachedChat: TypesGen.Chat, + watchedChat: TypesGen.Chat, + { eventKind, activeChatId }: MergeWatchedChatOptions, +): TypesGen.Chat => { + const isTitleEvent = eventKind === "title_change"; + const isStatusEvent = eventKind === "status_change"; + const isDiffStatusEvent = eventKind === "diff_status_change"; + const updatedAtComparison = compareUpdatedAtInstants( + cachedChat.updated_at, + watchedChat.updated_at, + ); + const isFreshEnough = updatedAtComparison <= 0; + const nextStatus = + isFreshEnough && isStatusEvent ? watchedChat.status : cachedChat.status; + // maybeGenerateChatTitle can publish a previously loaded chat snapshot, so + // apply title_change payloads even when the chat summary timestamp is older. + const nextTitle = isTitleEvent ? watchedChat.title : cachedChat.title; + // Diff status freshness is tracked outside chats.updated_at, so apply + // diff_status_change payloads even when the chat summary timestamp is older. + const nextDiffStatus = isDiffStatusEvent + ? watchedChat.diff_status + : cachedChat.diff_status; + const nextWorkspaceId = isFreshEnough + ? (watchedChat.workspace_id ?? cachedChat.workspace_id) + : cachedChat.workspace_id; + const nextBuildId = isFreshEnough + ? (watchedChat.build_id ?? cachedChat.build_id) + : cachedChat.build_id; + // All event types carry the current model config from the DB. + const nextLastModelConfigId = isFreshEnough + ? watchedChat.last_model_config_id + : cachedChat.last_model_config_id; + const nextHasUnread = + isFreshEnough && isStatusEvent && watchedChat.id !== activeChatId + ? true + : cachedChat.has_unread; + const nextUpdatedAt = + updatedAtComparison > 0 ? cachedChat.updated_at : watchedChat.updated_at; + + // Keep updated_at in the no-op guard. This gives up the old streaming + // rerender shortcut so later stale events cannot pass isFreshEnough + // against a timestamp that should already have been superseded. + if ( + nextStatus === cachedChat.status && + nextTitle === cachedChat.title && + diffStatusEqual(nextDiffStatus, cachedChat.diff_status) && + nextWorkspaceId === cachedChat.workspace_id && + nextBuildId === cachedChat.build_id && + nextLastModelConfigId === cachedChat.last_model_config_id && + nextHasUnread === cachedChat.has_unread && + nextUpdatedAt === cachedChat.updated_at + ) { + return cachedChat; + } + + return { + ...cachedChat, + status: nextStatus, + title: nextTitle, + diff_status: nextDiffStatus, + workspace_id: nextWorkspaceId, + build_id: nextBuildId, + last_model_config_id: nextLastModelConfigId, + has_unread: nextHasUnread, + updated_at: nextUpdatedAt, + }; +}; + +/** + * Applies the same event-scoped merge and stale guard across the list, + * parent-child, and per-chat caches, covering all three cache layers. + */ +export const mergeWatchedChatIntoCaches = ( + queryClient: QueryClient, + watchedChat: TypesGen.Chat, + options: MergeWatchedChatOptions, +) => { + const mergeCachedChat = (cachedChat: TypesGen.Chat) => + mergeWatchedChatSummary(cachedChat, watchedChat, options); + + updateInfiniteChatsCache(queryClient, (chats) => { + let didUpdate = false; + const nextChats = chats.map((chat) => { + if (chat.id !== watchedChat.id) { + return chat; + } + const mergedChat = mergeCachedChat(chat); + if (mergedChat !== chat) { + didUpdate = true; + } + return mergedChat; + }); + return didUpdate ? nextChats : chats; + }); + + updateChildInParentCache(queryClient, mergeCachedChat, watchedChat.id); + queryClient.setQueryData( + chatKey(watchedChat.id), + (cachedChat) => { + if (!cachedChat) { + return cachedChat; + } + return mergeCachedChat(cachedChat); + }, + ); +}; + const getNextOptimisticPinOrder = (queryClient: QueryClient): number => { let maxPinOrder = 0; const queries = queryClient.getQueriesData< diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index ddedc3e99d3bd..2166824882443 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -2160,6 +2160,7 @@ export const ChatProviderConfigSources: ChatProviderConfigSource[] = [ export interface ChatQueuedMessage { readonly id: number; readonly chat_id: string; + readonly model_config_id?: string; readonly content: readonly ChatMessagePart[]; readonly created_at: string; } diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index 621944ef91b23..ba074f37c0eee 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -20,6 +20,7 @@ import { chatsByWorkspaceKeyPrefix, infiniteChats, invalidateChatListQueries, + mergeWatchedChatIntoCaches, pinChat, prependToInfiniteChatsCache, readInfiniteChatsCache, @@ -29,7 +30,6 @@ import { unarchiveChat, unpinChat, updateChatTitle, - updateChildInParentCache, updateInfiniteChatsCache, } from "#/api/queries/chats"; import { workspaceById } from "#/api/queries/workspaces"; @@ -55,29 +55,6 @@ import { chatDetailErrorsEqual, } from "./utils/usageLimitMessage"; -// Shallow-compare two ChatDiffStatus objects by their meaningful -// fields, ignoring refreshed_at/stale_at which change on every poll. -function diffStatusEqual( - a: TypesGen.ChatDiffStatus | undefined, - b: TypesGen.ChatDiffStatus | undefined, -): boolean { - if (a === b) return true; - if (!a || !b) return false; - return ( - a.url === b.url && - a.pull_request_state === b.pull_request_state && - a.pull_request_title === b.pull_request_title && - a.pull_request_draft === b.pull_request_draft && - a.changes_requested === b.changes_requested && - a.additions === b.additions && - a.deletions === b.deletions && - a.changed_files === b.changed_files && - a.pr_number === b.pr_number && - a.approved === b.approved && - a.commits === b.commits - ); -} - export type { AgentsOutletContext } from "./AgentsPageView"; const AgentsPage: FC = () => { @@ -557,14 +534,8 @@ const AgentsPage: FC = () => { exact: true, }); } - // Scope field updates by event kind so that - // status_change events (which may carry a stale title - // snapshot from before async title generation - // finished) don't clobber a title_change that already - // landed. - const isTitleEvent = chatEvent.kind === "title_change"; - const isStatusEvent = chatEvent.kind === "status_change"; - const isDiffStatusEvent = chatEvent.kind === "diff_status_change"; + // Merge watch payloads by event kind so stale field + // snapshots do not clobber fresher cached metadata. // Cancel in-flight list and per-chat refetches so // they cannot overwrite the cache update below with @@ -606,117 +577,11 @@ const AgentsPage: FC = () => { prependToInfiniteChatsCache(queryClient, updatedChat); } } else { - // Build a field updater shared between root and - // child cache update paths. - const applyFields = (c: TypesGen.Chat): TypesGen.Chat => { - const nextStatus = isStatusEvent ? updatedChat.status : c.status; - const nextTitle = isTitleEvent ? updatedChat.title : c.title; - const nextDiffStatus = isDiffStatusEvent - ? updatedChat.diff_status - : c.diff_status; - const nextWorkspaceId = - updatedChat.workspace_id ?? c.workspace_id; - const nextBuildId = updatedChat.build_id ?? c.build_id; - const nextUpdatedAt = - c.updated_at > updatedChat.updated_at - ? c.updated_at - : updatedChat.updated_at; - // The server's pubsub path does not compute - // has_unread (it always sends false). For - // status_change events on non-active chats, - // optimistically mark as unread since the - // assistant produced new output. - const nextHasUnread = - isStatusEvent && updatedChat.id !== activeChatIDRef.current - ? true - : c.has_unread; - if ( - nextStatus === c.status && - nextTitle === c.title && - diffStatusEqual(nextDiffStatus, c.diff_status) && - nextWorkspaceId === c.workspace_id && - nextBuildId === c.build_id && - nextHasUnread === c.has_unread - ) { - return c; - } - return { - ...c, - status: nextStatus, - title: nextTitle, - diff_status: nextDiffStatus, - workspace_id: nextWorkspaceId, - build_id: nextBuildId, - updated_at: nextUpdatedAt, - has_unread: nextHasUnread, - }; - }; - - // Try root-level update first. - updateInfiniteChatsCache(queryClient, (chats) => { - let didUpdate = false; - const nextChats = chats.map((c) => { - if (c.id !== updatedChat.id) return c; - const result = applyFields(c); - if (result !== c) didUpdate = true; - return result; - }); - return didUpdate ? nextChats : chats; + mergeWatchedChatIntoCaches(queryClient, updatedChat, { + eventKind: chatEvent.kind, + activeChatId: activeChatIDRef.current, }); - - // Also update inside parent's children array - // in case the event targets a child chat. - updateChildInParentCache(queryClient, applyFields, updatedChat.id); } - queryClient.setQueryData( - chatKey(updatedChat.id), - (previousChat) => { - if (!previousChat) { - return previousChat; - } - // Only create a new object if a field actually - // changed. Returning the same reference prevents - // react-query from notifying subscribers, avoiding - // unnecessary re-renders of AgentChatPage during - // streaming when repeated status_change events - // carry the same "running" status. - const nextStatus = isStatusEvent - ? updatedChat.status - : previousChat.status; - const nextTitle = isTitleEvent - ? updatedChat.title - : previousChat.title; - const nextDiffStatus = isDiffStatusEvent - ? updatedChat.diff_status - : previousChat.diff_status; - const nextWorkspaceId = - updatedChat.workspace_id ?? previousChat.workspace_id; - const nextBuildId = updatedChat.build_id ?? previousChat.build_id; - const nextUpdatedAt = - previousChat.updated_at > updatedChat.updated_at - ? previousChat.updated_at - : updatedChat.updated_at; - - if ( - nextStatus === previousChat.status && - nextTitle === previousChat.title && - diffStatusEqual(nextDiffStatus, previousChat.diff_status) && - nextWorkspaceId === previousChat.workspace_id && - nextBuildId === previousChat.build_id - ) { - return previousChat; - } - return { - ...previousChat, - status: nextStatus, - title: nextTitle, - diff_status: nextDiffStatus, - workspace_id: nextWorkspaceId, - build_id: nextBuildId, - updated_at: nextUpdatedAt, - }; - }, - ); }); return ws; },