diff --git a/app/api/endpoints/login.py b/app/api/endpoints/login.py index 8df058a4..81bba212 100644 --- a/app/api/endpoints/login.py +++ b/app/api/endpoints/login.py @@ -1,21 +1,15 @@ -import secrets from datetime import timedelta from typing import Any, List from fastapi import APIRouter, Depends, Form, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.orm import Session from app import schemas from app.chain.tmdb import TmdbChain from app.chain.user import UserChain from app.core import security from app.core.config import settings -from app.core.security import get_password_hash -from app.db import get_db -from app.db.models.user import User from app.helper.sites import SitesHelper -from app.log import logger from app.utils.web import WebUtils router = APIRouter() @@ -23,60 +17,32 @@ router = APIRouter() @router.post("/access-token", summary="获取token", response_model=schemas.Token) async def login_access_token( - db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends(), otp_password: str = Form(None) ) -> Any: """ 获取认证Token """ - # 检查数据库 - success, user = User.authenticate( - db=db, - name=form_data.username, - password=form_data.password, - otp_password=otp_password - ) + success, user_or_message = UserChain().user_authenticate(username=form_data.username, + password=form_data.password, + mfa_code=otp_password) + if not success: - # 认证不成功 - if not user: - if not settings.AUXILIARY_AUTH_ENABLE: - logger.warn(f"用户 {form_data.username} 登录失败!") - raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确") - else: - # 如果找不到用户并开启了辅助认证 - logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...") - success = UserChain().user_authenticate(form_data.username, form_data.password) - if not success: - logger.warn(f"用户 {form_data.username} 登录失败!") - raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确") - else: - logger.info(f"用户 {form_data.username} 辅助认证成功,以普通用户登录...") - # 加入用户信息表 - logger.info(f"创建用户: {form_data.username}") - user = User(name=form_data.username, is_active=True, - is_superuser=False, hashed_password=get_password_hash(secrets.token_urlsafe(16))) - user.create(db) - else: - # 用户存在,但认证失败 - logger.warn(f"用户 {user.name} 登录失败!") - raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确") - elif user and not user.is_active: - raise HTTPException(status_code=403, detail="用户未启用") - logger.info(f"用户 {user.name} 登录成功!") + raise HTTPException(status_code=401, detail=user_or_message) + level = SitesHelper().auth_level return schemas.Token( access_token=security.create_access_token( - userid=user.id, - username=user.name, - super_user=user.is_superuser, + userid=user_or_message.id, + username=user_or_message.name, + super_user=user_or_message.is_superuser, expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), level=level ), token_type="bearer", - super_user=user.is_superuser, - user_name=user.name, - avatar=user.avatar, + super_user=user_or_message.is_superuser, + user_name=user_or_message.name, + avatar=user_or_message.avatar, level=level ) diff --git a/app/chain/user.py b/app/chain/user.py index 8c1c193b..67c9288a 100644 --- a/app/chain/user.py +++ b/app/chain/user.py @@ -1,80 +1,224 @@ +import secrets +from typing import Optional, Tuple, Union + from app.chain import ChainBase +from app.core.config import settings +from app.core.security import get_password_hash, verify_password +from app.db.models.user import User +from app.db.user_oper import UserOper from app.log import logger -from app.schemas.event import AuthPassedInterceptData, AuthVerificationData +from app.schemas.event import AuthCredentials, AuthInterceptCredentials from app.schemas.types import ChainEventType +from app.utils.otp import OtpUtils +from app.utils.singleton import Singleton + +PASSWORD_INVALID_CREDENTIALS_MESSAGE = "用户名或密码或二次校验码不正确" -class UserChain(ChainBase): +class UserChain(ChainBase, metaclass=Singleton): """ - 用户链 + 用户链,处理多种认证协议 """ - def user_authenticate(self, name: str, password: str) -> bool: - """ - 辅助完成用户认证。 + def __init__(self): + super().__init__() + self.user_oper = UserOper() - :param name: 用户名 - :param password: 密码 - :return: 认证成功时返回 True,否则返回 False + def user_authenticate( + self, + username: Optional[str] = None, + password: Optional[str] = None, + mfa_code: Optional[str] = None, + code: Optional[str] = None, + grant_type: str = "password" + ) -> Union[Tuple[bool, Optional[str]], Tuple[bool, Optional[User]]]: """ - logger.debug(f"开始对用户 {name} 通过系统预置渠道进行辅助认证") - auth_data = AuthVerificationData(name=name, password=password) - # 尝试通过默认的认证模块认证 - try: - result = self.run_module("user_authenticate", auth_data=auth_data) - if result: - return self._process_auth_success(name, result) - except Exception as e: - logger.error(f"认证模块运行出错:{e}") - return False + 认证用户,根据不同的 grant_type 处理不同的认证流程 - # 如果预置的认证未通过,则触发 AuthVerification 事件 - logger.debug(f"用户 {name} 未通过系统预置渠道认证,触发认证事件") - event = self.eventmanager.send_event( - etype=ChainEventType.AuthVerification, - data=auth_data + :param username: 用户名,适用于 "password" grant_type + :param password: 用户密码,适用于 "password" grant_type + :param mfa_code: 一次性密码,适用于 "password" grant_type + :param code: 授权码,适用于 "authorization_code" grant_type + :param grant_type: 认证类型,如 "password", "authorization_code", "client_credentials" + :return: + - 对于成功的认证,返回 (True, User) + - 对于失败的认证,返回 (False, "错误信息") + """ + credentials = AuthCredentials( + username=username, + password=password, + mfa_code=mfa_code, + code=code, + grant_type=grant_type ) - if not event: + logger.debug(f"开始使用 {grant_type} 认证,对用户 {username} 进行身份校验") + if credentials.grant_type == "password": + # Password 认证 + success, user_or_message = self.password_authenticate(credentials=credentials) + if success: + # 如果用户启用了二次验证码,则进一步验证 + if not self._verify_mfa(user_or_message, credentials.mfa_code): + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + logger.info(f"用户 {username} 通过密码认证成功") + return True, user_or_message + else: + # 用户不存在或密码错误,考虑辅助认证 + if settings.AUXILIARY_AUTH_ENABLE: + # 检查是否因为用户被禁用 + user = self.user_oper.get_by_name(name=username) + if user and not user.is_active: + logger.info(f"用户 {username} 已被禁用,跳过后续辅助认证") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + + logger.warning("密码认证失败,尝试通过外部服务进行辅助认证 ...") + aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials) + if aux_success: + # 辅助认证成功后再验证二次验证码 + if not self._verify_mfa(aux_user_or_message, credentials.mfa_code): + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + return True, aux_user_or_message + else: + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + else: + logger.debug(f"辅助认证未启用,用户 {username} 认证失败") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + else: + # 处理其他认证类型的分支 + if settings.AUXILIARY_AUTH_ENABLE: + aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials) + if aux_success: + logger.info(f"用户 {username} 辅助认证成功") + return True, aux_user_or_message + else: + logger.warning(f"用户 {username} 辅助认证失败") + return False, "认证失败" + else: + logger.debug(f"辅助认证未启用,认证类型 {grant_type} 未实现") + return False, "未实现的认证类型" + + def password_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]: + """ + 密码认证 + + :param credentials: 认证凭证,包含用户名、密码以及可选的 MFA 认证码 + :return: + - 成功时返回 (True, User),其中 User 是认证通过的用户对象 + - 失败时返回 (False, "错误信息") + """ + if not credentials or credentials.grant_type != "password": + logger.debug("密码认证失败,认证类型不匹配") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + + user = self.user_oper.get_by_name(name=credentials.username) + if not user: + logger.debug(f"密码认证失败,用户 {credentials.username} 不存在") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + + if not user.is_active: + logger.debug(f"密码认证失败,用户 {credentials.username} 已被禁用") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + + if not verify_password(credentials.password, str(user.hashed_password)): + logger.debug(f"密码认证失败,用户 {credentials.username} 的密码验证不通过") + return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE + + return True, user + + def auxiliary_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]: + """ + 辅助用户认证 + + :param credentials: 认证凭证,包含必要的认证信息 + :return: + - 成功时返回 (True, User),其中 User 是认证通过的用户对象 + - 失败时返回 (False, "错误信息") + """ + if not credentials: + return False, "认证凭证无效" + + logger.debug(f"尝试通过系统模块进行辅助认证,用户: {credentials.username}") + result = self.run_module("user_authenticate", credentials=credentials) + + if not result: + logger.debug(f"通过系统模块辅助认证失败,尝试触发 {ChainEventType.AuthVerification} 事件") + event = self.eventmanager.send_event(etype=ChainEventType.AuthVerification, data=credentials) + if not event or not event.event_data: + logger.error(f"{credentials.grant_type} 辅助认证失败,未返回有效数据") + return False, f"{credentials.grant_type} 辅助认证事件失败或无效" + + credentials = event.event_data # 使用事件返回的认证数据 + else: + logger.info(f"通过系统模块辅助认证成功,用户: {credentials.username}") + credentials = result # 使用模块认证返回的认证数据 + + # 处理认证成功的逻辑 + success = self._process_auth_success(username=credentials.username, credentials=credentials) + if success: + logger.info(f"用户 {credentials.username} 辅助认证通过") + return True, self.user_oper.get_by_name(credentials.username) + else: + logger.warning(f"用户 {credentials.username} 辅助认证未通过") + return False, "用户名或密码或二次校验码不正确" + + @staticmethod + def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool: + """ + 验证 MFA(二次验证码) + + :param user: 用户对象 + :param mfa_code: 二次验证码 + :return: 如果验证成功返回 True,否则返回 False + """ + if not user.is_otp: + return True + if not mfa_code: + logger.debug(f"用户 {user.name} 缺少 MFA 认证码") return False - if event and event.event_data: - try: - return self._process_auth_success(name, event.event_data) - except Exception as e: - logger.error(f"AuthVerificationData 数据验证失败:{e}") + if not OtpUtils.check(str(user.otp_secret), mfa_code): + logger.debug(f"用户 {user.name} 的 MFA 认证失败") + return False + return True + + def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool: + """ + 处理辅助认证成功的逻辑,返回用户对象或创建新用户 + + :param username: 用户名 + :param credentials: 认证凭证,包含 token、channel、service 等信息 + :return: + - 如果认证成功并且用户存在或已创建,返回 User 对象 + - 如果认证被拦截或失败,返回 None + """ + token, channel, service = credentials.token, credentials.channel, credentials.service + + if not all([token, channel, service]): + logger.debug(f"用户 {username} 未通过 {credentials.grant_type} 认证,必要信息不足") + return False + + anonymized_token = f"{token[:len(token) // 2]}********" + logger.info( + f"认证类型:{credentials.grant_type},用户:{username},渠道:{channel}," + f"服务:{service} 认证成功,token:{anonymized_token}") + + # 触发认证通过的拦截事件 + intercept_event = self.eventmanager.send_event( + etype=ChainEventType.AuthPassedIntercept, + data=AuthInterceptCredentials(username=username, channel=channel, service=service, token=token) + ) + + if intercept_event and intercept_event.event_data: + intercept_data: AuthInterceptCredentials = intercept_event.event_data + if intercept_data.cancel: + logger.warning( + f"认证被拦截,用户:{username},渠道:{channel},服务:{service},拦截源:{intercept_data.source}") return False - # 认证失败 - logger.warning(f"用户 {name} 辅助认证失败") - return False - - def _process_auth_success(self, name: str, data: AuthVerificationData) -> bool: - """ - 处理认证成功后的逻辑,记录日志并处理拦截事件。 - - :param name: 用户名 - :param data: 认证返回的数据,包含 token、channel 和 service - :return: 成功返回 True,若被拦截返回 False - """ - token, channel, service = data.token, data.channel, data.service - if token and channel and service: - # 匿名化 token - anonymized_token = f"{token[:len(token) // 2]}****" - logger.info(f"用户 {name} 通过渠道 {channel},服务: {service} 认证成功,token: {anonymized_token}") - - # 触发认证通过的拦截事件 - intercept_event = self.eventmanager.send_event( - etype=ChainEventType.AuthPassedIntercept, - data=AuthPassedInterceptData(name=name, channel=channel, service=service, token=token) - ) - - if intercept_event and intercept_event.event_data: - intercept_data: AuthPassedInterceptData = intercept_event.event_data - if intercept_data.cancel: - logger.info( - f"认证被拦截,用户: {name},渠道: {channel},服务: {service},拦截源: {intercept_data.source}") - return False - + # 检查用户是否存在,如果不存在则创建新用户 + user = self.user_oper.get_by_name(name=username) + if user: return True - logger.warning(f"用户 {name} 未通过辅助认证") - return False + logger.info(f"用户 {username} 不存在,已通过 {credentials.grant_type} 认证并已创建普通用户") + self.user_oper.add(name=username, is_active=True, is_superuser=False, + hashed_password=get_password_hash(secrets.token_urlsafe(16))) + return True diff --git a/app/core/config.py b/app/core/config.py index dd014abe..7df1dfe8 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -70,7 +70,7 @@ class ConfigModel(BaseModel): CONFIG_DIR: Optional[str] = None # 超级管理员 SUPERUSER: str = "admin" - # 辅助认证,允许通过外部服务(如媒体服务器/插件等)认证并创建用户 + # 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户 AUXILIARY_AUTH_ENABLE: bool = False # API密钥,需要更换 API_TOKEN: Optional[str] = None diff --git a/app/db/models/user.py b/app/db/models/user.py index 2a6fb748..646478e5 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -1,11 +1,7 @@ -from typing import Tuple, Any - -from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON +from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String from sqlalchemy.orm import Session -from app.core.security import verify_password -from app.db import db_query, db_update, Base -from app.utils.otp import OtpUtils +from app.db import Base, db_query, db_update class User(Base): @@ -35,20 +31,6 @@ class User(Base): # 用户个性化设置 json settings = Column(JSON, default=dict) - @staticmethod - @db_query - def authenticate(db: Session, name: str, password: str, - otp_password: str) -> Tuple[bool, Any]: - user = db.query(User).filter(User.name == name).first() - if not user: - return False, None - if not verify_password(password, str(user.hashed_password)): - return False, user - if user.is_otp: - if not otp_password or not OtpUtils.check(str(user.otp_secret), otp_password): - return False, user - return True, user - @staticmethod @db_query def get_by_name(db: Session, name: str): diff --git a/app/db/user_oper.py b/app/db/user_oper.py index 8037aa9a..1740b7c6 100644 --- a/app/db/user_oper.py +++ b/app/db/user_oper.py @@ -5,8 +5,7 @@ from sqlalchemy.orm import Session from app import schemas from app.core.security import verify_token -from app.db import DbOper -from app.db import get_db +from app.db import DbOper, get_db from app.db.models.user import User @@ -52,6 +51,19 @@ class UserOper(DbOper): 用户管理 """ + def add(self, **kwargs): + """ + 新增用户 + """ + user = User(**kwargs) + user.create(self._db) + + def get_by_name(self, name: str) -> User: + """ + 根据用户名获取用户 + """ + return User.get_by_name(self._db, name) + def get_permissions(self, name: str) -> dict: """ 获取用户权限 diff --git a/app/modules/emby/__init__.py b/app/modules/emby/__init__.py index a23d8289..5953e8e8 100644 --- a/app/modules/emby/__init__.py +++ b/app/modules/emby/__init__.py @@ -5,7 +5,7 @@ from app.core.context import MediaInfo from app.log import logger from app.modules import _MediaServerBase, _ModuleBase from app.modules.emby.emby import Emby -from app.schemas.event import AuthVerificationData +from app.schemas.event import AuthCredentials from app.schemas.types import MediaType, ModuleType @@ -58,22 +58,22 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]): logger.info(f"Emby服务器 {name} 连接断开,尝试重连 ...") server.reconnect() - def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]: + def user_authenticate(self, credentials: AuthCredentials) -> Optional[AuthCredentials]: """ 使用Emby用户辅助完成用户认证 - :param auth_data: 认证数据 + :param credentials: 认证数据 :return: 认证数据 """ # Emby认证 - if not auth_data: + if not credentials or credentials.grant_type != "password": return None for name, server in self.get_instances().items(): - token = server.authenticate(auth_data.name, auth_data.password) + token = server.authenticate(credentials.username, credentials.password) if token: - auth_data.channel = self.get_name() - auth_data.service = name - auth_data.token = token - return auth_data + credentials.channel = self.get_name() + credentials.service = name + credentials.token = token + return credentials return None def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]: diff --git a/app/modules/jellyfin/__init__.py b/app/modules/jellyfin/__init__.py index a48a0c71..cdcb45a0 100644 --- a/app/modules/jellyfin/__init__.py +++ b/app/modules/jellyfin/__init__.py @@ -5,7 +5,7 @@ from app.core.context import MediaInfo from app.log import logger from app.modules import _MediaServerBase, _ModuleBase from app.modules.jellyfin.jellyfin import Jellyfin -from app.schemas.event import AuthVerificationData +from app.schemas.event import AuthCredentials from app.schemas.types import MediaType, ModuleType @@ -58,22 +58,22 @@ class JellyfinModule(_ModuleBase, _MediaServerBase[Jellyfin]): return False, f"无法连接Jellyfin服务器:{name}" return True, "" - def user_authenticate(self, auth_data: AuthVerificationData) -> Optional[AuthVerificationData]: + def user_authenticate(self, credentials: AuthCredentials) -> Optional[AuthCredentials]: """ 使用Jellyfin用户辅助完成用户认证 - :param auth_data: 认证数据 + :param credentials: 认证数据 :return: 认证数据 """ # Jellyfin认证 - if not auth_data: + if not credentials or credentials.grant_type != "password": return None for name, server in self.get_instances().items(): - token = server.authenticate(auth_data.name, auth_data.password) + token = server.authenticate(credentials.username, credentials.password) if token: - auth_data.channel = self.get_name() - auth_data.service = name - auth_data.token = token - return auth_data + credentials.channel = self.get_name() + credentials.service = name + credentials.token = token + return credentials return None def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[schemas.WebhookEventInfo]: diff --git a/app/schemas/event.py b/app/schemas/event.py index 48368905..1a5251c5 100644 --- a/app/schemas/event.py +++ b/app/schemas/event.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, root_validator class BaseEventData(BaseModel): @@ -17,37 +17,60 @@ class ChainEventData(BaseEventData): pass -class AuthVerificationData(ChainEventData): +class AuthCredentials(ChainEventData): """ AuthVerification 事件的数据模型 Attributes: - # 输入参数 - name (str): 用户名 - password (str): 用户密码 - - # 输出参数 - token (str): 认证令牌 - channel (str): 认证渠道 - service (str): 服务名称 + username (Optional[str]): 用户名,适用于 "password" grant_type + password (Optional[str]): 用户密码,适用于 "password" grant_type + mfa_code (Optional[str]): 一次性密码,目前仅适用于 "password" 认证类型 + code (Optional[str]): 授权码,适用于 "authorization_code" grant_type + grant_type (str): 认证类型,如 "password", "authorization_code", "client_credentials" + # scope (List[str]): 权限范围,如 ["read", "write"] + token (Optional[str]): 认证令牌 + channel (Optional[str]): 认证渠道 + service (Optional[str]): 服务名称 """ # 输入参数 - name: str = Field(..., description="用户名") - password: str = Field(..., description="用户密码") + username: Optional[str] = Field(None, description="用户名,适用于 'password' 认证类型") + password: Optional[str] = Field(None, description="用户密码,适用于 'password' 认证类型") + mfa_code: Optional[str] = Field(None, description="一次性密码,目前仅适用于 'password' 认证类型") + code: Optional[str] = Field(None, description="授权码,适用于 'authorization_code' 认证类型") + grant_type: str = Field(..., description="认证类型,如 'password', 'authorization_code', 'client_credentials'") + # scope: List[str] = Field(default_factory=list, description="权限范围,如 ['read', 'write']") # 输出参数 + # grant_type 为 authorization_code 时,输出参数包括 username、token、channel、service token: Optional[str] = Field(None, description="认证令牌") channel: Optional[str] = Field(None, description="认证渠道") service: Optional[str] = Field(None, description="服务名称") + @root_validator(pre=True) + def check_fields_based_on_grant_type(cls, values): + grant_type = values.get("grant_type") + if not grant_type: + values["grant_type"] = "password" + grant_type = "password" -class AuthPassedInterceptData(ChainEventData): + if grant_type == "password": + if not values.get("username") or not values.get("password"): + raise ValueError("username and password are required for grant_type 'password'") + + elif grant_type == "authorization_code": + if not values.get("code"): + raise ValueError("code is required for grant_type 'authorization_code'") + + return values + + +class AuthInterceptCredentials(ChainEventData): """ - AuthPassedIntercept 事件的数据模型。 + AuthPassedIntercept 事件的数据模型 Attributes: # 输入参数 - name (str): 用户名 + username (str): 用户名 channel (str): 认证渠道 service (str): 服务名称 token (str): 认证令牌 @@ -57,7 +80,7 @@ class AuthPassedInterceptData(ChainEventData): cancel (bool): 是否取消认证,默认值为 False """ # 输入参数 - name: str = Field(..., description="用户名") + username: str = Field(..., description="用户名") channel: str = Field(..., description="认证渠道") service: str = Field(..., description="服务名称") token: Optional[str] = Field(None, description="认证令牌") diff --git a/config/app.env b/config/app.env index 31797080..1352e306 100644 --- a/config/app.env +++ b/config/app.env @@ -17,7 +17,7 @@ DB_MAX_OVERFLOW=5 DB_TIMEOUT=60 # 【*】超级管理员,设置后一但重启将固化到数据库中,修改将无效(初始化超级管理员密码仅会生成一次,请在日志中查看并自行登录系统修改) SUPERUSER=admin -# 辅助认证,允许通过媒体服务器认证并创建用户 +# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户 AUXILIARY_AUTH_ENABLE=false # 大内存模式,开启后会增加缓存数量,但会占用更多内存 BIG_MEMORY_MODE=false