Skip to content
157 changes: 108 additions & 49 deletions astrbot/core/agent/context/compressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable

from ...provider.modalities import (
log_context_sanitize_stats,
sanitize_contexts_by_modalities,
)
from ..message import Message
from .token_counter import EstimateTokenCounter, TokenCounter

if TYPE_CHECKING:
from astrbot import logger
Expand Down Expand Up @@ -96,23 +101,6 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
return truncated_messages


def _message_to_dict(msg: Message) -> dict:
    """Flatten a Message into a plain dict suitable for round splitting.

    ``role`` is always copied. ``content`` is copied unless it is ``None``
    (an empty string is kept). ``tool_calls`` and ``tool_call_id`` are copied
    only when the attribute exists and is truthy.
    """
    payload = {"role": msg.role}

    if msg.content is not None:
        payload["content"] = msg.content

    # Optional tool-related fields: skip missing, None, and empty values.
    for optional_key in ("tool_calls", "tool_call_id"):
        if getattr(msg, optional_key, None):
            payload[optional_key] = getattr(msg, optional_key)

    return payload


def _dict_to_message(d: dict) -> Message:
    """Rebuild a Message from its plain-dict form.

    Every key of ``d`` is forwarded to the ``Message`` constructor as a
    keyword argument.
    """
    rebuilt = Message(**d)
    return rebuilt


def _extract_system_messages(messages: list[Message]) -> list[Message]:
"""Return the leading system messages from a message list."""
result = []
Expand All @@ -126,28 +114,36 @@ def _extract_system_messages(messages: list[Message]) -> list[Message]:

class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
Uses LLM to summarize old conversation history while keeping a recent token
budget as exact context.
"""

TASK_CONTINUATION_INSTRUCTION = (
"If a task appears to be in progress, end the summary with the latest "
"known result and the concrete next step to continue the task."
)

def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
keep_recent_ratio: float = 0.15,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
token_counter: TokenCounter | None = None,
) -> None:
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
keep_recent_ratio: Ratio of current context tokens to keep as recent
exact context. Clamped to 0-0.3.
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.keep_recent_ratio = min(max(float(keep_recent_ratio), 0.0), 0.3)
self.compression_threshold = compression_threshold
self.existing_summary: str = ""
self.token_counter = token_counter or EstimateTokenCounter()

self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
Expand Down Expand Up @@ -177,47 +173,109 @@ def should_compress(
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold

def _split_recent_rounds_by_token_ratio(
self,
rounds: list[list[Message]],
total_tokens: int,
) -> tuple[list[list[Message]], list[list[Message]]]:
"""Split rounds into summarised history and exact recent context.
The token budget is computed from the current context token count and
`keep_recent_ratio`, then floored by `int(...)`. Mapping that budget to
rounds is round-granular: a positive ratio always preserves the latest
whole round, even if that round itself exceeds the budget. Earlier
rounds are added only while the accumulated recent rounds stay within
the budget. No round is split.
"""
if not rounds or self.keep_recent_ratio <= 0 or total_tokens <= 0:
return rounds, []

budget = max(1, int(total_tokens * self.keep_recent_ratio))
used = 0
recent_start = len(rounds)

for idx in range(len(rounds) - 1, -1, -1):
round_tokens = self.token_counter.count_tokens(rounds[idx])
if used > 0 and used + round_tokens > budget:
break
used += round_tokens
recent_start = idx

return rounds[:recent_start], rounds[recent_start:]

async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Uses round-based splitting to preserve user-assistant turn boundaries.
On LLM failure, returns the original messages unchanged (caller should
fall back to truncation).
"""
from .round_utils import rounds_to_text, split_into_rounds

# Convert messages to dict list for round splitting
msg_dicts = [_message_to_dict(m) for m in messages]
rounds = split_into_rounds(msg_dicts)

if len(rounds) <= self.keep_recent:
return messages
from .round_utils import split_into_rounds

