From fd280a49b79090327f76147c97d6baa9a619a861 Mon Sep 17 00:00:00 2001 From: jxxghp Date: Thu, 4 Jun 2026 08:23:54 +0800 Subject: [PATCH] feat(auth): implement authentication provider endpoints and ticket exchange --- app/api/apiv1.py | 3 +- app/api/endpoints/auth.py | 70 ++++++++++++++++++ app/core/auth_bridge.py | 146 ++++++++++++++++++++++++++++++++++++++ app/core/plugin.py | 41 +++++++++++ app/plugins/__init__.py | 15 ++++ 5 files changed, 274 insertions(+), 1 deletion(-) create mode 100644 app/api/endpoints/auth.py create mode 100644 app/core/auth_bridge.py diff --git a/app/api/apiv1.py b/app/api/apiv1.py index f1bfeebe..15834d0c 100644 --- a/app/api/apiv1.py +++ b/app/api/apiv1.py @@ -1,10 +1,11 @@ from fastapi import APIRouter -from app.api.endpoints import login, user, webhook, message, site, subscribe, \ +from app.api.endpoints import auth, login, user, webhook, message, site, subscribe, \ media, douban, search, plugin, tmdb, history, system, download, dashboard, \ transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm, notification api_router = APIRouter() +api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(user.router, prefix="/user", tags=["user"]) api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"]) diff --git a/app/api/endpoints/auth.py b/app/api/endpoints/auth.py new file mode 100644 index 00000000..2883cb50 --- /dev/null +++ b/app/api/endpoints/auth.py @@ -0,0 +1,70 @@ +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from app import schemas +from app.core.auth_bridge import build_token_response, consume_plugin_auth_ticket +from app.core.plugin import PluginManager +from app.db.models.passkey import PassKey +from app.db.models.user import User + +router = APIRouter() + + +class AuthExchangeRequest(BaseModel): + """ + 插件认证票据兑换请求。 + """ + + ticket: str + + +def _system_auth_providers() -> list[dict[str, Any]]: + """ + 获取系统内建的匿名登录方式摘要。 + + :return: 系统认证提供方列表 + """ + has_passkey = bool(PassKey.list(db=None)) + return [ + { + "id": "system:passkey", + "type": "system", + "method": "passkey", + "name": "通行密钥", + "icon": "material-symbols:passkey", + "enabled": has_passkey, + } + ] + + +@router.get("/providers", summary="查询登录认证提供方", response_model=list[dict]) +def auth_providers() -> list[dict[str, Any]]: + """ + 查询系统和插件提供的登录认证入口。 + + :return: 认证提供方摘要列表 + """ + providers = _system_auth_providers() + providers.extend(PluginManager().get_plugin_auth_providers()) + return [provider for provider in providers if provider.get("enabled", True)] + + +@router.post("/exchange", summary="兑换插件认证登录票据", response_model=schemas.Token) +def auth_exchange(body: AuthExchangeRequest) -> schemas.Token: + """ + 将插件认证成功后生成的一次性票据兑换为系统 Token。 + + :param body: 票据兑换请求 + :return: 标准登录 Token + """ + ticket_data = consume_plugin_auth_ticket(body.ticket) + if not ticket_data: + raise HTTPException(status_code=401, detail="认证票据无效或已过期") + + user = User.get(db=None, rid=ticket_data.get("user_id")) + if not user or not user.is_active: + raise HTTPException(status_code=403, detail="用户不存在或已禁用") + + return build_token_response(user) diff --git a/app/core/auth_bridge.py b/app/core/auth_bridge.py new file mode 100644 index 00000000..038979d4 --- /dev/null +++ b/app/core/auth_bridge.py @@ -0,0 +1,146 @@ +import secrets +import threading +import time +from datetime import timedelta +from typing import Any, Optional + +from app import schemas +from app.core import security +from app.core.config import settings +from app.db.models.user import User +from app.db.systemconfig_oper import SystemConfigOper +from app.helper.sites import SitesHelper +from app.schemas.types import SystemConfigKey +from app.utils.singleton import Singleton + + +class AuthTicketStore(metaclass=Singleton): + """ + 插件认证一次性票据存储。 + """ + + _ttl_seconds = 120 + _max_items = 1024 + + def __init__(self): + """ + 初始化内存票据缓存。 + """ + self._tickets: dict[str, dict[str, Any]] = {} + self._lock = threading.RLock() + + def create(self, user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str: + """ + 创建短时一次性登录票据。 + + :param user_id: 已通过插件认证的本地用户 ID + :param provider_id: 认证提供方 ID + :param metadata: 插件侧附加信息 + :return: 一次性票据字符串 + """ + ticket = secrets.token_urlsafe(32) + now = time.time() + with self._lock: + self._cleanup(now) + self._tickets[ticket] = { + "user_id": int(user_id), + "provider_id": provider_id, + "metadata": metadata or {}, + "created_at": now, + } + return ticket + + def consume(self, ticket: str) -> Optional[dict[str, Any]]: + """ + 消费并删除一次性登录票据。 + + :param ticket: 登录票据 + :return: 票据数据,票据不存在或过期时返回 None + """ + if not ticket: + return None + now = time.time() + with self._lock: + data = self._tickets.pop(ticket, None) + self._cleanup(now) + if not data: + return None + if now - float(data.get("created_at") or 0) > self._ttl_seconds: + return None + return data + + def _cleanup(self, now: Optional[float] = None) -> None: + """ + 清理过期或过量的票据缓存。 + + :param now: 当前时间戳,未传入时自动读取 + """ + current = now or time.time() + expired = [ + key + for key, value in self._tickets.items() + if current - float(value.get("created_at") or 0) > self._ttl_seconds + ] + for key in expired: + self._tickets.pop(key, None) + if len(self._tickets) <= self._max_items: + return + ordered = sorted( + self._tickets.items(), + key=lambda item: float(item[1].get("created_at") or 0), + ) + for key, _ in ordered[: len(self._tickets) - self._max_items]: + self._tickets.pop(key, None) + + +def create_plugin_auth_ticket(user_id: int, provider_id: str, metadata: Optional[dict[str, Any]] = None) -> str: + """ + 为插件认证成功的用户创建一次性登录票据。 + + :param user_id: 本地用户 ID + :param provider_id: 认证提供方 ID + :param metadata: 插件侧附加信息 + :return: 一次性票据字符串 + """ + return AuthTicketStore().create(user_id=user_id, provider_id=provider_id, metadata=metadata) + + +def consume_plugin_auth_ticket(ticket: str) -> Optional[dict[str, Any]]: + """ + 消费插件认证登录票据。 + + :param ticket: 登录票据 + :return: 票据数据,票据不存在或过期时返回 None + """ + return AuthTicketStore().consume(ticket) + + +def build_token_response(user: User) -> schemas.Token: + """ + 使用系统统一逻辑构造登录 Token 响应。 + + :param user: 已认证的本地用户 + :return: 标准 Token 响应 + """ + level = SitesHelper().auth_level + show_wizard = ( + not SystemConfigOper().get(SystemConfigKey.SetupWizardState) + and not settings.ADVANCED_MODE + ) + return schemas.Token( + access_token=security.create_access_token( + userid=user.id, + username=user.name, + super_user=user.is_superuser, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + level=level, + ), + token_type="bearer", + super_user=user.is_superuser, + user_id=user.id, + user_name=user.name, + avatar=user.avatar, + level=level, + permissions=user.permissions or {}, + wizard=show_wizard, + ) diff --git a/app/core/plugin.py b/app/core/plugin.py index cacbfce1..a4b4aaf7 100644 --- a/app/core/plugin.py +++ b/app/core/plugin.py @@ -950,6 +950,47 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton): }) return remotes + def get_plugin_auth_providers(self) -> List[Dict[str, Any]]: + """ + 聚合插件声明的登录认证提供方。 + + :return: 插件认证入口列表 + """ + providers: List[Dict[str, Any]] = [] + running_plugins_snapshot = dict(self._running_plugins) + for plugin_id, plugin in running_plugins_snapshot.items(): + if not plugin.get_state(): + continue + if not hasattr(plugin, "get_auth_providers") or not ObjectUtils.check_method(plugin.get_auth_providers): + continue + try: + plugin_providers = plugin.get_auth_providers() or [] + except Exception as e: + logger.error(f"获取插件 {plugin_id} 登录认证提供方出错:{str(e)}") + continue + render_mode = None + dist_path = None + if hasattr(plugin, "get_render_mode"): + render_mode, dist_path = plugin.get_render_mode() + for raw_provider in plugin_providers: + if not raw_provider or not isinstance(raw_provider, dict): + continue + provider = raw_provider.copy() + provider["type"] = "plugin" + provider["plugin_id"] = plugin_id + provider.setdefault("id", f"plugin:{plugin_id}") + provider.setdefault("name", plugin.plugin_name) + provider.setdefault("enabled", True) + if render_mode == "vue" and dist_path: + provider.setdefault("component", "AuthPage") + provider["remote"] = { + "id": plugin_id, + "url": self.get_plugin_remote_entry(plugin_id, dist_path), + "name": plugin.plugin_name, + } + providers.append(provider) + return providers + def get_plugin_sidebar_nav(self) -> List[Dict[str, Any]]: """ 聚合所有已启用 Vue 插件的侧栏导航项(get_sidebar_nav)。 diff --git a/app/plugins/__init__.py b/app/plugins/__init__.py index 2550b95a..b4804d6d 100644 --- a/app/plugins/__init__.py +++ b/app/plugins/__init__.py @@ -174,6 +174,21 @@ class _PluginBase(metaclass=ABCMeta): """ pass + def get_auth_providers(self) -> List[Dict[str, Any]]: + """ + 声明插件提供的登录认证入口。 + + 返回示例: + [{ + "id": "oidc", + "name": "OIDC 登录", + "icon": "mdi-openid", + "component": "AuthPage", + "enabled": True + }] + """ + pass + def get_module(self) -> Dict[str, Any]: """ 获取插件模块声明,用于胁持系统模块实现(方法名:方法实现)