diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index 883c4a48..08ae9758 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -2,6 +2,7 @@ from typing import List, Any, Annotated, Optional import cn2an from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from app import schemas @@ -11,7 +12,7 @@ from app.core.context import MediaInfo from app.core.event import eventmanager from app.core.metainfo import MetaInfo from app.core.security import verify_token, verify_apitoken -from app.db import get_db +from app.db import get_async_db, get_db from app.db.models.subscribe import Subscribe from app.db.models.subscribehistory import SubscribeHistory from app.db.models.user import User @@ -34,21 +35,21 @@ def start_subscribe_add(title: str, year: str, @router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe]) -def read_subscribes( - db: Session = Depends(get_db), +async def read_subscribes( + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询所有订阅 """ - return Subscribe.list(db) + return await Subscribe.async_list(db) @router.get("/list", summary="查询所有订阅(API_TOKEN)", response_model=List[schemas.Subscribe]) -def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any: +async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any: """ 查询所有订阅 API_TOKEN认证(?token=xxx) """ - return read_subscribes() + return await read_subscribes() @router.post("/", summary="新增订阅", response_model=schemas.Response) @@ -87,16 +88,16 @@ def create_subscribe( @router.put("/", summary="更新订阅", response_model=schemas.Response) -def update_subscribe( +async def update_subscribe( *, subscribe_in: schemas.Subscribe, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 更新订阅信息 """ - subscribe = Subscribe.get(db, subscribe_in.id) + subscribe = await Subscribe.async_get(db, subscribe_in.id) if not subscribe: return schemas.Response(success=False, message="订阅不存在") # 避免更新缺失集数 @@ -114,7 +115,7 @@ def update_subscribe( # 是否手动修改过总集数 if subscribe_in.total_episode != subscribe.total_episode: subscribe_dict["manual_total_episode"] = 1 - subscribe.update(db, subscribe_dict) + await subscribe.async_update(db, subscribe_dict) # 发送订阅调整事件 eventmanager.send_event(EventType.SubscribeModified, { "subscribe_id": subscribe.id, @@ -125,22 +126,22 @@ def update_subscribe( @router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response) -def update_subscribe_status( +async def update_subscribe_status( subid: int, state: str, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 更新订阅状态 """ - subscribe = Subscribe.get(db, subid) + subscribe = await Subscribe.async_get(db, subid) if not subscribe: return schemas.Response(success=False, message="订阅不存在") valid_states = ["R", "P", "S"] if state not in valid_states: return schemas.Response(success=False, message="无效的订阅状态") old_subscribe_dict = subscribe.to_dict() - subscribe.update(db, { + await subscribe.async_update(db, { "state": state }) # 发送订阅调整事件 @@ -153,11 +154,11 @@ def update_subscribe_status( @router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe) -def subscribe_mediaid( +async def subscribe_mediaid( mediaid: str, season: Optional[int] = None, title: Optional[str] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban: @@ -167,23 +168,23 @@ def subscribe_mediaid( tmdbid = mediaid[5:] if not tmdbid or not str(tmdbid).isdigit(): return Subscribe() - result = Subscribe.exists(db, tmdbid=int(tmdbid), season=season) + result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season) elif mediaid.startswith("douban:"): doubanid = mediaid[7:] if not doubanid: return Subscribe() - result = Subscribe.get_by_doubanid(db, doubanid) + result = await Subscribe.async_get_by_doubanid(db, doubanid) if not result and title: title_check = True elif mediaid.startswith("bangumi:"): bangumiid = mediaid[8:] if not bangumiid or not str(bangumiid).isdigit(): return Subscribe() - result = Subscribe.get_by_bangumiid(db, int(bangumiid)) + result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid)) if not result and title: title_check = True else: - result = Subscribe.get_by_mediaid(db, mediaid) + result = await Subscribe.async_get_by_mediaid(db, mediaid) if not result and title: title_check = True # 使用名称检查订阅 @@ -191,7 +192,7 @@ def subscribe_mediaid( meta = MetaInfo(title) if season: meta.begin_season = season - result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season) + result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season) return result if result else Subscribe() @@ -207,17 +208,17 @@ def refresh_subscribes( @router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response) -def reset_subscribes( +async def reset_subscribes( subid: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 重置订阅 """ - subscribe = Subscribe.get(db, subid) + subscribe = await Subscribe.async_get(db, subid) if subscribe: old_subscribe_dict = subscribe.to_dict() - subscribe.update(db, { + await subscribe.async_update(db, { "note": [], "lack_episode": subscribe.total_episode, "state": "R" @@ -243,7 +244,7 @@ def check_subscribes( @router.get("/search", summary="搜索所有订阅", response_model=schemas.Response) -def search_subscribes( +async def search_subscribes( background_tasks: BackgroundTasks, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ @@ -262,7 +263,7 @@ def search_subscribes( @router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response) -def search_subscribe( +async def search_subscribe( subscribe_id: int, background_tasks: BackgroundTasks, _: schemas.TokenPayload = Depends(verify_token)) -> Any: @@ -282,10 +283,10 @@ def search_subscribe( @router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response) -def delete_subscribe_by_mediaid( +async def delete_subscribe_by_mediaid( mediaid: str, season: Optional[int] = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ @@ -296,21 +297,21 @@ def delete_subscribe_by_mediaid( tmdbid = mediaid[5:] if not tmdbid or not str(tmdbid).isdigit(): return schemas.Response(success=False) - subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season) + subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season) delete_subscribes.extend(subscribes) elif mediaid.startswith("douban:"): doubanid = mediaid[7:] if not doubanid: return schemas.Response(success=False) - subscribe = Subscribe().get_by_doubanid(db, doubanid) + subscribe = await Subscribe.async_get_by_doubanid(db, doubanid) if subscribe: delete_subscribes.append(subscribe) else: - subscribe = Subscribe().get_by_mediaid(db, mediaid) + subscribe = await Subscribe.async_get_by_mediaid(db, mediaid) if subscribe: delete_subscribes.append(subscribe) for subscribe in delete_subscribes: - Subscribe().delete(db, subscribe.id) + await Subscribe.async_delete(db, subscribe.id) # 发送事件 eventmanager.send_event(EventType.SubscribeDeleted, { "subscribe_id": subscribe.id, @@ -373,33 +374,33 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks, @router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe]) -def subscribe_history( +async def subscribe_history( mtype: str, page: Optional[int] = 1, count: Optional[int] = 30, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询电影/电视剧订阅历史 """ - return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count) + return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count) @router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response) -def delete_subscribe( +async def delete_subscribe( history_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 删除订阅历史 """ - SubscribeHistory.delete(db, history_id) + await SubscribeHistory.async_delete(db, history_id) return schemas.Response(success=True) @router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo]) -def popular_subscribes( +async def popular_subscribes( stype: str, page: Optional[int] = 1, count: Optional[int] = 30, @@ -408,7 +409,7 @@ def popular_subscribes( """ 查询热门订阅 """ - subscribes = SubscribeHelper().get_statistic(stype=stype, page=page, count=count) + subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count) if subscribes: ret_medias = [] for sub in subscribes: @@ -444,14 +445,14 @@ def popular_subscribes( @router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe]) -def user_subscribes( +async def user_subscribes( username: str, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询用户订阅 """ - return Subscribe.list_by_username(db, username) + return await Subscribe.async_list_by_username(db, username) @router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo) @@ -469,27 +470,27 @@ def subscribe_files( @router.post("/share", summary="分享订阅", response_model=schemas.Response) -def subscribe_share( +async def subscribe_share( sub: schemas.SubscribeShare, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 分享订阅 """ - state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id, - share_title=sub.share_title, - share_comment=sub.share_comment, - share_user=sub.share_user) + state, errmsg = await SubscribeHelper().async_sub_share(subscribe_id=sub.subscribe_id, + share_title=sub.share_title, + share_comment=sub.share_comment, + share_user=sub.share_user) return schemas.Response(success=state, message=errmsg) @router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response) -def subscribe_share_delete( +async def subscribe_share_delete( share_id: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 删除分享 """ - state, errmsg = SubscribeHelper().share_delete(share_id=share_id) + state, errmsg = await SubscribeHelper().async_share_delete(share_id=share_id) return schemas.Response(success=state, message=errmsg) @@ -513,7 +514,7 @@ def subscribe_fork( @router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str]) -def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询已Follow的订阅分享人 """ @@ -521,7 +522,7 @@ def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any @router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response) -def follow_subscriber( +async def follow_subscriber( share_uid: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ @@ -535,7 +536,7 @@ def follow_subscriber( @router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response) -def unfollow_subscriber( +async def unfollow_subscriber( share_uid: Optional[str] = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ @@ -549,7 +550,7 @@ def unfollow_subscriber( @router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare]) -def popular_subscribes( +async def popular_subscribes( name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30, @@ -557,43 +558,43 @@ def popular_subscribes( """ 查询分享的订阅 """ - return SubscribeHelper().get_shares(name=name, page=page, count=count) + return await SubscribeHelper().async_get_shares(name=name, page=page, count=count) @router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics]) -def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any: +async def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 查询订阅分享统计 返回每个分享人分享的媒体数量以及总的复用人次 """ - return SubscribeHelper().get_share_statistics() + return await SubscribeHelper().async_get_share_statistics() @router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe) -def read_subscribe( +async def read_subscribe( subscribe_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token)) -> Any: """ 根据订阅编号查询订阅信息 """ if not subscribe_id: return Subscribe() - return Subscribe.get(db, subscribe_id) + return await Subscribe.async_get(db, subscribe_id) @router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response) -def delete_subscribe( +async def delete_subscribe( subscribe_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), _: schemas.TokenPayload = Depends(verify_token) ) -> Any: """ 删除订阅信息 """ - subscribe = Subscribe.get(db, subscribe_id) + subscribe = await Subscribe.async_get(db, subscribe_id) if subscribe: - subscribe.delete(db, subscribe_id) + await Subscribe.async_delete(db, subscribe_id) # 发送事件 eventmanager.send_event(EventType.SubscribeDeleted, { "subscribe_id": subscribe_id, diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 951d29ff..c014ae38 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -3,13 +3,14 @@ import re from typing import Annotated, Any, List, Union from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app import schemas from app.core.security import get_password_hash -from app.db import get_db +from app.db import get_async_db from app.db.models.user import User -from app.db.user_oper import get_current_active_superuser, get_current_active_user +from app.db.user_oper import get_current_active_superuser_async, \ + get_current_active_user_async, get_current_active_user from app.db.userconfig_oper import UserConfigOper from app.utils.otp import OtpUtils @@ -17,28 +18,27 @@ router = APIRouter() @router.get("/", summary="所有用户", response_model=List[schemas.User]) -def list_users( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_superuser), +async def list_users( + db: AsyncSession = Depends(get_async_db), + current_user: User = Depends(get_current_active_superuser_async), ) -> Any: """ 查询用户列表 """ - users = current_user.list(db) - return users + return await User.async_list(db) @router.post("/", summary="新增用户", response_model=schemas.Response) -def create_user( +async def create_user( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), user_in: schemas.UserCreate, - current_user: User = Depends(get_current_active_superuser), + current_user: User = Depends(get_current_active_superuser_async), ) -> Any: """ 新增用户 """ - user = current_user.get_by_name(db, name=user_in.name) + user = await User.async_get_by_name(db, name=user_in.name) if user: return schemas.Response(success=False, message="用户已存在") user_info = user_in.dict() @@ -46,16 +46,16 @@ def create_user( user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info.pop("password") user = User(**user_info) - user.create(db) + await user.async_create(db) return schemas.Response(success=True) @router.put("/", summary="更新用户", response_model=schemas.Response) -def update_user( +async def update_user( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), user_in: schemas.UserUpdate, - _: User = Depends(get_current_active_superuser), + current_user: User = Depends(get_current_active_superuser_async), ) -> Any: """ 更新用户 @@ -69,24 +69,24 @@ def update_user( message="密码需要同时包含字母、数字、特殊字符中的至少两项,且长度大于6位") user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info.pop("password") - user = User.get_by_id(db, user_id=user_info["id"]) + user = await User.async_get_by_id(db, user_id=user_info["id"]) user_name = user_info.get("name") if not user_name: return schemas.Response(success=False, message="用户名不能为空") # 新用户名去重 - users = User.list(db) + users = await User.async_list(db) for u in users: if u.name == user_name and u.id != user_info["id"]: return schemas.Response(success=False, message="用户名已被使用") if not user: return schemas.Response(success=False, message="用户不存在") - user.update(db, user_info) + await user.async_update(db, user_info) return schemas.Response(success=True) @router.get("/current", summary="当前登录用户信息", response_model=schemas.User) -def read_current_user( - current_user: User = Depends(get_current_active_user) +async def read_current_user( + current_user: User = Depends(get_current_active_user_async) ) -> Any: """ 当前登录用户信息 @@ -95,18 +95,18 @@ def read_current_user( @router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response) -def upload_avatar(user_id: int, db: Session = Depends(get_db), file: UploadFile = File(...), - _: User = Depends(get_current_active_user)): +async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), file: UploadFile = File(...), + _: User = Depends(get_current_active_user_async)): """ 上传用户头像 """ # 将文件转换为Base64 file_base64 = base64.b64encode(file.file.read()) # 更新到用户表 - user = User.get(db, user_id) + user = await User.async_get(db, user_id) if not user: return schemas.Response(success=False, message="用户不存在") - user.update(db, { + await user.async_update(db, { "avatar": f"data:image/ico;base64,{file_base64}" }) return schemas.Response(success=True, message=file.filename) @@ -121,31 +121,31 @@ def otp_generate( @router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response) -def otp_judge( +async def otp_judge( data: dict, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user) + db: AsyncSession = Depends(get_async_db), + current_user: User = Depends(get_current_active_user_async) ) -> Any: uri = data.get("uri") otp_password = data.get("otpPassword") if not OtpUtils.is_legal(uri, otp_password): return schemas.Response(success=False, message="验证码错误") - current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) + await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) return schemas.Response(success=True) @router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response) -def otp_disable( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user) +async def otp_disable( + db: AsyncSession = Depends(get_async_db), + current_user: User = Depends(get_current_active_user_async) ) -> Any: - current_user.update_otp_by_name(db, current_user.name, False, "") + await current_user.async_update_otp_by_name(db, current_user.name, False, "") return schemas.Response(success=True) @router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response) -def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any: - user: User = User.get_by_name(db, userid) +async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any: + user: User = await User.async_get_by_name(db, userid) if not user: return schemas.Response(success=False) return schemas.Response(success=user.is_otp) @@ -165,9 +165,9 @@ def get_config(key: str, @router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response) def set_config( - key: str, - value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None, - current_user: User = Depends(get_current_active_user), + key: str, + value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None, + current_user: User = Depends(get_current_active_user), ): """ 更新用户配置 @@ -177,49 +177,49 @@ def set_config( @router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response) -def delete_user_by_id( +async def delete_user_by_id( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), user_id: int, - current_user: User = Depends(get_current_active_superuser), + current_user: User = Depends(get_current_active_superuser_async), ) -> Any: """ 通过唯一ID删除用户 """ - user = current_user.get_by_id(db, user_id=user_id) + user = await User.async_get_by_id(db, user_id=user_id) if not user: return schemas.Response(success=False, message="用户不存在") - user.delete_by_id(db, user_id) + await User.async_delete(db, user_id) return schemas.Response(success=True) @router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response) -def delete_user_by_name( +async def delete_user_by_name( *, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), user_name: str, - current_user: User = Depends(get_current_active_superuser), + current_user: User = Depends(get_current_active_superuser_async), ) -> Any: """ 通过用户名删除用户 """ - user = current_user.get_by_name(db, name=user_name) + user = await User.async_get_by_name(db, name=user_name) if not user: return schemas.Response(success=False, message="用户不存在") - user.delete_by_name(db, user_name) + await User.async_delete(db, user.id) return schemas.Response(success=True) @router.get("/{username}", summary="用户详情", response_model=schemas.User) -def read_user_by_name( +async def read_user_by_name( username: str, - current_user: User = Depends(get_current_active_user), - db: Session = Depends(get_db), + current_user: User = Depends(get_current_active_user_async), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ 查询用户详情 """ - user = current_user.get_by_name(db, name=username) + user = await User.async_get_by_name(db, name=username) if not user: raise HTTPException( status_code=404, diff --git a/app/db/models/subscribe.py b/app/db/models/subscribe.py index f4ebd809..d77ab4aa 100644 --- a/app/db/models/subscribe.py +++ b/app/db/models/subscribe.py @@ -1,10 +1,11 @@ import time from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, Float, JSON +from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, db_update, Base +from app.db import db_query, db_update, Base, async_db_query, async_db_update class Subscribe(Base): @@ -99,6 +100,27 @@ class Subscribe(Base): return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() return None + @classmethod + @async_db_query + async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, + season: Optional[int] = None): + if tmdbid: + if season: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid, cls.season == season) + ) + else: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid) + ) + elif doubanid: + result = await db.execute( + select(cls).filter(cls.doubanid == doubanid) + ) + else: + return None + return result.scalars().first() + @staticmethod @db_query def get_by_state(db: Session, state: str): @@ -109,6 +131,19 @@ class Subscribe(Base): # 如果传入的状态不为空,拆分成多个状态 return db.query(Subscribe).filter(Subscribe.state.in_(state.split(','))).all() + @classmethod + @async_db_query + async def async_get_by_state(cls, db: AsyncSession, state: str): + # 如果 state 为空或 None,返回所有订阅 + if not state: + result = await db.execute(select(cls)) + else: + # 如果传入的状态不为空,拆分成多个状态 + result = await db.execute( + select(cls).filter(cls.state.in_(state.split(','))) + ) + return result.scalars().all() + @staticmethod @db_query def get_by_title(db: Session, title: str, season: Optional[int] = None): @@ -117,6 +152,19 @@ class Subscribe(Base): Subscribe.season == season).first() return db.query(Subscribe).filter(Subscribe.name == title).first() + @classmethod + @async_db_query + async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None): + if season: + result = await db.execute( + select(cls).filter(cls.name == title, cls.season == season) + ) + else: + result = await db.execute( + select(cls).filter(cls.name == title) + ) + return result.scalars().first() + @staticmethod @db_query def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None): @@ -126,21 +174,58 @@ class Subscribe(Base): else: return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() + @classmethod + @async_db_query + async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None): + if season: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid, cls.season == season) + ) + else: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid) + ) + return result.scalars().all() + @staticmethod @db_query def get_by_doubanid(db: Session, doubanid: str): return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() + @classmethod + @async_db_query + async def async_get_by_doubanid(cls, db: AsyncSession, doubanid: str): + result = await db.execute( + select(cls).filter(cls.doubanid == doubanid) + ) + return result.scalars().first() + @staticmethod @db_query def get_by_bangumiid(db: Session, bangumiid: int): return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first() + @classmethod + @async_db_query + async def async_get_by_bangumiid(cls, db: AsyncSession, bangumiid: int): + result = await db.execute( + select(cls).filter(cls.bangumiid == bangumiid) + ) + return result.scalars().first() + @staticmethod @db_query def get_by_mediaid(db: Session, mediaid: str): return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first() + @classmethod + @async_db_query + async def async_get_by_mediaid(cls, db: AsyncSession, mediaid: str): + result = await db.execute( + select(cls).filter(cls.mediaid == mediaid) + ) + return result.scalars().first() + @db_update def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int): subscrbies = self.get_by_tmdbid(db, tmdbid, season) @@ -148,6 +233,13 @@ class Subscribe(Base): subscrbie.delete(db, subscrbie.id) return True + @async_db_update + async def async_delete_by_tmdbid(self, db: AsyncSession, tmdbid: int, season: int): + subscrbies = await self.async_get_by_tmdbid(db, tmdbid, season) + for subscrbie in subscrbies: + await subscrbie.async_delete(db, subscrbie.id) + return True + @db_update def delete_by_doubanid(self, db: Session, doubanid: str): subscribe = self.get_by_doubanid(db, doubanid) @@ -155,6 +247,13 @@ class Subscribe(Base): subscribe.delete(db, subscribe.id) return True + @async_db_update + async def async_delete_by_doubanid(self, db: AsyncSession, doubanid: str): + subscribe = await self.async_get_by_doubanid(db, doubanid) + if subscribe: + await subscribe.async_delete(db, subscribe.id) + return True + @db_update def delete_by_mediaid(self, db: Session, mediaid: str): subscribe = self.get_by_mediaid(db, mediaid) @@ -162,6 +261,13 @@ class Subscribe(Base): subscribe.delete(db, subscribe.id) return True + @async_db_update + async def async_delete_by_mediaid(self, db: AsyncSession, mediaid: str): + subscribe = await self.async_get_by_mediaid(db, mediaid) + if subscribe: + await subscribe.async_delete(db, subscribe.id) + return True + @staticmethod @db_query def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None): @@ -180,6 +286,30 @@ class Subscribe(Base): else: return db.query(Subscribe).filter(Subscribe.username == username).all() + @classmethod + @async_db_query + async def async_list_by_username(cls, db: AsyncSession, username: str, state: Optional[str] = None, + mtype: Optional[str] = None): + if mtype: + if state: + result = await db.execute( + select(cls).filter(cls.state == state, cls.username == username, cls.type == mtype) + ) + else: + result = await db.execute( + select(cls).filter(cls.username == username, cls.type == mtype) + ) + else: + if state: + result = await db.execute( + select(cls).filter(cls.state == state, cls.username == username) + ) + else: + result = await db.execute( + select(cls).filter(cls.username == username) + ) + return result.scalars().all() + @staticmethod @db_query def list_by_type(db: Session, mtype: str, days: int): @@ -188,3 +318,15 @@ class Subscribe(Base): Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() - 86400 * int(days))) ).all() + + @classmethod + @async_db_query + async def async_list_by_type(cls, db: AsyncSession, mtype: str, days: int): + result = await db.execute( + select(cls).filter( + cls.type == mtype, + cls.date >= time.strftime("%Y-%m-%d %H:%M:%S", + time.localtime(time.time() - 86400 * int(days))) + ) + ) + return result.scalars().all() diff --git a/app/db/models/subscribehistory.py b/app/db/models/subscribehistory.py index 5294a541..b204ff42 100644 --- a/app/db/models/subscribehistory.py +++ b/app/db/models/subscribehistory.py @@ -1,9 +1,10 @@ from typing import Optional -from sqlalchemy import Column, Integer, String, Sequence, Float, JSON +from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import db_query, Base +from app.db import db_query, Base, async_db_query class SubscribeHistory(Base): @@ -81,6 +82,18 @@ class SubscribeHistory(Base): SubscribeHistory.date.desc() ).offset((page - 1) * count).limit(count).all() + @classmethod + @async_db_query + async def async_list_by_type(cls, db: AsyncSession, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30): + result = await db.execute( + select(cls).filter( + cls.type == mtype + ).order_by( + cls.date.desc() + ).offset((page - 1) * count).limit(count) + ) + return result.scalars().all() + @staticmethod @db_query def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None): @@ -92,3 +105,24 @@ class SubscribeHistory(Base): elif doubanid: return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first() return None + + @classmethod + @async_db_query + async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, + season: Optional[int] = None): + if tmdbid: + if season: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid, cls.season == season) + ) + else: + result = await db.execute( + select(cls).filter(cls.tmdbid == tmdbid) + ) + elif doubanid: + result = await db.execute( + select(cls).filter(cls.doubanid == doubanid) + ) + else: + return None + return result.scalars().first() diff --git a/app/db/models/user.py b/app/db/models/user.py index 646478e5..54eae5f1 100644 --- a/app/db/models/user.py +++ b/app/db/models/user.py @@ -1,7 +1,8 @@ -from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String +from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from app.db import Base, db_query, db_update +from app.db import Base, db_query, db_update, async_db_query, async_db_update class User(Base): @@ -36,11 +37,27 @@ class User(Base): def get_by_name(db: Session, name: str): return db.query(User).filter(User.name == name).first() + @classmethod + @async_db_query + async def async_get_by_name(cls, db: AsyncSession, name: str): + result = await db.execute( + select(cls).filter(cls.name == name) + ) + return result.scalars().first() + @staticmethod @db_query def get_by_id(db: Session, user_id: int): return db.query(User).filter(User.id == user_id).first() + @classmethod + @async_db_query + async def async_get_by_id(cls, db: AsyncSession, user_id: int): + result = await db.execute( + select(cls).filter(cls.id == user_id) + ) + return result.scalars().first() + @db_update def delete_by_name(self, db: Session, name: str): user = self.get_by_name(db, name) @@ -48,6 +65,13 @@ class User(Base): user.delete(db, user.id) return True + @async_db_update + async def async_delete_by_name(self, db: AsyncSession, name: str): + user = await self.async_get_by_name(db, name) + if user: + await user.async_delete(db, user.id) + return True + @db_update def delete_by_id(self, db: Session, user_id: int): user = self.get_by_id(db, user_id) @@ -55,6 +79,13 @@ class User(Base): user.delete(db, user.id) return True + @async_db_update + async def async_delete_by_id(self, db: AsyncSession, user_id: int): + user = await self.async_get_by_id(db, user_id) + if user: + await user.async_delete(db, user.id) + return True + @db_update def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str): user = self.get_by_name(db, name) @@ -65,3 +96,14 @@ class User(Base): }) return True return False + + @async_db_update + async def async_update_otp_by_name(self, db: AsyncSession, name: str, otp: bool, secret: str): + user = await self.async_get_by_name(db, name) + if user: + await user.async_update(db, { + 'is_otp': otp, + 'otp_secret': secret + }) + return True + return False diff --git a/app/db/subscribe_oper.py b/app/db/subscribe_oper.py index 1204b90d..69b64237 100644 --- a/app/db/subscribe_oper.py +++ b/app/db/subscribe_oper.py @@ -2,7 +2,7 @@ import time from typing import Tuple, List, Optional from app.core.context import MediaInfo -from app.db import DbOper +from app.db import DbOper, AsyncDbOper from app.db.models.subscribe import Subscribe from app.db.models.subscribehistory import SubscribeHistory @@ -48,7 +48,8 @@ class SubscribeOper(DbOper): else: return subscribe.id, "订阅已存在" - def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None) -> bool: + def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, + season: Optional[int] = None) -> bool: """ 判断是否存在 """ @@ -96,7 +97,8 @@ class SubscribeOper(DbOper): """ return Subscribe.get_by_tmdbid(self._db, tmdbid=tmdbid, season=season) - def list_by_username(self, username: str, state: Optional[str] = None, mtype: Optional[str] = None) -> List[Subscribe]: + def list_by_username(self, username: str, state: Optional[str] = None, + mtype: Optional[str] = None) -> List[Subscribe]: """ 获取指定用户的订阅 """ @@ -134,3 +136,15 @@ class SubscribeOper(DbOper): elif doubanid: return True if SubscribeHistory.exists(self._db, doubanid=doubanid) else False return False + + +class AsyncSubscribeOper(AsyncDbOper): + """ + 异步订阅管理 + """ + + async def get(self, sid: int) -> Subscribe: + """ + 获取订阅 + """ + return await Subscribe.async_get(self._db, id=sid) diff --git a/app/helper/subscribe.py b/app/helper/subscribe.py index 753d7ec1..369230e7 100644 --- a/app/helper/subscribe.py +++ b/app/helper/subscribe.py @@ -3,11 +3,11 @@ from typing import List, Tuple, Optional from app.core.cache import cached, cache_backend from app.core.config import settings -from app.db.subscribe_oper import SubscribeOper +from app.db.subscribe_oper import SubscribeOper, AsyncSubscribeOper from app.db.systemconfig_oper import SystemConfigOper from app.log import logger from app.schemas.types import SystemConfigKey -from app.utils.http import RequestUtils +from app.utils.http import RequestUtils, AsyncRequestUtils from app.utils.singleton import WeakSingleton from app.utils.system import SystemUtils @@ -60,7 +60,7 @@ class SubscribeHelper(metaclass=WeakSingleton): self.get_user_uuid() self.get_github_user() - @cached(maxsize=5, ttl=1800) + @cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True) def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]: """ 获取订阅统计数据 @@ -76,6 +76,22 @@ class SubscribeHelper(metaclass=WeakSingleton): return res.json() return [] + @cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True) + async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]: + """ + 异步获取订阅统计数据 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return [] + res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={ + "stype": stype, + "page": page, + "count": count + }) + if res and res.status_code == 200: + return res.json() + return [] + def sub_reg(self, sub: dict) -> bool: """ 新增订阅统计 @@ -167,6 +183,37 @@ class SubscribeHelper(metaclass=WeakSingleton): else: return False, res.json().get("message") + async def async_sub_share(self, subscribe_id: int, + share_title: str, share_comment: str, share_user: str) -> Tuple[bool, str]: + """ + 异步分享订阅 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return False, "当前没有开启订阅数据共享功能" + subscribe = await AsyncSubscribeOper().get(subscribe_id) + if not subscribe: + return False, "订阅不存在" + subscribe_dict = subscribe.to_dict() + subscribe_dict.pop("id") + cache_backend.clear(region=self._shares_cache_region) + res = await AsyncRequestUtils(proxies=settings.PROXY, content_type="application/json", + timeout=10).post(self._sub_share, + json={ + "share_title": share_title, + "share_comment": share_comment, + "share_user": share_user, + "share_uid": self._share_user_id, + **subscribe_dict + }) + if res is None: + return False, "连接MoviePilot服务器失败" + if res.status_code == 200: + # 清除 get_shares 的缓存,以便实时看到结果 + cache_backend.clear(region=self._shares_cache_region) + return True, "" + else: + return False, res.json().get("message") + def share_delete(self, share_id: int) -> Tuple[bool, str]: """ 删除分享 @@ -185,6 +232,24 @@ class SubscribeHelper(metaclass=WeakSingleton): else: return False, res.json().get("message") + async def async_share_delete(self, share_id: int) -> Tuple[bool, str]: + """ + 异步删除分享 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return False, "当前没有开启订阅数据共享功能" + res = await AsyncRequestUtils(proxies=settings.PROXY, + timeout=5).delete_res(f"{self._sub_share}/{share_id}", + params={"share_uid": self._share_user_id}) + if res is None: + return False, "连接MoviePilot服务器失败" + if res.status_code == 200: + # 清除 get_shares 的缓存,以便实时看到结果 + cache_backend.clear(region=self._shares_cache_region) + return True, "" + else: + return False, res.json().get("message") + def sub_fork(self, share_id: int) -> Tuple[bool, str]: """ 复用分享的订阅 @@ -201,6 +266,22 @@ class SubscribeHelper(metaclass=WeakSingleton): else: return False, res.json().get("message") + async def async_sub_fork(self, share_id: int) -> Tuple[bool, str]: + """ + 异步复用分享的订阅 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return False, "当前没有开启订阅数据共享功能" + res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={ + "Content-Type": "application/json" + }).get_res(self._sub_fork % share_id) + if res is None: + return False, "连接MoviePilot服务器失败" + if res.status_code == 200: + return True, "" + else: + return False, res.json().get("message") + @cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True) def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]: """ @@ -217,6 +298,23 @@ class SubscribeHelper(metaclass=WeakSingleton): return res.json() return [] + @cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True) + async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> \ + List[dict]: + """ + 异步获取订阅分享数据 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return [] + res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={ + "name": name, + "page": page, + "count": count + }) + if res and res.status_code == 200: + return res.json() + return [] + @cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True) def get_share_statistics(self) -> List[dict]: """ @@ -229,6 +327,18 @@ class SubscribeHelper(metaclass=WeakSingleton): return res.json() return [] + @cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True) + async def async_get_share_statistics(self) -> List[dict]: + """ + 异步获取订阅分享统计数据 + """ + if not settings.SUBSCRIBE_STATISTIC_SHARE: + return [] + res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_share_statistic) + if res and res.status_code == 200: + return res.json() + return [] + def get_user_uuid(self) -> str: """ 获取用户uuid diff --git a/app/utils/http.py b/app/utils/http.py index fcc63608..c16663d6 100644 --- a/app/utils/http.py +++ b/app/utils/http.py @@ -569,18 +569,22 @@ class AsyncRequestUtils: if not content_type: content_type = "application/x-www-form-urlencoded; charset=UTF-8" if headers: - self._headers = headers + # 过滤掉None值的headers + self._headers = {k: v for k, v in headers.items() if v is not None} else: if ua and ua == settings.USER_AGENT: caller_name = get_caller() if caller_name: ua = f"{settings.USER_AGENT} Plugin/{caller_name}" - self._headers = { - "User-Agent": ua, - "Content-Type": content_type, - "Accept": accept_type, - "referer": referer - } + self._headers = {} + if ua: + self._headers["User-Agent"] = ua + if content_type: + self._headers["Content-Type"] = content_type + if accept_type: + self._headers["Accept"] = accept_type + if referer: + self._headers["referer"] = referer if cookies: if isinstance(cookies, str): self._cookies = cookie_parse(cookies)