diff --git a/astrbot/core/agent/context/compressor.py b/astrbot/core/agent/context/compressor.py index 424e3381e2..759604dd93 100644 --- a/astrbot/core/agent/context/compressor.py +++ b/astrbot/core/agent/context/compressor.py @@ -1,6 +1,11 @@ from typing import TYPE_CHECKING, Protocol, runtime_checkable +from ...provider.modalities import ( + log_context_sanitize_stats, + sanitize_contexts_by_modalities, +) from ..message import Message +from .token_counter import EstimateTokenCounter, TokenCounter if TYPE_CHECKING: from astrbot import logger @@ -96,23 +101,6 @@ async def __call__(self, messages: list[Message]) -> list[Message]: return truncated_messages -def _message_to_dict(msg: Message) -> dict: - """Convert a Message to a plain dict suitable for round splitting.""" - d = {"role": msg.role} - if msg.content is not None: - d["content"] = msg.content - if getattr(msg, "tool_calls", None): - d["tool_calls"] = msg.tool_calls - if getattr(msg, "tool_call_id", None): - d["tool_call_id"] = msg.tool_call_id - return d - - -def _dict_to_message(d: dict) -> Message: - """Convert a plain dict back to a Message.""" - return Message(**d) - - def _extract_system_messages(messages: list[Message]) -> list[Message]: """Return the leading system messages from a message list.""" result = [] @@ -126,28 +114,36 @@ def _extract_system_messages(messages: list[Message]) -> list[Message]: class LLMSummaryCompressor: """LLM-based summary compressor. - Uses LLM to summarize the old conversation history, keeping the latest messages. + Uses LLM to summarize old conversation history while keeping a recent token + budget as exact context. """ + TASK_CONTINUATION_INSTRUCTION = ( + "If a task appears to be in progress, end the summary with the latest " + "known result and the concrete next step to continue the task." + ) + def __init__( self, provider: "Provider", - keep_recent: int = 4, + keep_recent_ratio: float = 0.15, instruction_text: str | None = None, compression_threshold: float = 0.82, + token_counter: TokenCounter | None = None, ) -> None: """Initialize the LLM summary compressor. Args: provider: The LLM provider instance. - keep_recent: The number of latest messages to keep (default: 4). + keep_recent_ratio: Ratio of current context tokens to keep as recent + exact context. Clamped to 0-0.3. instruction_text: Custom instruction for summary generation. compression_threshold: The compression trigger threshold (default: 0.82). """ self.provider = provider - self.keep_recent = keep_recent + self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3) self.compression_threshold = compression_threshold - self.existing_summary: str = "" + self.token_counter = token_counter or EstimateTokenCounter() self.instruction_text = instruction_text or ( "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n" @@ -177,6 +173,36 @@ def should_compress( usage_rate = current_tokens / max_tokens return usage_rate > self.compression_threshold + def _split_recent_rounds_by_token_ratio( + self, + rounds: list[list[Message]], + total_tokens: int, + ) -> tuple[list[list[Message]], list[list[Message]]]: + """Split rounds into summarised history and exact recent context. + + The token budget is computed from the current context token count and + `keep_recent_ratio`, then floored by `int(...)`. Mapping that budget to + rounds is round-granular: a positive ratio always preserves the latest + whole round, even if that round itself exceeds the budget. Earlier + rounds are added only while the accumulated recent rounds stay within + the budget. No round is split. + """ + if not rounds or self.keep_recent_ratio <= 0 or total_tokens <= 0: + return rounds, [] + + budget = max(1, int(total_tokens * self.keep_recent_ratio)) + used = 0 + recent_start = len(rounds) + + for idx in range(len(rounds) - 1, -1, -1): + round_tokens = self.token_counter.count_tokens(rounds[idx]) + if used > 0 and used + round_tokens > budget: + break + used += round_tokens + recent_start = idx + + return rounds[:recent_start], rounds[recent_start:] + async def __call__(self, messages: list[Message]) -> list[Message]: """Use LLM to generate a summary of the conversation history. @@ -184,40 +210,72 @@ async def __call__(self, messages: list[Message]) -> list[Message]: On LLM failure, returns the original messages unchanged (caller should fall back to truncation). """ - from .round_utils import rounds_to_text, split_into_rounds - - # Convert messages to dict list for round splitting - msg_dicts = [_message_to_dict(m) for m in messages] - rounds = split_into_rounds(msg_dicts) - - if len(rounds) <= self.keep_recent: - return messages + from .round_utils import split_into_rounds + + rounds = split_into_rounds(messages) + message_rounds = [ + [seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds + ] + total_tokens = self.token_counter.count_tokens(messages) + old_rounds, recent_rounds = self._split_recent_rounds_by_token_ratio( + message_rounds, + total_tokens, + ) - old_rounds = rounds[: -self.keep_recent] - recent_rounds = rounds[-self.keep_recent :] + # The latest user message is the active request. Keep its whole round + # exact even when the ratio is 0 or the ratio budget would otherwise + # summarize every round. + if messages and messages[-1].role == "user" and old_rounds: + latest_old_round = old_rounds[-1] + if latest_old_round and latest_old_round[-1] is messages[-1]: + old_rounds = old_rounds[:-1] + recent_rounds = [latest_old_round, *recent_rounds] if not old_rounds: - return messages - - # Build LLM payload - old_text = rounds_to_text(old_rounds) - existing_note = "" - if self.existing_summary: - existing_note = ( - "\nExisting memory summary (merge with old rounds above):\n" - f"{self.existing_summary}\n" + if recent_rounds and messages and messages[-1].role == "user": + return messages + old_rounds = message_rounds + recent_rounds = [] + + summary_contexts = [msg for rnd in old_rounds for msg in rnd] + if not any(msg.role != "system" for msg in summary_contexts): + if recent_rounds and messages and messages[-1].role == "user": + return messages + old_rounds = message_rounds + recent_rounds = [] + summary_contexts = [msg for rnd in old_rounds for msg in rnd] + if not any(msg.role != "system" for msg in summary_contexts): + return messages + + if summary_contexts[-1].role != "assistant": + summary_contexts.append( + Message( + role="assistant", + content="Acknowledged.", + ) + ) + summary_contexts.append( + Message( + role="user", + content=( + "Generate a summary of our previous conversation history.\n" + f"\n{self.instruction_text}\n\n" + f"{self.TASK_CONTINUATION_INSTRUCTION}\n" + "Respond ONLY with the summary content, without any additional text or formatting." + ), ) - prompt = ( - f"{self.instruction_text}\n\n" - "--- BEGIN CONVERSATION ROUNDS TO SUMMARIZE ---\n" - f"{old_text}\n" - "--- END CONVERSATION ROUNDS ---" - f"{existing_note}" ) + sanitized_summary_contexts, sanitize_stats = sanitize_contexts_by_modalities( + summary_contexts, + self.provider.provider_config.get("modalities", None), + ) + log_context_sanitize_stats(sanitize_stats) # Generate summary try: - response = await self.provider.text_chat(prompt=prompt) + response = await self.provider.text_chat( + contexts=sanitized_summary_contexts, + ) summary_content = (response.completion_text or "").strip() except Exception as e: logger.error(f"Failed to generate summary: {e}") @@ -246,6 +304,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]: # Flatten recent rounds back to message list for rnd in recent_rounds: for seg in rnd: - result.append(_dict_to_message(seg)) + if isinstance(seg, Message): + result.append(seg) return result diff --git a/astrbot/core/agent/context/config.py b/astrbot/core/agent/context/config.py index b8fd8eb968..aa216d9a25 100644 --- a/astrbot/core/agent/context/config.py +++ b/astrbot/core/agent/context/config.py @@ -25,8 +25,8 @@ class ContextConfig: """ llm_compress_instruction: str | None = None """Instruction prompt for LLM-based compression.""" - llm_compress_keep_recent: int = 0 - """Number of recent messages to keep during LLM-based compression.""" + llm_compress_keep_recent_ratio: float = 0.15 + """Percent of current context tokens to keep as exact recent context during LLM-based compression.""" llm_compress_provider: "Provider | None" = None """LLM provider used for compression tasks. If None, truncation strategy is used.""" custom_token_counter: TokenCounter | None = None diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 216a3e7e15..1a11ebff96 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -33,8 +33,9 @@ def __init__( elif config.llm_compress_provider: self.compressor = LLMSummaryCompressor( provider=config.llm_compress_provider, - keep_recent=config.llm_compress_keep_recent, + keep_recent_ratio=config.llm_compress_keep_recent_ratio, instruction_text=config.llm_compress_instruction, + token_counter=self.token_counter, ) else: self.compressor = TruncateByTurnsCompressor( diff --git a/astrbot/core/agent/context/round_utils.py b/astrbot/core/agent/context/round_utils.py index 20c2f5711f..c93057ef44 100644 --- a/astrbot/core/agent/context/round_utils.py +++ b/astrbot/core/agent/context/round_utils.py @@ -1,21 +1,32 @@ """Round-based utilities shared by LTM compaction and LLMSummaryCompressor.""" import json +from collections.abc import Sequence from typing import Any +from ..message import ContentPart, Message, ToolCall + +RoundSegment = dict[str, Any] | Message + + +def _segment_role(seg: RoundSegment) -> str: + if isinstance(seg, Message): + return seg.role + return str(seg.get("role", "?")) + def split_into_rounds( - contexts: list[dict[str, Any]], -) -> list[list[dict[str, Any]]]: + contexts: Sequence[RoundSegment], +) -> list[list[RoundSegment]]: """Split a flat contexts list into logical rounds. A round begins at a ``user`` segment and includes all subsequent ``assistant`` / ``tool`` segments until the next ``user`` segment. """ - rounds: list[list[dict[str, Any]]] = [] - current: list[dict[str, Any]] = [] + rounds: list[list[RoundSegment]] = [] + current: list[RoundSegment] = [] for seg in contexts: - if seg.get("role") == "user" and current: + if _segment_role(seg) == "user" and current: rounds.append(current) current = [] current.append(seg) @@ -24,15 +35,38 @@ def split_into_rounds( return rounds -def rounds_to_text(rounds: list[list[dict[str, Any]]]) -> str: +def _content_to_text(content: Any) -> str: + if isinstance(content, list): + normalized = [ + part.model_dump_for_context() if isinstance(part, ContentPart) else part + for part in content + ] + return json.dumps(normalized, ensure_ascii=False) + if isinstance(content, ContentPart): + return json.dumps(content.model_dump_for_context(), ensure_ascii=False) + return str(content or "") + + +def _segment_content(seg: RoundSegment) -> Any: + if isinstance(seg, Message): + if seg.content is not None: + return seg.content + if seg.tool_calls: + return [ + tc.model_dump() if isinstance(tc, ToolCall) else tc + for tc in seg.tool_calls + ] + return "" + return seg.get("content") or seg.get("tool_calls") or "" + + +def rounds_to_text(rounds: list[list[RoundSegment]]) -> str: """Render rounds into a plain-text string for LLM summarisation.""" lines: list[str] = [] for i, rnd in enumerate(rounds, 1): lines.append(f"--- Round {i} ---") for seg in rnd: - role = seg.get("role", "?") - content = seg.get("content") or seg.get("tool_calls") or "" - if isinstance(content, list): - content = json.dumps(content, ensure_ascii=False) + role = _segment_role(seg) + content = _content_to_text(_segment_content(seg)) lines.append(f"[{role}] {content}") return "\n".join(lines) diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 4078b6f72b..3f74f0ec9b 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -216,7 +216,7 @@ async def reset( enforce_max_turns: int = -1, # llm compressor llm_compress_instruction: str | None = None, - llm_compress_keep_recent: int = 0, + llm_compress_keep_recent_ratio: float = 0.15, llm_compress_provider: Provider | None = None, # truncate by turns compressor truncate_turns: int = 1, @@ -233,7 +233,7 @@ async def reset( self.streaming = streaming self.enforce_max_turns = enforce_max_turns self.llm_compress_instruction = llm_compress_instruction - self.llm_compress_keep_recent = llm_compress_keep_recent + self.llm_compress_keep_recent_ratio = llm_compress_keep_recent_ratio self.llm_compress_provider = llm_compress_provider self.truncate_turns = truncate_turns self.custom_token_counter = custom_token_counter @@ -241,19 +241,21 @@ async def reset( self.tool_result_overflow_dir = tool_result_overflow_dir self.read_tool = read_tool self._tool_result_token_counter = EstimateTokenCounter() - self.context_config = ContextConfig( + self.request_context_manager_config = ContextConfig( # <=0 disables token-based guarding. max_context_tokens=provider.provider_config.get("max_context_tokens", 0), # Enforce max turns before token-based guarding. enforce_max_turns=self.enforce_max_turns, truncate_turns=self.truncate_turns, llm_compress_instruction=self.llm_compress_instruction, - llm_compress_keep_recent=self.llm_compress_keep_recent, + llm_compress_keep_recent_ratio=self.llm_compress_keep_recent_ratio, llm_compress_provider=self.llm_compress_provider, custom_token_counter=self.custom_token_counter, custom_compressor=self.custom_compressor, ) - self.context_manager = ContextManager(self.context_config) + self.request_context_manager = ContextManager( + self.request_context_manager_config + ) self.provider = provider self.fallback_providers: list[Provider] = [] @@ -579,7 +581,9 @@ def _sanitize_contexts_for_provider( contexts: list[Message] | list[dict[str, T.Any]], ) -> list[Message] | list[dict[str, T.Any]]: modalities = self.provider.provider_config.get("modalities", None) - if not modalities: # Unconfigured (None or empty list) defaults to support all modalities + if ( + not modalities + ): # Unconfigured (None or empty list) defaults to support all modalities return contexts sanitized_contexts, stats = sanitize_contexts_by_modalities( contexts, @@ -600,11 +604,14 @@ def _func_tool_for_provider(self) -> ToolSet | None: return None return self.req.func_tool - def _simple_print_message_role(self, tag: str = ""): - roles = [] - for message in self.run_context.messages: - roles.append(message.role) - logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}") + def _simple_print_message_role(self, tag: str, messages: list): + roles = [m.role for m in messages] + n = len(roles) + if n > 10: + summary = ",".join(roles[:4]) + ",...," + ",".join(roles[-4:]) + else: + summary = ",".join(roles) + logger.debug(f"{tag} messages -> [{n}] {summary}") def follow_up( self, @@ -698,16 +705,13 @@ async def step(self): self._transition_state(AgentState.RUNNING) llm_resp_result = None - # Process request-time context on a copy so the runner's canonical - # messages are never mutated. The processed result is only used for this - # provider call. Persistent compaction is owned by the conversation / - # memory layer. + # Process request-time context before sending it to the provider. token_usage = self.req.conversation.token_usage if self.req.conversation else 0 - self._simple_print_message_role("[BefCompact]") - self.run_context.messages = await self.context_manager.process( + self._simple_print_message_role("[BefCompact]", self.run_context.messages) + self.run_context.messages = await self.request_context_manager.process( self.run_context.messages, trusted_token_usage=token_usage ) - self._simple_print_message_role("[AftCompact]") + self._simple_print_message_role("[AftCompact]", self.run_context.messages) async for llm_response in self._iter_llm_responses_with_fallback(): if llm_response.is_chunk: diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index 233b6bb1e1..1c4fd400a0 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -150,8 +150,8 @@ class MainAgentBuildConfig: """The strategy to handle context length limit reached.""" llm_compress_instruction: str = "" """The instruction for compression in llm_compress strategy.""" - llm_compress_keep_recent: int = 10 - """The number of most recent turns to keep during llm_compress strategy.""" + llm_compress_keep_recent_ratio: float = 0.15 + """Percent of current context tokens to keep as exact recent context during llm_compress strategy.""" llm_compress_provider_id: str = "" """The provider ID for the LLM used in context compression.""" max_context_length: int = 50 @@ -1531,9 +1531,10 @@ async def build_main_agent( agent_hooks=MAIN_AGENT_HOOKS, streaming=config.streaming_response, llm_compress_instruction=config.llm_compress_instruction, - llm_compress_keep_recent=config.llm_compress_keep_recent, + llm_compress_keep_recent_ratio=config.llm_compress_keep_recent_ratio, llm_compress_provider=_get_compress_provider(config, plugin_context, event), truncate_turns=config.dequeue_context_length, + enforce_max_turns=config.max_context_length, tool_schema_mode=config.tool_schema_mode, fallback_providers=fallback_providers, tool_result_overflow_dir=( diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 642a7bed50..22a53bb446 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -130,7 +130,7 @@ "4. If there was an initial user goal, state it first and describe the current progress/status.\n" "5. Write the summary in the user's language.\n" ), - "llm_compress_keep_recent": 10, + "llm_compress_keep_recent_ratio": 0.15, "llm_compress_provider_id": "", "max_context_length": 50, "dequeue_context_length": 10, @@ -3581,10 +3581,11 @@ "provider_settings.agent_runner_type": "local", }, }, - "provider_settings.llm_compress_keep_recent": { - "description": "压缩时保留最近对话轮数", - "type": "int", - "hint": "始终保留的最近 N 轮对话。", + "provider_settings.llm_compress_keep_recent_ratio": { + "description": "压缩时保留最近上下文比例", + "type": "float", + "slider": {"min": 0, "max": 0.3, "step": 0.01}, + "hint": "按当前上下文 token 数保留最近内容,范围 0-0.3。0.15 表示保留 15%;比例大于 0 时至少保留最后一轮。", "condition": { "provider_settings.context_limit_reached_strategy": "llm_compress", "provider_settings.agent_runner_type": "local", diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index c29a0d09f1..0b636b5b2b 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -97,8 +97,8 @@ async def initialize(self, ctx: PipelineContext) -> None: self.llm_compress_instruction: str = settings.get( "llm_compress_instruction", "" ) - self.llm_compress_keep_recent: int = settings.get( - "llm_compress_keep_recent", 5 + self.llm_compress_keep_recent_ratio: float = settings.get( + "llm_compress_keep_recent_ratio", 0.15 ) self.llm_compress_provider_id: str = settings.get( "llm_compress_provider_id", "" @@ -138,7 +138,7 @@ async def initialize(self, ctx: PipelineContext) -> None: file_extract_msh_api_key=self.file_extract_msh_api_key, context_limit_reached_strategy=self.context_limit_reached_strategy, llm_compress_instruction=self.llm_compress_instruction, - llm_compress_keep_recent=self.llm_compress_keep_recent, + llm_compress_keep_recent_ratio=self.llm_compress_keep_recent_ratio, llm_compress_provider_id=self.llm_compress_provider_id, max_context_length=self.max_context_length, dequeue_context_length=self.dequeue_context_length, diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 5c14ffbd0e..71e9dadc9d 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -1,4 +1,3 @@ - from google import genai from google.genai import types from google.genai.errors import APIError diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index be670a9b97..618b95bac4 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -265,9 +265,9 @@ "description": "Context Compression Instruction", "hint": "If empty, the default prompt will be used." }, - "llm_compress_keep_recent": { - "description": "Keep Recent Turns When Compressing", - "hint": "Always keep the most recent N turns of conversation when compressing context." + "llm_compress_keep_recent_ratio": { + "description": "Recent Context Token Ratio to Keep", + "hint": "Keep recent exact context by current context token ratio, from 0-0.3. 0.15 means keeping 15%; values above 0 keep at least the latest round." }, "llm_compress_provider_id": { "description": "Model Provider ID for Context Compression", diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index cf450bfa5e..c42a3313a5 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -265,9 +265,9 @@ "description": "Инструкция для сжатия контекста", "hint": "Если пусто, используется промпт по умолчанию." }, - "llm_compress_keep_recent": { - "description": "Сохранять последние раунды при сжатии", - "hint": "Всегда оставлять последние N раундов диалога без изменений при сжатии." + "llm_compress_keep_recent_ratio": { + "description": "Доля последних токенов контекста при сжатии", + "hint": "Сохраняет последние сообщения по доле текущих токенов контекста, от 0 до 0.3. 0.15 означает 15%; значение выше 0 сохраняет как минимум последний раунд." }, "llm_compress_provider_id": { "description": "Модель для сжатия контекста", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 3ff951e343..200b3b9fe1 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -267,9 +267,9 @@ "description": "上下文压缩提示词", "hint": "如果为空则使用默认提示词。" }, - "llm_compress_keep_recent": { - "description": "压缩时保留最近对话轮数", - "hint": "始终保留的最近 N 轮对话。" + "llm_compress_keep_recent_ratio": { + "description": "压缩时保留最近上下文比例", + "hint": "按当前上下文 token 数保留最近内容,范围 0-0.3。0.15 表示保留 15%;比例大于 0 时至少保留最后一轮。" }, "llm_compress_provider_id": { "description": "用于上下文压缩的模型提供商 ID", diff --git a/tests/agent/test_context_manager.py b/tests/agent/test_context_manager.py index a14b49691a..8e9e601b3f 100644 --- a/tests/agent/test_context_manager.py +++ b/tests/agent/test_context_manager.py @@ -12,7 +12,7 @@ from astrbot.core.agent.context.config import ContextConfig from astrbot.core.agent.context.manager import ContextManager -from astrbot.core.agent.message import Message, TextPart +from astrbot.core.agent.message import AudioURLPart, ImageURLPart, Message, TextPart from astrbot.core.provider.entities import LLMResponse @@ -25,14 +25,14 @@ def __init__(self): "model": "gpt-4", "modalities": ["text", "image", "tool_use"], } + self.last_text_chat_kwargs = None async def text_chat(self, **kwargs): """模拟 LLM 调用,返回摘要""" - messages = kwargs.get("messages", []) - # 简单的摘要逻辑:返回消息数量统计 + self.last_text_chat_kwargs = kwargs return LLMResponse( role="assistant", - completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。", + completion_text="Summary of conversation: Hello and discussed various topics.", ) def get_model(self): @@ -76,7 +76,7 @@ def test_init_with_llm_compressor(self): mock_provider = MockProvider() config = ContextConfig( llm_compress_provider=mock_provider, # type: ignore - llm_compress_keep_recent=5, + llm_compress_keep_recent_ratio=0.15, llm_compress_instruction="Summarize the conversation", ) manager = ContextManager(config) @@ -102,7 +102,7 @@ async def test_llm_compressor_keeps_history_when_summary_is_empty(self): provider.text_chat = AsyncMock( return_value=LLMResponse(role="assistant", completion_text=" ") ) - compressor = LLMSummaryCompressor(provider=provider, keep_recent=2) # type: ignore[arg-type] + compressor = LLMSummaryCompressor(provider=provider, keep_recent_ratio=0.15) # type: ignore[arg-type] messages = self.create_messages(6) with patch("astrbot.core.agent.context.compressor.logger") as mock_logger: @@ -113,6 +113,269 @@ async def test_llm_compressor_keeps_history_when_summary_is_empty(self): "LLM context compression returned an empty summary." ) + @pytest.mark.asyncio + async def test_llm_compressor_handles_textpart_content(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor(provider=provider, keep_recent_ratio=0.01) # type: ignore[arg-type] + messages = [ + Message(role="user", content=[TextPart(text="Hello")]), + Message(role="assistant", content=[TextPart(text="Hi there")]), + Message(role="user", content=[TextPart(text="Summarize our work")]), + Message(role="assistant", content=[TextPart(text="Sure")]), + ] + + result = await compressor(messages) + + assert provider.last_text_chat_kwargs is not None + assert "prompt" not in provider.last_text_chat_kwargs + assert "system_prompt" not in provider.last_text_chat_kwargs + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0] == { + "role": "user", + "content": [{"type": "text", "text": "Hello"}], + } + assert summary_contexts[1] == { + "role": "assistant", + "content": [{"type": "text", "text": "Hi there"}], + } + assert summary_contexts[-1]["role"] == "user" + assert compressor.instruction_text in summary_contexts[-1]["content"] + assert ( + compressor.TASK_CONTINUATION_INSTRUCTION in summary_contexts[-1]["content"] + ) + + assert len(result) == 4 + assert result[0].role == "user" + assert isinstance(result[0].content, str) + assert result[0].content.strip() + assert "Hello" in result[0].content + assert result[-1].content == [TextPart(text="Sure")] + + @pytest.mark.asyncio + async def test_llm_compressor_preserves_system_and_pads_before_instruction(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + instruction = "Summarize the old context." + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0.01, + instruction_text=instruction, + ) # type: ignore[arg-type] + messages = [ + Message(role="system", content="System prompt"), + Message(role="user", content="Old question"), + Message(role="user", content="Current question"), + ] + + result = await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0] == {"role": "system", "content": "System prompt"} + assert summary_contexts[1] == {"role": "user", "content": "Old question"} + assert summary_contexts[2]["role"] == "assistant" + assert summary_contexts[2]["content"] + assert summary_contexts[3]["role"] == "user" + assert instruction in summary_contexts[3]["content"] + assert ( + compressor.TASK_CONTINUATION_INSTRUCTION in summary_contexts[3]["content"] + ) + + assert result[0] is messages[0] + assert result[-1] is messages[-1] + + @pytest.mark.asyncio + async def test_llm_compressor_summarizes_single_long_round(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0.15, + instruction_text="Summarize the whole trajectory.", + ) # type: ignore[arg-type] + messages = [ + Message(role="user", content="Run the tool."), + Message( + role="assistant", + content="Calling tool", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="x" * 1000, tool_call_id="call_1"), + ] + + result = await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0] == {"role": "user", "content": "Run the tool."} + assert summary_contexts[1]["role"] == "assistant" + assert summary_contexts[1]["tool_calls"] + assert summary_contexts[2]["role"] == "tool" + assert summary_contexts[2]["tool_call_id"] == "call_1" + assert summary_contexts[3]["role"] == "assistant" + assert summary_contexts[4]["role"] == "user" + assert "Summarize the whole trajectory." in summary_contexts[4]["content"] + assert ( + compressor.TASK_CONTINUATION_INSTRUCTION in summary_contexts[4]["content"] + ) + assert all(original not in result for original in messages) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_llm_compressor_preserves_active_user_request(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0, + instruction_text="Summarize old context.", + ) # type: ignore[arg-type] + messages = [ + Message(role="user", content="Old question"), + Message(role="assistant", content="Old answer"), + Message(role="user", content="Current question"), + ] + + result = await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0] == {"role": "user", "content": "Old question"} + assert summary_contexts[1] == {"role": "assistant", "content": "Old answer"} + assert not any( + msg.get("content") == "Current question" for msg in summary_contexts + ) + assert result[-1] is messages[2] + + @pytest.mark.asyncio + async def test_llm_compressor_does_not_summarize_only_active_user_request(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0.15, + instruction_text="Summarize old context.", + ) # type: ignore[arg-type] + messages = [Message(role="user", content="Current question")] + + result = await compressor(messages) + + assert result == messages + assert provider.last_text_chat_kwargs is None + + @pytest.mark.asyncio + async def test_llm_compressor_summarizes_system_plus_single_completed_round(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0.15, + instruction_text="Summarize the completed round.", + ) # type: ignore[arg-type] + messages = [ + Message(role="system", content="System prompt"), + Message(role="user", content="Question"), + Message(role="assistant", content="x" * 1000), + ] + + result = await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0]["role"] == "system" + assert summary_contexts[1]["role"] == "user" + assert summary_contexts[2]["role"] == "assistant" + assert len(result) == 3 + assert result[0] is messages[0] + assert result[1].role == "user" + assert result[2].role == "assistant" + + @pytest.mark.asyncio + async def test_llm_compressor_sanitizes_context_for_text_only_provider(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + provider.provider_config["modalities"] = ["text"] + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0, + instruction_text="Summarize multimodal and tool history.", + ) # type: ignore[arg-type] + messages = [ + Message( + role="user", + content=[ + TextPart(text="Please inspect this."), + ImageURLPart( + image_url=ImageURLPart.ImageURL(url="data:image/png;base64,abc") + ), + AudioURLPart( + audio_url=AudioURLPart.AudioURL(url="data:audio/wav;base64,abc") + ), + ], + ), + Message( + role="assistant", + content="Calling tool", + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "inspect", "arguments": "{}"}, + } + ], + ), + Message(role="tool", content="tool output", tool_call_id="call_1"), + Message(role="assistant", content="Done"), + ] + + await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0]["content"][1] == {"type": "text", "text": "[Image]"} + assert summary_contexts[0]["content"][2] == {"type": "text", "text": "[Audio]"} + assert "tool_calls" not in summary_contexts[1] + assert summary_contexts[2] == { + "role": "user", + "content": "[Tool result]\ntool output", + } + + @pytest.mark.asyncio + async def test_llm_compressor_keeps_recent_by_token_ratio(self): + from astrbot.core.agent.context.compressor import LLMSummaryCompressor + + provider = MockProvider() + compressor = LLMSummaryCompressor( + provider=provider, + keep_recent_ratio=0.3, + instruction_text="Summarize.", + ) # type: ignore[arg-type] + messages = [ + Message(role="user", content="a" * 200), + Message(role="assistant", content="b" * 200), + Message(role="user", content="c" * 10), + Message(role="assistant", content="d" * 10), + Message(role="user", content="e" * 10), + Message(role="assistant", content="f" * 10), + ] + + result = await compressor(messages) + + summary_contexts = provider.last_text_chat_kwargs["contexts"] + assert summary_contexts[0] == {"role": "user", "content": "a" * 200} + assert summary_contexts[1] == {"role": "assistant", "content": "b" * 200} + assert not any(msg.get("content") == "c" * 10 for msg in summary_contexts) + assert result[-4:] == messages[2:] + # ==================== Empty and Edge Cases ==================== @pytest.mark.asyncio @@ -613,7 +876,7 @@ async def test_config_persistence(self): max_context_tokens=500, enforce_max_turns=5, truncate_turns=2, - llm_compress_keep_recent=3, + llm_compress_keep_recent_ratio=0.15, ) manager = ContextManager(config) @@ -621,7 +884,7 @@ async def test_config_persistence(self): assert manager.config.max_context_tokens == 500 assert manager.config.enforce_max_turns == 5 assert manager.config.truncate_turns == 2 - assert manager.config.llm_compress_keep_recent == 3 + assert manager.config.llm_compress_keep_recent_ratio == 0.15 # ==================== Run Compression Tests ==================== @@ -685,7 +948,7 @@ async def test_llm_compression_with_mock_provider(self): mock_provider = MockProvider() config = ContextConfig( llm_compress_provider=mock_provider, # type: ignore - llm_compress_keep_recent=3, + llm_compress_keep_recent_ratio=0.15, llm_compress_instruction="请总结对话内容", max_context_tokens=100, ) @@ -759,5 +1022,24 @@ def test_split_rounds_multi_tool(self): def test_split_rounds_empty(self): """Empty list returns no rounds.""" from astrbot.core.agent.context.round_utils import split_into_rounds + rounds = split_into_rounds([]) assert len(rounds) == 0 + + def test_split_rounds_accepts_message_objects(self): + """Message objects can be split without converting them to dictionaries.""" + from astrbot.core.agent.context.round_utils import split_into_rounds + + messages = [ + Message(role="system", content="System prompt"), + Message(role="user", content=[TextPart(text="hello")]), + Message(role="assistant", content="hi"), + Message(role="user", content="next"), + ] + + rounds = split_into_rounds(messages) + + assert len(rounds) == 3 + assert rounds[0][0] is messages[0] + assert rounds[1][0] is messages[1] + assert rounds[2][0] is messages[3] diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 304eda6374..db729a23ba 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -985,6 +985,39 @@ async def test_build_main_agent_basic( assert result is not None assert isinstance(result, module.MainAgentBuildResult) + @pytest.mark.asyncio + async def test_build_main_agent_passes_max_context_length_to_runner( + self, mock_event, mock_context, mock_provider + ): + module = ama + mock_context.get_provider_by_id.return_value = None + mock_context.get_using_provider.return_value = mock_provider + mock_context.get_config.return_value = {} + + conv_mgr = mock_context.conversation_manager + _setup_conversation_for_build(conv_mgr) + + with ( + patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, + patch("astrbot.core.astr_main_agent.AstrAgentContext"), + ): + mock_runner = MagicMock() + mock_runner.reset = AsyncMock() + mock_runner_cls.return_value = mock_runner + + result = await module.build_main_agent( + event=mock_event, + plugin_context=mock_context, + config=module.MainAgentBuildConfig( + tool_call_timeout=60, + max_context_length=7, + ), + ) + + assert result is not None + mock_runner.reset.assert_awaited_once() + assert mock_runner.reset.await_args.kwargs["enforce_max_turns"] == 7 + @pytest.mark.asyncio async def test_build_main_agent_no_provider(self, mock_event, mock_context): """Test building main agent when no provider is available."""