fix async apis

This commit is contained in:
jxxghp
2025-08-01 14:19:34 +08:00
parent 61f26d331b
commit 0c8fd5121a
28 changed files with 427 additions and 256 deletions

View File

@@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo])
def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询后台服务信息
"""
@@ -119,7 +119,7 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/schedule2", summary="后台服务API_TOKEN", response_model=List[schemas.ScheduleInfo])
def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询下载器信息 API_TOKEN认证?token=xxx
"""
@@ -127,12 +127,13 @@ def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
def transfer(days: Optional[int] = 7, db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def transfer(days: Optional[int] = 7,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询文件整理统计信息
"""
transfer_stat = TransferHistory.statistic(db, days)
transfer_stat = await TransferHistory.async_statistic(db, days)
return [stat[1] for stat in transfer_stat]

View File

@@ -116,7 +116,7 @@ def stop(hashString: str, name: Optional[str] = None,
@router.get("/clients", summary="查询可用下载器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用下载器
"""

View File

@@ -80,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
deletesrc: Optional[bool] = False,
deletedest: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
删除整理记录
"""

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.chain.download import DownloadChain
@@ -48,7 +48,7 @@ async def exists_local(title: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: Session = Depends(get_async_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
判断本地是否存在
@@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False,
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用媒体服务器
"""

View File

@@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db import get_async_db
from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
@@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
@router.get("/web", summary="获取WEB消息", response_model=List[dict])
def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: Session = Depends(get_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
async def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
"""
获取WEB消息列表
"""
ret_messages = []
messages = Message.list_by_page(db, page=page, count=count)
messages = Message.async_list_by_page(db, page=page, count=count)
for message in messages:
try:
ret_messages.append(message.to_dict())
@@ -106,7 +106,7 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int],
return str(err)
def vocechat_verify() -> Any:
async def vocechat_verify() -> Any:
"""
VoceChat验证响应
"""
@@ -128,7 +128,7 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None,
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)
def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
"""
客户端webpush通知订阅
"""

View File

@@ -13,6 +13,7 @@ from app.command import Command
from app.core.config import settings
from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.factory import app
@@ -138,7 +139,7 @@ def register_plugin(plugin_id: str):
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
async def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser_async),
async def all_plugins(_: User = Depends(get_current_active_superuser_async),
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
"""
查询所有插件清单包括本地插件和在线插件插件状态installed, market, all
@@ -187,7 +188,7 @@ async def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_super
@router.get("/installed", summary="已安装插件", response_model=List[str])
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询用户已安装插件清单
"""
@@ -203,7 +204,7 @@ async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response)
def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any:
"""
重新加载插件
"""
@@ -218,7 +219,7 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_
def install(plugin_id: str,
repo_url: Optional[str] = "",
force: Optional[bool] = False,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
安装插件
"""
@@ -260,7 +261,7 @@ def remotes(token: str) -> Any:
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
_: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件配置表单或Vue组件URL
"""
@@ -284,7 +285,7 @@ def plugin_form(plugin_id: str,
@router.get("/page/{plugin_id}", summary="获取插件数据页面")
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件数据页面
"""
@@ -333,7 +334,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
def reset_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
根据插件ID重置插件配置及数据
"""
@@ -394,7 +395,7 @@ async def plugin_static_file(plugin_id: str, filepath: str):
@router.get("/folders", summary="获取插件文件夹配置", response_model=dict)
def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
获取插件文件夹分组配置
"""
@@ -407,7 +408,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe
@router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response)
def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
保存插件文件夹分组配置
"""
@@ -420,7 +421,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur
@router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response)
def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def create_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
创建新的插件文件夹
"""
@@ -434,34 +436,35 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get
@router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response)
def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def delete_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
删除插件文件夹
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
if folder_name in folders:
del folders[folder_name]
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功")
else:
return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在")
@router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response)
def update_folder_plugins(folder_name: str, plugin_ids: List[str],
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
更新指定文件夹中的插件列表
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
folders[folder_name] = plugin_ids
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
@router.get("/{plugin_id}", summary="获取插件配置")
def plugin_config(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def plugin_config(plugin_id: str,
_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
根据插件ID获取插件配置信息
"""
@@ -470,7 +473,7 @@ def plugin_config(plugin_id: str,
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
def set_plugin_config(plugin_id: str, conf: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
更新插件配置
"""
@@ -486,7 +489,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
卸载插件
"""
@@ -527,7 +530,7 @@ def uninstall_plugin(plugin_id: str,
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
def clone_plugin(plugin_id: str,
clone_data: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
创建插件分身
"""

View File

@@ -10,7 +10,7 @@ from app.api.endpoints.plugin import register_plugin_api
from app.chain.site import SiteChain
from app.chain.torrents import TorrentsChain
from app.command import Command
from app.core.event import EventManager
from app.core.event import eventmanager
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.db import get_db, get_async_db
@@ -21,7 +21,7 @@ from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.sites import SitesHelper # noqa
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey, EventType
@@ -31,20 +31,20 @@ router = APIRouter()
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
def read_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
async def read_sites(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser)) -> List[dict]:
"""
获取站点列表
"""
return Site.list_order_by_pri(db)
return Site.async_list_order_by_pri(db)
@router.post("/", summary="新增站点", response_model=schemas.Response)
def add_site(
async def add_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
新增站点
@@ -54,10 +54,10 @@ def add_site(
if SitesHelper().auth_level < 2:
return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!")
domain = StringUtils.get_url_domain(site_in.url)
site_info = SitesHelper().get_indexer(domain)
site_info = await SitesHelper().async_get_indexer(domain)
if not site_info:
return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确")
if Site.get_by_domain(db, domain):
if await Site.async_get_by_domain(db, domain):
return schemas.Response(success=False, message=f"{domain} 站点己存在")
# 保存站点信息
site_in.domain = domain
@@ -70,39 +70,39 @@ def add_site(
site = Site(**site_in.dict())
site.create(db)
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": domain
})
return schemas.Response(success=True)
@router.put("/", summary="更新站点", response_model=schemas.Response)
def update_site(
async def update_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
更新站点信息
"""
site = Site.get(db, site_in.id)
site = await Site.async_get(db, site_in.id)
if not site:
return schemas.Response(success=False, message="站点不存在")
# 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/"
site.update(db, site_in.dict())
await site.async_update(db, site_in.dict())
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain
})
return schemas.Response(success=True)
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
运行CookieCloud同步站点信息
"""
@@ -111,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks,
@router.get("/reset", summary="重置站点", response_model=schemas.Response)
def reset(db: Session = Depends(get_db),
def reset(db: AsyncSession = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
"""
清空所有站点数据并重新同步CookieCloud站点信息
@@ -122,25 +122,25 @@ def reset(db: Session = Depends(get_db),
# 启动定时服务
Scheduler().start("cookiecloud", manual=True)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
eventmanager.send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
return schemas.Response(success=True, message="站点已重置!")
@router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response)
def update_sites_priority(
async def update_sites_priority(
priorities: List[dict],
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
批量更新站点优先级
"""
for priority in priorities:
site = Site.get(db, priority.get("id"))
site = await Site.async_get(db, priority.get("id"))
if site:
site.update(db, {"pri": priority.get("pri")})
await site.async_update(db, {"pri": priority.get("pri")})
return schemas.Response(success=True)
@@ -151,7 +151,7 @@ def update_cookie(
password: str,
code: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
使用用户密码更新站点Cookie
"""
@@ -174,7 +174,7 @@ def update_cookie(
def refresh_userdata(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
刷新站点用户数据
"""
@@ -192,34 +192,34 @@ def refresh_userdata(
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
def read_userdata_latest(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def read_userdata_latest(
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询所有站点最新用户数据
"""
user_datas = SiteUserData.get_latest(db)
user_datas = await SiteUserData.async_get_latest(db)
if not user_datas:
return []
return [user_data.to_dict() for user_data in user_datas]
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
def read_userdata(
async def read_userdata(
site_id: int,
workdate: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询站点用户数据
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate)
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
if not user_data:
return schemas.Response(success=False, data=[])
return schemas.Response(success=True, data=user_data)
@@ -264,19 +264,19 @@ async def site_icon(site_id: int,
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
def site_category(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def site_category(site_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点分类
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
indexer = SitesHelper().get_indexer(site.domain)
indexer = await SitesHelper().async_get_indexer(site.domain)
if not indexer:
raise HTTPException(
status_code=404,
@@ -294,38 +294,38 @@ def site_category(site_id: int,
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
浏览站点资源
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]
@router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site)
def read_site_by_domain(
async def read_site_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点信息
"""
domain = StringUtils.get_url_domain(site_url)
site = Site.get_by_domain(db, domain)
site = await Site.async_get_by_domain(db, domain)
if not site:
raise HTTPException(
status_code=404,
@@ -335,35 +335,35 @@ def read_site_by_domain(
@router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic)
def read_statistic_by_domain(
async def read_statistic_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点统计信息
"""
domain = StringUtils.get_url_domain(site_url)
sitestatistic = SiteStatistic.get_by_domain(db, domain)
sitestatistic = await SiteStatistic.async_get_by_domain(db, domain)
if sitestatistic:
return sitestatistic
return schemas.SiteStatistic(domain=domain)
@router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic])
def read_statistics(
db: Session = Depends(get_db),
async def read_statistics(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
获取所有站点统计信息
"""
return SiteStatistic.list(db)
return await SiteStatistic.async_list(db)
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
def read_rss_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
async def read_rss_sites(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
"""
获取站点列表
"""
@@ -371,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db),
selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
# 所有站点
all_site = Site.list_order_by_pri(db)
all_site = await Site.async_list_order_by_pri(db)
if not selected_sites:
return all_site
@@ -381,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db),
@router.get("/auth", summary="查询认证站点", response_model=dict)
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
获取可认证站点列表
"""
@@ -409,12 +409,12 @@ def auth_site(
@router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response)
def site_mapping(_: User = Depends(get_current_active_superuser)):
async def site_mapping(_: User = Depends(get_current_active_superuser_async)):
"""
获取站点域名到名称的映射关系
"""
try:
sites = SiteOper().list()
sites = await SiteOper().async_list()
mapping = {}
for site in sites:
mapping[site.domain] = site.name
@@ -424,7 +424,7 @@ def site_mapping(_: User = Depends(get_current_active_superuser)):
@router.get("/supporting", summary="获取支持的站点列表", response_model=dict)
def support_sites(_: User = Depends(get_current_active_superuser)):
async def support_sites(_: User = Depends(get_current_active_superuser_async)):
"""
获取支持的站点列表
"""
@@ -432,15 +432,15 @@ def support_sites(_: User = Depends(get_current_active_superuser)):
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
def read_site(
async def read_site(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
通过ID获取站点信息
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
@@ -450,18 +450,18 @@ def read_site(
@router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response)
def delete_site(
async def delete_site(
site_id: int,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
删除站点
"""
Site.delete(db, site_id)
await Site.async_delete(db, site_id)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
await eventmanager.async_send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
return schemas.Response(success=True)

View File

@@ -12,7 +12,7 @@ from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey
@@ -222,7 +222,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询支持的整理方式
"""

View File

@@ -117,7 +117,7 @@ async def update_subscribe(
subscribe_dict["manual_total_episode"] = 1
# 发送订阅调整事件
subscribe = await subscribe.async_get(db, subscribe_in.id)
eventmanager.send_event(EventType.SubscribeModified, {
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe_in.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
@@ -145,7 +145,7 @@ async def update_subscribe_status(
"state": state
})
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
@@ -224,7 +224,7 @@ async def reset_subscribes(
"state": "R"
})
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
@@ -313,7 +313,7 @@ async def delete_subscribe_by_mediaid(
for subscribe in delete_subscribes:
await Subscribe.async_delete(db, subscribe.id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict()
})
@@ -531,7 +531,7 @@ async def follow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid not in subscribers:
subscribers.append(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@@ -545,7 +545,7 @@ async def unfollow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid in subscribers:
subscribers.remove(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@@ -596,7 +596,7 @@ async def delete_subscribe(
if subscribe:
await Subscribe.async_delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict()
})

View File

@@ -23,7 +23,7 @@ from app.core.module import ModuleManager
from app.core.security import verify_apitoken, verify_resource_token, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
@@ -202,7 +202,7 @@ def get_global_setting(token: str):
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
def get_env_setting(_: User = Depends(get_current_active_superuser)):
async def get_env_setting(_: User = Depends(get_current_active_superuser_async)):
"""
查询系统环境变量,包括当前版本号(仅管理员)
"""
@@ -220,8 +220,8 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser)):
async def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser_async)):
"""
更新系统环境变量(仅管理员)
"""
@@ -243,7 +243,7 @@ def set_env_setting(env: dict,
if success_updates:
for key in success_updates.keys():
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=getattr(settings, key, None),
change_type="update"
@@ -280,8 +280,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
def get_setting(key: str,
_: User = Depends(get_current_active_superuser)):
async def get_setting(key: str,
_: User = Depends(get_current_active_superuser_async)):
"""
查询系统设置(仅管理员)
"""
@@ -295,10 +295,10 @@ def get_setting(key: str,
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
def set_setting(
async def set_setting(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser),
_: User = Depends(get_current_active_superuser_async),
):
"""
更新系统设置(仅管理员)
@@ -307,7 +307,7 @@ def set_setting(
success, message = settings.update_setting(key=key, value=value)
if success:
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"
@@ -319,10 +319,10 @@ def set_setting(
if isinstance(value, list):
value = list(filter(None, value))
value = value if value else None
success = SystemConfigOper().set(key, value)
success = await SystemConfigOper().async_set(key, value)
if success:
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"

View File

@@ -12,6 +12,7 @@ from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models import User
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.helper.directory import DirectoryHelper
@@ -81,7 +82,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v
def manual_transfer(transer_item: ManualTransferItem,
background: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
手动转移,文件或历史记录,支持自定义剧集识别格式
:param transer_item: 手工整理项

View File

@@ -25,7 +25,7 @@ async def list_users(
"""
查询用户列表
"""
return await User.async_list(db)
return await current_user.async_list(db)
@router.post("/", summary="新增用户", response_model=schemas.Response)
@@ -38,7 +38,7 @@ async def create_user(
"""
新增用户
"""
user = await User.async_get_by_name(db, name=user_in.name)
user = await current_user.async_get_by_name(db, name=user_in.name)
if user:
return schemas.Response(success=False, message="用户已存在")
user_info = user_in.dict()
@@ -68,12 +68,12 @@ async def update_user(
message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位")
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = await User.async_get_by_id(db, user_id=user_info["id"])
user = await current_user.async_get_by_id(db, user_id=user_info["id"])
user_name = user_info.get("name")
if not user_name:
return schemas.Response(success=False, message="用户名不能为空")
# 新用户名去重
users = await User.async_list(db)
users = await current_user.async_list(db)
for u in users:
if u.name == user_name and u.id != user_info["id"]:
return schemas.Response(success=False, message="用户名已被使用")
@@ -185,10 +185,10 @@ async def delete_user_by_id(
"""
通过唯一ID删除用户
"""
user = await User.async_get_by_id(db, user_id=user_id)
user = await current_user.async_get_by_id(db, user_id=user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
await User.async_delete(db, user_id)
await current_user.async_delete(db, user_id)
return schemas.Response(success=True)
@@ -202,10 +202,10 @@ async def delete_user_by_name(
"""
通过用户名删除用户
"""
user = await User.async_get_by_name(db, name=user_name)
user = await current_user.async_get_by_name(db, name=user_name)
if not user:
return schemas.Response(success=False, message="用户不存在")
await User.async_delete(db, user.id)
await current_user.async_delete(db, user.id)
return schemas.Response(success=True)
@@ -218,7 +218,7 @@ async def read_user_by_name(
"""
查询用户详情
"""
user = await User.async_get_by_name(db, name=username)
user = await current_user.async_get_by_name(db, name=username)
if not user:
raise HTTPException(
status_code=404,

View File

@@ -657,6 +657,19 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 标题
:param cat: 分类
:param page: 页码
:reutrn: 种子资源列表
"""
return await self.async_run_module("async_refresh_torrents",
site=site, keyword=keyword, cat=cat, page=page)
def filter_torrents(self, rule_groups: List[str],
torrent_list: List[TorrentInfo],
mediainfo: MediaInfo = None) -> List[TorrentInfo]:

View File

@@ -4,12 +4,11 @@ from datetime import datetime
from typing import Optional, Tuple, Union, Dict
from urllib.parse import urljoin
from app.helper.sites import SitesHelper # noqa
from lxml import etree
from app.chain import ChainBase
from app.core.config import global_vars, settings
from app.core.event import Event, EventManager, eventmanager
from app.core.event import Event, eventmanager
from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
@@ -18,6 +17,7 @@ from app.helper.cloudflare import under_challenge
from app.helper.cookie import CookieHelper
from app.helper.cookiecloud import CookieCloudHelper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas import MessageChannel, Notification, SiteUserData
from app.schemas.types import EventType, NotificationType
@@ -58,7 +58,7 @@ class SiteChain(ChainBase):
name=site.get("name"),
payload=userdata.dict())
# 发送事件
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": site.get("id")
})
# 发送站点消息
@@ -103,7 +103,7 @@ class SiteChain(ChainBase):
any_site_updated = True
result[site.get("name")] = userdata
if any_site_updated:
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": "*"
})
@@ -415,7 +415,7 @@ class SiteChain(ChainBase):
# 通知站点更新
if indexer:
EventManager().send_event(EventType.SiteUpdated, {
eventmanager.send_event(EventType.SiteUpdated, {
"domain": domain,
})
# 处理完成

View File

@@ -15,7 +15,7 @@ from app.chain.tmdb import TmdbChain
from app.chain.torrents import TorrentsChain
from app.core.config import settings, global_vars
from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.event import eventmanager, Event, EventManager
from app.core.event import eventmanager, Event
from app.core.meta import MetaBase
from app.core.meta.words import WordsMatcher
from app.core.metainfo import MetaInfo
@@ -237,7 +237,7 @@ class SubscribeChain(ChainBase):
username=username
)
# 发送事件
EventManager().send_event(EventType.SubscribeAdded, {
eventmanager.send_event(EventType.SubscribeAdded, {
"subscribe_id": sid,
"username": username,
"mediainfo": mediainfo.to_dict(),
@@ -1090,7 +1090,7 @@ class SubscribeChain(ChainBase):
username=subscribe.username
)
# 发送事件
EventManager().send_event(EventType.SubscribeComplete, {
eventmanager.send_event(EventType.SubscribeComplete, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict(),
"mediainfo": mediainfo.to_dict(),

View File

@@ -7,12 +7,11 @@ from typing import Union, Optional
from app.chain import ChainBase
from app.core.config import settings
from app.core.plugin import PluginManager
from app.helper.system import SystemHelper
from app.log import logger
from app.schemas import Notification, MessageChannel
from app.utils.http import RequestUtils
from app.utils.system import SystemUtils
from app.helper.system import SystemHelper
from app.helper.plugin import PluginHelper
from version import FRONTEND_VERSION, APP_VERSION

View File

@@ -85,6 +85,22 @@ class TorrentsChain(ChainBase):
return []
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步浏览站点首页内容返回种子清单TTL缓存5分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = await SitesHelper().async_get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
def rss(self, domain: str) -> List[TorrentInfo]:
"""
获取站点RSS内容返回种子清单TTL缓存3分钟

View File

@@ -1,9 +1,10 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class Message(Base):
@@ -38,3 +39,11 @@ class Message(Base):
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
)
return result.scalars().all()

View File

@@ -1,10 +1,10 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base, async_db_query
from app.db import db_query, db_update, Base, async_db_query, async_db_update
class Site(Base):
@@ -82,6 +82,12 @@ class Site(Base):
def list_order_by_pri(cls, db: Session):
return db.query(cls).order_by(cls.pri).all()
@classmethod
@async_db_query
async def async_list_order_by_pri(cls, db: AsyncSession):
result = await db.execute(select(cls).order_by(cls.pri))
return result.scalars().all()
@classmethod
@db_query
def get_domains_by_ids(cls, db: Session, ids: list):
@@ -91,3 +97,8 @@ class Site(Base):
@db_update
def reset(cls, db: Session):
db.query(cls).delete()
@classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession):
await db.execute(delete(cls))

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class SiteUserData(Base):
@@ -65,6 +66,17 @@ class SiteUserData(Base):
cls.updated_day == workdate).all()
return db.query(cls).filter(cls.domain == domain).all()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
query = select(cls).filter(cls.domain == domain)
if workdate and worktime:
query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime)
elif workdate:
query = query.filter(cls.updated_day == workdate)
result = await db.execute(query)
return result.scalars().all()
@classmethod
@db_query
def get_by_date(cls, db: Session, date: str):
@@ -92,3 +104,28 @@ class SiteUserData(Base):
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()).all()
@classmethod
@async_db_query
async def async_get_latest(cls, db: AsyncSession):
"""
异步获取各站点最新一天的数据
"""
subquery = (
select(
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
result = await db.execute(
select(cls).join(
subquery,
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()))
return result.scalars().all()

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query
class SystemConfig(Base):
@@ -19,6 +20,12 @@ class SystemConfig(Base):
def get_by_key(cls, db: Session, key: str):
return db.query(cls).filter(cls.key == key).first()
@classmethod
@async_db_query
async def async_get_by_key(cls, db: AsyncSession, key: str):
result = await db.execute(select(cls).where(cls.key == key))
return result.scalar_one_or_none()
@db_update
def delete_by_key(self, db: Session, key: str):
systemconfig = self.get_by_key(db, key)

View File

@@ -173,6 +173,21 @@ class TransferHistory(Base):
time.localtime(time.time() - 86400 * days))).subquery()
return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
@classmethod
@async_db_query
async def async_statistic(cls, db: AsyncSession, days: Optional[int] = 7):
"""
统计最近days天的下载历史数量按日期分组返回每日数量
"""
sub_query = select(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
result = await db.execute(
select(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date)
)
return result.scalars().all()
@classmethod
@db_query
def count(cls, db: Session, status: bool = None):

View File

@@ -35,6 +35,12 @@ class SiteOper(DbOper):
"""
return Site.list(self._db)
async def async_list(self) -> List[Site]:
"""
异步获取站点列表
"""
return await Site.async_list(self._db)
def list_order_by_pri(self) -> List[Site]:
"""
获取站点列表

View File

@@ -47,6 +47,33 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
conf.create(self._db)
return True
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
"""
异步设置系统设置
:param key: 配置键
:param value: 配置值
:return: 是否设置成功True 成功/False 失败/None 无需更新)
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = await SystemConfig.async_get_by_key(self._db, key)
if conf:
if old_value != value:
if value:
conf.update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
await conf.async_create(self._db)
return True
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
"""
获取系统设置

View File

@@ -60,7 +60,6 @@ class PlaywrightHelper:
except Exception as e:
logger.error(f"网页操作失败: {str(e)}")
finally:
# 确保资源被正确清理
if page:
page.close()
if context:

View File

@@ -144,53 +144,50 @@ class CookieHelper:
break
if not submit_xpath:
return None, None, "未找到登录按钮"
finally:
if html is not None:
del html
# 点击登录按钮
try:
# 等待登录按钮准备好
page.wait_for_selector(submit_xpath)
# 输入用户名
page.fill(username_xpath, username)
# 输入密码
page.fill(password_xpath, password)
# 输入二步验证码
if twostep_xpath:
page.fill(twostep_xpath, otp_code)
# 识别验证码
if captcha_xpath and captcha_img_url:
captcha_element = page.query_selector(captcha_xpath)
if captcha_element.is_visible():
# 验证码图片地址
code_url = self.__get_captcha_url(url, captcha_img_url)
# 获取当前的cookie和ua
cookie = self.parse_cookies(page.context.cookies())
ua = page.evaluate("() => window.navigator.userAgent")
# 自动OCR识别验证码
captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url)
if captcha:
logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha))
else:
return None, None, "验证码识别失败"
# 输入验证码
captcha_element.fill(captcha)
else:
# 不可见元素不处理
pass
# 点击登录按钮
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"仿真登录失败:{str(e)}")
return None, None, f"仿真登录失败:{str(e)}"
# 对于某二次验证码为单页面的站点,输入二次验证码
if "verify" in page.url:
if not otp_code:
return None, None, "需要二次验证码"
html = etree.HTML(page.content())
try:
# 等待登录按钮准备好
page.wait_for_selector(submit_xpath)
# 输入用户名
page.fill(username_xpath, username)
# 输入密码
page.fill(password_xpath, password)
# 输入二步验证码
if twostep_xpath:
page.fill(twostep_xpath, otp_code)
# 识别验证码
if captcha_xpath and captcha_img_url:
captcha_element = page.query_selector(captcha_xpath)
if captcha_element.is_visible():
# 验证码图片地址
code_url = self.__get_captcha_url(url, captcha_img_url)
# 获取当前的cookie和ua
cookie = self.parse_cookies(page.context.cookies())
ua = page.evaluate("() => window.navigator.userAgent")
# 自动OCR识别验证码
captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url)
if captcha:
logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha))
else:
return None, None, "验证码识别失败"
# 输入验证码
captcha_element.fill(captcha)
else:
# 不可见元素不处理
pass
# 点击登录按钮
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"仿真登录失败:{str(e)}")
return None, None, f"仿真登录失败:{str(e)}"
# 对于某二次验证码为单页面的站点,输入二次验证码
if "verify" in page.url:
if not otp_code:
return None, None, "需要二次验证码"
html = etree.HTML(page.content())
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
try:
@@ -204,28 +201,29 @@ class CookieHelper:
logger.error(f"二次验证码输入失败:{str(e)}")
return None, None, f"二次验证码输入失败:{str(e)}"
break
finally:
if html is not None:
del html
# 登录后的源码
html_text = page.content()
if not html_text:
return None, None, "获取网页源码失败"
if SiteUtils.is_logged_in(html_text):
return self.parse_cookies(page.context.cookies()), \
page.evaluate("() => window.navigator.userAgent"), ""
else:
# 读取错误信息
error_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("error"):
if html.xpath(xpath):
error_xpath = xpath
break
if not error_xpath:
return None, None, "登录失败"
# 登录后的源码
html_text = page.content()
if not html_text:
return None, None, "获取网页源码失败"
if SiteUtils.is_logged_in(html_text):
return self.parse_cookies(page.context.cookies()), \
page.evaluate("() => window.navigator.userAgent"), ""
else:
error_msg = html.xpath(error_xpath)[0]
return None, None, error_msg
# 读取错误信息
error_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("error"):
if html.xpath(xpath):
error_xpath = xpath
break
if not error_xpath:
return None, None, "登录失败"
else:
error_msg = html.xpath(error_xpath)[0]
return None, None, error_msg
finally:
if html:
del html
if not url or not username or not password:
return None, None, "参数错误"

View File

@@ -457,6 +457,20 @@ class IndexerModule(_ModuleBase):
"""
return self.search_torrents(site=site, keywords=[keyword], cat=cat, page=page)
async def async_refresh_torrents(self, site: dict,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0) -> Optional[List[TorrentInfo]]:
"""
异步获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 关键字
:param cat: 分类
:param page: 页码
:reutrn: 种子资源列表
"""
return await self.async_search_torrents(site=site, keywords=[keyword], cat=cat, page=page)
def refresh_userdata(self, site: dict) -> Optional[SiteUserData]:
"""
刷新站点的用户数据

View File

@@ -21,6 +21,7 @@ from app.core.config import settings
from app.core.event import EventManager, eventmanager, Event
from app.core.plugin import PluginManager
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.message import MessageHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper
from app.log import logger
@@ -380,46 +381,60 @@ class Scheduler(metaclass=Singleton):
# 启动定时服务
self._scheduler.start()
def __prepare_job(self, job_id: str) -> Optional[dict]:
"""
准备定时任务
"""
with self._lock:
job = self._jobs.get(job_id)
if not job:
return None
if job.get("running"):
logger.warning(f"定时任务 {job_id} - {job.get("name")} 正在运行 ...")
return None
self._jobs[job_id]["running"] = True
return job
def __finish_job(self, job_id: str):
"""
完成定时任务
"""
with self._lock:
try:
self._jobs[job_id]["running"] = False
except KeyError:
pass
def start(self, job_id: str, *args, **kwargs):
"""
启动定时服务
"""
# 处理job_id格式
with self._lock:
job = self._jobs.get(job_id)
if not job:
return
job_name = job.get("name")
if job.get("running"):
logger.warning(f"定时任务 {job_id} - {job_name} 正在运行 ...")
return
self._jobs[job_id]["running"] = True
# 获取定时任务
job = self.__prepare_job(job_id)
if not job:
return
# 开始运行
try:
if not kwargs:
kwargs = job.get("kwargs") or {}
job["func"](*args, **kwargs)
except Exception as e:
logger.error(f"定时任务 {job_name} 执行失败:{str(e)} - {traceback.format_exc()}")
SchedulerChain().messagehelper.put(title=f"{job_name} 执行失败",
message=str(e),
role="system")
EventManager().send_event(
logger.error(f"定时任务 {job.get('name')} 执行失败:{str(e)} - {traceback.format_exc()}")
MessageHelper().put(title=f"{job.get('name')} 执行失败",
message=str(e),
role="system")
eventmanager.send_event(
EventType.SystemError,
{
"type": "scheduler",
"scheduler_id": job_id,
"scheduler_name": job_name,
"scheduler_name": job.get('name'),
"error": str(e),
"traceback": traceback.format_exc()
}
)
# 运行结束
with self._lock:
try:
self._jobs[job_id]["running"] = False
except KeyError:
pass
self.__finish_job(job_id)
def init_plugin_jobs(self):
"""