diff --git a/astrbot/builtin_stars/astrbot/long_term_memory.py b/astrbot/builtin_stars/astrbot/long_term_memory.py index e08cdc5157..288668bf25 100644 --- a/astrbot/builtin_stars/astrbot/long_term_memory.py +++ b/astrbot/builtin_stars/astrbot/long_term_memory.py @@ -23,6 +23,12 @@ def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None: self.session_chats = defaultdict(list) """记录群成员的群聊记录""" + def _group_key(self, event: AstrMessageEvent) -> str: + """获取群级别的 key,不受 unique_session 影响""" + if event.get_message_type() == MessageType.GROUP_MESSAGE and event.get_group_id(): + return f"{event.get_platform_id()}:GroupMessage:{event.get_group_id()}" + return event.unified_msg_origin + def cfg(self, event: AstrMessageEvent): cfg = self.context.get_config(umo=event.unified_msg_origin) try: @@ -58,9 +64,10 @@ def cfg(self, event: AstrMessageEvent): async def remove_session(self, event: AstrMessageEvent) -> int: cnt = 0 - if event.unified_msg_origin in self.session_chats: - cnt = len(self.session_chats[event.unified_msg_origin]) - del self.session_chats[event.unified_msg_origin] + group_key = self._group_key(event) + if group_key in self.session_chats: + cnt = len(self.session_chats[group_key]) + del self.session_chats[group_key] return cnt async def get_image_caption( @@ -143,17 +150,19 @@ async def handle_message(self, event: AstrMessageEvent) -> None: parts.append(f" [At: {comp.name}]") final_message = "".join(parts) - logger.debug(f"ltm | {event.unified_msg_origin} | {final_message}") - self.session_chats[event.unified_msg_origin].append(final_message) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + group_key = self._group_key(event) + logger.debug(f"ltm | {group_key} | {final_message}") + self.session_chats[group_key].append(final_message) + if len(self.session_chats[group_key]) > cfg["max_cnt"]: + self.session_chats[group_key].pop(0) async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None: """当触发 LLM 请求前,调用此方法修改 req""" - if event.unified_msg_origin not in self.session_chats: + group_key = self._group_key(event) + if group_key not in self.session_chats: return - chats_str = "\n---\n".join(self.session_chats[event.unified_msg_origin]) + chats_str = "\n---\n".join(self.session_chats[group_key]) cfg = self.cfg(event) if cfg["enable_active_reply"]: @@ -174,15 +183,16 @@ async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> Non async def after_req_llm( self, event: AstrMessageEvent, llm_resp: LLMResponse ) -> None: - if event.unified_msg_origin not in self.session_chats: + group_key = self._group_key(event) + if group_key not in self.session_chats: return if llm_resp.completion_text: final_message = f"[You/{datetime.datetime.now().strftime('%H:%M:%S')}]: {llm_resp.completion_text}" logger.debug( - f"Recorded AI response: {event.unified_msg_origin} | {final_message}" + f"Recorded AI response: {group_key} | {final_message}" ) - self.session_chats[event.unified_msg_origin].append(final_message) + self.session_chats[group_key].append(final_message) cfg = self.cfg(event) - if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]: - self.session_chats[event.unified_msg_origin].pop(0) + if len(self.session_chats[group_key]) > cfg["max_cnt"]: + self.session_chats[group_key].pop(0) diff --git a/astrbot/builtin_stars/astrbot/main.py b/astrbot/builtin_stars/astrbot/main.py index 3d800edd26..eb7816a234 100644 --- a/astrbot/builtin_stars/astrbot/main.py +++ b/astrbot/builtin_stars/astrbot/main.py @@ -33,6 +33,19 @@ def __init__(self, context: star.Context) -> None: self.ltm = None try: self.ltm = LongTermMemory(self.context.astrbot_config_mgr, self.context) + + async def _clear_ltm_session(umo: str) -> None: + self.ltm.session_chats.pop(umo, None) + # Also clear group-level key for unique_session scenarios + parts = umo.split(":") + if len(parts) >= 3 and parts[1] == "GroupMessage": + group_id = parts[2].split("%")[-1] + group_key = f"{parts[0]}:GroupMessage:{group_id}" + self.ltm.session_chats.pop(group_key, None) + + self.context.conversation_manager.register_on_session_deleted( + _clear_ltm_session + ) except BaseException as e: logger.error(f"聊天增强 err: {e}") diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 2c282867f9..cd7c593a20 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -142,6 +142,7 @@ async def delete_conversation( if curr_cid == conversation_id: self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") + await self._trigger_session_deleted(unified_msg_origin) async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 9ec24d254d..75c76096f3 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1315,6 +1315,14 @@ async def post_delete_platform(self): try: save_config(self.config, self.config, is_core=True) await self.core_lifecycle.platform_manager.terminate_platform(platform_id) + convs = await self.core_lifecycle.db.get_conversations( + platform_id=platform_id + ) + for conv in convs: + await self.core_lifecycle.conversation_manager.delete_conversation( + unified_msg_origin=conv.user_id, + conversation_id=conv.conversation_id, + ) except Exception as e: return Response().error(str(e)).__dict__ return Response().ok(None, "删除平台配置成功~").__dict__ diff --git a/tests/test_ltm_cleanup_on_delete.py b/tests/test_ltm_cleanup_on_delete.py new file mode 100644 index 0000000000..0edcc19096 --- /dev/null +++ b/tests/test_ltm_cleanup_on_delete.py @@ -0,0 +1,175 @@ +"""测试 Web UI 删除对话后 LTM session_chats 是否被正确清理 (Issue #8386 Bug 1)""" + +import pytest +import pytest_asyncio +from collections import defaultdict +from unittest.mock import AsyncMock, MagicMock + +from astrbot.core.conversation_mgr import ConversationManager +from astrbot.core.platform.message_type import MessageType +from astrbot.builtin_stars.astrbot.long_term_memory import LongTermMemory + + +@pytest_asyncio.fixture +async def conversation_manager(): + db = AsyncMock() + db.delete_conversation = AsyncMock() + db.get_conversation_by_id = AsyncMock(return_value=None) + mgr = ConversationManager(db) + return mgr + + +@pytest.fixture +def ltm(): + acm = MagicMock() + context = MagicMock() + context.get_config = MagicMock(return_value={ + "provider_ltm_settings": { + "group_message_max_cnt": 300, + "image_caption": False, + "image_caption_provider_id": "", + "active_reply": {"enable": False, "method": "possibility_reply", "possibility_reply": 0.1}, + }, + "provider_settings": {"image_caption_prompt": ""}, + }) + return LongTermMemory(acm, context) + + +@pytest.mark.asyncio +async def test_delete_conversation_triggers_session_deleted_callback(conversation_manager): + """验证 delete_conversation 会触发 _on_session_deleted_callbacks""" + callback = AsyncMock() + conversation_manager.register_on_session_deleted(callback) + + umo = "feishu:group:test_group_123" + conversation_manager.session_conversations[umo] = "conv-id-1" + + await conversation_manager.delete_conversation( + unified_msg_origin=umo, + conversation_id="conv-id-1", + ) + + callback.assert_called_once_with(umo) + + +@pytest.mark.asyncio +async def test_delete_conversation_clears_ltm_session_chats(conversation_manager, ltm): + """模拟完整流程:LTM 注册回调后,Web UI 删除对话应清理 session_chats""" + umo = "feishu:group:test_group_456" + + # 模拟群聊中已有 LTM 记录 + ltm.session_chats[umo] = [ + "[Alice/10:00:00]: 你好", + "[Bob/10:01:00]: 你好啊", + "[You/10:01:30]: 大家好!", + ] + + # 注册回调(和 main.py 中的逻辑一致) + async def _clear_ltm_session(origin: str) -> None: + ltm.session_chats.pop(origin, None) + + conversation_manager.register_on_session_deleted(_clear_ltm_session) + + # 模拟当前会话指向该对话 + conversation_manager.session_conversations[umo] = "conv-id-2" + + # 执行删除(Web UI 路径) + await conversation_manager.delete_conversation( + unified_msg_origin=umo, + conversation_id="conv-id-2", + ) + + # 验证 LTM 内存已清理 + assert umo not in ltm.session_chats + + +@pytest.mark.asyncio +async def test_ltm_on_req_llm_skips_after_session_cleared(conversation_manager, ltm): + """删除对话后,on_req_llm 不应再注入已删除的历史到 system_prompt""" + umo = "lark:GroupMessage:test_group_789" + + # 模拟已有 LTM 记录(存储在 group-level key 下) + ltm.session_chats[umo] = [ + "[User1/09:00:00]: 之前的秘密对话", + ] + + async def _clear_ltm_session(origin: str) -> None: + ltm.session_chats.pop(origin, None) + parts = origin.split(":") + if len(parts) >= 3 and parts[1] == "GroupMessage": + group_id = parts[2].split("%")[-1] + group_key = f"{parts[0]}:GroupMessage:{group_id}" + ltm.session_chats.pop(group_key, None) + + conversation_manager.register_on_session_deleted(_clear_ltm_session) + conversation_manager.session_conversations[umo] = "conv-id-3" + + # 删除对话 + await conversation_manager.delete_conversation( + unified_msg_origin=umo, + conversation_id="conv-id-3", + ) + + # 模拟后续 LLM 请求 + event = MagicMock() + event.unified_msg_origin = umo + event.get_message_type.return_value = MessageType.GROUP_MESSAGE + event.get_group_id.return_value = "test_group_789" + event.get_platform_id.return_value = "lark" + + req = MagicMock() + req.system_prompt = "You are a helpful assistant." + req.prompt = "你好" + req.contexts = [] + + await ltm.on_req_llm(event, req) + + # system_prompt 不应包含已删除的历史 + assert "秘密对话" not in req.system_prompt + + +@pytest.mark.asyncio +async def test_delete_other_conversation_does_not_affect_unrelated_session(conversation_manager, ltm): + """删除某个 session 的对话不应影响其他 session 的 LTM 记录""" + umo_a = "feishu:group:group_a" + umo_b = "feishu:group:group_b" + + ltm.session_chats[umo_a] = ["[A/10:00:00]: hello"] + ltm.session_chats[umo_b] = ["[B/10:00:00]: world"] + + async def _clear_ltm_session(origin: str) -> None: + ltm.session_chats.pop(origin, None) + + conversation_manager.register_on_session_deleted(_clear_ltm_session) + conversation_manager.session_conversations[umo_a] = "conv-a" + + # 只删除 group_a + await conversation_manager.delete_conversation( + unified_msg_origin=umo_a, + conversation_id="conv-a", + ) + + assert umo_a not in ltm.session_chats + assert ltm.session_chats[umo_b] == ["[B/10:00:00]: world"] + + +@pytest.mark.asyncio +async def test_group_key_ignores_unique_session(ltm): + """Bug 3: unique_session 开启时,不同用户的 group_key 应相同(都指向群级别)""" + # 用户 A 的 event(unique_session 改写了 unified_msg_origin) + event_a = MagicMock() + event_a.unified_msg_origin = "lark:GroupMessage:userA%group123" + event_a.get_message_type.return_value = MessageType.GROUP_MESSAGE + event_a.get_group_id.return_value = "group123" + event_a.get_platform_id.return_value = "lark" + + # 用户 B 的 event + event_b = MagicMock() + event_b.unified_msg_origin = "lark:GroupMessage:userB%group123" + event_b.get_message_type.return_value = MessageType.GROUP_MESSAGE + event_b.get_group_id.return_value = "group123" + event_b.get_platform_id.return_value = "lark" + + # 两者的 group_key 应该相同 + assert ltm._group_key(event_a) == ltm._group_key(event_b) + assert ltm._group_key(event_a) == "lark:GroupMessage:group123"