rounds = split_into_rounds(messages)
message_rounds = [
[seg for seg in rnd if isinstance(seg, Message)] for rnd in rounds
]
total_tokens = self.token_counter.count_tokens(messages)
old_rounds, recent_rounds = self._split_recent_rounds_by_token_ratio(
message_rounds,
total_tokens,
)

old_rounds = rounds[: -self.keep_recent]
recent_rounds = rounds[-self.keep_recent :]
# The latest user message is the active request. Keep its whole round
# exact even when the ratio is 0 or the ratio budget would otherwise
# summarize every round.
if messages and messages[-1].role == "user" and old_rounds:
latest_old_round = old_rounds[-1]
if latest_old_round and latest_old_round[-1] is messages[-1]:
old_rounds = old_rounds[:-1]
recent_rounds = [latest_old_round, *recent_rounds]

if not old_rounds:
return messages

# Build LLM payload
old_text = rounds_to_text(old_rounds)
existing_note = ""
if self.existing_summary:
existing_note = (
"\nExisting memory summary (merge with old rounds above):\n"
f"{self.existing_summary}\n"
if recent_rounds and messages and messages[-1].role == "user":
return messages
old_rounds = message_rounds
recent_rounds = []

summary_contexts = [msg for rnd in old_rounds for msg in rnd]
if not any(msg.role != "system" for msg in summary_contexts):
if recent_rounds and messages and messages[-1].role == "user":
return messages
old_rounds = message_rounds
recent_rounds = []
summary_contexts = [msg for rnd in old_rounds for msg in rnd]
if not any(msg.role != "system" for msg in summary_contexts):
return messages

if summary_contexts[-1].role != "assistant":
summary_contexts.append(
Message(
role="assistant",
content="Acknowledged.",
)
)
summary_contexts.append(
Message(
role="user",
content=(
"Generate a summary of our previous conversation history.\n"
f"<extra_instruction>\n{self.instruction_text}\n\n"
f"{self.TASK_CONTINUATION_INSTRUCTION}</extra_instruction>\n"
"Respond ONLY with the summary content, without any additional text or formatting."
),
)
prompt = (
f"{self.instruction_text}\n\n"
"--- BEGIN CONVERSATION ROUNDS TO SUMMARIZE ---\n"
f"{old_text}\n"
"--- END CONVERSATION ROUNDS ---"
f"{existing_note}"
)
sanitized_summary_contexts, sanitize_stats = sanitize_contexts_by_modalities(
summary_contexts,
self.provider.provider_config.get("modalities", None),
)
log_context_sanitize_stats(sanitize_stats)

# Generate summary
try:
response = await self.provider.text_chat(prompt=prompt)
response = await self.provider.text_chat(
contexts=sanitized_summary_contexts,
)
summary_content = (response.completion_text or "").strip()
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
Expand Down Expand Up @@ -246,6 +304,7 @@ async def __call__(self, messages: list[Message]) -> list[Message]:
# Flatten recent rounds back to message list
for rnd in recent_rounds:
for seg in rnd:
result.append(_dict_to_message(seg))
if isinstance(seg, Message):
result.append(seg)

return result
4 changes: 2 additions & 2 deletions astrbot/core/agent/context/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class ContextConfig:
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_keep_recent_ratio: float = 0.15
"""Percent of current context tokens to keep as exact recent context during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
Expand Down
3 changes: 2 additions & 1 deletion astrbot/core/agent/context/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def __init__(
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
keep_recent_ratio=config.llm_compress_keep_recent_ratio,
instruction_text=config.llm_compress_instruction,
token_counter=self.token_counter,
)
else:
self.compressor = TruncateByTurnsCompressor(
Expand Down
54 changes: 44 additions & 10 deletions astrbot/core/agent/context/round_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
"""Round-based utilities shared by LTM compaction and LLMSummaryCompressor."""

import json
from collections.abc import Sequence
from typing import Any

from ..message import ContentPart, Message, ToolCall

RoundSegment = dict[str, Any] | Message


def _segment_role(seg: RoundSegment) -> str:
    """Return the role of a round segment.

    A ``Message`` yields its ``role`` attribute as-is. Any other segment is
    read as a mapping: its ``"role"`` entry is used, falling back to ``"?"``
    when the key is absent, and the value is coerced to ``str``.
    """
    if not isinstance(seg, Message):
        raw_role = seg.get("role", "?")
        return str(raw_role)
    return seg.role


def split_into_rounds(
contexts: list[dict[str, Any]],
) -> list[list[dict[str, Any]]]:
contexts: Sequence[RoundSegment],
) -> list[list[RoundSegment]]:
"""Split a flat contexts list into logical rounds.

A round begins at a ``user`` segment and includes all subsequent
``assistant`` / ``tool`` segments until the next ``user`` segment.
"""
rounds: list[list[dict[str, Any]]] = []
current: list[dict[str, Any]] = []
rounds: list[list[RoundSegment]] = []
current: list[RoundSegment] = []
for seg in contexts:
if seg.get("role") == "user" and current:
if _segment_role(seg) == "user" and current:
rounds.append(current)
current = []
current.append(seg)
Expand All @@ -24,15 +35,38 @@ def split_into_rounds(
return rounds


def rounds_to_text(rounds: list[list[dict[str, Any]]]) -> str:
def _content_to_text(content: Any) -> str:
    """Render segment content as a single string.

    * A list is JSON-encoded (non-ASCII preserved); any ``ContentPart`` item
      is first replaced by its ``model_dump_for_context()`` result, other
      items are passed through unchanged.
    * A lone ``ContentPart`` is JSON-encoded from its
      ``model_dump_for_context()`` result.
    * Anything else is converted with ``str``; falsy values become ``""``.
    """
    if isinstance(content, list):
        serializable = []
        for item in content:
            if isinstance(item, ContentPart):
                serializable.append(item.model_dump_for_context())
            else:
                serializable.append(item)
        return json.dumps(serializable, ensure_ascii=False)

    if isinstance(content, ContentPart):
        dumped = content.model_dump_for_context()
        return json.dumps(dumped, ensure_ascii=False)

    return str(content) if content else ""


def _segment_content(seg: RoundSegment) -> Any:
    """Return the payload of a round segment for text rendering.

    For a ``Message``: its ``content`` when that is not ``None``; otherwise
    its tool calls as a list (``ToolCall`` items converted with
    ``model_dump()``, other items unchanged) when there are any; otherwise
    ``""``.

    For any other segment (read as a mapping): the first truthy value among
    its ``"content"`` and ``"tool_calls"`` entries, or ``""``.
    """
    if not isinstance(seg, Message):
        return seg.get("content") or seg.get("tool_calls") or ""

    if seg.content is not None:
        return seg.content

    if not seg.tool_calls:
        return ""

    dumped_calls = []
    for call in seg.tool_calls:
        if isinstance(call, ToolCall):
            dumped_calls.append(call.model_dump())
        else:
            dumped_calls.append(call)
    return dumped_calls


def rounds_to_text(rounds: list[list[RoundSegment]]) -> str:
    """Render rounds into a plain-text string for LLM summarisation.

    Each round is introduced by a ``--- Round N ---`` header (numbered from
    1), followed by one ``[role] content`` line per segment. Role and content
    are resolved through ``_segment_role`` and
    ``_content_to_text(_segment_content(...))``, so segments may be either
    ``Message`` objects or plain dicts.

    Args:
        rounds: Rounds in the order they should appear in the output.

    Returns:
        The rendered lines joined with newlines; ``""`` for no rounds.
    """
    lines: list[str] = []
    for i, rnd in enumerate(rounds, 1):
        lines.append(f"--- Round {i} ---")
        for seg in rnd:
            # Do not call seg.get(...) here: Message segments are not dicts.
            role = _segment_role(seg)
            content = _content_to_text(_segment_content(seg))
            lines.append(f"[{role}] {content}")
    return "\n".join(lines)
Loading
Loading