"""LLM provider registry, auth flows, and model metadata helpers.""" from __future__ import annotations import asyncio import base64 import copy import hashlib import json import secrets import threading import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional, Tuple from urllib.parse import urlencode import aiofiles import httpx import jwt from app.core.config import settings from app.db.systemconfig_oper import SystemConfigOper from app.log import logger from app.schemas.types import SystemConfigKey from app.utils.singleton import Singleton class LLMProviderError(RuntimeError): """通用 LLM provider 异常。""" class LLMProviderAuthError(LLMProviderError): """LLM provider 鉴权异常。""" @dataclass(frozen=True) class ProviderAuthMethod: """前端展示用的授权方式定义。""" id: str type: str label: str description: str = "" @dataclass(frozen=True) class ProviderUrlPreset: """前端展示用的 Base URL 预设。""" id: str label: str value: str runtime: Optional[str] = None model_list_strategy: Optional[str] = None model_list_base_url: Optional[str] = None models_dev_provider_id: Optional[str] = None @dataclass(frozen=True) class ProviderSpec: """描述一个可接入的 LLM provider。""" id: str name: str runtime: str models_dev_provider_id: Optional[str] = None default_base_url: Optional[str] = None base_url_presets: Tuple[ProviderUrlPreset, ...] = () base_url_editable: bool = False requires_base_url: bool = False supports_api_key: bool = True api_key_label: str = "API Key" api_key_hint: str = "" oauth_methods: Tuple[ProviderAuthMethod, ...] = () supports_model_refresh: bool = True model_list_strategy: str = "openai_compatible" sort_order: int = 100 description: str = "" @dataclass class PendingAuthSession: """保存临时鉴权会话,避免把 PKCE/device code 等状态写回配置。""" session_id: str provider_id: str method_id: str flow_type: str status: str = "pending" message: str = "" authorize_url: Optional[str] = None instructions: Optional[str] = None verification_url: Optional[str] = None user_code: Optional[str] = None interval_seconds: int = 5 expires_at: float = 0 created_at: float = field(default_factory=time.time) context: Dict[str, Any] = field(default_factory=dict) class LLMProviderManager(metaclass=Singleton): """统一维护 provider 目录、models.dev 缓存和 OAuth 状态。""" _MODELS_DEV_URL = "https://models.dev/api.json" _MODELS_DEV_BUNDLED_PATH = Path(__file__).with_name("models.json") _MODELS_DEV_CACHE_TTL = 7 * 24 * 60 * 60 _CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann" _CHATGPT_ISSUER = "https://auth.openai.com" _CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex" _COPILOT_CLIENT_ID = "Ov23li8tweQw6odWQebz" _DEFAULT_TIMEOUT = httpx.Timeout(15.0, connect=10.0) _MODELS_DEV_DYNAMIC_SKIP_IDS = { "aihubmix", "amazon-bedrock", "azure", "azure-cognitive-services", "cloudflare-ai-gateway", "cohere", "gitlab", "google-vertex", "google-vertex-anthropic", "kiro", "sap-ai-core", "v0", "vercel", } _PROVIDER_PATCHES = { "bailing": { "runtime": "openai_compatible", "default_base_url": "https://api.tbox.cn/api/llm/v1", "description": "Bailing OpenAI-compatible 端点。", }, "cerebras": { "runtime": "openai_compatible", "default_base_url": "https://api.cerebras.ai/v1", "description": "Cerebras 官方兼容端点。", }, "deepinfra": { "runtime": "openai_compatible", "default_base_url": "https://api.deepinfra.com/v1/openai", "description": "DeepInfra 官方兼容端点。", }, "mistral": { "runtime": "openai_compatible", "default_base_url": "https://api.mistral.ai/v1", "description": "Mistral 官方兼容端点。", }, "perplexity": { "runtime": "openai_compatible", "default_base_url": "https://api.perplexity.ai/v1", "description": "Perplexity 官方兼容端点。", }, "togetherai": { "runtime": "openai_compatible", "default_base_url": "https://api.together.xyz/v1", "description": "Together AI 官方兼容端点。", }, "venice": { "runtime": "openai_compatible", "default_base_url": "https://api.venice.ai/api/v1", "description": "Venice AI 官方兼容端点。", }, "cloudflare-workers-ai": { "api_key_hint": "填写 Cloudflare API Token,并将 Base URL 中的 ${CLOUDFLARE_ACCOUNT_ID} 替换为真实账户 ID。", "description": "Cloudflare Workers AI OpenAI-compatible 端点,需要替换账户 ID。", }, "privatemode-ai": { "api_key_hint": "如未启用鉴权,可填写任意占位值。", "description": "Privatemode AI 本地 OpenAI-compatible 端点。", }, } def __init__(self): """初始化管理器实例及各类锁和缓存变量。""" self._lock = threading.RLock() self._models_dev_lock = asyncio.Lock() self._pending_sessions: dict[str, PendingAuthSession] = {} self._oauth_state_index: dict[str, str] = {} self._models_dev_data: dict[str, Any] | None = None self._models_dev_loaded_at: float = 0 self._models_dev_cache_path = ( Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json" ) @staticmethod def _builtin_provider_specs() -> tuple[ProviderSpec, ...]: """ 返回受支持的 provider 定义。 OpenAI 保留为用户自定义 OpenAI-compatible 兜底入口,因此仍要求填写 Base URL;ChatGPT 则单独承接官方 API Key / ChatGPT 订阅鉴权。 """ browser_auth = ProviderAuthMethod( id="browser_oauth", type="oauth", label="浏览器授权", description="使用 ChatGPT Plus/Pro 浏览器登录并回调授权。", ) device_auth = ProviderAuthMethod( id="device_code", type="device", label="设备码授权", description="适合无回调环境,复制设备码到浏览器完成登录。", ) url_preset = ProviderUrlPreset provider_patches = LLMProviderManager._PROVIDER_PATCHES def openai_provider( provider_id: str, name: str, default_base_url: str, sort_order: int, *, models_dev_provider_id: Optional[str] = None, base_url_presets: Tuple[ProviderUrlPreset, ...] = (), api_key_hint: Optional[str] = None, description: Optional[str] = None, model_list_strategy: str = "openai_compatible", api_key_label: str = "API Key", ) -> ProviderSpec: return ProviderSpec( id=provider_id, name=name, runtime="openai_compatible", models_dev_provider_id=models_dev_provider_id or provider_id, default_base_url=default_base_url, base_url_presets=base_url_presets, api_key_label=api_key_label, api_key_hint=api_key_hint or f"填写 {name} API Key。", model_list_strategy=model_list_strategy, description=description or f"{name} OpenAI-compatible 端点。", sort_order=sort_order, ) def catalog_openai_provider( provider_id: str, name: str, default_base_url: str, sort_order: int, *, models_dev_provider_id: Optional[str] = None, base_url_presets: Tuple[ProviderUrlPreset, ...] = (), api_key_hint: Optional[str] = None, description: Optional[str] = None, api_key_label: str = "API Key", ) -> ProviderSpec: return openai_provider( provider_id=provider_id, name=name, default_base_url=default_base_url, sort_order=sort_order, models_dev_provider_id=models_dev_provider_id, base_url_presets=base_url_presets, api_key_hint=api_key_hint, description=description, model_list_strategy="models_dev_only", api_key_label=api_key_label, ) def anthropic_provider( provider_id: str, name: str, default_base_url: str, sort_order: int, *, models_dev_provider_id: Optional[str] = None, base_url_presets: Tuple[ProviderUrlPreset, ...] = (), api_key_hint: Optional[str] = None, description: Optional[str] = None, ) -> ProviderSpec: return ProviderSpec( id=provider_id, name=name, runtime="anthropic_compatible", models_dev_provider_id=models_dev_provider_id or provider_id, default_base_url=default_base_url, base_url_presets=base_url_presets, api_key_hint=api_key_hint or f"填写 {name} API Key。", model_list_strategy="anthropic_compatible", description=description or f"{name} Anthropic-compatible 端点。", sort_order=sort_order, ) catalog_openai_providers = ( {"id": "huggingface", "name": "Hugging Face", "base_url": "https://router.huggingface.co/v1"}, {"id": "jiekou", "name": "接口 AI", "base_url": "https://api.jiekou.ai/openai"}, {"id": "kilo", "name": "Kilo Gateway", "base_url": "https://api.kilo.ai/api/gateway"}, {"id": "llama", "name": "Llama", "base_url": "https://api.llama.com/compat/v1/"}, {"id": "llmgateway", "name": "LLM Gateway", "base_url": "https://api.llmgateway.io/v1"}, {"id": "modelscope", "name": "ModelScope", "base_url": "https://api-inference.modelscope.cn/v1"}, {"id": "nova", "name": "Nova", "base_url": "https://api.nova.amazon.com/v1"}, {"id": "fireworks-ai", "name": "Fireworks AI", "base_url": "https://api.fireworks.ai/inference/v1/"}, {"id": "poe", "name": "Poe", "base_url": "https://api.poe.com/v1"}, {"id": "qihang-ai", "name": "启航 AI", "base_url": "https://api.qhaigc.net/v1"}, {"id": "qiniu-ai", "name": "七牛", "base_url": "https://api.qnaigc.com/v1"}, ) providers = [ ProviderSpec( id="openai", name="OpenAI 兼容", runtime="openai_compatible", default_base_url="", base_url_editable=True, requires_base_url=True, supports_api_key=True, api_key_hint="通用 OpenAI-compatible 兜底入口,需要手动填写 Base URL。", description="通用 OpenAI-compatible 模型服务。", sort_order=1, ), ProviderSpec( id="chatgpt", name="ChatGPT", runtime="chatgpt", models_dev_provider_id="openai", default_base_url="https://api.openai.com/v1", api_key_hint="可直接填写 OpenAI API Key,或使用 ChatGPT Plus/Pro 登录授权。", oauth_methods=(browser_auth, device_auth), model_list_strategy="chatgpt", description="支持 ChatGPT Plus/Pro 鉴权或 OpenAI 官方 API Key。", sort_order=10, ), ProviderSpec( id="google", name="Google", runtime="google", models_dev_provider_id="google", supports_api_key=True, api_key_hint="填写 Gemini / Google AI Studio API Key。", model_list_strategy="google", description="Gemini / Google AI Studio。", sort_order=20, ), anthropic_provider( provider_id="anthropic", name="Anthropic", default_base_url="https://api.anthropic.com/v1", sort_order=30, api_key_hint="填写 Anthropic API Key。", description="Anthropic Claude 官方端点。", ), ProviderSpec( id="deepseek", name="DeepSeek", runtime="deepseek", models_dev_provider_id="deepseek", default_base_url="https://api.deepseek.com", api_key_hint="填写 DeepSeek API Key。", description="DeepSeek 官方平台。", sort_order=40, ), ProviderSpec( id="openrouter", name="OpenRouter", runtime="openai_compatible", models_dev_provider_id="openrouter", default_base_url="https://openrouter.ai/api/v1", api_key_hint="填写 OpenRouter API Key。", description="OpenRouter 聚合模型平台。", sort_order=50, ), ProviderSpec( id="github-copilot", name="GitHub Copilot", runtime="github_copilot", models_dev_provider_id="github-copilot", supports_api_key=False, api_key_label="GitHub Token", oauth_methods=( ProviderAuthMethod( id="device_code", type="device", label="GitHub 设备码授权", description="使用 GitHub Copilot 订阅登录授权。", ), ), model_list_strategy="github_copilot", description="通过 GitHub Copilot 订阅接入。", sort_order=60, ), catalog_openai_provider( provider_id="github-models", name="GitHub Models", default_base_url="https://models.github.ai/inference", sort_order=70, api_key_label="GitHub Token", api_key_hint="填写具有 GitHub Models 访问权限的 GitHub Token。", description="GitHub Models 推理端点。", ), catalog_openai_provider( provider_id="moonshot", name="Moonshot / Kimi", default_base_url="https://api.moonshot.cn/v1", sort_order=80, models_dev_provider_id="moonshotai-cn", base_url_presets=( url_preset( id="moonshot-cn", label="中国站", value="https://api.moonshot.cn/v1", models_dev_provider_id="moonshotai-cn", ), url_preset( id="moonshot-global", label="国际站", value="https://api.moonshot.ai/v1", models_dev_provider_id="moonshotai", ), url_preset( id="moonshot-kimi-coding", label="Kimi for Coding", value="https://api.kimi.com/coding/v1", runtime="anthropic_compatible", model_list_strategy="anthropic_compatible", models_dev_provider_id="kimi-for-coding", ), ), api_key_hint="填写 Moonshot / Kimi API Key,可在中国站、国际站与 Kimi for Coding 端点间切换。", description="Moonshot / Kimi 官方端点,支持通用 API 与 Kimi for Coding 预设。", ), anthropic_provider( provider_id="minimax", name="MiniMax", default_base_url="https://api.minimaxi.com/anthropic/v1", sort_order=90, models_dev_provider_id="minimax-cn", base_url_presets=( url_preset( id="minimax-cn-general", label="中国内地 / 通用", value="https://api.minimaxi.com/anthropic/v1", models_dev_provider_id="minimax-cn", ), url_preset( id="minimax-global-general", label="国际站 / 通用", value="https://api.minimax.io/anthropic/v1", models_dev_provider_id="minimax", ), url_preset( id="minimax-cn-coding", label="中国内地 / Coding Plan", value="https://api.minimaxi.com/anthropic/v1", models_dev_provider_id="minimax-cn-coding-plan", ), url_preset( id="minimax-global-coding", label="国际站 / Coding Plan", value="https://api.minimax.io/anthropic/v1", models_dev_provider_id="minimax-coding-plan", ), ), api_key_hint="填写 MiniMax API Key,可在中国内地、国际站、通用与 Coding Plan 目录间切换。", description="MiniMax Anthropic-compatible 端点,支持通用与 Coding Plan 目录预设。", ), catalog_openai_provider( provider_id="xiaomi", name="Xiaomi", default_base_url="https://api.xiaomimimo.com/v1", sort_order=100, base_url_presets=( url_preset( id="xiaomi-standard", label="标准端点", value="https://api.xiaomimimo.com/v1", models_dev_provider_id="xiaomi", ), url_preset( id="xiaomi-token-plan-cn", label="Token Plan / 中国", value="https://token-plan-cn.xiaomimimo.com/v1", models_dev_provider_id="xiaomi-token-plan-cn", ), url_preset( id="xiaomi-token-plan-sgp", label="Token Plan / 新加坡", value="https://token-plan-sgp.xiaomimimo.com/v1", models_dev_provider_id="xiaomi-token-plan-sgp", ), url_preset( id="xiaomi-token-plan-ams", label="Token Plan / 欧洲", value="https://token-plan-ams.xiaomimimo.com/v1", models_dev_provider_id="xiaomi-token-plan-ams", ), ), api_key_hint="填写 Xiaomi API Key,可在标准端点与各区域 Token Plan 端点间切换。", description="小米 Mimo 兼容端点。", ), openai_provider( provider_id="zhipu", name="智谱 GLM", default_base_url="https://open.bigmodel.cn/api/paas/v4", sort_order=110, models_dev_provider_id="zhipuai", base_url_presets=( url_preset( id="zhipu-general", label="Token Plan / 通用 API", value="https://open.bigmodel.cn/api/paas/v4", models_dev_provider_id="zhipuai", ), url_preset( id="zhipu-coding", label="Coding Plan", value="https://open.bigmodel.cn/api/coding/paas/v4", model_list_base_url="https://open.bigmodel.cn/api/paas/v4", models_dev_provider_id="zhipuai-coding-plan", ), ), api_key_hint="填写智谱开放平台 API Key,可在 Token Plan / 通用 API 与 Coding Plan 端点间切换。", description="智谱开放平台国内站,支持通用 API 与 GLM Coding Plan 端点。", ), openai_provider( provider_id="siliconflow", name="硅基流动", default_base_url="https://api.siliconflow.cn/v1", sort_order=120, models_dev_provider_id="siliconflow-cn", base_url_presets=( url_preset( id="siliconflow-cn", label="中国大陆", value="https://api.siliconflow.cn/v1", models_dev_provider_id="siliconflow-cn", ), url_preset( id="siliconflow-global", label="Global", value="https://api.siliconflow.com/v1", models_dev_provider_id="siliconflow", ), ), api_key_hint="填写硅基流动 API Key,可在中国大陆与 Global 端点间切换。", description="SiliconFlow 官方兼容端点。", ), openai_provider( provider_id="alibaba", name="阿里云百炼", default_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", sort_order=130, models_dev_provider_id="alibaba-cn", base_url_presets=( url_preset( id="alibaba-cn-general", label="中国内地 / 通用", value="https://dashscope.aliyuncs.com/compatible-mode/v1", models_dev_provider_id="alibaba-cn", ), url_preset( id="alibaba-global-general", label="国际站 / 通用", value="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", models_dev_provider_id="alibaba", ), url_preset( id="alibaba-cn-coding", label="中国内地 / Coding Plan", value="https://coding.dashscope.aliyuncs.com/v1", model_list_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", models_dev_provider_id="alibaba-coding-plan-cn", ), url_preset( id="alibaba-global-coding", label="国际站 / Coding Plan", value="https://coding-intl.dashscope.aliyuncs.com/v1", model_list_base_url="https://dashscope-intl.aliyuncs.com/compatible-mode/v1", models_dev_provider_id="alibaba-coding-plan", ), ), api_key_hint="填写 DashScope / Alibaba API Key,可在中国内地、国际站与 Coding Plan 端点间切换。", description="阿里云百炼兼容端点。", ), ProviderSpec( id="baidu-qianfan-coding-plan", name="百度千帆", runtime="openai_compatible", default_base_url="https://qianfan.baidubce.com/v2", base_url_presets=( url_preset( id="baidu-qianfan-general", label="通用 API", value="https://qianfan.baidubce.com/v2", ), url_preset( id="baidu-qianfan-coding", label="Coding Plan", value="https://qianfan.baidubce.com/v2/coding", ), ), api_key_hint="填写百度千帆 API Key,可在通用 API 与 Coding Plan 端点间切换。通用 API 请使用 https://qianfan.baidubce.com/v2;Coding Plan 请切换到 https://qianfan.baidubce.com/v2/coding。", supports_model_refresh=False, description="百度千帆 OpenAI-compatible V2 端点,支持通用 API 与 Coding Plan 地址预设。", sort_order=140, ), ProviderSpec( id="jdcloud", name="京东云", runtime="openai_compatible", default_base_url="https://modelservice.jdcloud.com/v1", base_url_presets=( url_preset( id="jdcloud-general", label="通用 API", value="https://modelservice.jdcloud.com/v1", ), url_preset( id="jdcloud-coding", label="Coding Plan", value="https://modelservice.jdcloud.com/coding/openai/v1", ), ), api_key_hint="填写京东云 JoyBuilder API Key。通用 API 请使用 https://modelservice.jdcloud.com/v1;Coding Plan 请切换到 https://modelservice.jdcloud.com/coding/openai/v1,不要把 /v1 当成 Coding Plan 地址。", supports_model_refresh=False, description="京东云 JoyBuilder OpenAI-compatible 端点,支持通用 API 与 Coding Plan 地址预设。", sort_order=150, ), ProviderSpec( id="kuaishou-wanqing", name="快手万擎", runtime="openai_compatible", default_base_url="https://wanqing.streamlakeapi.com/api/gateway/v1/endpoints", base_url_presets=( url_preset( id="kuaishou-wanqing-usage", label="按量计费", value="https://wanqing.streamlakeapi.com/api/gateway/v1/endpoints", ), url_preset( id="kuaishou-wanqing-coding", label="Coding Plan", value="https://wanqing.streamlakeapi.com/api/gateway/coding/v1", ), ), api_key_hint="填写快手万擎 API Key;模型名称请填写万擎控制台或 OpenClaw 配置中的 model ID。", supports_model_refresh=False, model_list_strategy="manual", description="快手万擎 OpenAI-compatible 端点,支持按量计费与 Coding Plan 地址预设。", sort_order=155, ), ProviderSpec( id="volcengine", name="火山方舟", runtime="openai_compatible", default_base_url="https://ark.cn-beijing.volces.com/api/v3", api_key_hint="填写火山方舟 API Key。", description="字节跳动火山引擎兼容端点。", sort_order=160, ), ProviderSpec( id="tencent", name="腾讯云", runtime="openai_compatible", models_dev_provider_id="tencent-tokenhub", default_base_url="https://tokenhub.tencentmaas.com/v1", base_url_presets=( url_preset( id="tencent-tokenhub", label="TokenHub", value="https://tokenhub.tencentmaas.com/v1", models_dev_provider_id="tencent-tokenhub", ), url_preset( id="tencent-coding", label="Coding Plan", value="https://api.lkeap.cloud.tencent.com/coding/v3", models_dev_provider_id="tencent-coding-plan", ), ), api_key_hint="填写 Tencent API Key,可在 TokenHub 与 Coding Plan 端点间切换。", model_list_strategy="models_dev_only", description="腾讯兼容端点。", sort_order=170, ), ProviderSpec( id="ollama-cloud", name="Ollama Cloud", runtime="openai_compatible", models_dev_provider_id="ollama-cloud", default_base_url="https://ollama.com/v1", api_key_hint="填写 Ollama Cloud API Key。", description="Ollama Cloud 云端模型服务。", sort_order=180, ), ProviderSpec( id="nvidia", name="Nvidia", runtime="openai_compatible", models_dev_provider_id="nvidia", default_base_url="https://integrate.api.nvidia.com/v1", api_key_hint="填写 Nvidia API Key。", description="Nvidia 集成推理平台。", sort_order=190, ), catalog_openai_provider( provider_id="opencode", name="OpenCode", default_base_url="https://opencode.ai/zen/v1", sort_order=200, base_url_presets=( url_preset( id="opencode-zen", label="Zen", value="https://opencode.ai/zen/v1", models_dev_provider_id="opencode", ), url_preset( id="opencode-go", label="Go", value="https://opencode.ai/zen/go/v1", models_dev_provider_id="opencode-go", ), ), api_key_hint="填写 OpenCode API Key,可在 Zen 与 Go 端点间切换。", description="OpenCode Zen / Go 端点。", ), catalog_openai_provider( provider_id="groq", name="Groq", default_base_url="https://api.groq.com/openai/v1", sort_order=210, api_key_hint="填写 Groq API Key。", description="Groq 官方 OpenAI-compatible 端点。", ), catalog_openai_provider( provider_id="xai", name="xAI", default_base_url="https://api.x.ai/v1", sort_order=220, api_key_hint="填写 xAI API Key。", description="xAI 官方 OpenAI-compatible 端点。", ), catalog_openai_provider( provider_id="zai", name="Z.AI", default_base_url="https://api.z.ai/api/paas/v4", sort_order=230, base_url_presets=( url_preset( id="zai-general", label="Token Plan / 通用 API", value="https://api.z.ai/api/paas/v4", models_dev_provider_id="zai", ), url_preset( id="zai-coding", label="Coding Plan", value="https://api.z.ai/api/coding/paas/v4", models_dev_provider_id="zai-coding-plan", ), ), api_key_hint="填写 Z.AI API Key,可在通用 API 与 Coding Plan 端点间切换。", description="Z.AI 官方端点。", ), ] for sort_order, provider_entry in enumerate( catalog_openai_providers, start=1000, ): provider_id = provider_entry["id"] overrides = provider_patches.get(provider_id, {}) providers.append( catalog_openai_provider( provider_id=provider_id, name=provider_entry["name"], default_base_url=provider_entry["base_url"], sort_order=sort_order, api_key_hint=overrides.get("api_key_hint"), description=overrides.get("description"), ) ) return tuple(providers) def _cached_models_dev_payload(self) -> dict[str, Any]: """获取缓存的 models.dev payload。""" if isinstance(self._models_dev_data, dict): return self._models_dev_data try: if not self._models_dev_cache_path.exists(): payload = None else: payload = json.loads(self._models_dev_cache_path.read_text(encoding="utf-8")) except Exception as err: logger.warning(f"读取 models.dev provider 缓存失败: {err}") payload = None if not isinstance(payload, dict): payload = self._load_bundled_models_dev_payload() if not isinstance(payload, dict): return {} self._models_dev_data = payload return payload @staticmethod def _models_dev_env_names(payload: dict[str, Any]) -> tuple[str, ...]: """从 models.dev 数据中提取支持的环境变量名。""" raw_env_names = payload.get("env") if not isinstance(raw_env_names, list): return () env_names = [] for item in raw_env_names: value = str(item or "").strip() if value: env_names.append(value) return tuple(env_names) @classmethod def _models_dev_reserved_provider_ids( cls, specs: tuple[ProviderSpec, ...] ) -> set[str]: """获取所有已保留的 models_dev_provider_id 集合。""" reserved_ids: set[str] = set() for spec in specs: if spec.models_dev_provider_id: reserved_ids.add(spec.models_dev_provider_id) for preset in spec.base_url_presets: if preset.models_dev_provider_id: reserved_ids.add(preset.models_dev_provider_id) return reserved_ids @staticmethod def _dynamic_api_key_label(env_names: tuple[str, ...]) -> str: """根据环境变量名动态推断 API Key 标签名称。""" first_env = env_names[0].upper() if env_names else "" if "TOKEN" in first_env and "KEY" not in first_env: return "API Token" return "API Key" @classmethod def _normalize_models_dev_base_url( cls, runtime: str, base_url: Optional[str] ) -> Optional[str]: """规范化从 models.dev 获取的 Base URL。""" normalized = cls._sanitize_base_url(base_url) if not normalized: return None suffixes = { "openai_compatible": ( "/chat/completions", "/completions", "/responses", "/embeddings", "/audio/speech", "/audio/transcriptions", ), "anthropic_compatible": ( "/messages", ), } for suffix in suffixes.get(runtime, ()): if normalized.endswith(suffix): normalized = normalized[: -len(suffix)] break return cls._sanitize_base_url(normalized) @classmethod def _models_dev_dynamic_provider_spec( cls, provider_id: str, payload: dict[str, Any], sort_order: int, ) -> ProviderSpec | None: """根据 models.dev 数据动态生成 ProviderSpec 实例。""" normalized_id = str(provider_id or "").strip().lower() if not normalized_id or normalized_id in cls._MODELS_DEV_DYNAMIC_SKIP_IDS: return None override = cls._PROVIDER_PATCHES.get(normalized_id, {}) npm_package = str(payload.get("npm") or "").strip() runtime = override.get("runtime") if not runtime: if npm_package == "@ai-sdk/openai-compatible": runtime = "openai_compatible" elif npm_package == "@ai-sdk/anthropic": runtime = "anthropic_compatible" else: return None model_list_strategy = override.get("model_list_strategy") if not model_list_strategy: model_list_strategy = ( "anthropic_compatible" if runtime == "anthropic_compatible" else "models_dev_only" ) default_base_url = cls._normalize_models_dev_base_url( runtime=runtime, base_url=override.get("default_base_url") or payload.get("api"), ) requires_base_url = not bool(default_base_url) env_names = cls._models_dev_env_names(payload) api_key_label = override.get("api_key_label") or cls._dynamic_api_key_label( env_names ) name = str(payload.get("name") or override.get("name") or normalized_id).strip() description = override.get("description") if not description: transport_name = "Anthropic-compatible" if runtime == "anthropic_compatible" else "OpenAI-compatible" description = f"{name} {transport_name} 端点(来自 models.dev 目录)。" api_key_hint = override.get("api_key_hint") if not api_key_hint: api_key_hint = f"填写 {name} {api_key_label}。" if requires_base_url: api_key_hint = f"填写 {name} {api_key_label},并手动填写 Base URL。" return ProviderSpec( id=normalized_id, name=name, runtime=runtime, models_dev_provider_id=normalized_id, default_base_url=default_base_url, base_url_editable=True, requires_base_url=requires_base_url, api_key_label=api_key_label, api_key_hint=api_key_hint, model_list_strategy=model_list_strategy, description=description, sort_order=sort_order, ) def _dynamic_provider_specs( self, builtin_specs: tuple[ProviderSpec, ...] ) -> tuple[ProviderSpec, ...]: """获取从 models.dev 动态加载的所有 ProviderSpec 实例。""" payload = self._cached_models_dev_payload() if not payload: return () explicit_ids = {spec.id for spec in builtin_specs} reserved_ids = self._models_dev_reserved_provider_ids(builtin_specs) candidates: list[tuple[str, str, dict[str, Any]]] = [] for provider_id, provider_payload in payload.items(): normalized_id = str(provider_id or "").strip().lower() if not normalized_id or not isinstance(provider_payload, dict): continue if normalized_id in explicit_ids or normalized_id in reserved_ids: continue spec = self._models_dev_dynamic_provider_spec( provider_id=normalized_id, payload=provider_payload, sort_order=0, ) if not spec: continue candidates.append((spec.name.lower(), normalized_id, provider_payload)) dynamic_specs = [] for sort_order, (_, provider_id, provider_payload) in enumerate( sorted(candidates), start=700, ): spec = self._models_dev_dynamic_provider_spec( provider_id=provider_id, payload=provider_payload, sort_order=sort_order, ) if not spec: continue dynamic_specs.append(spec) return tuple(dynamic_specs) def _provider_specs(self) -> tuple[ProviderSpec, ...]: """获取所有支持的 ProviderSpec,包括内置和动态加载的。""" builtin_specs = self._builtin_provider_specs() return builtin_specs + self._dynamic_provider_specs(builtin_specs) async def _get_provider_async( self, provider_id: str, force_refresh: bool = False ) -> ProviderSpec: """异步获取指定 provider 的 ProviderSpec 实例。""" normalized_provider_id = self._normalize_provider_id(provider_id) try: return self.get_provider(normalized_provider_id) except LLMProviderError: await self.get_models_dev_data(force_refresh=force_refresh) return self.get_provider(normalized_provider_id) def _serialize_provider(self, spec: ProviderSpec) -> dict[str, Any]: """将 ProviderSpec 序列化为前端可用的字典。""" return { "id": spec.id, "name": spec.name, "runtime": spec.runtime, "default_base_url": self._default_base_url_for_provider(spec) or "", "base_url_presets": [ { "id": preset.id, "label": preset.label, "value": self._sanitize_base_url(preset.value) or "", "runtime": preset.runtime, "model_list_strategy": preset.model_list_strategy, } for preset in spec.base_url_presets ], "base_url_editable": spec.base_url_editable, "requires_base_url": spec.requires_base_url, "supports_api_key": spec.supports_api_key, "api_key_label": spec.api_key_label, "api_key_hint": spec.api_key_hint, "supports_model_refresh": spec.supports_model_refresh, "oauth_methods": [ { "id": method.id, "type": method.type, "label": method.label, "description": method.description, } for method in spec.oauth_methods ], "description": spec.description, "auth_status": self.get_auth_status(spec.id), } async def list_providers_async( self, force_refresh: bool = False ) -> list[dict[str, Any]]: """返回前端可渲染的 provider 目录,并优先补齐 models.dev 动态平台。""" try: await self.get_models_dev_data(force_refresh=force_refresh) except Exception as err: logger.debug(f"加载 models.dev provider 目录失败,回退内置列表: {err}") return self.list_providers() def list_providers(self) -> list[dict[str, Any]]: """返回前端可渲染的 provider 目录。""" return [ self._serialize_provider(spec) for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order) ] def get_provider(self, provider_id: str) -> ProviderSpec: """按 provider id 获取定义。""" normalized = self._normalize_provider_id(provider_id) for spec in self._provider_specs(): if spec.id == normalized: return spec raise LLMProviderError(f"不支持的 LLM 提供商:{provider_id}") @staticmethod def _sanitize_base_url(base_url: Optional[str]) -> Optional[str]: """清理 Base URL 中多余的空格和结尾斜杠。""" if base_url is None: return None value = str(base_url).strip() if not value: return None return value.rstrip("/") @classmethod def _default_base_url_for_provider(cls, spec: ProviderSpec) -> Optional[str]: """获取 provider 的默认 Base URL。""" default_base_url = cls._sanitize_base_url(spec.default_base_url) if default_base_url: return default_base_url if not spec.base_url_presets: return None return cls._sanitize_base_url(spec.base_url_presets[0].value) @classmethod def _normalize_provider_id(cls, provider_id: str) -> str: """规范化 provider_id 以兼容旧版配置。""" normalized = (provider_id or "").strip().lower() if normalized == "minimax-coding": return "minimax" if normalized == "kimi-coding": return "moonshot" return normalized @classmethod def _normalize_base_url_preset_id( cls, provider_id: str, base_url_preset_id: Optional[str] ) -> Optional[str]: """规范化 Base URL 预设 ID。""" normalized_provider_id = cls._normalize_provider_id(provider_id) normalized_preset_id = str(base_url_preset_id or "").strip().lower() or None if not normalized_preset_id: return None if normalized_provider_id == "minimax" and normalized_preset_id == "minimax-coding": return "minimax-cn-coding" if normalized_provider_id == "moonshot" and normalized_preset_id == "kimi-coding": return "moonshot-kimi-coding" return normalized_preset_id @classmethod def _resolve_provider_preset( cls, spec: ProviderSpec, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[ProviderUrlPreset]: """根据给定的参数解析出适用的 Base URL 预设。""" normalized_preset_id = cls._normalize_base_url_preset_id(spec.id, base_url_preset_id) if normalized_preset_id: for preset in spec.base_url_presets: if preset.id == normalized_preset_id: return preset normalized_base_url = cls._sanitize_base_url(base_url) if normalized_base_url: for preset in spec.base_url_presets: preset_value = cls._sanitize_base_url(preset.value) if normalized_base_url == preset_value: return preset return None default_base_url = cls._default_base_url_for_provider(spec) if default_base_url: for preset in spec.base_url_presets: preset_value = cls._sanitize_base_url(preset.value) if preset_value == default_base_url: return preset return None @classmethod def _resolve_provider_runtime( cls, spec: ProviderSpec, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> str: """解析提供商最终适用的 runtime。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) return preset.runtime or spec.runtime if preset else spec.runtime @classmethod def _resolve_provider_model_list_strategy( cls, spec: ProviderSpec, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> str: """解析获取模型列表的策略。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) return preset.model_list_strategy or spec.model_list_strategy if preset else spec.model_list_strategy @classmethod def _resolve_provider_model_list_base_url( cls, spec: ProviderSpec, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: """解析用于获取模型列表的 Base URL。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) if preset: preset_value = cls._sanitize_base_url(preset.value) return cls._sanitize_base_url(preset.model_list_base_url) or preset_value normalized_base_url = cls._sanitize_base_url(base_url) if normalized_base_url: return normalized_base_url return cls._default_base_url_for_provider(spec) @classmethod def _resolve_provider_models_dev_provider_id( cls, spec: ProviderSpec, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: """解析对应的 models.dev provider id。""" preset = cls._resolve_provider_preset(spec, base_url, base_url_preset_id) if preset: return preset.models_dev_provider_id or spec.models_dev_provider_id normalized_base_url = cls._sanitize_base_url(base_url) if normalized_base_url: return spec.models_dev_provider_id return spec.models_dev_provider_id def resolve_model_list_base_url( self, provider_id: str, base_url: Optional[str], base_url_preset_id: Optional[str] = None, ) -> Optional[str]: """解析对外暴露的用于获取模型列表的 Base URL。""" spec = self.get_provider(provider_id) return self._resolve_provider_model_list_base_url( spec, base_url, base_url_preset_id=base_url_preset_id, ) @staticmethod def _httpx_proxy_key() -> str: """兼容不同 httpx 版本的 proxy 参数名。""" params = httpx.Client.__init__.__code__.co_varnames return "proxy" if "proxy" in params else "proxies" def _build_httpx_kwargs(self) -> dict[str, Any]: """构造用于 httpx 客户端的参数,如代理等。""" kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT} if settings.PROXY_HOST: kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST return kwargs @staticmethod def _read_agent_config() -> dict[str, Any]: """读取 AI Agent 配置信息。""" config = SystemConfigOper().get(SystemConfigKey.AIAgentConfig) if isinstance(config, dict): return config return {} @staticmethod async def _write_agent_config(value: dict[str, Any]) -> None: """ 使用异步持久化写回 provider 鉴权配置。 `SystemConfigOper().get()` 读取的是内存缓存,这里保留同步调用; 但写入需要落库,因此统一走 `async_set()`。 """ await SystemConfigOper().async_set( SystemConfigKey.AIAgentConfig, copy.deepcopy(value) or None, ) def _get_auth_store(self) -> dict[str, Any]: """获取所有鉴权数据。""" config = self._read_agent_config() auth_store = config.get("provider_auth") if isinstance(auth_store, dict): return auth_store return {} def get_saved_auth(self, provider_id: str) -> dict[str, Any] | None: """读取持久化 provider 鉴权信息。""" return copy.deepcopy(self._get_auth_store().get(provider_id)) async def save_auth(self, provider_id: str, auth_data: dict[str, Any]) -> None: """写入 provider 鉴权信息。""" config = self._read_agent_config() auth_store = config.get("provider_auth") if not isinstance(auth_store, dict): auth_store = {} auth_store[provider_id] = copy.deepcopy(auth_data) config["provider_auth"] = auth_store await self._write_agent_config(config) async def clear_auth(self, provider_id: str) -> None: """移除 provider 鉴权信息。""" config = self._read_agent_config() auth_store = config.get("provider_auth") if not isinstance(auth_store, dict): return auth_store.pop(provider_id, None) if auth_store: config["provider_auth"] = auth_store else: config.pop("provider_auth", None) await self._write_agent_config(config) def get_auth_status(self, provider_id: str) -> dict[str, Any]: """返回前端展示用的 provider 鉴权摘要。""" auth = self.get_saved_auth(provider_id) if not auth: return {"connected": False} return { "connected": True, "type": auth.get("type"), "label": auth.get("label") or auth.get("email") or auth.get("account_id") or "已授权", "expires_at": auth.get("expires_at"), "updated_at": auth.get("updated_at"), } async def _load_models_dev_from_disk(self) -> dict[str, Any] | None: """从磁盘缓存加载 models.dev 数据。""" try: if not self._models_dev_cache_path.exists(): return None async with aiofiles.open( self._models_dev_cache_path, mode="r", encoding="utf-8" ) as stream: return json.loads(await stream.read()) except Exception as err: logger.warning(f"读取 models.dev 缓存失败: {err}") return None def _load_bundled_models_dev_payload(self) -> dict[str, Any] | None: """从随代码附带的离线文件加载 models.dev 数据。""" try: if not self._MODELS_DEV_BUNDLED_PATH.exists(): return None payload = json.loads( self._MODELS_DEV_BUNDLED_PATH.read_text(encoding="utf-8") ) except Exception as err: logger.warning(f"读取本地 models.dev 离线文件失败: {err}") return None return payload if isinstance(payload, dict) else None async def _write_models_dev_to_disk(self, payload: dict[str, Any]) -> None: """将 models.dev 数据写入磁盘缓存。""" try: self._models_dev_cache_path.parent.mkdir(parents=True, exist_ok=True) async with aiofiles.open( self._models_dev_cache_path, mode="w", encoding="utf-8" ) as stream: await stream.write(json.dumps(payload, ensure_ascii=False)) except Exception as err: logger.warning(f"写入 models.dev 缓存失败: {err}") async def _fetch_models_dev(self) -> dict[str, Any]: """通过网络请求获取最新 models.dev 数据。""" headers = {"User-Agent": "MoviePilot/1.0"} async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.get(self._MODELS_DEV_URL, headers=headers) response.raise_for_status() return response.json() async def get_models_dev_data(self, force_refresh: bool = False) -> dict[str, Any]: """ 返回 models.dev 原始数据。 这里复用 opencode 的做法,把公共模型目录缓存到本地文件中,避免每次 刷新模型列表都直接打到远端。 """ async with self._models_dev_lock: now = time.time() if ( not force_refresh and self._models_dev_data is not None and now - self._models_dev_loaded_at < self._MODELS_DEV_CACHE_TTL ): return self._models_dev_data if not force_refresh and self._models_dev_cache_path.exists(): mtime = self._models_dev_cache_path.stat().st_mtime if now - mtime < self._MODELS_DEV_CACHE_TTL: cached = await self._load_models_dev_from_disk() if isinstance(cached, dict): self._models_dev_data = cached self._models_dev_loaded_at = now return cached try: payload = await self._fetch_models_dev() self._models_dev_data = payload self._models_dev_loaded_at = now await self._write_models_dev_to_disk(payload) return payload except Exception as err: logger.warning(f"刷新 models.dev 失败,尝试回退本地缓存: {err}") cached = await self._load_models_dev_from_disk() if isinstance(cached, dict): self._models_dev_data = cached self._models_dev_loaded_at = now return cached bundled = self._load_bundled_models_dev_payload() if isinstance(bundled, dict): self._models_dev_data = bundled self._models_dev_loaded_at = now return bundled raise LLMProviderError(f"获取 models.dev 数据失败: {err}") from err async def _models_dev_provider_payload( self, provider_id: str, base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any]: """获取指定 provider 在 models.dev 中的完整负载。""" spec = await self._get_provider_async(provider_id) models_dev_provider_id = self._resolve_provider_models_dev_provider_id( spec, base_url, base_url_preset_id=base_url_preset_id, ) if not models_dev_provider_id: return {} return (await self.get_models_dev_data()).get(models_dev_provider_id, {}) or {} async def _models_dev_model( self, provider_id: str, model_id: str, base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any] | None: """获取指定模型的 models.dev 元数据。""" payload = await self._models_dev_provider_payload( provider_id, base_url=base_url, base_url_preset_id=base_url_preset_id, ) models = payload.get("models") if isinstance(payload, dict) else None if not isinstance(models, dict): return None candidates = [model_id] if model_id.startswith("models/"): candidates.append(model_id.removeprefix("models/")) for candidate in candidates: if candidate in models: return models[candidate] return None @staticmethod def _normalize_model_record( model_id: str, display_name: Optional[str] = None, metadata: Optional[dict[str, Any]] = None, transport: str = "openai", live_context: Optional[int] = None, live_input: Optional[int] = None, live_output: Optional[int] = None, live_supports_tools: Optional[bool] = None, live_supports_reasoning: Optional[bool] = None, live_supports_image: Optional[bool] = None, live_supports_audio: Optional[bool] = None, source: str = "provider", ) -> dict[str, Any]: """ 统一输出模型记录格式,前端据此直接渲染和自动回填上下文等参数。 """ metadata = metadata or {} limit = metadata.get("limit") or {} modalities = metadata.get("modalities") or {} input_modalities = set(modalities.get("input") or []) context_tokens = live_context or limit.get("context") input_tokens = live_input or limit.get("input") output_tokens = live_output or limit.get("output") supports_image_input = ( live_supports_image if live_supports_image is not None else "image" in input_modalities ) supports_audio_input = ( live_supports_audio if live_supports_audio is not None else "audio" in input_modalities ) supports_tools = ( live_supports_tools if live_supports_tools is not None else bool(metadata.get("tool_call")) ) supports_reasoning = ( live_supports_reasoning if live_supports_reasoning is not None else bool(metadata.get("reasoning")) ) if context_tokens: try: ct_int = int(context_tokens) if ct_int % 1024 == 0 or ct_int == 1048576 or ct_int == 2097152: context_tokens_k = max(1, ct_int // 1024) else: context_tokens_k = max(1, (ct_int + 999) // 1000) except Exception: context_tokens_k = None else: context_tokens_k = None return { "id": model_id, "name": display_name or metadata.get("name") or model_id, "family": metadata.get("family"), "context_tokens": context_tokens, "input_tokens": input_tokens, "output_tokens": output_tokens, "context_tokens_k": context_tokens_k, "supports_reasoning": supports_reasoning, "supports_tools": supports_tools, "supports_image_input": supports_image_input, "supports_audio_input": supports_audio_input, "transport": transport, "source": source, "release_date": metadata.get("release_date"), "status": metadata.get("status"), } def _normalize_base_url_for_anthropic(self, base_url: str) -> str: """对 Anthropic 的 Base URL 进行适配处理。""" normalized = self._sanitize_base_url(base_url) or "" if normalized.endswith("/v1"): return normalized[:-3] return normalized async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]: """从 Google AI Studio 获取模型列表。""" from google import genai from google.genai.types import HttpOptions http_options = None if settings.PROXY_HOST: proxy_key = self._httpx_proxy_key() proxy_args = {proxy_key: settings.PROXY_HOST} http_options = HttpOptions( client_args=proxy_args, async_client_args=proxy_args, ) client = genai.Client(api_key=api_key, http_options=http_options) response = await client.aio.models.list() results = [] for model in response.page: supported = set(model.supported_actions or []) if "generateContent" not in supported: continue model_id = model.name metadata = await self._models_dev_model("google", model_id) or {} results.append( self._normalize_model_record( model_id=model_id, display_name=model.display_name or metadata.get("name") or model_id, metadata=metadata, source="provider", ) ) return sorted(results, key=lambda item: item["name"].lower()) async def _list_models_from_openai_compatible( self, provider_id: str, api_key: str, base_url: str, default_headers: Optional[dict[str, str]] = None, ) -> list[dict[str, Any]]: """通过 OpenAI 兼容接口获取模型列表。""" from openai import AsyncOpenAI client = AsyncOpenAI( api_key=api_key, base_url=base_url, default_headers=default_headers, timeout=15.0, max_retries=2, ) results = [] response = await client.models.list() for model in response.data: metadata = await self._models_dev_model( provider_id, model.id, base_url=base_url, ) or {} results.append( self._normalize_model_record( model_id=model.id, display_name=metadata.get("name") or model.id, metadata=metadata, source="provider", ) ) return sorted(results, key=lambda item: item["name"].lower()) async def _list_models_from_models_dev_only( self, provider_id: str, transport: str = "openai", base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> list[dict[str, Any]]: """ 某些 provider 没有统一稳定的 models.list 行为, 因此优先读取 models.dev 目录;若未来 provider 暴露标准 models 接口, 再平滑补充实时刷新即可。 """ payload = await self._models_dev_provider_payload( provider_id, base_url=base_url, base_url_preset_id=base_url_preset_id, ) models = payload.get("models") if isinstance(payload, dict) else None if not isinstance(models, dict): raise LLMProviderError(f"{provider_id} 暂无可用模型目录") results = [] for model_id, metadata in models.items(): results.append( self._normalize_model_record( model_id=model_id, display_name=metadata.get("name") or model_id, metadata=metadata, transport=transport, source="models.dev", ) ) return sorted(results, key=lambda item: item["name"].lower()) @staticmethod def _copilot_headers( token: Optional[str] = None, include_auth: bool = True ) -> dict[str, str]: """ 构造 GitHub Copilot 请求头。 OpenAI-compatible 调用会由 SDK 自行补 Authorization,因此这里允许 仅补充 Copilot 必需的意图头,避免重复覆盖。 """ headers = { "User-Agent": "MoviePilot/1.0", "Openai-Intent": "conversation-edits", "x-initiator": "user", } if include_auth and token: headers["Authorization"] = f"Bearer {token}" return headers async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]: """从 GitHub Copilot 端点获取模型列表。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.get( "https://api.githubcopilot.com/models", headers=self._copilot_headers(token), ) response.raise_for_status() payload = response.json() raw_models = payload.get("data") if isinstance(payload, dict) else None if not isinstance(raw_models, list): raise LLMProviderError("GitHub Copilot 模型列表响应格式不正确") results = [] for item in raw_models: if not isinstance(item, dict): continue if not item.get("model_picker_enabled", True): continue if (item.get("policy") or {}).get("state") == "disabled": continue model_id = str(item.get("id") or "").strip() if not model_id: continue endpoints = set(item.get("supported_endpoints") or []) # 优先兼容 OpenAI 风格端点;仅在缺失时再切到 Anthropic 风格消息接口。 transport = ( "anthropic" if "/v1/messages" in endpoints and "/v1/chat/completions" not in endpoints and "/v1/responses" not in endpoints else "openai" ) limits = ((item.get("capabilities") or {}).get("limits") or {}) supports = ((item.get("capabilities") or {}).get("supports") or {}) metadata = await self._models_dev_model("github-copilot", model_id) or {} results.append( self._normalize_model_record( model_id=model_id, display_name=item.get("name") or metadata.get("name") or model_id, metadata=metadata, transport=transport, live_context=limits.get("max_context_window_tokens"), live_input=limits.get("max_prompt_tokens"), live_output=limits.get("max_output_tokens"), live_supports_tools=supports.get("tool_calls"), live_supports_reasoning=bool( supports.get("adaptive_thinking") or supports.get("reasoning_effort") or supports.get("max_thinking_budget") is not None or supports.get("min_thinking_budget") is not None ), live_supports_image=bool( supports.get("vision") or ((limits.get("vision") or {}).get("supported_media_types")) ), source="provider", ) ) return sorted(results, key=lambda i: i["name"].lower()) async def _list_chatgpt_oauth_models( self, provider_id: str, base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> list[dict[str, Any]]: """获取开启 OAuth 的 ChatGPT 模型列表。""" # ChatGPT OAuth 仍然是 chatgpt provider 专属能力,但模型目录不再维护 # 一份内部名单,直接跟随当前 provider 对应的 models.dev 数据。 payload = await self._models_dev_provider_payload( provider_id, base_url=base_url, base_url_preset_id=base_url_preset_id, ) models = payload.get("models") if isinstance(payload, dict) else None if not isinstance(models, dict): return [] results = [] for model_id, metadata in models.items(): results.append( self._normalize_model_record( model_id=model_id, display_name=metadata.get("name") or model_id, metadata=metadata, source="models.dev", ) ) return sorted(results, key=lambda item: item["name"].lower()) async def list_models( self, provider_id: str, api_key: Optional[str] = None, base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, force_refresh: bool = False, ) -> list[dict[str, Any]]: """返回标准化后的模型目录。""" spec = await self._get_provider_async(provider_id, force_refresh=force_refresh) resolved_model_list_strategy = self._resolve_provider_model_list_strategy( spec, base_url, base_url_preset_id=base_url_preset_id, ) if self._resolve_provider_models_dev_provider_id( spec, base_url, base_url_preset_id=base_url_preset_id, ): # 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表” # 在使用目录型 provider 时也能拿到最新参数。 if force_refresh: await self.get_models_dev_data(force_refresh=True) if resolved_model_list_strategy == "manual": # 万擎等推理点型平台没有稳定的全局模型目录,模型 ID 需要用户从控制台复制。 return [] runtime = await self.resolve_runtime( provider_id, model=None, api_key=api_key, base_url=base_url, base_url_preset_id=base_url_preset_id, ) if resolved_model_list_strategy == "google": return await self._list_models_from_google(runtime["api_key"]) if resolved_model_list_strategy == "github_copilot": return await self._list_models_from_copilot(runtime["api_key"]) if resolved_model_list_strategy == "chatgpt": if runtime.get("auth_mode") == "oauth": return await self._list_chatgpt_oauth_models( provider_id=provider_id, base_url=base_url, base_url_preset_id=base_url_preset_id, ) return await self._list_models_from_openai_compatible( provider_id="chatgpt", api_key=runtime["api_key"], base_url=self._resolve_provider_model_list_base_url( spec, runtime["base_url"], base_url_preset_id=base_url_preset_id, ), default_headers=runtime.get("default_headers"), ) if resolved_model_list_strategy == "anthropic_compatible": return await self._list_models_from_models_dev_only( provider_id=provider_id, transport="anthropic", base_url=base_url, base_url_preset_id=base_url_preset_id, ) if resolved_model_list_strategy == "models_dev_only": return await self._list_models_from_models_dev_only( provider_id=provider_id, transport="openai", base_url=base_url, base_url_preset_id=base_url_preset_id, ) # openai-compatible / deepseek 默认走官方 models 端点。 return await self._list_models_from_openai_compatible( provider_id=provider_id, api_key=runtime["api_key"], base_url=self._resolve_provider_model_list_base_url( spec, runtime["base_url"], base_url_preset_id=base_url_preset_id, ), default_headers=runtime.get("default_headers"), ) async def resolve_model_metadata( self, provider_id: str, model_id: Optional[str], base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any] | None: """解析并返回指定模型在 models.dev 中的元数据。""" if not model_id: return None metadata = await self._models_dev_model( provider_id, model_id, base_url=base_url, base_url_preset_id=base_url_preset_id, ) if metadata: return metadata if provider_id == "chatgpt": return await self._models_dev_model("openai", model_id) if provider_id == "openai": models_dev = await self.get_models_dev_data() return models_dev.get("openai", {}).get("models", {}).get(model_id) return None @staticmethod def _jwt_claims(token: str) -> dict[str, Any]: """解析 JWT token 内容(不验证签名)。""" try: return jwt.decode(token, options={"verify_signature": False}) except Exception as err: print(err) return {} @staticmethod def _extract_chatgpt_account_id(token_payload: dict[str, Any]) -> Optional[str]: """从 ChatGPT 的 Token payload 中提取 account id。""" if token_payload.get("chatgpt_account_id"): return token_payload["chatgpt_account_id"] auth_payload = token_payload.get("https://api.openai.com/auth") or {} if auth_payload.get("chatgpt_account_id"): return auth_payload["chatgpt_account_id"] organizations = token_payload.get("organizations") or [] if organizations and isinstance(organizations[0], dict): return organizations[0].get("id") return None def _chatgpt_authorize_url( self, redirect_uri: str, challenge: str, state: str ) -> str: """构建 ChatGPT OAuth 授权链接。""" query = urlencode( { "response_type": "code", "client_id": self._CHATGPT_CLIENT_ID, "redirect_uri": redirect_uri, "scope": "openid profile email offline_access", "code_challenge": challenge, "code_challenge_method": "S256", "id_token_add_organizations": "true", "codex_cli_simplified_flow": "true", "state": state, "originator": "moviepilot", } ) return f"{self._CHATGPT_ISSUER}/oauth/authorize?{query}" @staticmethod def _pkce_pair() -> tuple[str, str]: """生成 PKCE verifier 和 challenge。""" verifier = secrets.token_urlsafe(64).replace("=", "") digest = hashlib.sha256(verifier.encode("utf-8")).digest() challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=") return verifier, challenge async def start_auth( self, provider_id: str, method_id: str, callback_url: Optional[str] = None, ) -> dict[str, Any]: """ 启动 OAuth / device code 会话。 API Key 方式已经由普通设置表单覆盖,这里只处理需要交互式授权的 provider。 """ provider = await self._get_provider_async(provider_id) method = next( (item for item in provider.oauth_methods if item.id == method_id), None, ) if not method: raise LLMProviderAuthError(f"{provider.name} 不支持授权方式:{method_id}") session = PendingAuthSession( session_id=secrets.token_urlsafe(18), provider_id=provider_id, method_id=method_id, flow_type=method.type, expires_at=time.time() + 600, ) if provider_id == "chatgpt" and method_id == "browser_oauth": if not callback_url: raise LLMProviderAuthError("ChatGPT 浏览器授权缺少回调地址") verifier, challenge = self._pkce_pair() state = secrets.token_urlsafe(24) session.authorize_url = self._chatgpt_authorize_url( redirect_uri=callback_url, challenge=challenge, state=state, ) session.instructions = "请在浏览器中完成 ChatGPT Plus/Pro 登录授权。" session.context.update( { "code_verifier": verifier, "state": state, "redirect_uri": callback_url, } ) with self._lock: self._pending_sessions[session.session_id] = session self._oauth_state_index[state] = session.session_id return { "session_id": session.session_id, "flow_type": "oauth_browser", "authorize_url": session.authorize_url, "instructions": session.instructions, "expires_at": session.expires_at, } if provider_id == "chatgpt" and method_id == "device_code": async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/usercode", headers={ "Content-Type": "application/json", "User-Agent": "MoviePilot/1.0", }, json={"client_id": self._CHATGPT_CLIENT_ID}, ) response.raise_for_status() payload = response.json() session.verification_url = f"{self._CHATGPT_ISSUER}/codex/device" session.user_code = payload.get("user_code") session.interval_seconds = max(int(payload.get("interval") or 5), 1) session.instructions = f"请在浏览器输入设备码:{session.user_code}" session.context.update( { "device_auth_id": payload.get("device_auth_id"), "user_code": payload.get("user_code"), } ) with self._lock: self._pending_sessions[session.session_id] = session return { "session_id": session.session_id, "flow_type": "device_code", "verification_url": session.verification_url, "user_code": session.user_code, "interval_seconds": session.interval_seconds, "instructions": session.instructions, "expires_at": session.expires_at, } if provider_id == "github-copilot" and method_id == "device_code": async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( "https://github.com/login/device/code", headers={ "Accept": "application/json", "Content-Type": "application/json", "User-Agent": "MoviePilot/1.0", }, json={ "client_id": self._COPILOT_CLIENT_ID, "scope": "read:user", }, ) response.raise_for_status() payload = response.json() session.verification_url = payload.get("verification_uri") session.user_code = payload.get("user_code") session.interval_seconds = max(int(payload.get("interval") or 5), 1) session.instructions = f"请在 GitHub 页面输入设备码:{session.user_code}" session.context.update( { "device_code": payload.get("device_code"), } ) with self._lock: self._pending_sessions[session.session_id] = session return { "session_id": session.session_id, "flow_type": "device_code", "verification_url": session.verification_url, "user_code": session.user_code, "interval_seconds": session.interval_seconds, "instructions": session.instructions, "expires_at": session.expires_at, } raise LLMProviderAuthError(f"暂未实现 {provider.name} 的授权方式:{method.label}") def get_session_status(self, session_id: str) -> dict[str, Any]: """读取临时授权会话状态。""" with self._lock: session = self._pending_sessions.get(session_id) if not session: raise LLMProviderAuthError("授权会话不存在或已过期") return { "session_id": session.session_id, "provider_id": session.provider_id, "status": session.status, "message": session.message, "user_code": session.user_code, "verification_url": session.verification_url, "authorize_url": session.authorize_url, "instructions": session.instructions, "interval_seconds": session.interval_seconds, "expires_at": session.expires_at, } async def _mark_session_success( self, session: PendingAuthSession, auth_data: dict[str, Any] ) -> None: """标记授权会话为成功,并保存认证信息。""" auth_data["updated_at"] = int(time.time()) await self.save_auth(session.provider_id, auth_data) session.status = "authorized" session.message = "授权成功" @staticmethod def _mark_session_error(session: PendingAuthSession, message: str) -> None: """标记授权会话为失败,并记录错误信息。""" session.status = "failed" session.message = message async def handle_chatgpt_callback( self, provider_id: str, code: Optional[str], state: Optional[str], error: Optional[str], error_description: Optional[str], ) -> tuple[bool, str]: """处理 ChatGPT 浏览器 OAuth 回调。""" if provider_id != "chatgpt": return False, "当前 provider 不支持浏览器 OAuth 回调" if error: message = error_description or error with self._lock: session_id = self._oauth_state_index.pop(state or "", None) if session_id and session_id in self._pending_sessions: self._mark_session_error(self._pending_sessions[session_id], message) return False, message if not code or not state: return False, "缺少授权码或 state 参数" with self._lock: session_id = self._oauth_state_index.pop(state, None) session = self._pending_sessions.get(session_id or "") if not session: return False, "授权会话不存在或已失效" if state != session.context.get("state"): self._mark_session_error(session, "state 校验失败") return False, "state 校验失败" try: payload = await self._exchange_chatgpt_code_for_tokens( code=code, redirect_uri=session.context["redirect_uri"], code_verifier=session.context["code_verifier"], ) claims = self._jwt_claims(payload.get("id_token") or payload["access_token"]) account_id = self._extract_chatgpt_account_id(claims) auth_data = { "type": "oauth", "provider": "chatgpt", "access_token": payload["access_token"], "refresh_token": payload["refresh_token"], "expires_at": int(time.time() + int(payload.get("expires_in") or 3600)), "account_id": account_id, "email": claims.get("email"), "label": claims.get("email") or account_id or "ChatGPT Plus/Pro", } await self._mark_session_success(session, auth_data) return True, "ChatGPT 授权成功" except Exception as err: message = f"ChatGPT 授权失败: {err}" self._mark_session_error(session, message) return False, message async def poll_auth_session(self, session_id: str) -> dict[str, Any]: """ 执行一次 device code 轮询,并返回最新状态。 前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。 """ with self._lock: session = self._pending_sessions.get(session_id) if not session: raise LLMProviderAuthError("授权会话不存在或已过期") if session.status != "pending": return self.get_session_status(session_id) try: if session.provider_id == "chatgpt" and session.method_id == "device_code": await self._poll_chatgpt_device_auth(session) elif session.provider_id == "github-copilot" and session.method_id == "device_code": await self._poll_copilot_device_auth(session) else: raise LLMProviderAuthError("当前授权会话不支持轮询") except Exception as err: self._mark_session_error(session, str(err)) return self.get_session_status(session_id) async def _exchange_chatgpt_code_for_tokens( self, code: str, redirect_uri: str, code_verifier: str ) -> dict[str, Any]: """使用 authorization code 交换 ChatGPT 令牌。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/oauth/token", headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "grant_type": "authorization_code", "code": code, "redirect_uri": redirect_uri, "client_id": self._CHATGPT_CLIENT_ID, "code_verifier": code_verifier, }, ) response.raise_for_status() return response.json() async def _refresh_chatgpt_tokens(self, refresh_token: str) -> dict[str, Any]: """刷新 ChatGPT 的 access_token。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/oauth/token", headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "grant_type": "refresh_token", "refresh_token": refresh_token, "client_id": self._CHATGPT_CLIENT_ID, }, ) response.raise_for_status() return response.json() async def _poll_chatgpt_device_auth(self, session: PendingAuthSession) -> None: """轮询 ChatGPT Device Auth 状态。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token", headers={ "Content-Type": "application/json", "User-Agent": "MoviePilot/1.0", }, json={ "device_auth_id": session.context["device_auth_id"], "user_code": session.context["user_code"], }, ) if response.status_code in {403, 404}: session.message = "等待用户在浏览器完成授权" return response.raise_for_status() payload = response.json() token_payload = await self._exchange_chatgpt_code_for_tokens( code=payload["authorization_code"], redirect_uri=f"{self._CHATGPT_ISSUER}/deviceauth/callback", code_verifier=payload["code_verifier"], ) claims = self._jwt_claims( token_payload.get("id_token") or token_payload["access_token"] ) account_id = self._extract_chatgpt_account_id(claims) await self._mark_session_success( session, { "type": "oauth", "provider": "chatgpt", "access_token": token_payload["access_token"], "refresh_token": token_payload["refresh_token"], "expires_at": int(time.time() + int(token_payload.get("expires_in") or 3600)), "account_id": account_id, "email": claims.get("email"), "label": claims.get("email") or account_id or "ChatGPT Plus/Pro", }, ) async def _poll_copilot_device_auth(self, session: PendingAuthSession) -> None: """轮询 GitHub Copilot Device Auth 状态。""" async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client: response = await client.post( "https://github.com/login/oauth/access_token", headers={ "Accept": "application/json", "Content-Type": "application/json", "User-Agent": "MoviePilot/1.0", }, json={ "client_id": self._COPILOT_CLIENT_ID, "device_code": session.context["device_code"], "grant_type": "urn:ietf:params:oauth:grant-type:device_code", }, ) response.raise_for_status() payload = response.json() access_token = payload.get("access_token") if access_token: await self._mark_session_success( session, { "type": "oauth", "provider": "github-copilot", "access_token": access_token, # Copilot 设备码授权返回的是长期可复用 token,这里复用 access 字段即可。 "refresh_token": access_token, "expires_at": None, "label": "GitHub Copilot", }, ) return error = payload.get("error") if error == "authorization_pending": session.message = "等待用户在 GitHub 页面完成授权" return if error == "slow_down": session.interval_seconds = max(session.interval_seconds + 5, 10) session.message = "GitHub 要求降低轮询频率,稍后继续。" return if error: raise LLMProviderAuthError(f"GitHub Copilot 授权失败: {error}") async def _resolve_chatgpt_oauth(self) -> dict[str, Any]: """解析并返回 ChatGPT OAuth 鉴权,支持自动刷新 Token。""" auth = self.get_saved_auth("chatgpt") if not auth or auth.get("type") != "oauth": raise LLMProviderAuthError("尚未完成 ChatGPT Plus/Pro 授权") expires_at = auth.get("expires_at") refresh_token = auth.get("refresh_token") # 预留 60 秒刷新缓冲,避免刚发起请求就遇到过期。 if expires_at and refresh_token and int(expires_at) <= int(time.time()) + 60: payload = await self._refresh_chatgpt_tokens(refresh_token) claims = self._jwt_claims(payload.get("id_token") or payload["access_token"]) auth.update( { "access_token": payload["access_token"], "refresh_token": payload.get("refresh_token") or refresh_token, "expires_at": int(time.time() + int(payload.get("expires_in") or 3600)), "account_id": auth.get("account_id") or self._extract_chatgpt_account_id(claims), "email": auth.get("email") or claims.get("email"), "label": auth.get("label") or claims.get("email") or auth.get("account_id") or "ChatGPT Plus/Pro", } ) await self.save_auth("chatgpt", auth) return auth async def resolve_runtime( self, provider_id: str, model: Optional[str], api_key: Optional[str] = None, base_url: Optional[str] = None, base_url_preset_id: Optional[str] = None, ) -> dict[str, Any]: """ 解析 provider 运行时参数。 返回统一结构,供 `LLMHelper` 创建具体 LangChain 模型实例时使用。 """ normalized_provider_id = self._normalize_provider_id(provider_id) normalized_base_url_preset_id = self._normalize_base_url_preset_id( normalized_provider_id, base_url_preset_id, ) spec = await self._get_provider_async(normalized_provider_id) resolved_runtime = self._resolve_provider_runtime( spec, base_url, base_url_preset_id=normalized_base_url_preset_id, ) normalized_api_key = str(api_key or "").strip() or None normalized_base_url = self._sanitize_base_url(base_url) model_record = None if model: try: model_record = next( ( item for item in await self.list_models( normalized_provider_id, api_key=api_key, base_url=base_url, base_url_preset_id=normalized_base_url_preset_id, ) if item["id"] == model ), None, ) except Exception as err: print(err) model_record = None result: dict[str, Any] = { "provider_id": normalized_provider_id, "runtime": resolved_runtime, "model_id": model, "model_record": model_record, "model_metadata": await self.resolve_model_metadata( normalized_provider_id, model, base_url=base_url, base_url_preset_id=normalized_base_url_preset_id, ), "default_headers": None, "use_responses_api": None, "auth_mode": "api_key", } if normalized_provider_id == "chatgpt": auth = None try: auth = await self._resolve_chatgpt_oauth() except Exception as err: print(err) pass if auth: headers = {"originator": "moviepilot"} if auth.get("account_id"): headers["ChatGPT-Account-Id"] = auth["account_id"] result.update( { "runtime": "chatgpt", "api_key": auth["access_token"], "base_url": self._CHATGPT_CODEX_BASE_URL, "default_headers": headers, "use_responses_api": True, "auth_mode": "oauth", } ) return result if normalized_api_key: result.update( { "runtime": "openai_compatible", "api_key": normalized_api_key, "base_url": normalized_base_url or self._default_base_url_for_provider(spec), "auth_mode": "api_key", } ) return result raise LLMProviderAuthError("请提供 API Key 或完成 ChatGPT 授权") if normalized_provider_id == "github-copilot": auth = self.get_saved_auth("github-copilot") if auth and auth.get("type") == "oauth": token = auth.get("refresh_token") or auth.get("access_token") elif normalized_api_key: token = normalized_api_key else: raise LLMProviderAuthError("请先完成 GitHub Copilot 授权") transport = (model_record or {}).get("transport") or "openai" result.update( { "runtime": "copilot_anthropic" if transport == "anthropic" else "github_copilot", "api_key": token, "base_url": "https://api.githubcopilot.com", "default_headers": self._copilot_headers( token, include_auth=transport == "anthropic", ), "auth_mode": "oauth" if auth else "api_key", } ) return result if resolved_runtime == "google": if not normalized_api_key: raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") result.update( { "api_key": normalized_api_key, "base_url": None, "auth_mode": "api_key", } ) return result if resolved_runtime == "anthropic_compatible": effective_base_url = normalized_base_url or self._default_base_url_for_provider( spec ) if not normalized_api_key: raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") if not effective_base_url: raise LLMProviderAuthError(f"{spec.name} 缺少 Base URL") result.update( { "api_key": normalized_api_key, "base_url": self._normalize_base_url_for_anthropic( effective_base_url ), "auth_mode": "api_key", } ) return result effective_base_url = normalized_base_url or self._default_base_url_for_provider(spec) if spec.requires_base_url and not effective_base_url: raise LLMProviderAuthError(f"{spec.name} 需要填写 Base URL") if not normalized_api_key: raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key") result.update( { "api_key": normalized_api_key, "base_url": effective_base_url, "auth_mode": "api_key", } ) return result def render_auth_result_html(success: bool, message: str) -> str: """OAuth 回调落地页。""" title = "授权成功" if success else "授权失败" accent = "#3aa675" if success else "#e24b4b" return f"""
{message}