refactor(agent): consolidate provider preset resolution

jxxghp
2026-05-08 12:35:02 +08:00
parent c9931aa948
commit c6d95cd006
3 changed files with 254 additions and 135 deletions


@@ -51,6 +51,8 @@ class ProviderUrlPreset:
id: str
label: str
value: str
runtime: Optional[str] = None
model_list_strategy: Optional[str] = None
model_list_base_url: Optional[str] = None
models_dev_provider_id: Optional[str] = None
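For reference, a preset exercising the new optional fields might be declared like this (an illustrative sketch, not part of the diff, mirroring the Kimi for Coding entry added further down; fields left as None keep falling back to the provider-level spec):
kimi_coding_preset = ProviderUrlPreset(
    id="moonshot-kimi-coding",
    label="Kimi for Coding",
    value="https://api.kimi.com/coding/v1",
    runtime="anthropic_compatible",              # per-preset override of spec.runtime
    model_list_strategy="anthropic_compatible",  # per-preset override of spec.model_list_strategy
    models_dev_provider_id="kimi-for-coding",
)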
@@ -107,17 +109,6 @@ class LLMProviderManager(metaclass=Singleton):
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
_COPILOT_CLIENT_ID = "Ov23li8tweQw6odWQebz"
_DEFAULT_TIMEOUT = httpx.Timeout(15.0, connect=10.0)
_CHATGPT_ALLOWED_OAUTH_MODELS = {
"gpt-5.1-codex",
"gpt-5.1-codex-max",
"gpt-5.1-codex-mini",
"gpt-5.2",
"gpt-5.2-codex",
"gpt-5.3-codex",
"gpt-5.4",
"gpt-5.4-mini",
"gpt-5.5",
}
_MODELS_DEV_DYNAMIC_SKIP_IDS = {
"aihubmix",
"amazon-bedrock",
@@ -133,7 +124,7 @@ class LLMProviderManager(metaclass=Singleton):
"v0",
"vercel",
}
_MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES = {
_PROVIDER_PATCHES = {
"bailing": {
"runtime": "openai_compatible",
"default_base_url": "https://api.tbox.cn/api/llm/v1",
@@ -169,6 +160,14 @@ class LLMProviderManager(metaclass=Singleton):
"default_base_url": "https://api.venice.ai/api/v1",
"description": "Venice AI 官方兼容端点。",
},
"cloudflare-workers-ai": {
"api_key_hint": "填写 Cloudflare API Token并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。",
"description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。",
},
"privatemode-ai": {
"api_key_hint": "如未启用鉴权,可填写任意占位值。",
"description": "Privatemode AI 本地 OpenAI-compatible 端点。",
},
}
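The renamed _PROVIDER_PATCHES table now feeds both code paths: the static catalog loop reads api_key_hint/description from it, while the models.dev dynamic builder reads runtime/default_base_url from the same entries. A rough sketch of the lookup (illustrative only, not part of the diff):
patch = LLMProviderManager._PROVIDER_PATCHES.get("cloudflare-workers-ai", {})
api_key_hint = patch.get("api_key_hint")  # consumed when building the catalog ProviderSpec
runtime = patch.get("runtime")            # consumed by the models.dev dynamic path; None falls back to the payload-derived runtime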
def __init__(self):
@@ -203,6 +202,7 @@ class LLMProviderManager(metaclass=Singleton):
description="适合无回调环境,复制设备码到浏览器完成登录。",
)
url_preset = ProviderUrlPreset
provider_patches = LLMProviderManager._PROVIDER_PATCHES
def openai_provider(
provider_id: str,
name: str,
@@ -280,28 +280,18 @@ class LLMProviderManager(metaclass=Singleton):
)
catalog_openai_providers = (
("huggingface", "Hugging Face", "https://router.huggingface.co/v1"),
("jiekou", "接口 AI", "https://api.jiekou.ai/openai"),
("kilo", "Kilo Gateway", "https://api.kilo.ai/api/gateway"),
("llama", "Llama", "https://api.llama.com/compat/v1/"),
("llmgateway", "LLM Gateway", "https://api.llmgateway.io/v1"),
("modelscope", "ModelScope", "https://api-inference.modelscope.cn/v1"),
("nova", "Nova", "https://api.nova.amazon.com/v1"),
("fireworks-ai", "Fireworks AI", "https://api.fireworks.ai/inference/v1/"),
("poe", "Poe", "https://api.poe.com/v1"),
("qihang-ai", "启航 AI", "https://api.qhaigc.net/v1"),
("qiniu-ai", "七牛", "https://api.qnaigc.com/v1"),
{"id": "huggingface", "name": "Hugging Face", "base_url": "https://router.huggingface.co/v1"},
{"id": "jiekou", "name": "接口 AI", "base_url": "https://api.jiekou.ai/openai"},
{"id": "kilo", "name": "Kilo Gateway", "base_url": "https://api.kilo.ai/api/gateway"},
{"id": "llama", "name": "Llama", "base_url": "https://api.llama.com/compat/v1/"},
{"id": "llmgateway", "name": "LLM Gateway", "base_url": "https://api.llmgateway.io/v1"},
{"id": "modelscope", "name": "ModelScope", "base_url": "https://api-inference.modelscope.cn/v1"},
{"id": "nova", "name": "Nova", "base_url": "https://api.nova.amazon.com/v1"},
{"id": "fireworks-ai", "name": "Fireworks AI", "base_url": "https://api.fireworks.ai/inference/v1/"},
{"id": "poe", "name": "Poe", "base_url": "https://api.poe.com/v1"},
{"id": "qihang-ai", "name": "启航 AI", "base_url": "https://api.qhaigc.net/v1"},
{"id": "qiniu-ai", "name": "七牛", "base_url": "https://api.qnaigc.com/v1"},
)
catalog_openai_overrides = {
"cloudflare-workers-ai": {
"api_key_hint": "填写 Cloudflare API Token并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。",
"description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。",
},
"privatemode-ai": {
"api_key_hint": "如未启用鉴权,可填写任意占位值。",
"description": "Privatemode AI 本地 OpenAI-compatible 端点。",
},
}
providers = [
ProviderSpec(
@@ -436,7 +426,7 @@ class LLMProviderManager(metaclass=Singleton):
),
catalog_openai_provider(
provider_id="moonshot",
name="Moonshot AI",
name="Moonshot / Kimi",
default_base_url="https://api.moonshot.cn/v1",
sort_order=62,
models_dev_provider_id="moonshotai-cn",
@@ -453,18 +443,17 @@ class LLMProviderManager(metaclass=Singleton):
value="https://api.moonshot.ai/v1",
models_dev_provider_id="moonshotai",
),
url_preset(
id="moonshot-kimi-coding",
label="Kimi for Coding",
value="https://api.kimi.com/coding/v1",
runtime="anthropic_compatible",
model_list_strategy="anthropic_compatible",
models_dev_provider_id="kimi-for-coding",
),
),
api_key_hint="填写 Moonshot / Kimi API Key可在中国站国际站端点间切换。",
description="Moonshot / Kimi 官方兼容端点。",
),
anthropic_provider(
provider_id="kimi-coding",
name="Kimi for Coding",
default_base_url="https://api.kimi.com/coding/v1",
sort_order=63,
models_dev_provider_id="kimi-for-coding",
api_key_hint="填写 Moonshot / Kimi API Key。",
description="Moonshot Kimi Coding Anthropic-compatible 端点。",
api_key_hint="填写 Moonshot / Kimi API Key可在中国站国际站与 Kimi for Coding 端点间切换。",
description="Moonshot / Kimi 官方端点,支持通用 API 与 Kimi for Coding 预设",
),
openai_provider(
provider_id="zhipu",
@@ -740,16 +729,17 @@ class LLMProviderManager(metaclass=Singleton):
),
]
for sort_order, (provider_id, name, base_url) in enumerate(
for sort_order, provider_entry in enumerate(
catalog_openai_providers,
start=200,
):
overrides = catalog_openai_overrides.get(provider_id, {})
provider_id = provider_entry["id"]
overrides = provider_patches.get(provider_id, {})
providers.append(
catalog_openai_provider(
provider_id=provider_id,
name=name,
default_base_url=base_url,
name=provider_entry["name"],
default_base_url=provider_entry["base_url"],
sort_order=sort_order,
api_key_hint=overrides.get("api_key_hint"),
description=overrides.get("description"),
@@ -846,7 +836,7 @@ class LLMProviderManager(metaclass=Singleton):
if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS:
return None
override = cls._MODELS_DEV_DYNAMIC_PROVIDER_OVERRIDES.get(normalized_id, {})
override = cls._PROVIDER_PATCHES.get(normalized_id, {})
npm_package = str(payload.get("npm") or "").strip()
runtime = override.get("runtime")
if not runtime:
@@ -968,6 +958,8 @@ class LLMProviderManager(metaclass=Singleton):
"id": preset.id,
"label": preset.label,
"value": self._sanitize_base_url(preset.value) or "",
"runtime": preset.runtime,
"model_list_strategy": preset.model_list_strategy,
}
for preset in spec.base_url_presets
],
@@ -1038,6 +1030,8 @@ class LLMProviderManager(metaclass=Singleton):
normalized = (provider_id or "").strip().lower()
if normalized == "minimax-coding":
return "minimax"
if normalized == "kimi-coding":
return "moonshot"
return normalized
@classmethod
@@ -1050,8 +1044,59 @@ class LLMProviderManager(metaclass=Singleton):
return None
if normalized_provider_id == "minimax" and normalized_preset_id == "minimax-coding":
return "minimax-cn-coding"
if normalized_provider_id == "moonshot" and normalized_preset_id == "kimi-coding":
return "moonshot-kimi-coding"
return normalized_preset_id
@classmethod
def _resolve_provider_preset(
cls,
spec: ProviderSpec,
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[ProviderUrlPreset]:
normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id)
if normalized_preset_id:
for preset in spec.base_url_presets:
if preset.id == normalized_preset_id:
return preset
normalized_base_url = cls._sanitize_base_url(base_url)
if normalized_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if normalized_base_url == preset_value:
return preset
return None
default_base_url = cls._default_base_url_for_provider(spec)
if default_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if preset_value == default_base_url:
return preset
return None
@classmethod
def _resolve_provider_runtime(
cls,
spec: ProviderSpec,
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> str:
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
return (preset.runtime or spec.runtime) if preset else spec.runtime
@classmethod
def _resolve_provider_model_list_strategy(
cls,
spec: ProviderSpec,
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> str:
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
return (preset.model_list_strategy or spec.model_list_strategy) if preset else spec.model_list_strategy
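All three resolvers now funnel through _resolve_provider_preset, so the precedence is uniform: an explicit base_url_preset_id wins, then a base_url that matches a preset value, then the preset matching the provider's default base URL. A usage sketch (not part of the diff; mirrors the new tests):
manager = LLMProviderManager()
spec = manager.get_provider("moonshot")
LLMProviderManager._resolve_provider_runtime(
    spec,
    base_url="https://api.kimi.com/coding/v1",
    base_url_preset_id="moonshot-kimi-coding",
)  # -> "anthropic_compatible"; the cn/global presets leave runtime as None and fall back to spec.runtime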
@classmethod
def _resolve_provider_model_list_base_url(
cls,
@@ -1059,31 +1104,16 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[str]:
normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id)
if normalized_preset_id:
for preset in spec.base_url_presets:
if preset.id != normalized_preset_id:
continue
preset_value = cls._sanitize_base_url(preset.value)
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
if preset:
preset_value = cls._sanitize_base_url(preset.value)
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
normalized_base_url = cls._sanitize_base_url(base_url)
if normalized_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if normalized_base_url != preset_value:
continue
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
return normalized_base_url
default_base_url = cls._default_base_url_for_provider(spec)
if default_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if preset_value != default_base_url:
continue
return cls._sanitize_base_url(preset.model_list_base_url) or preset_value
return default_base_url
return cls._default_base_url_for_provider(spec)
@classmethod
def _resolve_provider_models_dev_provider_id(
@@ -1092,29 +1122,14 @@ class LLMProviderManager(metaclass=Singleton):
base_url: Optional[str],
base_url_preset_id: Optional[str] = None,
) -> Optional[str]:
normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id)
if normalized_preset_id:
for preset in spec.base_url_presets:
if preset.id != normalized_preset_id:
continue
return preset.models_dev_provider_id or spec.models_dev_provider_id
preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id)
if preset:
return preset.models_dev_provider_id or spec.models_dev_provider_id
normalized_base_url = cls._sanitize_base_url(base_url)
if normalized_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if normalized_base_url != preset_value:
continue
return preset.models_dev_provider_id or spec.models_dev_provider_id
return spec.models_dev_provider_id
default_base_url = cls._default_base_url_for_provider(spec)
if default_base_url:
for preset in spec.base_url_presets:
preset_value = cls._sanitize_base_url(preset.value)
if preset_value != default_base_url:
continue
return preset.models_dev_provider_id or spec.models_dev_provider_id
return spec.models_dev_provider_id
def resolve_model_list_base_url(
@@ -1584,50 +1599,33 @@ class LLMProviderManager(metaclass=Singleton):
)
return sorted(results, key=lambda i: i["name"].lower())
async def _list_chatgpt_oauth_models(self) -> list[dict[str, Any]]:
payload = await self._models_dev_provider_payload("chatgpt")
async def _list_chatgpt_oauth_models(
self,
provider_id: str,
base_url: Optional[str] = None,
base_url_preset_id: Optional[str] = None,
) -> list[dict[str, Any]]:
# ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护
# 一份内部名单,直接跟随当前 provider 对应的 models.dev 数据。
payload = await self._models_dev_provider_payload(
provider_id,
base_url=base_url,
base_url_preset_id=base_url_preset_id,
)
models = payload.get("models") if isinstance(payload, dict) else None
if not isinstance(models, dict):
return [
{
"id": model_id,
"name": model_id,
"context_tokens": None,
"input_tokens": None,
"output_tokens": None,
"context_tokens_k": settings.LLM_MAX_CONTEXT_TOKENS,
"supports_reasoning": True,
"supports_tools": True,
"supports_image_input": True,
"supports_audio_input": False,
"transport": "openai",
"source": "builtin",
"release_date": None,
"status": None,
}
for model_id in sorted(self._CHATGPT_ALLOWED_OAUTH_MODELS)
]
return []
results = []
for model_id, metadata in models.items():
if "codex" in model_id or model_id in self._CHATGPT_ALLOWED_OAUTH_MODELS:
match = None
if model_id.startswith("gpt-"):
try:
match = float(model_id.split("-")[1].replace(".mini", ""))
except Exception as err:
print(err)
match = None
if match is not None and match > 5.4 and "codex" not in model_id:
continue
results.append(
self._normalize_model_record(
model_id=model_id,
display_name=metadata.get("name") or model_id,
metadata=metadata,
source="models.dev",
)
results.append(
self._normalize_model_record(
model_id=model_id,
display_name=metadata.get("name") or model_id,
metadata=metadata,
source="models.dev",
)
)
return sorted(results, key=lambda item: item["name"].lower())
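The hard-coded _CHATGPT_ALLOWED_OAUTH_MODELS allow-list and the version-number filter are gone: every model reported by models.dev for the resolved provider is normalized and returned, and a missing catalog now yields an empty list instead of the old built-in fallback. Roughly (a sketch mirroring the tests added below, not part of the diff):
payload = {"models": {"gpt-5.5": {"name": "GPT-5.5"}, "o4-mini": {"name": "o4-mini"}}}
# _list_chatgpt_oauth_models(provider_id="chatgpt") now returns one normalized record per
# entry above, sorted by display name; with no models.dev payload it returns [].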
async def list_models(
@@ -1640,6 +1638,11 @@ class LLMProviderManager(metaclass=Singleton):
) -> list[dict[str, Any]]:
"""返回标准化后的模型目录。"""
spec = await self._get_provider_async(provider_id, force_refresh=force_refresh)
resolved_model_list_strategy = self._resolve_provider_model_list_strategy(
spec,
base_url,
base_url_preset_id=base_url_preset_id,
)
if self._resolve_provider_models_dev_provider_id(
spec,
base_url,
@@ -1657,15 +1660,19 @@ class LLMProviderManager(metaclass=Singleton):
base_url_preset_id=base_url_preset_id,
)
if spec.model_list_strategy == "google":
if resolved_model_list_strategy == "google":
return await self._list_models_from_google(runtime["api_key"])
if spec.model_list_strategy == "github_copilot":
if resolved_model_list_strategy == "github_copilot":
return await self._list_models_from_copilot(runtime["api_key"])
if spec.model_list_strategy == "chatgpt":
if resolved_model_list_strategy == "chatgpt":
if runtime.get("auth_mode") == "oauth":
return await self._list_chatgpt_oauth_models()
return await self._list_chatgpt_oauth_models(
provider_id=provider_id,
base_url=base_url,
base_url_preset_id=base_url_preset_id,
)
return await self._list_models_from_openai_compatible(
provider_id="chatgpt",
api_key=runtime["api_key"],
@@ -1677,7 +1684,7 @@ class LLMProviderManager(metaclass=Singleton):
default_headers=runtime.get("default_headers"),
)
if spec.model_list_strategy == "anthropic_compatible":
if resolved_model_list_strategy == "anthropic_compatible":
return await self._list_models_from_models_dev_only(
provider_id=provider_id,
transport="anthropic",
@@ -1685,7 +1692,7 @@ class LLMProviderManager(metaclass=Singleton):
base_url_preset_id=base_url_preset_id,
)
if spec.model_list_strategy == "models_dev_only":
if resolved_model_list_strategy == "models_dev_only":
return await self._list_models_from_models_dev_only(
provider_id=provider_id,
transport="openai",
@@ -2188,6 +2195,11 @@ class LLMProviderManager(metaclass=Singleton):
base_url_preset_id,
)
spec = await self._get_provider_async(normalized_provider_id)
resolved_runtime = self._resolve_provider_runtime(
spec,
base_url,
base_url_preset_id=normalized_base_url_preset_id,
)
normalized_api_key = str(api_key or "").strip() or None
normalized_base_url = self._sanitize_base_url(base_url)
model_record = None
@@ -2212,7 +2224,7 @@ class LLMProviderManager(metaclass=Singleton):
result: dict[str, Any] = {
"provider_id": normalized_provider_id,
"runtime": spec.runtime,
"runtime": resolved_runtime,
"model_id": model,
"model_record": model_record,
"model_metadata": await self.resolve_model_metadata(
@@ -2290,7 +2302,7 @@ class LLMProviderManager(metaclass=Singleton):
)
return result
if spec.runtime == "google":
if resolved_runtime == "google":
if not normalized_api_key:
raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key")
result.update(
@@ -2302,7 +2314,7 @@ class LLMProviderManager(metaclass=Singleton):
)
return result
if spec.runtime == "anthropic_compatible":
if resolved_runtime == "anthropic_compatible":
effective_base_url = normalized_base_url or self._default_base_url_for_provider(
spec
)


@@ -1200,6 +1200,8 @@ def _llm_provider_defaults(
provider_definitions: list[dict[str, Any]],
) -> dict[str, str]:
normalized_provider = str(provider or "").strip().lower()
if normalized_provider == "kimi-coding":
normalized_provider = "moonshot"
defaults = dict(LLM_PROVIDER_DEFAULTS.get(normalized_provider) or {})
provider_meta = next(
(
@@ -1231,6 +1233,8 @@ def _llm_provider_meta(
provider_definitions: list[dict[str, Any]],
) -> dict[str, Any]:
normalized_provider = str(provider or "").strip().lower()
if normalized_provider == "kimi-coding":
normalized_provider = "moonshot"
provider_meta = next(
(
item
@@ -1784,6 +1788,8 @@ def _collect_agent_config(
provider_definitions = _load_llm_provider_definitions(runtime_python=runtime_python)
provider_choices = _llm_provider_choice_map(provider_definitions)
current_provider = _env_default("LLM_PROVIDER", "deepseek").lower()
if current_provider == "kimi-coding":
current_provider = "moonshot"
if current_provider not in provider_choices:
current_provider = "deepseek"


@@ -294,6 +294,107 @@ class LlmProviderRegistryTest(unittest.TestCase):
"minimax-cn-coding-plan",
)
def test_builtin_moonshot_provider_includes_kimi_for_coding_preset(self):
manager = LLMProviderManager()
provider = manager.get_provider("moonshot")
serialized = manager.list_providers()
moonshot_payload = next(item for item in serialized if item["id"] == "moonshot")
self.assertEqual(provider.name, "Moonshot / Kimi")
self.assertEqual(provider.runtime, "openai_compatible")
self.assertEqual(
tuple((preset.id, preset.label, preset.value, preset.runtime) for preset in provider.base_url_presets),
(
("moonshot-cn", "中国站", "https://api.moonshot.cn/v1", None),
("moonshot-global", "国际站", "https://api.moonshot.ai/v1", None),
(
"moonshot-kimi-coding",
"Kimi for Coding",
"https://api.kimi.com/coding/v1",
"anthropic_compatible",
),
),
)
self.assertEqual(
tuple(item["id"] for item in moonshot_payload["base_url_presets"]),
("moonshot-cn", "moonshot-global", "moonshot-kimi-coding"),
)
def test_kimi_coding_alias_resolves_to_moonshot_provider(self):
manager = LLMProviderManager()
provider = manager.get_provider("kimi-coding")
self.assertEqual(provider.id, "moonshot")
def test_resolve_runtime_prefers_kimi_for_coding_preset_runtime(self):
manager = LLMProviderManager()
runtime = asyncio.run(
manager.resolve_runtime(
provider_id="moonshot",
model=None,
api_key="sk-test",
base_url="https://api.kimi.com/coding/v1",
base_url_preset_id="moonshot-kimi-coding",
)
)
self.assertEqual(runtime["provider_id"], "moonshot")
self.assertEqual(runtime["runtime"], "anthropic_compatible")
self.assertEqual(runtime["base_url"], "https://api.kimi.com/coding")
def test_resolve_model_list_strategy_prefers_kimi_for_coding_preset(self):
manager = LLMProviderManager()
provider = manager.get_provider("moonshot")
self.assertEqual(
manager._resolve_provider_model_list_strategy(
provider,
base_url="https://api.kimi.com/coding/v1",
base_url_preset_id="moonshot-kimi-coding",
),
"anthropic_compatible",
)
def test_chatgpt_oauth_models_follow_models_dev_catalog(self):
manager = LLMProviderManager()
payload = {
"openai": {
"id": "openai",
"name": "OpenAI",
"models": {
"gpt-5.5": {
"name": "GPT-5.5",
"limit": {"context": 400000},
},
"o4-mini": {
"name": "o4-mini",
"limit": {"context": 200000},
},
},
}
}
with patch.object(manager, "get_models_dev_data", AsyncMock(return_value=payload)):
models = asyncio.run(
manager._list_chatgpt_oauth_models(provider_id="chatgpt")
)
self.assertEqual([item["id"] for item in models], ["gpt-5.5", "o4-mini"])
self.assertTrue(all(item["source"] == "models.dev" for item in models))
def test_chatgpt_oauth_models_return_empty_when_catalog_missing(self):
manager = LLMProviderManager()
with patch.object(manager, "get_models_dev_data", AsyncMock(return_value={})):
models = asyncio.run(
manager._list_chatgpt_oauth_models(provider_id="chatgpt")
)
self.assertEqual(models, [])
if __name__ == "__main__":
unittest.main()