mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-06 20:42:43 +08:00
1537 lines
60 KiB
Python
1537 lines
60 KiB
Python
"""LLM provider registry, auth flows, and model metadata helpers."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import copy
|
||
import hashlib
|
||
import json
|
||
import secrets
|
||
import threading
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
from typing import Any, Dict, Optional, Tuple
|
||
from urllib.parse import urlencode
|
||
|
||
import aiofiles
|
||
import httpx
|
||
import jwt
|
||
|
||
from app.core.config import settings
|
||
from app.db.systemconfig_oper import SystemConfigOper
|
||
from app.log import logger
|
||
from app.schemas.types import SystemConfigKey
|
||
from app.utils.singleton import Singleton
|
||
|
||
|
||
class LLMProviderError(RuntimeError):
|
||
"""通用 LLM provider 异常。"""
|
||
|
||
|
||
class LLMProviderAuthError(LLMProviderError):
|
||
"""LLM provider 鉴权异常。"""
|
||
|
||
|
||
@dataclass(frozen=True)
|
||
class ProviderAuthMethod:
|
||
"""前端展示用的授权方式定义。"""
|
||
|
||
id: str
|
||
type: str
|
||
label: str
|
||
description: str = ""
|
||
|
||
|
||
@dataclass(frozen=True)
|
||
class ProviderSpec:
|
||
"""描述一个可接入的 LLM provider。"""
|
||
|
||
id: str
|
||
name: str
|
||
runtime: str
|
||
models_dev_provider_id: Optional[str] = None
|
||
default_base_url: Optional[str] = None
|
||
base_url_editable: bool = False
|
||
requires_base_url: bool = False
|
||
supports_api_key: bool = True
|
||
api_key_label: str = "API Key"
|
||
api_key_hint: str = ""
|
||
oauth_methods: Tuple[ProviderAuthMethod, ...] = ()
|
||
supports_model_refresh: bool = True
|
||
model_list_strategy: str = "openai_compatible"
|
||
sort_order: int = 100
|
||
description: str = ""
|
||
|
||
|
||
@dataclass
|
||
class PendingAuthSession:
|
||
"""保存临时鉴权会话,避免把 PKCE/device code 等状态写回配置。"""
|
||
|
||
session_id: str
|
||
provider_id: str
|
||
method_id: str
|
||
flow_type: str
|
||
status: str = "pending"
|
||
message: str = ""
|
||
authorize_url: Optional[str] = None
|
||
instructions: Optional[str] = None
|
||
verification_url: Optional[str] = None
|
||
user_code: Optional[str] = None
|
||
interval_seconds: int = 5
|
||
expires_at: float = 0
|
||
created_at: float = field(default_factory=time.time)
|
||
context: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class LLMProviderManager(metaclass=Singleton):
|
||
"""统一维护 provider 目录、models.dev 缓存和 OAuth 状态。"""
|
||
|
||
_MODELS_DEV_URL = "https://models.dev/api.json"
|
||
_MODELS_DEV_CACHE_TTL = 12 * 60 * 60
|
||
_CHATGPT_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||
_CHATGPT_ISSUER = "https://auth.openai.com"
|
||
_CHATGPT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||
_COPILOT_CLIENT_ID = "Ov23li8tweQw6odWQebz"
|
||
_DEFAULT_TIMEOUT = httpx.Timeout(15.0, connect=10.0)
|
||
_CHATGPT_ALLOWED_OAUTH_MODELS = {
|
||
"gpt-5.1-codex",
|
||
"gpt-5.1-codex-max",
|
||
"gpt-5.1-codex-mini",
|
||
"gpt-5.2",
|
||
"gpt-5.2-codex",
|
||
"gpt-5.3-codex",
|
||
"gpt-5.4",
|
||
"gpt-5.4-mini",
|
||
"gpt-5.5",
|
||
}
|
||
|
||
def __init__(self):
|
||
self._lock = threading.RLock()
|
||
self._models_dev_lock = asyncio.Lock()
|
||
self._pending_sessions: dict[str, PendingAuthSession] = {}
|
||
self._oauth_state_index: dict[str, str] = {}
|
||
self._models_dev_data: dict[str, Any] | None = None
|
||
self._models_dev_loaded_at: float = 0
|
||
self._models_dev_cache_path = (
|
||
Path(settings.TEMP_PATH) / "llm_provider_models_dev_cache.json"
|
||
)
|
||
|
||
@staticmethod
|
||
def _provider_specs() -> tuple[ProviderSpec, ...]:
|
||
"""
|
||
返回受支持的 provider 定义。
|
||
|
||
OpenAI 保留为用户自定义 OpenAI-compatible 兜底入口,因此仍要求填写
|
||
Base URL;ChatGPT 则单独承接官方 API Key / ChatGPT 订阅鉴权。
|
||
"""
|
||
browser_auth = ProviderAuthMethod(
|
||
id="browser_oauth",
|
||
type="oauth",
|
||
label="浏览器授权",
|
||
description="使用 ChatGPT Plus/Pro 浏览器登录并回调授权。",
|
||
)
|
||
device_auth = ProviderAuthMethod(
|
||
id="device_code",
|
||
type="device",
|
||
label="设备码授权",
|
||
description="适合无回调环境,复制设备码到浏览器完成登录。",
|
||
)
|
||
return (
|
||
ProviderSpec(
|
||
id="chatgpt",
|
||
name="ChatGPT",
|
||
runtime="chatgpt",
|
||
models_dev_provider_id="openai",
|
||
default_base_url="https://api.openai.com/v1",
|
||
api_key_hint="可直接填写 OpenAI API Key,或使用 ChatGPT Plus/Pro 登录授权。",
|
||
oauth_methods=(browser_auth, device_auth),
|
||
model_list_strategy="chatgpt",
|
||
description="支持 ChatGPT Plus/Pro 鉴权或 OpenAI 官方 API Key。",
|
||
sort_order=10,
|
||
),
|
||
ProviderSpec(
|
||
id="google",
|
||
name="Google",
|
||
runtime="google",
|
||
models_dev_provider_id="google",
|
||
supports_api_key=True,
|
||
api_key_hint="填写 Gemini / Google AI Studio API Key。",
|
||
model_list_strategy="google",
|
||
description="Gemini / Google AI Studio。",
|
||
sort_order=20,
|
||
),
|
||
ProviderSpec(
|
||
id="deepseek",
|
||
name="DeepSeek",
|
||
runtime="deepseek",
|
||
models_dev_provider_id="deepseek",
|
||
default_base_url="https://api.deepseek.com",
|
||
api_key_hint="填写 DeepSeek API Key。",
|
||
description="DeepSeek 官方平台。",
|
||
sort_order=30,
|
||
),
|
||
ProviderSpec(
|
||
id="openrouter",
|
||
name="OpenRouter",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="openrouter",
|
||
default_base_url="https://openrouter.ai/api/v1",
|
||
api_key_hint="填写 OpenRouter API Key。",
|
||
description="OpenRouter 聚合模型平台。",
|
||
sort_order=40,
|
||
),
|
||
ProviderSpec(
|
||
id="github-copilot",
|
||
name="GitHub Copilot",
|
||
runtime="github_copilot",
|
||
models_dev_provider_id="github-copilot",
|
||
supports_api_key=False,
|
||
api_key_label="GitHub Token",
|
||
oauth_methods=(
|
||
ProviderAuthMethod(
|
||
id="device_code",
|
||
type="device",
|
||
label="GitHub 设备码授权",
|
||
description="使用 GitHub Copilot 订阅登录授权。",
|
||
),
|
||
),
|
||
model_list_strategy="github_copilot",
|
||
description="通过 GitHub Copilot 订阅接入。",
|
||
sort_order=50,
|
||
),
|
||
ProviderSpec(
|
||
id="siliconflow",
|
||
name="硅基流动",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="siliconflow",
|
||
default_base_url="https://api.siliconflow.cn/v1",
|
||
api_key_hint="填写硅基流动 API Key。",
|
||
description="SiliconFlow 官方兼容端点。",
|
||
sort_order=60,
|
||
),
|
||
ProviderSpec(
|
||
id="alibaba",
|
||
name="阿里云百炼",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="alibaba",
|
||
default_base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||
api_key_hint="填写 DashScope / Alibaba API Key。",
|
||
description="阿里云百炼兼容端点。",
|
||
sort_order=70,
|
||
),
|
||
ProviderSpec(
|
||
id="volcengine",
|
||
name="火山方舟",
|
||
runtime="openai_compatible",
|
||
default_base_url="https://ark.cn-beijing.volces.com/api/v3",
|
||
api_key_hint="填写火山方舟 API Key。",
|
||
description="字节跳动火山引擎兼容端点。",
|
||
sort_order=80,
|
||
),
|
||
ProviderSpec(
|
||
id="tencent",
|
||
name="腾讯云",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="tencent-tokenhub",
|
||
default_base_url="https://tokenhub.tencentmaas.com/v1",
|
||
api_key_hint="填写 Tencent API Key。",
|
||
model_list_strategy="models_dev_only",
|
||
description="腾讯兼容端点。",
|
||
sort_order=90,
|
||
),
|
||
ProviderSpec(
|
||
id="ollama-cloud",
|
||
name="Ollama Cloud",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="ollama-cloud",
|
||
default_base_url="https://ollama.com/v1",
|
||
api_key_hint="填写 Ollama Cloud API Key。",
|
||
description="Ollama Cloud 云端模型服务。",
|
||
sort_order=100,
|
||
),
|
||
ProviderSpec(
|
||
id="nvidia",
|
||
name="Nvidia",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="nvidia",
|
||
default_base_url="https://integrate.api.nvidia.com/v1",
|
||
api_key_hint="填写 Nvidia API Key。",
|
||
description="Nvidia 集成推理平台。",
|
||
sort_order=110,
|
||
),
|
||
ProviderSpec(
|
||
id="minimax",
|
||
name="MiniMax",
|
||
runtime="anthropic_compatible",
|
||
models_dev_provider_id="minimax",
|
||
default_base_url="https://api.minimaxi.com/anthropic/v1",
|
||
api_key_hint="填写 MiniMax API Key。",
|
||
model_list_strategy="anthropic_compatible",
|
||
description="MiniMax Anthropic-compatible 端点。",
|
||
sort_order=120,
|
||
),
|
||
ProviderSpec(
|
||
id="xiaomi",
|
||
name="Xiaomi",
|
||
runtime="openai_compatible",
|
||
models_dev_provider_id="xiaomi",
|
||
default_base_url="https://api.xiaomimimo.com/v1",
|
||
api_key_hint="填写 Xiaomi API Key。",
|
||
description="小米 Mimo 兼容端点。",
|
||
sort_order=130,
|
||
),
|
||
ProviderSpec(
|
||
id="openai",
|
||
name="OpenAI Compatible",
|
||
runtime="openai_compatible",
|
||
default_base_url="",
|
||
base_url_editable=True,
|
||
requires_base_url=True,
|
||
supports_api_key=True,
|
||
api_key_hint="通用 OpenAI-compatible 兜底入口,需要手动填写 Base URL。",
|
||
description="通用 OpenAI-compatible 模型服务。",
|
||
sort_order=200,
|
||
),
|
||
)
|
||
|
||
def list_providers(self) -> list[dict[str, Any]]:
|
||
"""返回前端可渲染的 provider 目录。"""
|
||
providers = []
|
||
for spec in sorted(self._provider_specs(), key=lambda item: item.sort_order):
|
||
providers.append(
|
||
{
|
||
"id": spec.id,
|
||
"name": spec.name,
|
||
"runtime": spec.runtime,
|
||
"default_base_url": spec.default_base_url or "",
|
||
"base_url_editable": spec.base_url_editable,
|
||
"requires_base_url": spec.requires_base_url,
|
||
"supports_api_key": spec.supports_api_key,
|
||
"api_key_label": spec.api_key_label,
|
||
"api_key_hint": spec.api_key_hint,
|
||
"supports_model_refresh": spec.supports_model_refresh,
|
||
"oauth_methods": [
|
||
{
|
||
"id": method.id,
|
||
"type": method.type,
|
||
"label": method.label,
|
||
"description": method.description,
|
||
}
|
||
for method in spec.oauth_methods
|
||
],
|
||
"description": spec.description,
|
||
"auth_status": self.get_auth_status(spec.id),
|
||
}
|
||
)
|
||
return providers
|
||
|
||
def get_provider(self, provider_id: str) -> ProviderSpec:
|
||
"""按 provider id 获取定义。"""
|
||
normalized = (provider_id or "").strip().lower()
|
||
for spec in self._provider_specs():
|
||
if spec.id == normalized:
|
||
return spec
|
||
raise LLMProviderError(f"不支持的 LLM 提供商:{provider_id}")
|
||
|
||
@staticmethod
|
||
def _sanitize_base_url(base_url: Optional[str]) -> Optional[str]:
|
||
if base_url is None:
|
||
return None
|
||
value = str(base_url).strip()
|
||
if not value:
|
||
return None
|
||
return value.rstrip("/")
|
||
|
||
@staticmethod
|
||
def _httpx_proxy_key() -> str:
|
||
"""兼容不同 httpx 版本的 proxy 参数名。"""
|
||
params = httpx.Client.__init__.__code__.co_varnames
|
||
return "proxy" if "proxy" in params else "proxies"
|
||
|
||
def _build_httpx_kwargs(self) -> dict[str, Any]:
|
||
kwargs: dict[str, Any] = {"timeout": self._DEFAULT_TIMEOUT}
|
||
if settings.PROXY_HOST:
|
||
kwargs[self._httpx_proxy_key()] = settings.PROXY_HOST
|
||
return kwargs
|
||
|
||
@staticmethod
|
||
def _read_agent_config() -> dict[str, Any]:
|
||
config = SystemConfigOper().get(SystemConfigKey.AIAgentConfig)
|
||
if isinstance(config, dict):
|
||
return config
|
||
return {}
|
||
|
||
@staticmethod
|
||
async def _write_agent_config(value: dict[str, Any]) -> None:
|
||
"""
|
||
使用异步持久化写回 provider 鉴权配置。
|
||
|
||
`SystemConfigOper().get()` 读取的是内存缓存,这里保留同步调用;
|
||
但写入需要落库,因此统一走 `async_set()`。
|
||
"""
|
||
await SystemConfigOper().async_set(
|
||
SystemConfigKey.AIAgentConfig,
|
||
copy.deepcopy(value) or None,
|
||
)
|
||
|
||
def _get_auth_store(self) -> dict[str, Any]:
|
||
config = self._read_agent_config()
|
||
auth_store = config.get("provider_auth")
|
||
if isinstance(auth_store, dict):
|
||
return auth_store
|
||
return {}
|
||
|
||
def get_saved_auth(self, provider_id: str) -> dict[str, Any] | None:
|
||
"""读取持久化 provider 鉴权信息。"""
|
||
return copy.deepcopy(self._get_auth_store().get(provider_id))
|
||
|
||
async def save_auth(self, provider_id: str, auth_data: dict[str, Any]) -> None:
|
||
"""写入 provider 鉴权信息。"""
|
||
config = self._read_agent_config()
|
||
auth_store = config.get("provider_auth")
|
||
if not isinstance(auth_store, dict):
|
||
auth_store = {}
|
||
auth_store[provider_id] = copy.deepcopy(auth_data)
|
||
config["provider_auth"] = auth_store
|
||
await self._write_agent_config(config)
|
||
|
||
async def clear_auth(self, provider_id: str) -> None:
|
||
"""移除 provider 鉴权信息。"""
|
||
config = self._read_agent_config()
|
||
auth_store = config.get("provider_auth")
|
||
if not isinstance(auth_store, dict):
|
||
return
|
||
auth_store.pop(provider_id, None)
|
||
if auth_store:
|
||
config["provider_auth"] = auth_store
|
||
else:
|
||
config.pop("provider_auth", None)
|
||
await self._write_agent_config(config)
|
||
|
||
def get_auth_status(self, provider_id: str) -> dict[str, Any]:
|
||
"""返回前端展示用的 provider 鉴权摘要。"""
|
||
auth = self.get_saved_auth(provider_id)
|
||
if not auth:
|
||
return {"connected": False}
|
||
return {
|
||
"connected": True,
|
||
"type": auth.get("type"),
|
||
"label": auth.get("label") or auth.get("email") or auth.get("account_id") or "已授权",
|
||
"expires_at": auth.get("expires_at"),
|
||
"updated_at": auth.get("updated_at"),
|
||
}
|
||
|
||
async def _load_models_dev_from_disk(self) -> dict[str, Any] | None:
|
||
try:
|
||
if not self._models_dev_cache_path.exists():
|
||
return None
|
||
async with aiofiles.open(
|
||
self._models_dev_cache_path, mode="r", encoding="utf-8"
|
||
) as stream:
|
||
return json.loads(await stream.read())
|
||
except Exception as err:
|
||
logger.warning(f"读取 models.dev 缓存失败: {err}")
|
||
return None
|
||
|
||
async def _write_models_dev_to_disk(self, payload: dict[str, Any]) -> None:
|
||
try:
|
||
self._models_dev_cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||
async with aiofiles.open(
|
||
self._models_dev_cache_path, mode="w", encoding="utf-8"
|
||
) as stream:
|
||
await stream.write(json.dumps(payload, ensure_ascii=False))
|
||
except Exception as err:
|
||
logger.warning(f"写入 models.dev 缓存失败: {err}")
|
||
|
||
async def _fetch_models_dev(self) -> dict[str, Any]:
|
||
headers = {"User-Agent": "MoviePilot/1.0"}
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.get(self._MODELS_DEV_URL, headers=headers)
|
||
response.raise_for_status()
|
||
return response.json()
|
||
|
||
async def get_models_dev_data(self, force_refresh: bool = False) -> dict[str, Any]:
|
||
"""
|
||
返回 models.dev 原始数据。
|
||
|
||
这里复用 opencode 的做法,把公共模型目录缓存到本地文件中,避免每次
|
||
刷新模型列表都直接打到远端。
|
||
"""
|
||
async with self._models_dev_lock:
|
||
now = time.time()
|
||
if (
|
||
not force_refresh
|
||
and self._models_dev_data is not None
|
||
and now - self._models_dev_loaded_at < self._MODELS_DEV_CACHE_TTL
|
||
):
|
||
return self._models_dev_data
|
||
|
||
if not force_refresh and self._models_dev_cache_path.exists():
|
||
mtime = self._models_dev_cache_path.stat().st_mtime
|
||
if now - mtime < self._MODELS_DEV_CACHE_TTL:
|
||
cached = await self._load_models_dev_from_disk()
|
||
if isinstance(cached, dict):
|
||
self._models_dev_data = cached
|
||
self._models_dev_loaded_at = now
|
||
return cached
|
||
|
||
try:
|
||
payload = await self._fetch_models_dev()
|
||
self._models_dev_data = payload
|
||
self._models_dev_loaded_at = now
|
||
await self._write_models_dev_to_disk(payload)
|
||
return payload
|
||
except Exception as err:
|
||
logger.warning(f"刷新 models.dev 失败,尝试回退本地缓存: {err}")
|
||
cached = await self._load_models_dev_from_disk()
|
||
if isinstance(cached, dict):
|
||
self._models_dev_data = cached
|
||
self._models_dev_loaded_at = now
|
||
return cached
|
||
raise LLMProviderError(f"获取 models.dev 数据失败: {err}") from err
|
||
|
||
async def _models_dev_provider_payload(self, provider_id: str) -> dict[str, Any]:
|
||
spec = self.get_provider(provider_id)
|
||
if not spec.models_dev_provider_id:
|
||
return {}
|
||
return (await self.get_models_dev_data()).get(spec.models_dev_provider_id, {}) or {}
|
||
|
||
async def _models_dev_model(
|
||
self, provider_id: str, model_id: str
|
||
) -> dict[str, Any] | None:
|
||
payload = await self._models_dev_provider_payload(provider_id)
|
||
models = payload.get("models") if isinstance(payload, dict) else None
|
||
if not isinstance(models, dict):
|
||
return None
|
||
|
||
candidates = [model_id]
|
||
if model_id.startswith("models/"):
|
||
candidates.append(model_id.removeprefix("models/"))
|
||
|
||
for candidate in candidates:
|
||
if candidate in models:
|
||
return models[candidate]
|
||
return None
|
||
|
||
@staticmethod
|
||
def _normalize_model_record(
|
||
model_id: str,
|
||
display_name: Optional[str] = None,
|
||
metadata: Optional[dict[str, Any]] = None,
|
||
transport: str = "openai",
|
||
live_context: Optional[int] = None,
|
||
live_input: Optional[int] = None,
|
||
live_output: Optional[int] = None,
|
||
live_supports_tools: Optional[bool] = None,
|
||
live_supports_reasoning: Optional[bool] = None,
|
||
live_supports_image: Optional[bool] = None,
|
||
live_supports_audio: Optional[bool] = None,
|
||
source: str = "provider",
|
||
) -> dict[str, Any]:
|
||
"""
|
||
统一输出模型记录格式,前端据此直接渲染和自动回填上下文等参数。
|
||
"""
|
||
metadata = metadata or {}
|
||
limit = metadata.get("limit") or {}
|
||
modalities = metadata.get("modalities") or {}
|
||
input_modalities = set(modalities.get("input") or [])
|
||
|
||
context_tokens = live_context or limit.get("context")
|
||
input_tokens = live_input or limit.get("input")
|
||
output_tokens = live_output or limit.get("output")
|
||
supports_image_input = (
|
||
live_supports_image
|
||
if live_supports_image is not None
|
||
else "image" in input_modalities
|
||
)
|
||
supports_audio_input = (
|
||
live_supports_audio
|
||
if live_supports_audio is not None
|
||
else "audio" in input_modalities
|
||
)
|
||
supports_tools = (
|
||
live_supports_tools
|
||
if live_supports_tools is not None
|
||
else bool(metadata.get("tool_call"))
|
||
)
|
||
supports_reasoning = (
|
||
live_supports_reasoning
|
||
if live_supports_reasoning is not None
|
||
else bool(metadata.get("reasoning"))
|
||
)
|
||
|
||
if context_tokens:
|
||
try:
|
||
ct_int = int(context_tokens)
|
||
if ct_int % 1024 == 0 or ct_int == 1048576 or ct_int == 2097152:
|
||
context_tokens_k = max(1, ct_int // 1024)
|
||
else:
|
||
context_tokens_k = max(1, (ct_int + 999) // 1000)
|
||
except Exception:
|
||
context_tokens_k = None
|
||
else:
|
||
context_tokens_k = None
|
||
|
||
return {
|
||
"id": model_id,
|
||
"name": display_name or metadata.get("name") or model_id,
|
||
"family": metadata.get("family"),
|
||
"context_tokens": context_tokens,
|
||
"input_tokens": input_tokens,
|
||
"output_tokens": output_tokens,
|
||
"context_tokens_k": context_tokens_k,
|
||
"supports_reasoning": supports_reasoning,
|
||
"supports_tools": supports_tools,
|
||
"supports_image_input": supports_image_input,
|
||
"supports_audio_input": supports_audio_input,
|
||
"transport": transport,
|
||
"source": source,
|
||
"release_date": metadata.get("release_date"),
|
||
"status": metadata.get("status"),
|
||
}
|
||
|
||
def _normalize_base_url_for_anthropic(self, base_url: str) -> str:
|
||
normalized = self._sanitize_base_url(base_url) or ""
|
||
if normalized.endswith("/v1"):
|
||
return normalized[:-3]
|
||
return normalized
|
||
|
||
async def _list_models_from_google(self, api_key: str) -> list[dict[str, Any]]:
|
||
from google import genai
|
||
from google.genai.types import HttpOptions
|
||
|
||
http_options = None
|
||
if settings.PROXY_HOST:
|
||
proxy_key = self._httpx_proxy_key()
|
||
proxy_args = {proxy_key: settings.PROXY_HOST}
|
||
http_options = HttpOptions(
|
||
client_args=proxy_args,
|
||
async_client_args=proxy_args,
|
||
)
|
||
|
||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||
response = await client.aio.models.list()
|
||
results = []
|
||
for model in response.page:
|
||
supported = set(model.supported_actions or [])
|
||
if "generateContent" not in supported:
|
||
continue
|
||
model_id = model.name
|
||
metadata = await self._models_dev_model("google", model_id) or {}
|
||
results.append(
|
||
self._normalize_model_record(
|
||
model_id=model_id,
|
||
display_name=model.display_name or metadata.get("name") or model_id,
|
||
metadata=metadata,
|
||
source="provider",
|
||
)
|
||
)
|
||
return sorted(results, key=lambda item: item["name"].lower())
|
||
|
||
async def _list_models_from_openai_compatible(
|
||
self,
|
||
provider_id: str,
|
||
api_key: str,
|
||
base_url: str,
|
||
default_headers: Optional[dict[str, str]] = None,
|
||
) -> list[dict[str, Any]]:
|
||
from openai import AsyncOpenAI
|
||
|
||
client = AsyncOpenAI(
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
default_headers=default_headers,
|
||
timeout=15.0,
|
||
max_retries=2,
|
||
)
|
||
results = []
|
||
response = await client.models.list()
|
||
for model in response.data:
|
||
metadata = await self._models_dev_model(provider_id, model.id) or {}
|
||
results.append(
|
||
self._normalize_model_record(
|
||
model_id=model.id,
|
||
display_name=metadata.get("name") or model.id,
|
||
metadata=metadata,
|
||
source="provider",
|
||
)
|
||
)
|
||
return sorted(results, key=lambda item: item["name"].lower())
|
||
|
||
async def _list_models_from_models_dev_only(
|
||
self,
|
||
provider_id: str,
|
||
transport: str = "openai",
|
||
) -> list[dict[str, Any]]:
|
||
"""
|
||
某些 provider 没有统一稳定的 models.list 行为,
|
||
因此优先读取 models.dev 目录;若未来 provider 暴露标准 models 接口,
|
||
再平滑补充实时刷新即可。
|
||
"""
|
||
payload = await self._models_dev_provider_payload(provider_id)
|
||
models = payload.get("models") if isinstance(payload, dict) else None
|
||
if not isinstance(models, dict):
|
||
raise LLMProviderError(f"{provider_id} 暂无可用模型目录")
|
||
results = []
|
||
for model_id, metadata in models.items():
|
||
results.append(
|
||
self._normalize_model_record(
|
||
model_id=model_id,
|
||
display_name=metadata.get("name") or model_id,
|
||
metadata=metadata,
|
||
transport=transport,
|
||
source="models.dev",
|
||
)
|
||
)
|
||
return sorted(results, key=lambda item: item["name"].lower())
|
||
|
||
@staticmethod
|
||
def _copilot_headers(
|
||
token: Optional[str] = None, include_auth: bool = True
|
||
) -> dict[str, str]:
|
||
"""
|
||
构造 GitHub Copilot 请求头。
|
||
|
||
OpenAI-compatible 调用会由 SDK 自行补 Authorization,因此这里允许
|
||
仅补充 Copilot 必需的意图头,避免重复覆盖。
|
||
"""
|
||
headers = {
|
||
"User-Agent": "MoviePilot/1.0",
|
||
"Openai-Intent": "conversation-edits",
|
||
"x-initiator": "user",
|
||
}
|
||
if include_auth and token:
|
||
headers["Authorization"] = f"Bearer {token}"
|
||
return headers
|
||
|
||
async def _list_models_from_copilot(self, token: str) -> list[dict[str, Any]]:
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.get(
|
||
"https://api.githubcopilot.com/models",
|
||
headers=self._copilot_headers(token),
|
||
)
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
|
||
raw_models = payload.get("data") if isinstance(payload, dict) else None
|
||
if not isinstance(raw_models, list):
|
||
raise LLMProviderError("GitHub Copilot 模型列表响应格式不正确")
|
||
|
||
results = []
|
||
for item in raw_models:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
if not item.get("model_picker_enabled", True):
|
||
continue
|
||
if (item.get("policy") or {}).get("state") == "disabled":
|
||
continue
|
||
|
||
model_id = str(item.get("id") or "").strip()
|
||
if not model_id:
|
||
continue
|
||
|
||
endpoints = set(item.get("supported_endpoints") or [])
|
||
# 优先兼容 OpenAI 风格端点;仅在缺失时再切到 Anthropic 风格消息接口。
|
||
transport = (
|
||
"anthropic"
|
||
if "/v1/messages" in endpoints
|
||
and "/v1/chat/completions" not in endpoints
|
||
and "/v1/responses" not in endpoints
|
||
else "openai"
|
||
)
|
||
|
||
limits = ((item.get("capabilities") or {}).get("limits") or {})
|
||
supports = ((item.get("capabilities") or {}).get("supports") or {})
|
||
metadata = await self._models_dev_model("github-copilot", model_id) or {}
|
||
results.append(
|
||
self._normalize_model_record(
|
||
model_id=model_id,
|
||
display_name=item.get("name") or metadata.get("name") or model_id,
|
||
metadata=metadata,
|
||
transport=transport,
|
||
live_context=limits.get("max_context_window_tokens"),
|
||
live_input=limits.get("max_prompt_tokens"),
|
||
live_output=limits.get("max_output_tokens"),
|
||
live_supports_tools=supports.get("tool_calls"),
|
||
live_supports_reasoning=bool(
|
||
supports.get("adaptive_thinking")
|
||
or supports.get("reasoning_effort")
|
||
or supports.get("max_thinking_budget") is not None
|
||
or supports.get("min_thinking_budget") is not None
|
||
),
|
||
live_supports_image=bool(
|
||
supports.get("vision")
|
||
or ((limits.get("vision") or {}).get("supported_media_types"))
|
||
),
|
||
source="provider",
|
||
)
|
||
)
|
||
return sorted(results, key=lambda i: i["name"].lower())
|
||
|
||
async def _list_chatgpt_oauth_models(self) -> list[dict[str, Any]]:
|
||
payload = await self._models_dev_provider_payload("chatgpt")
|
||
models = payload.get("models") if isinstance(payload, dict) else None
|
||
if not isinstance(models, dict):
|
||
return [
|
||
{
|
||
"id": model_id,
|
||
"name": model_id,
|
||
"context_tokens": None,
|
||
"input_tokens": None,
|
||
"output_tokens": None,
|
||
"context_tokens_k": settings.LLM_MAX_CONTEXT_TOKENS,
|
||
"supports_reasoning": True,
|
||
"supports_tools": True,
|
||
"supports_image_input": True,
|
||
"supports_audio_input": False,
|
||
"transport": "openai",
|
||
"source": "builtin",
|
||
"release_date": None,
|
||
"status": None,
|
||
}
|
||
for model_id in sorted(self._CHATGPT_ALLOWED_OAUTH_MODELS)
|
||
]
|
||
|
||
results = []
|
||
for model_id, metadata in models.items():
|
||
if "codex" in model_id or model_id in self._CHATGPT_ALLOWED_OAUTH_MODELS:
|
||
match = None
|
||
if model_id.startswith("gpt-"):
|
||
try:
|
||
match = float(model_id.split("-")[1].replace(".mini", ""))
|
||
except Exception as err:
|
||
print(err)
|
||
match = None
|
||
if match is not None and match > 5.4 and "codex" not in model_id:
|
||
continue
|
||
results.append(
|
||
self._normalize_model_record(
|
||
model_id=model_id,
|
||
display_name=metadata.get("name") or model_id,
|
||
metadata=metadata,
|
||
source="models.dev",
|
||
)
|
||
)
|
||
return sorted(results, key=lambda item: item["name"].lower())
|
||
|
||
async def list_models(
|
||
self,
|
||
provider_id: str,
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
force_refresh: bool = False,
|
||
) -> list[dict[str, Any]]:
|
||
"""返回标准化后的模型目录。"""
|
||
spec = self.get_provider(provider_id)
|
||
if force_refresh and spec.models_dev_provider_id:
|
||
# 对依赖 models.dev 的 provider 主动刷新一次缓存,保证“刷新模型列表”
|
||
# 在使用目录型 provider 时也能拿到最新参数。
|
||
await self.get_models_dev_data(force_refresh=True)
|
||
runtime = await self.resolve_runtime(
|
||
provider_id,
|
||
model=None,
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
)
|
||
|
||
if spec.model_list_strategy == "google":
|
||
return await self._list_models_from_google(runtime["api_key"])
|
||
|
||
if spec.model_list_strategy == "github_copilot":
|
||
return await self._list_models_from_copilot(runtime["api_key"])
|
||
|
||
if spec.model_list_strategy == "chatgpt":
|
||
if runtime.get("auth_mode") == "oauth":
|
||
return await self._list_chatgpt_oauth_models()
|
||
return await self._list_models_from_openai_compatible(
|
||
provider_id="chatgpt",
|
||
api_key=runtime["api_key"],
|
||
base_url=runtime["base_url"],
|
||
default_headers=runtime.get("default_headers"),
|
||
)
|
||
|
||
if spec.model_list_strategy == "anthropic_compatible":
|
||
return await self._list_models_from_models_dev_only(
|
||
provider_id=provider_id,
|
||
transport="anthropic",
|
||
)
|
||
|
||
if spec.model_list_strategy == "models_dev_only":
|
||
return await self._list_models_from_models_dev_only(
|
||
provider_id=provider_id,
|
||
transport="openai",
|
||
)
|
||
|
||
# openai-compatible / deepseek 默认走官方 models 端点。
|
||
return await self._list_models_from_openai_compatible(
|
||
provider_id=provider_id,
|
||
api_key=runtime["api_key"],
|
||
base_url=runtime["base_url"],
|
||
default_headers=runtime.get("default_headers"),
|
||
)
|
||
|
||
async def resolve_model_metadata(
|
||
self, provider_id: str, model_id: Optional[str]
|
||
) -> dict[str, Any] | None:
|
||
if not model_id:
|
||
return None
|
||
metadata = await self._models_dev_model(provider_id, model_id)
|
||
if metadata:
|
||
return metadata
|
||
if provider_id == "chatgpt":
|
||
return await self._models_dev_model("openai", model_id)
|
||
if provider_id == "openai":
|
||
models_dev = await self.get_models_dev_data()
|
||
return models_dev.get("openai", {}).get("models", {}).get(model_id)
|
||
return None
|
||
|
||
@staticmethod
|
||
def _jwt_claims(token: str) -> dict[str, Any]:
|
||
try:
|
||
return jwt.decode(token, options={"verify_signature": False})
|
||
except Exception as err:
|
||
print(err)
|
||
return {}
|
||
|
||
@staticmethod
|
||
def _extract_chatgpt_account_id(token_payload: dict[str, Any]) -> Optional[str]:
|
||
if token_payload.get("chatgpt_account_id"):
|
||
return token_payload["chatgpt_account_id"]
|
||
auth_payload = token_payload.get("https://api.openai.com/auth") or {}
|
||
if auth_payload.get("chatgpt_account_id"):
|
||
return auth_payload["chatgpt_account_id"]
|
||
organizations = token_payload.get("organizations") or []
|
||
if organizations and isinstance(organizations[0], dict):
|
||
return organizations[0].get("id")
|
||
return None
|
||
|
||
def _chatgpt_authorize_url(
|
||
self, redirect_uri: str, challenge: str, state: str
|
||
) -> str:
|
||
query = urlencode(
|
||
{
|
||
"response_type": "code",
|
||
"client_id": self._CHATGPT_CLIENT_ID,
|
||
"redirect_uri": redirect_uri,
|
||
"scope": "openid profile email offline_access",
|
||
"code_challenge": challenge,
|
||
"code_challenge_method": "S256",
|
||
"id_token_add_organizations": "true",
|
||
"codex_cli_simplified_flow": "true",
|
||
"state": state,
|
||
"originator": "moviepilot",
|
||
}
|
||
)
|
||
return f"{self._CHATGPT_ISSUER}/oauth/authorize?{query}"
|
||
|
||
@staticmethod
|
||
def _pkce_pair() -> tuple[str, str]:
|
||
verifier = secrets.token_urlsafe(64).replace("=", "")
|
||
digest = hashlib.sha256(verifier.encode("utf-8")).digest()
|
||
challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
|
||
return verifier, challenge
|
||
|
||
async def start_auth(
|
||
self,
|
||
provider_id: str,
|
||
method_id: str,
|
||
callback_url: Optional[str] = None,
|
||
) -> dict[str, Any]:
|
||
"""
|
||
启动 OAuth / device code 会话。
|
||
|
||
API Key 方式已经由普通设置表单覆盖,这里只处理需要交互式授权的 provider。
|
||
"""
|
||
provider = self.get_provider(provider_id)
|
||
method = next(
|
||
(item for item in provider.oauth_methods if item.id == method_id),
|
||
None,
|
||
)
|
||
if not method:
|
||
raise LLMProviderAuthError(f"{provider.name} 不支持授权方式:{method_id}")
|
||
|
||
session = PendingAuthSession(
|
||
session_id=secrets.token_urlsafe(18),
|
||
provider_id=provider_id,
|
||
method_id=method_id,
|
||
flow_type=method.type,
|
||
expires_at=time.time() + 600,
|
||
)
|
||
|
||
if provider_id == "chatgpt" and method_id == "browser_oauth":
|
||
if not callback_url:
|
||
raise LLMProviderAuthError("ChatGPT 浏览器授权缺少回调地址")
|
||
verifier, challenge = self._pkce_pair()
|
||
state = secrets.token_urlsafe(24)
|
||
session.authorize_url = self._chatgpt_authorize_url(
|
||
redirect_uri=callback_url,
|
||
challenge=challenge,
|
||
state=state,
|
||
)
|
||
session.instructions = "请在浏览器中完成 ChatGPT Plus/Pro 登录授权。"
|
||
session.context.update(
|
||
{
|
||
"code_verifier": verifier,
|
||
"state": state,
|
||
"redirect_uri": callback_url,
|
||
}
|
||
)
|
||
with self._lock:
|
||
self._pending_sessions[session.session_id] = session
|
||
self._oauth_state_index[state] = session.session_id
|
||
return {
|
||
"session_id": session.session_id,
|
||
"flow_type": "oauth_browser",
|
||
"authorize_url": session.authorize_url,
|
||
"instructions": session.instructions,
|
||
"expires_at": session.expires_at,
|
||
}
|
||
|
||
if provider_id == "chatgpt" and method_id == "device_code":
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/usercode",
|
||
headers={
|
||
"Content-Type": "application/json",
|
||
"User-Agent": "MoviePilot/1.0",
|
||
},
|
||
json={"client_id": self._CHATGPT_CLIENT_ID},
|
||
)
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
|
||
session.verification_url = f"{self._CHATGPT_ISSUER}/codex/device"
|
||
session.user_code = payload.get("user_code")
|
||
session.interval_seconds = max(int(payload.get("interval") or 5), 1)
|
||
session.instructions = f"请在浏览器输入设备码:{session.user_code}"
|
||
session.context.update(
|
||
{
|
||
"device_auth_id": payload.get("device_auth_id"),
|
||
"user_code": payload.get("user_code"),
|
||
}
|
||
)
|
||
with self._lock:
|
||
self._pending_sessions[session.session_id] = session
|
||
return {
|
||
"session_id": session.session_id,
|
||
"flow_type": "device_code",
|
||
"verification_url": session.verification_url,
|
||
"user_code": session.user_code,
|
||
"interval_seconds": session.interval_seconds,
|
||
"instructions": session.instructions,
|
||
"expires_at": session.expires_at,
|
||
}
|
||
|
||
if provider_id == "github-copilot" and method_id == "device_code":
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
"https://github.com/login/device/code",
|
||
headers={
|
||
"Accept": "application/json",
|
||
"Content-Type": "application/json",
|
||
"User-Agent": "MoviePilot/1.0",
|
||
},
|
||
json={
|
||
"client_id": self._COPILOT_CLIENT_ID,
|
||
"scope": "read:user",
|
||
},
|
||
)
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
|
||
session.verification_url = payload.get("verification_uri")
|
||
session.user_code = payload.get("user_code")
|
||
session.interval_seconds = max(int(payload.get("interval") or 5), 1)
|
||
session.instructions = f"请在 GitHub 页面输入设备码:{session.user_code}"
|
||
session.context.update(
|
||
{
|
||
"device_code": payload.get("device_code"),
|
||
}
|
||
)
|
||
with self._lock:
|
||
self._pending_sessions[session.session_id] = session
|
||
return {
|
||
"session_id": session.session_id,
|
||
"flow_type": "device_code",
|
||
"verification_url": session.verification_url,
|
||
"user_code": session.user_code,
|
||
"interval_seconds": session.interval_seconds,
|
||
"instructions": session.instructions,
|
||
"expires_at": session.expires_at,
|
||
}
|
||
|
||
raise LLMProviderAuthError(f"暂未实现 {provider.name} 的授权方式:{method.label}")
|
||
|
||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||
"""读取临时授权会话状态。"""
|
||
with self._lock:
|
||
session = self._pending_sessions.get(session_id)
|
||
if not session:
|
||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||
return {
|
||
"session_id": session.session_id,
|
||
"provider_id": session.provider_id,
|
||
"status": session.status,
|
||
"message": session.message,
|
||
"user_code": session.user_code,
|
||
"verification_url": session.verification_url,
|
||
"authorize_url": session.authorize_url,
|
||
"instructions": session.instructions,
|
||
"interval_seconds": session.interval_seconds,
|
||
"expires_at": session.expires_at,
|
||
}
|
||
|
||
async def _mark_session_success(
|
||
self, session: PendingAuthSession, auth_data: dict[str, Any]
|
||
) -> None:
|
||
auth_data["updated_at"] = int(time.time())
|
||
await self.save_auth(session.provider_id, auth_data)
|
||
session.status = "authorized"
|
||
session.message = "授权成功"
|
||
|
||
@staticmethod
|
||
def _mark_session_error(session: PendingAuthSession, message: str) -> None:
|
||
session.status = "failed"
|
||
session.message = message
|
||
|
||
async def handle_chatgpt_callback(
|
||
self,
|
||
provider_id: str,
|
||
code: Optional[str],
|
||
state: Optional[str],
|
||
error: Optional[str],
|
||
error_description: Optional[str],
|
||
) -> tuple[bool, str]:
|
||
"""处理 ChatGPT 浏览器 OAuth 回调。"""
|
||
if provider_id != "chatgpt":
|
||
return False, "当前 provider 不支持浏览器 OAuth 回调"
|
||
|
||
if error:
|
||
message = error_description or error
|
||
with self._lock:
|
||
session_id = self._oauth_state_index.pop(state or "", None)
|
||
if session_id and session_id in self._pending_sessions:
|
||
self._mark_session_error(self._pending_sessions[session_id], message)
|
||
return False, message
|
||
|
||
if not code or not state:
|
||
return False, "缺少授权码或 state 参数"
|
||
|
||
with self._lock:
|
||
session_id = self._oauth_state_index.pop(state, None)
|
||
session = self._pending_sessions.get(session_id or "")
|
||
|
||
if not session:
|
||
return False, "授权会话不存在或已失效"
|
||
|
||
if state != session.context.get("state"):
|
||
self._mark_session_error(session, "state 校验失败")
|
||
return False, "state 校验失败"
|
||
|
||
try:
|
||
payload = await self._exchange_chatgpt_code_for_tokens(
|
||
code=code,
|
||
redirect_uri=session.context["redirect_uri"],
|
||
code_verifier=session.context["code_verifier"],
|
||
)
|
||
claims = self._jwt_claims(payload.get("id_token") or payload["access_token"])
|
||
account_id = self._extract_chatgpt_account_id(claims)
|
||
auth_data = {
|
||
"type": "oauth",
|
||
"provider": "chatgpt",
|
||
"access_token": payload["access_token"],
|
||
"refresh_token": payload["refresh_token"],
|
||
"expires_at": int(time.time() + int(payload.get("expires_in") or 3600)),
|
||
"account_id": account_id,
|
||
"email": claims.get("email"),
|
||
"label": claims.get("email") or account_id or "ChatGPT Plus/Pro",
|
||
}
|
||
await self._mark_session_success(session, auth_data)
|
||
return True, "ChatGPT 授权成功"
|
||
except Exception as err:
|
||
message = f"ChatGPT 授权失败: {err}"
|
||
self._mark_session_error(session, message)
|
||
return False, message
|
||
|
||
async def poll_auth_session(self, session_id: str) -> dict[str, Any]:
|
||
"""
|
||
执行一次 device code 轮询,并返回最新状态。
|
||
|
||
前端可按 interval_seconds 轮询,直到状态变为 authorized / failed。
|
||
"""
|
||
with self._lock:
|
||
session = self._pending_sessions.get(session_id)
|
||
if not session:
|
||
raise LLMProviderAuthError("授权会话不存在或已过期")
|
||
if session.status != "pending":
|
||
return self.get_session_status(session_id)
|
||
|
||
try:
|
||
if session.provider_id == "chatgpt" and session.method_id == "device_code":
|
||
await self._poll_chatgpt_device_auth(session)
|
||
elif session.provider_id == "github-copilot" and session.method_id == "device_code":
|
||
await self._poll_copilot_device_auth(session)
|
||
else:
|
||
raise LLMProviderAuthError("当前授权会话不支持轮询")
|
||
except Exception as err:
|
||
self._mark_session_error(session, str(err))
|
||
return self.get_session_status(session_id)
|
||
|
||
async def _exchange_chatgpt_code_for_tokens(
|
||
self, code: str, redirect_uri: str, code_verifier: str
|
||
) -> dict[str, Any]:
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
f"{self._CHATGPT_ISSUER}/oauth/token",
|
||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||
data={
|
||
"grant_type": "authorization_code",
|
||
"code": code,
|
||
"redirect_uri": redirect_uri,
|
||
"client_id": self._CHATGPT_CLIENT_ID,
|
||
"code_verifier": code_verifier,
|
||
},
|
||
)
|
||
response.raise_for_status()
|
||
return response.json()
|
||
|
||
async def _refresh_chatgpt_tokens(self, refresh_token: str) -> dict[str, Any]:
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
f"{self._CHATGPT_ISSUER}/oauth/token",
|
||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||
data={
|
||
"grant_type": "refresh_token",
|
||
"refresh_token": refresh_token,
|
||
"client_id": self._CHATGPT_CLIENT_ID,
|
||
},
|
||
)
|
||
response.raise_for_status()
|
||
return response.json()
|
||
|
||
async def _poll_chatgpt_device_auth(self, session: PendingAuthSession) -> None:
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
f"{self._CHATGPT_ISSUER}/api/accounts/deviceauth/token",
|
||
headers={
|
||
"Content-Type": "application/json",
|
||
"User-Agent": "MoviePilot/1.0",
|
||
},
|
||
json={
|
||
"device_auth_id": session.context["device_auth_id"],
|
||
"user_code": session.context["user_code"],
|
||
},
|
||
)
|
||
|
||
if response.status_code in {403, 404}:
|
||
session.message = "等待用户在浏览器完成授权"
|
||
return
|
||
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
token_payload = await self._exchange_chatgpt_code_for_tokens(
|
||
code=payload["authorization_code"],
|
||
redirect_uri=f"{self._CHATGPT_ISSUER}/deviceauth/callback",
|
||
code_verifier=payload["code_verifier"],
|
||
)
|
||
claims = self._jwt_claims(
|
||
token_payload.get("id_token") or token_payload["access_token"]
|
||
)
|
||
account_id = self._extract_chatgpt_account_id(claims)
|
||
await self._mark_session_success(
|
||
session,
|
||
{
|
||
"type": "oauth",
|
||
"provider": "chatgpt",
|
||
"access_token": token_payload["access_token"],
|
||
"refresh_token": token_payload["refresh_token"],
|
||
"expires_at": int(time.time() + int(token_payload.get("expires_in") or 3600)),
|
||
"account_id": account_id,
|
||
"email": claims.get("email"),
|
||
"label": claims.get("email") or account_id or "ChatGPT Plus/Pro",
|
||
},
|
||
)
|
||
|
||
async def _poll_copilot_device_auth(self, session: PendingAuthSession) -> None:
|
||
async with httpx.AsyncClient(**self._build_httpx_kwargs()) as client:
|
||
response = await client.post(
|
||
"https://github.com/login/oauth/access_token",
|
||
headers={
|
||
"Accept": "application/json",
|
||
"Content-Type": "application/json",
|
||
"User-Agent": "MoviePilot/1.0",
|
||
},
|
||
json={
|
||
"client_id": self._COPILOT_CLIENT_ID,
|
||
"device_code": session.context["device_code"],
|
||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||
},
|
||
)
|
||
response.raise_for_status()
|
||
payload = response.json()
|
||
|
||
access_token = payload.get("access_token")
|
||
if access_token:
|
||
await self._mark_session_success(
|
||
session,
|
||
{
|
||
"type": "oauth",
|
||
"provider": "github-copilot",
|
||
"access_token": access_token,
|
||
# Copilot 设备码授权返回的是长期可复用 token,这里复用 access 字段即可。
|
||
"refresh_token": access_token,
|
||
"expires_at": None,
|
||
"label": "GitHub Copilot",
|
||
},
|
||
)
|
||
return
|
||
|
||
error = payload.get("error")
|
||
if error == "authorization_pending":
|
||
session.message = "等待用户在 GitHub 页面完成授权"
|
||
return
|
||
if error == "slow_down":
|
||
session.interval_seconds = max(session.interval_seconds + 5, 10)
|
||
session.message = "GitHub 要求降低轮询频率,稍后继续。"
|
||
return
|
||
if error:
|
||
raise LLMProviderAuthError(f"GitHub Copilot 授权失败: {error}")
|
||
|
||
async def _resolve_chatgpt_oauth(self) -> dict[str, Any]:
|
||
auth = self.get_saved_auth("chatgpt")
|
||
if not auth or auth.get("type") != "oauth":
|
||
raise LLMProviderAuthError("尚未完成 ChatGPT Plus/Pro 授权")
|
||
|
||
expires_at = auth.get("expires_at")
|
||
refresh_token = auth.get("refresh_token")
|
||
# 预留 60 秒刷新缓冲,避免刚发起请求就遇到过期。
|
||
if expires_at and refresh_token and int(expires_at) <= int(time.time()) + 60:
|
||
payload = await self._refresh_chatgpt_tokens(refresh_token)
|
||
claims = self._jwt_claims(payload.get("id_token") or payload["access_token"])
|
||
auth.update(
|
||
{
|
||
"access_token": payload["access_token"],
|
||
"refresh_token": payload.get("refresh_token") or refresh_token,
|
||
"expires_at": int(time.time() + int(payload.get("expires_in") or 3600)),
|
||
"account_id": auth.get("account_id")
|
||
or self._extract_chatgpt_account_id(claims),
|
||
"email": auth.get("email") or claims.get("email"),
|
||
"label": auth.get("label")
|
||
or claims.get("email")
|
||
or auth.get("account_id")
|
||
or "ChatGPT Plus/Pro",
|
||
}
|
||
)
|
||
await self.save_auth("chatgpt", auth)
|
||
return auth
|
||
|
||
async def resolve_runtime(
|
||
self,
|
||
provider_id: str,
|
||
model: Optional[str],
|
||
api_key: Optional[str] = None,
|
||
base_url: Optional[str] = None,
|
||
) -> dict[str, Any]:
|
||
"""
|
||
解析 provider 运行时参数。
|
||
|
||
返回统一结构,供 `LLMHelper` 创建具体 LangChain 模型实例时使用。
|
||
"""
|
||
spec = self.get_provider(provider_id)
|
||
normalized_api_key = str(api_key or "").strip() or None
|
||
normalized_base_url = self._sanitize_base_url(base_url)
|
||
model_record = None
|
||
if model:
|
||
try:
|
||
model_record = next(
|
||
(
|
||
item
|
||
for item in await self.list_models(
|
||
provider_id,
|
||
api_key=api_key,
|
||
base_url=base_url,
|
||
)
|
||
if item["id"] == model
|
||
),
|
||
None,
|
||
)
|
||
except Exception as err:
|
||
print(err)
|
||
model_record = None
|
||
|
||
result: dict[str, Any] = {
|
||
"provider_id": provider_id,
|
||
"runtime": spec.runtime,
|
||
"model_id": model,
|
||
"model_record": model_record,
|
||
"model_metadata": await self.resolve_model_metadata(provider_id, model),
|
||
"default_headers": None,
|
||
"use_responses_api": None,
|
||
"auth_mode": "api_key",
|
||
}
|
||
|
||
if provider_id == "chatgpt":
|
||
auth = None
|
||
try:
|
||
auth = await self._resolve_chatgpt_oauth()
|
||
except Exception as err:
|
||
print(err)
|
||
pass
|
||
|
||
if auth:
|
||
headers = {"originator": "moviepilot"}
|
||
if auth.get("account_id"):
|
||
headers["ChatGPT-Account-Id"] = auth["account_id"]
|
||
result.update(
|
||
{
|
||
"runtime": "chatgpt",
|
||
"api_key": auth["access_token"],
|
||
"base_url": self._CHATGPT_CODEX_BASE_URL,
|
||
"default_headers": headers,
|
||
"use_responses_api": True,
|
||
"auth_mode": "oauth",
|
||
}
|
||
)
|
||
return result
|
||
|
||
if normalized_api_key:
|
||
result.update(
|
||
{
|
||
"runtime": "openai_compatible",
|
||
"api_key": normalized_api_key,
|
||
"base_url": normalized_base_url or spec.default_base_url,
|
||
"auth_mode": "api_key",
|
||
}
|
||
)
|
||
return result
|
||
|
||
raise LLMProviderAuthError("请提供 API Key 或完成 ChatGPT 授权")
|
||
|
||
if provider_id == "github-copilot":
|
||
auth = self.get_saved_auth("github-copilot")
|
||
if auth and auth.get("type") == "oauth":
|
||
token = auth.get("refresh_token") or auth.get("access_token")
|
||
elif normalized_api_key:
|
||
token = normalized_api_key
|
||
else:
|
||
raise LLMProviderAuthError("请先完成 GitHub Copilot 授权")
|
||
|
||
transport = (model_record or {}).get("transport") or "openai"
|
||
result.update(
|
||
{
|
||
"runtime": "copilot_anthropic"
|
||
if transport == "anthropic"
|
||
else "github_copilot",
|
||
"api_key": token,
|
||
"base_url": "https://api.githubcopilot.com",
|
||
"default_headers": self._copilot_headers(
|
||
token,
|
||
include_auth=transport == "anthropic",
|
||
),
|
||
"auth_mode": "oauth" if auth else "api_key",
|
||
}
|
||
)
|
||
return result
|
||
|
||
if spec.runtime == "google":
|
||
if not normalized_api_key:
|
||
raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key")
|
||
result.update(
|
||
{
|
||
"api_key": normalized_api_key,
|
||
"base_url": None,
|
||
"auth_mode": "api_key",
|
||
}
|
||
)
|
||
return result
|
||
|
||
if spec.runtime == "anthropic_compatible":
|
||
effective_base_url = normalized_base_url or spec.default_base_url
|
||
if not normalized_api_key:
|
||
raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key")
|
||
if not effective_base_url:
|
||
raise LLMProviderAuthError(f"{spec.name} 缺少 Base URL")
|
||
result.update(
|
||
{
|
||
"api_key": normalized_api_key,
|
||
"base_url": self._normalize_base_url_for_anthropic(
|
||
effective_base_url
|
||
),
|
||
"auth_mode": "api_key",
|
||
}
|
||
)
|
||
return result
|
||
|
||
effective_base_url = normalized_base_url or spec.default_base_url
|
||
if spec.requires_base_url and not effective_base_url:
|
||
raise LLMProviderAuthError(f"{spec.name} 需要填写 Base URL")
|
||
if not normalized_api_key:
|
||
raise LLMProviderAuthError(f"{spec.name} 需要填写 API Key")
|
||
result.update(
|
||
{
|
||
"api_key": normalized_api_key,
|
||
"base_url": effective_base_url,
|
||
"auth_mode": "api_key",
|
||
}
|
||
)
|
||
return result
|
||
|
||
|
||
def render_auth_result_html(success: bool, message: str) -> str:
|
||
"""OAuth 回调落地页。"""
|
||
title = "授权成功" if success else "授权失败"
|
||
accent = "#3aa675" if success else "#e24b4b"
|
||
return f"""<!doctype html>
|
||
<html lang="zh-CN">
|
||
<head>
|
||
<meta charset="utf-8" />
|
||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||
<title>{title}</title>
|
||
<style>
|
||
body {{
|
||
margin: 0;
|
||
min-height: 100vh;
|
||
display: flex;
|
||
align-items: center;
|
||
justify-content: center;
|
||
background: #101418;
|
||
color: #f3f5f7;
|
||
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif;
|
||
}}
|
||
.card {{
|
||
width: min(480px, calc(100vw - 32px));
|
||
padding: 28px 24px;
|
||
border-radius: 18px;
|
||
background: rgba(20, 28, 36, 0.92);
|
||
box-shadow: 0 18px 48px rgba(0, 0, 0, 0.28);
|
||
}}
|
||
h1 {{
|
||
margin: 0 0 12px;
|
||
font-size: 24px;
|
||
color: {accent};
|
||
}}
|
||
p {{
|
||
margin: 0;
|
||
line-height: 1.7;
|
||
color: #d4dbe3;
|
||
}}
|
||
</style>
|
||
</head>
|
||
<body>
|
||
<div class="card">
|
||
<h1>{title}</h1>
|
||
<p>{message}</p>
|
||
</div>
|
||
<script>
|
||
if (window.opener) {{
|
||
try {{
|
||
window.opener.postMessage({json.dumps({"type": "moviepilot-llm-auth", "success": success})}, "*");
|
||
}} catch (err) {{}}
|
||
}}
|
||
setTimeout(() => window.close(), 1800);
|
||
</script>
|
||
</body>
|
||
</html>"""
|