Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/mkdocs/en/memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,6 @@ memory_service = MempalaceMemoryService(
),
wing="my_app_user",
room="conversations",
store_only_model_visible=True,
)
```

Expand Down
1 change: 0 additions & 1 deletion docs/mkdocs/zh/memory.md
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,6 @@ memory_service = MempalaceMemoryService(
memory_service_config=memory_service_config,
wing="my_app_user", # 可选:记忆命名空间;不传则默认由 save_key 推导
room="conversations", # 可选:记忆类别;默认 conversations
store_only_model_visible=True,
)
```

Expand Down
2 changes: 0 additions & 2 deletions examples/memory_service_with_mempalace/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,13 @@ memory_service = MempalaceMemoryService(
memory_service_config=memory_service_config,
wing="trpc-agent",
room="conversations",
store_only_model_visible=True,
)
```

这里的含义是:

- `wing="trpc-agent"`:把示例记忆固定写入 `trpc-agent` 这个 wing。
- `room="conversations"`:把普通对话记忆写入 `conversations` room。
- `store_only_model_visible=True`:只存模型可见的事件。
- `ttl_seconds=20`:超过 20 秒的记忆会被后台 cleanup 删除。
- `cleanup_interval_seconds=20`:每 20 秒执行一次清理。

Expand Down
1 change: 0 additions & 1 deletion examples/memory_service_with_mempalace/run_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def create_memory_service():
memory_service_config=memory_service_config,
wing="trpc-agent",
room="conversations",
store_only_model_visible=True,
)
return memory_service

Expand Down
7 changes: 3 additions & 4 deletions tests/agents/core/test_history_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,16 @@ def test_invocation_mode_filters_by_id(self, invocation_context):
assert len(events) == 1
assert events[0].content.parts[0].text == "current"

def test_invocation_mode_includes_summary_events(self, invocation_context):
def test_invocation_mode_filters_summary_events_by_id(self, invocation_context):
proc = HistoryProcessor(timeline_filter_mode=TimelineFilterMode.INVOCATION)
summary_event = _make_event("system", "Previous conversation summary", invocation_id="summary")
summary_event.set_summary_event(True)
current_event = _make_event("user", "current", invocation_id="inv-1")

events = proc.filter_events(invocation_context, [summary_event, current_event])

assert len(events) == 2
assert events[0].is_summary_event()
assert events[1].content.parts[0].text == "current"
assert len(events) == 1
assert events[0].content.parts[0].text == "current"


# ---------------------------------------------------------------------------
Expand Down
28 changes: 23 additions & 5 deletions tests/memory/test_mempalace_memory_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,43 @@ def fake_store(session, events_to_store, wing, room):
assert "remember this" in events_to_store[0][1]
assert events_to_store[0][2] in svc._stored_drawer_ids

async def test_store_session_skips_invisible_events(self, monkeypatch):
async def test_store_session_ignores_model_visible_flag(self, monkeypatch):
calls = []

def fake_store(session, events_to_store, wing, room):
calls.append(events_to_store)
return {drawer_id for _, _, drawer_id in events_to_store}

visible_event = _make_event("visible")
invisible_event = _make_event("hidden")
invisible_event.set_model_visible(False)
flagged_event = _make_event("hidden")
flagged_event.set_model_visible(False)
svc = MempalaceMemoryService(memory_service_config=_make_config())
monkeypatch.setattr(svc, "_store_events", fake_store)

await svc.store_session(_make_session(events=[visible_event, invisible_event]))
await svc.store_session(_make_session(events=[visible_event, flagged_event]))
await svc.close()

assert len(calls) == 1
assert len(calls[0]) == 1
assert len(calls[0]) == 2
assert "visible" in calls[0][0][1]
assert "hidden" in calls[0][1][1]

async def test_store_only_model_visible_flag_is_compatibility_noop(self, monkeypatch):
calls = []

def fake_store(session, events_to_store, wing, room):
calls.append(events_to_store)
return {drawer_id for _, _, drawer_id in events_to_store}

svc = MempalaceMemoryService(memory_service_config=_make_config(), store_only_model_visible=False)
monkeypatch.setattr(svc, "_store_events", fake_store)

await svc.store_session(_make_session(events=[_make_event("active event")]))
await svc.close()

assert len(calls) == 1
assert len(calls[0]) == 1
assert "active event" in calls[0][0][1]

async def test_store_session_is_incremental(self, monkeypatch):
calls = []
Expand Down
57 changes: 41 additions & 16 deletions tests/sessions/test_base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,18 @@ async def test_append_event_empty_state_delta(self):
await svc.append_event(session, event)
assert len(session.events) == 1

async def test_append_event_stores_filtered_events_when_configured(self):
config = SessionServiceConfig(max_events=2, store_historical_events=True)
svc = ConcreteSessionService(session_config=config)
session = _make_session()

for i in range(6):
event = _make_event(author="user" if i == 2 else "agent", text=f"msg{i}")
await svc.append_event(session, event)

assert [event.get_text() for event in session.events] == ["msg2", "msg5"]
assert [event.get_text() for event in session.historical_events] == ["msg0", "msg1", "msg3", "msg4"]


class TestBaseSessionServiceTrimTempDeltaState:
"""Test _trim_temp_delta_state method."""
Expand Down Expand Up @@ -172,10 +184,23 @@ def test_filter_by_num_recent_events(self):
for i in range(10):
author = "user" if i == 7 else "agent"
session.events.append(_make_event(author=author, text=f"msg{i}"))
svc.filter_events(session)
filtered_session = svc.filter_events(session)
assert filtered_session is session
assert [event.get_text() for event in session.events] == ["msg7", "msg8", "msg9"]

def test_filter_by_num_recent_events_with_copy(self):
config = SessionServiceConfig(num_recent_events=3)
svc = ConcreteSessionService(session_config=config)
session = _make_session()
for i in range(10):
author = "user" if i == 7 else "agent"
session.events.append(_make_event(author=author, text=f"msg{i}"))

filtered_session = svc.filter_events(session, need_copy=True)

assert filtered_session is not session
assert len(session.events) == 10
visible_events = [event for event in session.events if event.is_model_visible()]
assert [event.get_text() for event in visible_events] == ["msg7", "msg8", "msg9"]
assert [event.get_text() for event in filtered_session.events] == ["msg7", "msg8", "msg9"]

def test_filter_by_event_ttl(self):
config = SessionServiceConfig(event_ttl_seconds=5.0)
Expand All @@ -190,18 +215,18 @@ def test_filter_by_event_ttl(self):
new_event.timestamp = time.time()
session.events.append(new_event)

svc.filter_events(session)
assert len(session.events) == 2
visible_events = [event for event in session.events if event.is_model_visible()]
assert len(visible_events) == 1
assert visible_events[0].get_text() == "new"
filtered_session = svc.filter_events(session)
assert filtered_session is session
assert len(session.events) == 1
assert session.events[0].get_text() == "new"

def test_filter_no_config(self):
svc = ConcreteSessionService()
session = _make_session()
for i in range(5):
session.events.append(_make_event(text=f"msg{i}"))
svc.filter_events(session)
filtered_session = svc.filter_events(session)
assert filtered_session is session
assert len(session.events) == 5

def test_filter_ttl_removes_all_old(self):
Expand All @@ -212,9 +237,9 @@ def test_filter_ttl_removes_all_old(self):
e = _make_event(text=f"old{i}")
e.timestamp = time.time() - 100
session.events.append(e)
svc.filter_events(session)
assert len(session.events) == 5
assert all(not event.is_model_visible() for event in session.events)
filtered_session = svc.filter_events(session)
assert filtered_session is session
assert session.events == []

def test_filter_by_num_recent_events_preserves_summary_anchor(self):
config = SessionServiceConfig(num_recent_events=3)
Expand All @@ -227,11 +252,11 @@ def test_filter_by_num_recent_events_preserves_summary_anchor(self):
for i in range(5):
session.events.append(_make_event(text=f"agent{i}"))

svc.filter_events(session)
filtered_session = svc.filter_events(session)

visible_events = [event for event in session.events if event.is_model_visible()]
assert len(visible_events) == 1
assert visible_events[0].is_summary_event()
assert filtered_session is session
assert len(session.events) == 1
assert session.events[0].is_summary_event()


class TestBaseSessionServiceSetSummarizerManager:
Expand Down
1 change: 1 addition & 0 deletions tests/sessions/test_in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ async def test_list_sessions_have_no_events(self):
result = await svc.list_sessions(app_name="app", user_id="user")
assert len(result.sessions) == 1
assert result.sessions[0].events == []
assert result.sessions[0].historical_events == []
await svc.close()

async def test_list_nonexistent_app(self):
Expand Down
44 changes: 42 additions & 2 deletions tests/sessions/test_redis_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
from trpc_agent_sdk.types import Content, EventActions, Part, State


def _make_config(ttl_seconds=0, cleanup_interval=0.0, enable_ttl=False):
config = SessionServiceConfig()
def _make_config(ttl_seconds=0,
cleanup_interval=0.0,
enable_ttl=False,
max_events=0,
store_historical_events=False):
config = SessionServiceConfig(max_events=max_events, store_historical_events=store_historical_events)
if enable_ttl:
config.ttl = SessionServiceConfig.create_ttl_config(
enable=True, ttl_seconds=ttl_seconds, cleanup_interval_seconds=cleanup_interval)
Expand Down Expand Up @@ -203,6 +207,7 @@ async def test_list_sessions_have_no_events(self):
result = await svc.list_sessions(app_name="app", user_id="user")
for s in result.sessions:
assert s.events == []
assert s.historical_events == []
await svc.close()


Expand Down Expand Up @@ -256,6 +261,28 @@ async def test_append_with_state_delta(self):
assert stored.state[f"{State.USER_PREFIX}user_key"] == "uv"
await svc.close()

async def test_append_does_not_persist_merged_or_temp_state_in_session_json(self):
svc = _create_service()
session = await svc.create_session(app_name="app", user_id="user", session_id="s1")
event = _make_event(state_delta={
"session_key": "sv",
f"{State.APP_PREFIX}app_key": "av",
f"{State.USER_PREFIX}user_key": "uv",
f"{State.TEMP_PREFIX}temp_key": "tv",
})

await svc.append_event(session, event)

stored_json = svc._redis_storage._store["session:app:user:s1"]
raw_session = Session.model_validate_json(stored_json)
assert raw_session.state == {"session_key": "sv"}
stored = await svc.get_session(app_name="app", user_id="user", session_id="s1")
assert stored.state["session_key"] == "sv"
assert stored.state[f"{State.APP_PREFIX}app_key"] == "av"
assert stored.state[f"{State.USER_PREFIX}user_key"] == "uv"
assert f"{State.TEMP_PREFIX}temp_key" not in raw_session.state
await svc.close()

async def test_append_to_nonexistent_session(self):
svc = _create_service()
session = _make_session_obj(id="nonexistent")
Expand All @@ -264,6 +291,19 @@ async def test_append_to_nonexistent_session(self):
assert result is event
await svc.close()

async def test_append_persists_filtered_active_and_historical_events(self):
config = _make_config(max_events=2, store_historical_events=True)
svc = _create_service(config=config)
session = await svc.create_session(app_name="app", user_id="user", session_id="s1")

for i in range(4):
await svc.append_event(session, _make_event(author="user" if i == 2 else "agent", text=f"msg{i}"))

stored = await svc.get_session(app_name="app", user_id="user", session_id="s1")
assert [event.get_text() for event in stored.events] == ["msg2", "msg3"]
assert [event.get_text() for event in stored.historical_events] == ["msg0", "msg1"]
await svc.close()


class TestRedisUpdateSession:
async def test_update_existing(self):
Expand Down
Loading
Loading