diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index ac5801994a..2bca7062a7 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -6,6 +6,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "4.25.2" +ASTRBOT_USER_AGENT = f"astrbot/{VERSION.removeprefix('v')}" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") PERSONAL_WECHAT_CONFIG_METADATA = { "weixin_oc_base_url": { @@ -1199,7 +1200,7 @@ "api_base": "https://api.kimi.com/coding", "timeout": 120, "proxy": "", - "custom_headers": {"User-Agent": "claude-code/0.1.0"}, + "custom_headers": {"User-Agent": ASTRBOT_USER_AGENT}, "anth_thinking_config": {"type": "", "budget": 0, "effort": ""}, }, "Moonshot": { @@ -1236,7 +1237,7 @@ "api_base": "https://api.minimaxi.com/anthropic", "timeout": 120, "proxy": "", - "custom_headers": {"User-Agent": "claude-code/0.1.0"}, + "custom_headers": {"User-Agent": ASTRBOT_USER_AGENT}, "anth_thinking_config": {"type": "", "budget": 0, "effort": ""}, }, "Xiaomi": { @@ -1261,7 +1262,7 @@ "api_base": "https://token-plan-cn.xiaomimimo.com/anthropic", "timeout": 120, "proxy": "", - "custom_headers": {"User-Agent": "claude-code/0.1.0"}, + "custom_headers": {"User-Agent": ASTRBOT_USER_AGENT}, "anth_thinking_config": {"type": "", "budget": 0, "effort": ""}, }, "xAI": { diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 8b1efab1b6..fb78884002 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -13,9 +13,11 @@ from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet +from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.network_utils import ( create_proxy_client, @@ -50,13 +52,12 @@ def _ensure_usable_response( @staticmethod def _normalize_custom_headers(provider_config: dict) -> dict[str, str] | None: - custom_headers = provider_config.get("custom_headers", {}) - if not isinstance(custom_headers, dict) or not custom_headers: + normalized_headers = normalize_headers( + provider_config.get("custom_headers", {}) + ) + if not normalized_headers: return None - normalized_headers: dict[str, str] = {} - for key, value in custom_headers.items(): - normalized_headers[str(key)] = str(value) - return normalized_headers or None + return normalized_headers @classmethod def _resolve_custom_headers( @@ -67,9 +68,7 @@ def _resolve_custom_headers( ) -> dict[str, str] | None: merged_headers = cls._normalize_custom_headers(provider_config) or {} if required_headers: - for header_name, header_value in required_headers.items(): - if not merged_headers.get(header_name, "").strip(): - merged_headers[header_name] = header_value + merged_headers = apply_default_headers(merged_headers, required_headers) return merged_headers or None def __init__( @@ -89,7 +88,10 @@ def __init__( if isinstance(self.timeout, str): self.timeout = int(self.timeout) self.thinking_config = provider_config.get("anth_thinking_config", {}) - self.custom_headers = self._resolve_custom_headers(provider_config) + self.custom_headers = self._resolve_custom_headers( + provider_config, + required_headers={"User-Agent": ASTRBOT_USER_AGENT}, + ) if use_api_key: self._init_api_key(provider_config) diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index f38fcfc359..99cf1399c3 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -18,11 +18,13 @@ from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.agent.message import AudioURLPart, ContentPart, ImageURLPart, TextPart +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers from astrbot.core.utils.io import download_file, download_image_by_url from astrbot.core.utils.media_utils import ensure_wav from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure @@ -76,17 +78,41 @@ def __init__( if self.api_base and self.api_base.endswith("/"): self.api_base = self.api_base[:-1] + self.custom_headers = self._resolve_custom_headers(provider_config) self._http_client: httpx.AsyncClient | None = None self._stale_http_clients: list[httpx.AsyncClient] = [] self._init_client() self.set_model(provider_config.get("model", "unknown")) self._init_safety_settings() + @staticmethod + def _resolve_custom_headers(provider_config: dict) -> dict[str, str]: + headers = apply_default_headers( + normalize_headers(provider_config.get("custom_headers", {})), + {"user-agent": ASTRBOT_USER_AGENT}, + ) + return { + "user-agent" if key.lower() == "user-agent" else key: value + for key, value in headers.items() + } + + @staticmethod + def _set_gemini_user_agent(client: object, user_agent: str) -> None: + api_client = getattr(client, "_api_client", None) + http_options = getattr(api_client, "_http_options", None) + if http_options is None or http_options.headers is None: + return + for key in list(http_options.headers): + if key.lower() == "user-agent": + http_options.headers.pop(key) + http_options.headers["user-agent"] = user_agent + def _init_client(self) -> None: """初始化Gemini客户端""" proxy = self.provider_config.get("proxy", "") http_options = types.HttpOptions( base_url=self.api_base, + headers=dict(self.custom_headers), timeout=self.timeout * 1000, # 毫秒 ) @@ -94,6 +120,7 @@ def _init_client(self) -> None: # httpx.AsyncClient 的 timeout 单位为秒(与 HttpOptions 的毫秒不同) async_client_kwargs: dict = { "base_url": self.api_base, + "headers": dict(self.custom_headers), "timeout": self.timeout, } if proxy: @@ -112,10 +139,15 @@ def _init_client(self) -> None: self._http_client = httpx.AsyncClient(**async_client_kwargs) http_options.httpx_async_client = self._http_client - self.client = genai.Client( + genai_client = genai.Client( api_key=self.chosen_api_key, http_options=http_options, - ).aio + ) + self._set_gemini_user_agent( + genai_client, + self.custom_headers["user-agent"], + ) + self.client = genai_client.aio def _init_safety_settings(self) -> None: """初始化安全设置""" diff --git a/astrbot/core/provider/sources/kimi_code_source.py b/astrbot/core/provider/sources/kimi_code_source.py index 02c200271f..cc52e09eef 100644 --- a/astrbot/core/provider/sources/kimi_code_source.py +++ b/astrbot/core/provider/sources/kimi_code_source.py @@ -1,9 +1,11 @@ +from astrbot.core.config.default import ASTRBOT_USER_AGENT + from ..register import register_provider_adapter from .anthropic_source import ProviderAnthropic KIMI_CODE_API_BASE = "https://api.kimi.com/coding" KIMI_CODE_DEFAULT_MODEL = "kimi-for-coding" -KIMI_CODE_USER_AGENT = "claude-code/0.1.0" +KIMI_CODE_USER_AGENT = ASTRBOT_USER_AGENT @register_provider_adapter( diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 8aa2778f1b..873f8b5ea5 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -34,10 +34,12 @@ TextPart, ) from astrbot.core.agent.tool import ToolSet +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult from astrbot.core.utils.astrbot_path import get_astrbot_temp_path +from astrbot.core.utils.http_headers import apply_default_headers, normalize_headers from astrbot.core.utils.io import download_file, download_image_by_url from astrbot.core.utils.media_utils import ensure_wav from astrbot.core.utils.network_utils import ( @@ -68,6 +70,13 @@ class ProviderOpenAIOfficial(Provider): "AVIF": "image/avif", } + @staticmethod + def _resolve_custom_headers(provider_config: dict) -> dict[str, str]: + return apply_default_headers( + normalize_headers(provider_config.get("custom_headers", {})), + {"User-Agent": ASTRBOT_USER_AGENT}, + ) + @classmethod def _truncate_error_text_candidate(cls, text: str) -> str: if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS: @@ -498,16 +507,10 @@ def __init__(self, provider_config, provider_settings) -> None: self.api_keys: list = super().get_keys() self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None self.timeout = provider_config.get("timeout", 120) - self.custom_headers = provider_config.get("custom_headers", {}) + self.custom_headers = self._resolve_custom_headers(provider_config) if isinstance(self.timeout, str): self.timeout = int(self.timeout) - if not isinstance(self.custom_headers, dict) or not self.custom_headers: - self.custom_headers = None - else: - for key in self.custom_headers: - self.custom_headers[key] = str(self.custom_headers[key]) - if "api_version" in provider_config: # Using Azure OpenAI API self.client = AsyncAzureOpenAI( diff --git a/astrbot/core/utils/http_headers.py b/astrbot/core/utils/http_headers.py new file mode 100644 index 0000000000..f408880ad4 --- /dev/null +++ b/astrbot/core/utils/http_headers.py @@ -0,0 +1,31 @@ +from collections.abc import Mapping + + +def normalize_headers(headers: object) -> dict[str, str]: + if not isinstance(headers, dict): + return {} + return {str(key): str(value) for key, value in headers.items()} + + +def apply_default_headers( + headers: dict[str, str], + default_headers: Mapping[str, str], +) -> dict[str, str]: + merged_headers = dict(headers) + for default_name, default_value in default_headers.items(): + existing_name = next( + ( + header_name + for header_name in merged_headers + if header_name.lower() == default_name.lower() + ), + None, + ) + if existing_name is None: + merged_headers[default_name] = default_value + continue + if merged_headers[existing_name].strip(): + continue + merged_headers.pop(existing_name) + merged_headers[default_name] = default_value + return merged_headers diff --git a/tests/test_anthropic_kimi_code_provider.py b/tests/test_anthropic_kimi_code_provider.py index a8d60927e0..470b1aa440 100644 --- a/tests/test_anthropic_kimi_code_provider.py +++ b/tests/test_anthropic_kimi_code_provider.py @@ -4,6 +4,7 @@ import astrbot.core.provider.sources.anthropic_source as anthropic_source import astrbot.core.provider.sources.kimi_code_source as kimi_code_source +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.entities import LLMResponse @@ -16,6 +17,25 @@ async def close(self): return None +def test_anthropic_provider_uses_astrbot_default_user_agent(monkeypatch): + monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic) + + provider = anthropic_source.ProviderAnthropic( + provider_config={ + "id": "anthropic-test", + "type": "anthropic_chat_completion", + "model": "claude-test", + "key": ["test-key"], + }, + provider_settings={}, + ) + + assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT} + assert provider.client.kwargs["default_headers"] == { + "User-Agent": ASTRBOT_USER_AGENT, + } + + def test_anthropic_provider_passes_custom_headers_via_default_headers(monkeypatch): monkeypatch.setattr(anthropic_source, "AsyncAnthropic", _FakeAsyncAnthropic) diff --git a/tests/test_gemini_source.py b/tests/test_gemini_source.py index 4db8e92bfe..928687804a 100644 --- a/tests/test_gemini_source.py +++ b/tests/test_gemini_source.py @@ -1,10 +1,48 @@ import pytest +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError +import astrbot.core.provider.sources.gemini_source as gemini_source from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.sources.gemini_source import ProviderGoogleGenAI +class _FakeGenAIClient: + def __init__(self, **kwargs): + self.kwargs = kwargs + self._api_client = type( + "FakeAPIClient", + (), + {"_http_options": kwargs["http_options"]}, + )() + self.aio = type("FakeAioClient", (), {"_api_client": self._api_client})() + + +@pytest.mark.asyncio +async def test_gemini_provider_uses_astrbot_default_user_agent(monkeypatch): + monkeypatch.setattr(gemini_source.genai, "Client", _FakeGenAIClient) + + provider = ProviderGoogleGenAI( + provider_config={ + "id": "gemini-test", + "type": "googlegenai_chat_completion", + "model": "gemini-test", + "key": ["test-key"], + "api_base": "https://generativelanguage.googleapis.com/", + }, + provider_settings={}, + ) + + try: + assert provider.custom_headers["user-agent"] == ASTRBOT_USER_AGENT + assert provider.client._api_client._http_options.headers["user-agent"] == ( + ASTRBOT_USER_AGENT + ) + assert provider._http_client.headers["user-agent"] == ASTRBOT_USER_AGENT + finally: + await provider.terminate() + + def test_gemini_empty_output_raises_empty_model_output_error(): llm_response = LLMResponse(role="assistant") diff --git a/tests/test_openai_source.py b/tests/test_openai_source.py index b5587ffb14..831f5d21a4 100644 --- a/tests/test_openai_source.py +++ b/tests/test_openai_source.py @@ -9,6 +9,7 @@ from PIL import Image as PILImage import astrbot.core.provider.sources.openai_source as openai_source_module +from astrbot.core.config.default import ASTRBOT_USER_AGENT from astrbot.core.exceptions import EmptyModelOutputError from astrbot.core.provider.sources.groq_source import ProviderGroq from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial @@ -26,6 +27,17 @@ def __init__(self, message: str, response_text: str): self.response = SimpleNamespace(text=response_text) +class _FakeChatCompletions: + def create(self): + return None + + +class _FakeAsyncOpenAI: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.chat = SimpleNamespace(completions=_FakeChatCompletions()) + + def _make_provider(overrides: dict | None = None) -> ProviderOpenAIOfficial: provider_config = { "id": "test-openai", @@ -56,6 +68,39 @@ def _make_groq_provider(overrides: dict | None = None) -> ProviderGroq: ) +def test_openai_provider_uses_astrbot_default_user_agent(monkeypatch): + monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI) + + provider = _make_provider() + + assert provider.custom_headers == {"User-Agent": ASTRBOT_USER_AGENT} + assert provider.client.kwargs["default_headers"] == { + "User-Agent": ASTRBOT_USER_AGENT, + } + + +def test_openai_provider_preserves_custom_user_agent(monkeypatch): + monkeypatch.setattr(openai_source_module, "AsyncOpenAI", _FakeAsyncOpenAI) + + provider = _make_provider( + { + "custom_headers": { + "User-Agent": "custom-agent/1.0", + "X-Test-Header": 123, + }, + }, + ) + + assert provider.custom_headers == { + "User-Agent": "custom-agent/1.0", + "X-Test-Header": "123", + } + assert provider.client.kwargs["default_headers"] == { + "User-Agent": "custom-agent/1.0", + "X-Test-Header": "123", + } + + def test_create_http_client_uses_openai_httpx_module(monkeypatch): captured: dict[str, object] = {}