From c6d95cd0066536c7031c3196818e61113795a06a Mon Sep 17 00:00:00 2001 From: jxxghp Date: Fri, 8 May 2026 12:35:02 +0800 Subject: [PATCH] refactor(agent): consolidate provider preset resolution --- app/agent/llm/provider.py | 282 +++++++++++++++------------- scripts/local_setup.py | 6 + tests/test_llm_provider_registry.py | 101 ++++++++++ 3 files changed, 254 insertions(+), 135 deletions(-) diff --git a/app/agent/llm/provider.py b/app/agent/llm/provider.py index c92d4d73..3fc289a5 100644 --- a/app/agent/llm/provider.py +++ b/app/agent/llm/provider.py @@ -51,6 +51,8 @@ class ProviderUrlPreset: id: str label: str value: str + runtime: Optional[str] = None + model_list_strategy: Optional[str] = None model_list_base_url: Optional[str] = None models_dev_provider_id: Optional[str] = None @@ -107,17 +109,6 @@ class LLMProviderManager(metaclass=Singleton): _CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" _COPILOT_CLIENT_ID = "Ov23li8tweQw6odWQebz" _DEFAULT_TIMEOUT = httpx.Timeout(15.0, connect=10.0) - _CHATGPT_ALLOWED_OAUTH_MODELS = { - "gpt-5.1-codex", - "gpt-5.1-codex-max", - "gpt-5.1-codex-mini", - "gpt-5.2", - "gpt-5.2-codex", - "gpt-5.3-codex", - "gpt-5.4", - "gpt-5.4-mini", - "gpt-5.5", - } _MODELS_DEV_DYNAMIC_SKIP_IDS = { "aihubmix", "amazon-bedrock", @@ -133,7 +124,7 @@ class LLMProviderManager(metaclass=Singleton): "v0", "vercel", } - _MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES = { + _PROVIDER_PATCHES = { "bailing": { "runtime": "openai_compatible", "default_base_url": "https://api.tbox.cn/api/llm/v1", @@ -169,6 +160,14 @@ class LLMProviderManager(metaclass=Singleton): "default_base_url": "https://api.venice.ai/api/v1", "description": "Venice AI 官方兼容端点。", }, + "cloudflare-workers-ai": { + "api_key_hint": "填写 Cloudflare API Token,并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。", + "description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。", + }, + "privatemode-ai": { + "api_key_hint": "如未启用鉴权,可填写任意占位值。", + "description": "Privatemode AI 本地 OpenAI-compatible 端点。", + }, } def __init__(self): @@ -203,6 +202,7 @@ class LLMProviderManager(metaclass=Singleton): description="适合无回调环境,复制设备码到浏览器完成登录。", ) url_preset = ProviderUrlPreset + provider_patches = LLMProviderManager._PROVIDER_PATCHES def openai_provider( provider_id: str, name: str, @@ -280,28 +280,18 @@ class LLMProviderManager(metaclass=Singleton): ) catalog_openai_providers = ( - ("huggingface", "Hugging Face", "https://router.huggingface.co/v1"), - ("jiekou", "接口 AI", "https://api.jiekou.ai/openai"), - ("kilo", "Kilo Gateway", "https://api.kilo.ai/api/gateway"), - ("llama", "Llama", "https://api.llama.com/compat/v1/"), - ("llmgateway", "LLM Gateway", "https://api.llmgateway.io/v1"), - ("modelscope", "ModelScope", "https://api-inference.modelscope.cn/v1"), - ("nova", "Nova", "https://api.nova.amazon.com/v1"), - ("fireworks-ai", "Fireworks AI", "https://api.fireworks.ai/inference/v1/"), - ("poe", "Poe", "https://api.poe.com/v1"), - ("qihang-ai", "启航 AI", "https://api.qhaigc.net/v1"), - ("qiniu-ai", "七牛", "https://api.qnaigc.com/v1"), + {"id": "huggingface", "name": "Hugging Face", "base_url": "https://router.huggingface.co/v1"}, + {"id": "jiekou", "name": "接口 AI", "base_url": "https://api.jiekou.ai/openai"}, + {"id": "kilo", "name": "Kilo Gateway", "base_url": "https://api.kilo.ai/api/gateway"}, + {"id": "llama", "name": "Llama", "base_url": "https://api.llama.com/compat/v1/"}, + {"id": "llmgateway", "name": "LLM Gateway", "base_url": "https://api.llmgateway.io/v1"}, + {"id": "modelscope", "name": "ModelScope", "base_url": "https://api-inference.modelscope.cn/v1"}, + {"id": "nova", "name": "Nova", "base_url": "https://api.nova.amazon.com/v1"}, + {"id": "fireworks-ai", "name": "Fireworks AI", "base_url": "https://api.fireworks.ai/inference/v1/"}, + {"id": "poe", "name": "Poe", "base_url": "https://api.poe.com/v1"}, + {"id": "qihang-ai", "name": "启航 AI", "base_url": "https://api.qhaigc.net/v1"}, + {"id": "qiniu-ai", "name": "七牛", "base_url": "https://api.qnaigc.com/v1"}, ) - catalog_openai_overrides = { - "cloudflare-workers-ai": { - "api_key_hint": "填写 Cloudflare API Token,并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。", - "description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。", - }, - "privatemode-ai": { - "api_key_hint": "如未启用鉴权,可填写任意占位值。", - "description": "Privatemode AI 本地 OpenAI-compatible 端点。", - }, - } providers = [ ProviderSpec( @@ -436,7 +426,7 @@ class LLMProviderManager(metaclass=Singleton): ), catalog_openai_provider( provider_id="moonshot", - name="Moonshot AI", + name="Moonshot / Kimi", default_base_url="https://api.moonshot.cn/v1", sort_order=62, models_dev_provider_id="moonshotai-cn", @@ -453,18 +443,17 @@ class LLMProviderManager(metaclass=Singleton): value="https://api.moonshot.ai/v1", models_dev_provider_id="moonshotai", ), + url_preset( + id="moonshot-kimi-coding", + label="Kimi for Coding", + value="https://api.kimi.com/coding/v1", + runtime="anthropic_compatible", + model_list_strategy="anthropic_compatible", + models_dev_provider_id="kimi-for-coding", + ), ), - api_key_hint="填写 Moonshot / Kimi API Key,可在中国站与国际站端点间切换。", - description="Moonshot / Kimi 官方兼容端点。", - ), - anthropic_provider( - provider_id="kimi-coding", - name="Kimi for Coding", - default_base_url="https://api.kimi.com/coding/v1", - sort_order=63, - models_dev_provider_id="kimi-for-coding", - api_key_hint="填写 Moonshot / Kimi API Key。", - description="Moonshot Kimi Coding Anthropic-compatible 端点。", + api_key_hint="填写 Moonshot / Kimi API Key,可在中国站、国际站与 Kimi for Coding 端点间切换。", + description="Moonshot / Kimi 官方端点,支持通用 API 与 Kimi for Coding 预设。", ), openai_provider( provider_id="zhipu", @@ -740,16 +729,17 @@ class LLMProviderManager(metaclass=Singleton): ), ] - for sort_order, (provider_id, name, base_url) in enumerate( + for sort_order, provider_entry in enumerate( catalog_openai_providers, start=200, ): - overrides = catalog_openai_overrides.get(provider_id, {}) + provider_id = provider_entry["id"] + overrides = provider_patches.get(provider_id, {}) providers.append( catalog_openai_provider( provider_id=provider_id, - name=name, - default_base_url=base_url, + name=provider_entry["name"], + default_base_url=provider_entry["base_url"], sort_order=sort_order, api_key_hint=overrides.get("api_key_hint"), description=overrides.get("description"), @@ -846,7 +836,7 @@ class LLMProviderManager(metaclass=Singleton): if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS: return None - override = cls._MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES.get(normalized_id, {}) + override = cls._PROVIDER_PATCHES.get(normalized_id, {}) npm_package = str(payload.get("npm") or "").strip() runtime = override.get("runtime") if not runtime: @@ -968,6 +958,8 @@ class LLMProviderManager(metaclass=Singleton): "id": preset.id, "label": preset.label, "value": self._sanitize_base_url(preset.value) or "", + "runtime": preset.runtime, + "model_list_strategy": preset.model_list_strategy, } for preset in spec.base_url_presets ], @@ -1038,6 +1030,8 @@ class LLMProviderManager(metaclass=Singleton): normalized = (provider_id or "").strip().lower() if normalized == "minimax-coding": return "minimax" + if normalized == "kimi-coding": + return "moonshot" return normalized @classmethod @@ -1050,8 +1044,59 @@ class LLMProviderManager(metaclass=Singleton): return None if normalized_provider_id == "minimax" and normalized_preset_id == "minimax-coding": return "minimax-cn-coding" + if normalized_provider_id == "moonshot" and normalized_preset_id == "kimi-coding": + return "moonshot-kimi-coding" return normalized_preset_id + @classmethod + def _resolve_provider_preset( + cls, + spec: ProviderSpec, + base_url: Optional[str], + base_url_preset_id: Optional[str] = None, + ) -> Optional[ProviderUrlPreset]: + normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id) + if normalized_preset_id: + for preset in spec.base_url_presets: + if preset.id == normalized_preset_id: + return preset + + normalized_base_url = cls._sanitize_base_url(base_url) + if normalized_base_url: + for preset in spec.base_url_presets: + preset_value = cls._sanitize_base_url(preset.value) + if normalized_base_url == preset_value: + return preset + return None + + default_base_url = cls._default_base_url_for_provider(spec) + if default_base_url: + for preset in spec.base_url_presets: + preset_value = cls._sanitize_base_url(preset.value) + if preset_value == default_base_url: + return preset + return None + + @classmethod + def _resolve_provider_runtime( + cls, + spec: ProviderSpec, + base_url: Optional[str], + base_url_preset_id: Optional[str] = None, + ) -> str: + preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) + return preset.runtime or spec.runtime if preset else spec.runtime + + @classmethod + def _resolve_provider_model_list_strategy( + cls, + spec: ProviderSpec, + base_url: Optional[str], + base_url_preset_id: Optional[str] = None, + ) -> str: + preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) + return preset.model_list_strategy or spec.model_list_strategy if preset else spec.model_list_strategy + @classmethod def _resolve_provider_model_list_base_url( cls, @@ -1059,31 +1104,16 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: - normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id) - if normalized_preset_id: - for preset in spec.base_url_presets: - if preset.id != normalized_preset_id: - continue - preset_value = cls._sanitize_base_url(preset.value) - return cls._sanitize_base_url(preset.model_list_base_url) or preset_value + preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) + if preset: + preset_value = cls._sanitize_base_url(preset.value) + return cls._sanitize_base_url(preset.model_list_base_url) or preset_value normalized_base_url = cls._sanitize_base_url(base_url) if normalized_base_url: - for preset in spec.base_url_presets: - preset_value = cls._sanitize_base_url(preset.value) - if normalized_base_url != preset_value: - continue - return cls._sanitize_base_url(preset.model_list_base_url) or preset_value return normalized_base_url - default_base_url = cls._default_base_url_for_provider(spec) - if default_base_url: - for preset in spec.base_url_presets: - preset_value = cls._sanitize_base_url(preset.value) - if preset_value != default_base_url: - continue - return cls._sanitize_base_url(preset.model_list_base_url) or preset_value - return default_base_url + return cls._default_base_url_for_provider(spec) @classmethod def _resolve_provider_models_dev_provider_id( @@ -1092,29 +1122,14 @@ class LLMProviderManager(metaclass=Singleton): base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: - normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id) - if normalized_preset_id: - for preset in spec.base_url_presets: - if preset.id != normalized_preset_id: - continue - return preset.models_dev_provider_id or spec.models_dev_provider_id + preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) + if preset: + return preset.models_dev_provider_id or spec.models_dev_provider_id normalized_base_url = cls._sanitize_base_url(base_url) if normalized_base_url: - for preset in spec.base_url_presets: - preset_value = cls._sanitize_base_url(preset.value) - if normalized_base_url != preset_value: - continue - return preset.models_dev_provider_id or spec.models_dev_provider_id return spec.models_dev_provider_id - default_base_url = cls._default_base_url_for_provider(spec) - if default_base_url: - for preset in spec.base_url_presets: - preset_value = cls._sanitize_base_url(preset.value) - if preset_value != default_base_url: - continue - return preset.models_dev_provider_id or spec.models_dev_provider_id return spec.models_dev_provider_id def resolve_model_list_base_url( @@ -1584,50 +1599,33 @@ class LLMProviderManager(metaclass=Singleton): ) return sorted(results, key=lambda i: i["name"].lower()) - async def _list_chatgpt_oauth_models(self) -> list[dict[str, Any]]: - payload = await self._models_dev_provider_payload("chatgpt") + async def _list_chatgpt_oauth_models( + self, + provider_id: str, + base_url: Optional[str] = None, + base_url_preset_id: Optional[str] = None, + ) -> list[dict[str, Any]]: + # ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护 + # 一份内部名单,直接跟随当前 provider 对应的 models.dev 数据。 + payload = await self._models_dev_provider_payload( + provider_id, + base_url=base_url, + base_url_preset_id=base_url_preset_id, + ) models = payload.get("models") if isinstance(payload, dict) else None if not isinstance(models, dict): - return [ - { - "id": model_id, - "name": model_id, - "context_tokens": None, - "input_tokens": None, - "output_tokens": None, - "context_tokens_k": settings.LLM_MAX_CONTEXT_TOKENS, - "supports_reasoning": True, - "supports_tools": True, - "supports_image_input": True, - "supports_audio_input": False, - "transport": "openai", - "source": "builtin", - "release_date": None, - "status": None, - } - for model_id in sorted(self._CHATGPT_ALLOWED_OAUTH_MODELS) - ] + return [] results = [] for model_id, metadata in models.items(): - if "codex" in model_id or model_id in self._CHATGPT_ALLOWED_OAUTH_MODELS: - match = None - if model_id.startswith("gpt-"): - try: - match = float(model_id.split("-")[1].replace(".mini", "")) - except Exception as err: - print(err) - match = None - if match is not None and match > 5.4 and "codex" not in model_id: - continue - results.append( - self._normalize_model_record( - model_id=model_id, - display_name=metadata.get("name") or model_id, - metadata=metadata, - source="models.dev", - ) + results.append( + self._normalize_model_record( + model_id=model_id, + display_name=metadata.get("name") or model_id, + metadata=metadata, + source="models.dev", ) + ) return sorted(results, key=lambda item: item["name"].lower()) async def list_models( @@ -1640,6 +1638,11 @@ class LLMProviderManager(metaclass=Singleton): ) -> list[dict[str, Any]]: """返回标准化后的模型目录。""" spec = await self._get_provider_async(provider_id, force_refresh=force_refresh) + resolved_model_list_strategy = self._resolve_provider_model_list_strategy( + spec, + base_url, + base_url_preset_id=base_url_preset_id, + ) if self._resolve_provider_models_dev_provider_id( spec, base_url, @@ -1657,15 +1660,19 @@ class LLMProviderManager(metaclass=Singleton): base_url_preset_id=base_url_preset_id, ) - if spec.model_list_strategy == "google": + if resolved_model_list_strategy == "google": return await self._list_models_from_google(runtime["api_key"]) - if spec.model_list_strategy == "github_copilot": + if resolved_model_list_strategy == "github_copilot": return await self._list_models_from_copilot(runtime["api_key"]) - if spec.model_list_strategy == "chatgpt": + if resolved_model_list_strategy == "chatgpt": if runtime.get("auth_mode") == "oauth": - return await self._list_chatgpt_oauth_models() + return await self._list_chatgpt_oauth_models( + provider_id=provider_id, + base_url=base_url, + base_url_preset_id=base_url_preset_id, + ) return await self._list_models_from_openai_compatible( provider_id="chatgpt", api_key=runtime["api_key"], @@ -1677,7 +1684,7 @@ class LLMProviderManager(metaclass=Singleton): default_headers=runtime.get("default_headers"), ) - if spec.model_list_strategy == "anthropic_compatible": + if resolved_model_list_strategy == "anthropic_compatible": return await self._list_models_from_models_dev_only( provider_id=provider_id, transport="anthropic", @@ -1685,7 +1692,7 @@ class LLMProviderManager(metaclass=Singleton): base_url_preset_id=base_url_preset_id, ) - if spec.model_list_strategy == "models_dev_only": + if resolved_model_list_strategy == "models_dev_only": return await self._list_models_from_models_dev_only( provider_id=provider_id, transport="openai", @@ -2188,6 +2195,11 @@ class LLMProviderManager(metaclass=Singleton): base_url_preset_id, ) spec = await self._get_provider_async(normalized_provider_id) + resolved_runtime = self._resolve_provider_runtime( + spec, + base_url, + base_url_preset_id=normalized_base_url_preset_id, + ) normalized_api_key = str(api_key or "").strip() or None normalized_base_url = self._sanitize_base_url(base_url) model_record = None @@ -2212,7 +2224,7 @@ class LLMProviderManager(metaclass=Singleton): result: dict[str, Any] = { "provider_id": normalized_provider_id, - "runtime": spec.runtime, + "runtime": resolved_runtime, "model_id": model, "model_record": model_record, "model_metadata": await self.resolve_model_metadata( @@ -2290,7 +2302,7 @@ class LLMProviderManager(metaclass=Singleton): ) return result - if spec.runtime == "google": + if resolved_runtime == "google": if not normalized_api_key: raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") result.update( @@ -2302,7 +2314,7 @@ class LLMProviderManager(metaclass=Singleton): ) return result - if spec.runtime == "anthropic_compatible": + if resolved_runtime == "anthropic_compatible": effective_base_url = normalized_base_url or self._default_base_url_for_provider( spec ) diff --git a/scripts/local_setup.py b/scripts/local_setup.py index 281a05b4..dc473ac4 100644 --- a/scripts/local_setup.py +++ b/scripts/local_setup.py @@ -1200,6 +1200,8 @@ def _llm_provider_defaults( provider_definitions: list[dict[str, Any]], ) -> dict[str, str]: normalized_provider = str(provider or "").strip().lower() + if normalized_provider == "kimi-coding": + normalized_provider = "moonshot" defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {}) provider_meta = next( ( @@ -1231,6 +1233,8 @@ def _llm_provider_meta( provider_definitions: list[dict[str, Any]], ) -> dict[str, Any]: normalized_provider = str(provider or "").strip().lower() + if normalized_provider == "kimi-coding": + normalized_provider = "moonshot" provider_meta = next( ( item @@ -1784,6 +1788,8 @@ def _collect_agent_config( provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python) provider_choices = _llm_provider_choice_map(provider_definitions) current_provider = _env_default("LLM_PROVIDER", "deepseek").lower() + if current_provider == "kimi-coding": + current_provider = "moonshot" if current_provider not in provider_choices: current_provider = "deepseek" diff --git a/tests/test_llm_provider_registry.py b/tests/test_llm_provider_registry.py index 69d90174..7bf7b904 100644 --- a/tests/test_llm_provider_registry.py +++ b/tests/test_llm_provider_registry.py @@ -294,6 +294,107 @@ class LlmProviderRegistryTest(unittest.TestCase): "minimax-cn-coding-plan", ) + def test_builtin_moonshot_provider_includes_kimi_for_coding_preset(self): + manager = LLMProviderManager() + + provider = manager.get_provider("moonshot") + serialized = manager.list_providers() + moonshot_payload = next(item for item in serialized if item["id"] == "moonshot") + + self.assertEqual(provider.name, "Moonshot / Kimi") + self.assertEqual(provider.runtime, "openai_compatible") + self.assertEqual( + tuple((preset.id, preset.label, preset.value, preset.runtime) for preset in provider.base_url_presets), + ( + ("moonshot-cn", "中国站", "https://api.moonshot.cn/v1", None), + ("moonshot-global", "国际站", "https://api.moonshot.ai/v1", None), + ( + "moonshot-kimi-coding", + "Kimi for Coding", + "https://api.kimi.com/coding/v1", + "anthropic_compatible", + ), + ), + ) + self.assertEqual( + tuple(item["id"] for item in moonshot_payload["base_url_presets"]), + ("moonshot-cn", "moonshot-global", "moonshot-kimi-coding"), + ) + + def test_kimi_coding_alias_resolves_to_moonshot_provider(self): + manager = LLMProviderManager() + + provider = manager.get_provider("kimi-coding") + + self.assertEqual(provider.id, "moonshot") + + def test_resolve_runtime_prefers_kimi_for_coding_preset_runtime(self): + manager = LLMProviderManager() + + runtime = asyncio.run( + manager.resolve_runtime( + provider_id="moonshot", + model=None, + api_key="sk-test", + base_url="https://api.kimi.com/coding/v1", + base_url_preset_id="moonshot-kimi-coding", + ) + ) + + self.assertEqual(runtime["provider_id"], "moonshot") + self.assertEqual(runtime["runtime"], "anthropic_compatible") + self.assertEqual(runtime["base_url"], "https://api.kimi.com/coding") + + def test_resolve_model_list_strategy_prefers_kimi_for_coding_preset(self): + manager = LLMProviderManager() + provider = manager.get_provider("moonshot") + + self.assertEqual( + manager._resolve_provider_model_list_strategy( + provider, + base_url="https://api.kimi.com/coding/v1", + base_url_preset_id="moonshot-kimi-coding", + ), + "anthropic_compatible", + ) + + def test_chatgpt_oauth_models_follow_models_dev_catalog(self): + manager = LLMProviderManager() + payload = { + "openai": { + "id": "openai", + "name": "OpenAI", + "models": { + "gpt-5.5": { + "name": "GPT-5.5", + "limit": {"context": 400000}, + }, + "o4-mini": { + "name": "o4-mini", + "limit": {"context": 200000}, + }, + }, + } + } + + with patch.object(manager, "get_models_dev_data", AsyncMock(return_value=payload)): + models = asyncio.run( + manager._list_chatgpt_oauth_models(provider_id="chatgpt") + ) + + self.assertEqual([item["id"] for item in models], ["gpt-5.5", "o4-mini"]) + self.assertTrue(all(item["source"] == "models.dev" for item in models)) + + def test_chatgpt_oauth_models_return_empty_when_catalog_missing(self): + manager = LLMProviderManager() + + with patch.object(manager, "get_models_dev_data", AsyncMock(return_value={})): + models = asyncio.run( + manager._list_chatgpt_oauth_models(provider_id="chatgpt") + ) + + self.assertEqual(models, []) + if __name__ == "__main__": unittest.main()