Compare commits

..

71 Commits
v2.3.0 ... v1

Author SHA1 Message Date
jxxghp
1336b2136d Merge pull request #4340 from jtcymc/main 2025-05-25 07:59:41 +08:00
shaw
b20e21e700 fix(SearchChain): with 关闭线程池
- 使用 with 语句管理 ThreadPoolExecutor,确保线程池正确关闭
2025-05-25 00:50:34 +08:00
jxxghp
c27ab4a4c7 v1.9.19
- 默认关闭自动升级
2025-05-12 16:41:57 +08:00
jxxghp
d9e6532325 更新 version.py 2025-04-10 14:55:24 +08:00
jxxghp
049f16ba01 Merge pull request #4130 from cddjr/fix_v1_mteam 2025-04-10 14:36:21 +08:00
景大侠
6541458326 backport: 适配馒头API变动 2025-04-10 14:18:04 +08:00
jxxghp
9f2912426b Merge pull request #2833 from wikrin/main 2024-10-10 22:49:12 +08:00
Attente
fde33d267a fix: 修正重复的特殊字符
将重复的特殊字符 `—`[U+2014](https://symbl.cc/cn/2014/) 修改为 `―`[U+2015](https://symbl.cc/cn/2015/)
2024-10-10 22:23:43 +08:00
jxxghp
ef7f0afa37 v1.9.17
- 修复115扫码登录问题
- 索引站点新增支持 `PTLGS`
2024-09-18 17:52:40 +08:00
jxxghp
bea77a8243 fix 115 2024-09-18 13:39:39 +08:00
jxxghp
b984b83870 v1.9.16
- 修复了有些情况下新增目录类型为全部时不生效的问题
2024-09-08 13:11:59 +08:00
jxxghp
2153ad48db v1.9.15
- 修复部分通知消息查看详细链接错误的问题
- 修复了麒麟无法索引综艺的问题
- 修复了插件无升级提示图标的问题
2024-08-28 15:01:46 +08:00
jxxghp
c9c43fde74 Merge pull request #2638 from Linvery/main 2024-08-13 12:41:51 +08:00
Linvery
e2c9742f64 fix: 解决推送消息错误的url路径 2024-08-13 12:40:33 +08:00
jxxghp
3d459a40f7 - 仅调整了插件页面的UI 2024-08-13 11:46:01 +08:00
jxxghp
5675cd5b11 v1.9.13
- 修复已知问题
2024-07-30 06:36:30 +08:00
jxxghp
74a4d0bd66 Merge pull request #2614 from InfinityPacer/main 2024-07-30 06:33:00 +08:00
InfinityPacer
2b8c313019 refactor(event): 事件处理调整为深复制,避免多线程环境下数据异常 2024-07-29 23:18:58 +08:00
jxxghp
62fb6b80a3 Merge pull request #2612 from audichuang/main 2024-07-28 18:48:06 +08:00
audichuang
eea86528d8 Modifying the Bilingual Subtitle Matching Specification 2024-07-28 18:07:28 +08:00
jxxghp
84e6abb659 Merge pull request #2600 from InfinityPacer/main 2024-07-23 19:45:18 +08:00
InfinityPacer
da2c755b6d fix(Plugin): 重置插件时初始化调整为reload,以保留默认配置 2024-07-23 19:41:29 +08:00
jxxghp
51f39be9bc Merge pull request #2590 from Akimio521/main 2024-07-20 20:15:04 +08:00
Akimio521
21b762e75c perfect(releasegroup):完善anime字幕组 2024-07-20 20:05:05 +08:00
jxxghp
54095074b6 v1.9.12
- 新增认证站点:`ROUSI`
- 修复Telegram部分消息格式异常问题
2024-07-16 08:07:24 +08:00
jxxghp
33525730b5 fix #2515 2024-07-16 08:04:48 +08:00
jxxghp
71260f04b5 Merge pull request #2560 from InfinityPacer/main 2024-07-12 17:03:55 +08:00
InfinityPacer
e2acec321d fix tips 2024-07-12 16:48:44 +08:00
InfinityPacer
74a462a09f fix SitesHelper import tips 2024-07-12 16:30:06 +08:00
jxxghp
ad9e1a5da6 Merge pull request #2552 from BrettDean/main 2024-07-11 21:26:04 +08:00
Dean
d90e3c29a5 优化微信文本消息发送:支持长文本分块发送 2024-07-11 20:30:09 +08:00
jxxghp
19165eff75 Merge pull request #2537 from InfinityPacer/main 2024-07-09 11:01:53 +08:00
jxxghp
52d0703812 v1.9.11
- 支持环境变量配置DOH域名和DNS服务器
- 问题修复
2024-07-09 08:09:58 +08:00
InfinityPacer
1431a5e82a fix #2518 移除不必要的debug日志 2024-07-09 01:40:37 +08:00
jxxghp
23fe643526 Merge pull request #2534 from InfinityPacer/main 2024-07-08 12:23:16 +08:00
jxxghp
545b3c0482 Merge pull request #2527 from s0urcelab/main 2024-07-08 12:22:50 +08:00
InfinityPacer
f102119eef fix #2526 Backdrop优先调整为取art 2024-07-07 23:03:34 +08:00
s0urce
9bb3d707c9 feat: history query add title field 2024-07-07 18:27:06 +08:00
jxxghp
b892ef50dc Merge pull request #2526 from InfinityPacer/main 2024-07-07 16:49:11 +08:00
InfinityPacer
41e2907168 fix jxxghp/MoviePilot-Plugins#38 2024-07-07 16:32:37 +08:00
jxxghp
14e28ed693 Merge pull request #2518 from InfinityPacer/main 2024-07-07 08:58:39 +08:00
InfinityPacer
79393c21ff feat: 支持插件进行私钥认证 2024-07-06 20:03:49 +08:00
InfinityPacer
cafa4d217c feat: 增加指定的仓库Github token 2024-07-06 16:07:46 +08:00
jxxghp
2b9e69b112 Merge pull request #2515 from BrettDean/main 2024-07-06 07:50:59 +08:00
Dean
3ffcea70a7 Fixed parsing of Telegram entities 2024-07-06 01:44:51 +08:00
jxxghp
ffc72ba6fe fix #2508 2024-07-05 17:03:00 +08:00
jxxghp
848becd946 Merge pull request #2506 from Akimio521/main 2024-07-05 14:44:50 +08:00
Akimio521
71fe96d7f9 feat: 添加 DOH 解析服务器列表至配置文件实现自定义 DOH 服务器 2024-07-05 13:55:48 +08:00
jxxghp
35c7238ede Merge pull request #2503 from Akimio521/main 2024-07-05 11:31:21 +08:00
Akimio521
3578204508 style 2024-07-05 10:52:39 +08:00
Akimio521
c11cf17f62 style:app.core.config.settings 2024-07-05 10:34:57 +08:00
Akimio521
5a59652684 feat:将使用 DOH 域名解析的域名添加至 app.core.config.settings 2024-07-05 09:31:41 +08:00
jxxghp
7f5f31f143 Merge pull request #2484 from InfinityPacer/main 2024-07-02 06:12:13 +08:00
InfinityPacer
dc1cee80b1 fix plugin install and reg 2024-07-02 01:11:42 +08:00
jxxghp
92cb066748 更新 version.py 2024-07-01 21:46:07 +08:00
jxxghp
6c8ef4122b fix e5ec02e043 2024-07-01 12:23:02 +08:00
jxxghp
971b02ac8c - 重新兼容了v1.9.1之前的版本直接升级
- 索引站点新增支持`HDVBits`
- 自定义重命名新增季年份`season_year`占位符
- 修复了普通用户搜索越权问题
2024-07-01 10:46:29 +08:00
jxxghp
d4a9643f47 Merge pull request #2463 from InfinityPacer/main
处理链run_module支持raise_exception
2024-07-01 10:33:55 +08:00
InfinityPacer
e56d31fedc fix exception 2024-06-30 11:50:26 +08:00
InfinityPacer
b9d91c5cd7 feat: DoubanModule触发限流时支持立即抛出限流异常 2024-06-30 11:48:29 +08:00
InfinityPacer
57cdb57331 feat: retry支持立即抛出异常 2024-06-30 11:47:30 +08:00
InfinityPacer
0f7a7ef44f feat: 添加ImmediateException 2024-06-30 11:47:00 +08:00
InfinityPacer
6267b3f670 feat: run_module支持raise_exception 2024-06-30 11:41:00 +08:00
jxxghp
82f77b4729 Merge pull request #2456 from AisukaYuki/main 2024-06-30 09:09:36 +08:00
jxxghp
58da0ebb4f Merge pull request #2460 from thsrite/main 2024-06-30 09:08:35 +08:00
thsrite
7a43e43478 fix 删除文件未删除thumb.jpg 2024-06-29 20:12:26 +08:00
AisukaYuki
e5ec02e043 add 自定义重命名新增季年份season_year 2024-06-29 13:50:40 +08:00
jxxghp
2944c343a8 Merge pull request #2432 from InfinityPacer/main 2024-06-26 18:21:00 +08:00
InfinityPacer
940cc566c8 fix douban rate_limit tips 2024-06-26 18:17:31 +08:00
jxxghp
db7b2cdcac fix error 2024-06-26 17:42:08 +08:00
jxxghp
8111cf5dc8 - 站点索引及用户认证新增支持海胆之家 2024-06-26 16:18:14 +08:00
295 changed files with 11408 additions and 25228 deletions

View File

@@ -1,45 +0,0 @@
name: 功能提案
description: Request for Comments
title: "[RFC]"
labels: ["RFC"]
body:
- type: markdown
attributes:
value: |
一份提案(RFC)定位为 **「在某功能/重构的具体开发前,用于开发者间 review 技术设计/方案的文档」**
目的是让协作的开发者间清晰的知道「要做什么」和「具体会怎么做」,以及所有的开发者都能公开透明的参与讨论;
以便评估和讨论产生的影响 (遗漏的考虑、向后兼容性、与现有功能的冲突)
因此提案侧重在对解决问题的 **方案、设计、步骤** 的描述上。
如果仅希望讨论是否添加或改进某功能本身,请使用 -> [Issue: 功能改进](https://github.com/jxxghp/MoviePilot/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+)
- type: textarea
id: background
attributes:
label: 背景 or 问题
description: 简单描述遇到的什么问题或需要改动什么。可以引用其他 issue、讨论、文档等。
validations:
required: true
- type: textarea
id: goal
attributes:
label: "目标 & 方案简述"
description: 简单描述提案此提案实现后,**预期的目标效果**,以及简单大致描述会采取的方案/步骤,可能会/不会产生什么影响。
validations:
required: true
- type: textarea
id: design
attributes:
label: "方案设计 & 实现步骤"
description: |
详细描述你设计的具体方案,可以考虑拆分列表或要点,一步步描述具体打算如何实现的步骤和相关细节。
这部份不需要一次性写完整,即使在创建完此提案 issue 后,依旧可以再次编辑修改。
validations:
required: false
- type: textarea
id: alternative
attributes:
label: "替代方案 & 对比"
description: |
[可选] 为来实现目标效果,还考虑过什么其他方案,有什么对比?
validations:
required: false

View File

@@ -1,11 +1,11 @@
name: MoviePilot Builder v2
name: MoviePilot Builder
on:
workflow_dispatch:
push:
branches:
- v2
- main
paths:
- 'version.py'
- version.py
jobs:
Docker-build:
@@ -25,7 +25,7 @@ jobs:
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
images: ${{ secrets.DOCKER_USERNAME }}/moviepilot
tags: |
type=raw,value=${{ env.app_version }}
type=raw,value=latest
@@ -51,25 +51,181 @@ jobs:
linux/amd64
linux/arm64/v8
push: true
build-args: |
MOVIEPILOT_VERSION=${{ env.app_version }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha, scope=${{ github.workflow }}-docker
cache-to: type=gha, scope=${{ github.workflow }}-docker
- name: Delete Release
uses: dev-drprasad/delete-tag-and-release@v1.1
with:
tag_name: ${{ env.app_version }}
delete_release: true
github_token: ${{ secrets.GITHUB_TOKEN }}
Windows-build:
runs-on: windows-latest
name: Build Windows Binary
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Generate Release
uses: softprops/action-gh-release@v2
- name: Init Python 3.11.4
uses: actions/setup-python@v4
with:
tag_name: v${{ env.app_version }}
name: v${{ env.app_version }}
draft: false
prerelease: false
make_latest: false
python-version: '3.11.4'
cache: 'pip'
- name: Install Dependent Packages
run: |
python -m pip install --upgrade pip
pip install wheel pyinstaller
pip install -r requirements.txt
shell: pwsh
- name: Prepare Frontend
run: |
# 下载nginx
Invoke-WebRequest -Uri "http://nginx.org/download/nginx-1.25.2.zip" -OutFile "nginx.zip"
Expand-Archive -Path "nginx.zip" -DestinationPath "nginx-1.25.2"
Move-Item -Path "nginx-1.25.2/nginx-1.25.2" -Destination "nginx"
Remove-Item -Path "nginx.zip"
Remove-Item -Path "nginx-1.25.2" -Recurse -Force
# 下载前端
$FRONTEND_VERSION = (Invoke-WebRequest -Uri "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/latest" | ConvertFrom-Json).tag_name
Invoke-WebRequest -Uri "https://github.com/jxxghp/MoviePilot-Frontend/releases/download/$FRONTEND_VERSION/dist.zip" -OutFile "dist.zip"
Expand-Archive -Path "dist.zip" -DestinationPath "dist"
Move-Item -Path "dist/dist/*" -Destination "nginx/html" -Force
Remove-Item -Path "dist.zip"
Remove-Item -Path "dist" -Recurse -Force
Move-Item -Path "nginx/html/nginx.conf" -Destination "nginx/conf/nginx.conf" -Force
New-Item -Path "nginx/temp" -ItemType Directory -Force
New-Item -Path "nginx/temp/__keep__.txt" -ItemType File -Force
New-Item -Path "nginx/logs" -ItemType Directory -Force
New-Item -Path "nginx/logs/__keep__.txt" -ItemType File -Force
# 下载插件 jxxghp
Invoke-WebRequest -Uri "https://github.com/jxxghp/MoviePilot-Plugins/archive/refs/heads/main.zip" -OutFile "MoviePilot-Plugins-main.zip"
Expand-Archive -Path "MoviePilot-Plugins-main.zip" -DestinationPath "MoviePilot-Plugins-main"
Move-Item -Path "MoviePilot-Plugins-main/MoviePilot-Plugins-main/plugins/*" -Destination "app/plugins/" -Force -ErrorAction SilentlyContinue
Remove-Item -Path "MoviePilot-Plugins-main.zip"
Remove-Item -Path "MoviePilot-Plugins-main" -Recurse -Force
# 下载插件 thsrite
Invoke-WebRequest -Uri "https://github.com/thsrite/MoviePilot-Plugins/archive/refs/heads/main.zip" -OutFile "MoviePilot-Plugins-main.zip"
Expand-Archive -Path "MoviePilot-Plugins-main.zip" -DestinationPath "MoviePilot-Plugins-main"
Move-Item -Path "MoviePilot-Plugins-main/MoviePilot-Plugins-main/plugins/*" -Destination "app/plugins/" -Force -ErrorAction SilentlyContinue
Remove-Item -Path "MoviePilot-Plugins-main.zip"
Remove-Item -Path "MoviePilot-Plugins-main" -Recurse -Force
# 下载插件 honue
Invoke-WebRequest -Uri "https://github.com/honue/MoviePilot-Plugins/archive/refs/heads/main.zip" -OutFile "MoviePilot-Plugins-main.zip"
Expand-Archive -Path "MoviePilot-Plugins-main.zip" -DestinationPath "MoviePilot-Plugins-main"
Move-Item -Path "MoviePilot-Plugins-main/MoviePilot-Plugins-main/plugins/*" -Destination "app/plugins/" -Force -ErrorAction SilentlyContinue
Remove-Item -Path "MoviePilot-Plugins-main.zip"
Remove-Item -Path "MoviePilot-Plugins-main" -Recurse -Force
# 下载插件 InfinityPacer
Invoke-WebRequest -Uri "https://github.com/InfinityPacer/MoviePilot-Plugins/archive/refs/heads/main.zip" -OutFile "MoviePilot-Plugins-main.zip"
Expand-Archive -Path "MoviePilot-Plugins-main.zip" -DestinationPath "MoviePilot-Plugins-main"
Move-Item -Path "MoviePilot-Plugins-main/MoviePilot-Plugins-main/plugins/*" -Destination "app/plugins/" -Force -ErrorAction SilentlyContinue
Remove-Item -Path "MoviePilot-Plugins-main.zip"
Remove-Item -Path "MoviePilot-Plugins-main" -Recurse -Force
# 下载资源
Invoke-WebRequest -Uri "https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip" -OutFile "MoviePilot-Resources-main.zip"
Expand-Archive -Path "MoviePilot-Resources-main.zip" -DestinationPath "MoviePilot-Resources-main"
Move-Item -Path "MoviePilot-Resources-main/MoviePilot-Resources-main/resources/*" -Destination "app/helper/" -Force
Remove-Item -Path "MoviePilot-Resources-main.zip"
Remove-Item -Path "MoviePilot-Resources-main" -Recurse -Force
shell: pwsh
- name: Pyinstaller
run: |
pyinstaller frozen.spec
shell: pwsh
- name: Upload Windows File
uses: actions/upload-artifact@v3
with:
name: windows
path: dist/MoviePilot.exe
Linux-build-amd64:
runs-on: ubuntu-latest
name: Build Linux Amd64
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Init Python 3.11.4
uses: actions/setup-python@v4
with:
python-version: '3.11.4'
cache: 'pip'
- name: Install Dependent Packages
run: |
python -m pip install --upgrade pip
pip install wheel pyinstaller
pip install -r requirements.txt
find app/plugins -name requirements.txt -exec pip install -r {} \;
- name: Prepare Frontend
run: |
wget https://github.com/jxxghp/MoviePilot-Plugins/archive/refs/heads/main.zip
unzip main.zip
mv MoviePilot-Plugins-main/plugins/* app/plugins/
rm main.zip
rm -rf MoviePilot-Plugins-main
wget https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip
unzip main.zip
mv MoviePilot-Resources-main/resources/* app/helper/
rm main.zip
rm -rf MoviePilot-Resources-main
- name: Pyinstaller
run: |
pyinstaller frozen.spec
mv dist/MoviePilot dist/MoviePilot_Amd64
- name: Upload Linux File
uses: actions/upload-artifact@v3
with:
name: linux-amd64
path: dist/MoviePilot_Amd64
Create-release:
permissions: write-all
runs-on: ubuntu-latest
needs: [ Windows-build, Docker-build, Linux-build-amd64]
steps:
- uses: actions/checkout@v2
- name: Release Version
id: release_version
run: |
app_version=$(cat version.py |sed -ne "s/APP_VERSION\s=\s'v\(.*\)'/\1/gp")
echo "app_version=$app_version" >> $GITHUB_ENV
- name: Download Artifact
uses: actions/download-artifact@v3
- name: get release_informations
shell: bash
run: |
mkdir releases
mv ./windows/MoviePilot.exe ./releases/MoviePilot_Win_v${{ env.app_version }}.exe
mv ./linux-amd64/MoviePilot_Amd64 ./releases/MoviePilot_Amd64_v${{ env.app_version }}
- name: Create Release
id: create_release
uses: actions/create-release@latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: v${{ env.app_version }}
release_name: v${{ env.app_version }}
body: ${{ github.event.commits[0].message }}
draft: false
prerelease: false
- name: Upload Release Asset
uses: dwenegar/upload-release-assets@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
assets_path: |
./releases/

9
.gitignore vendored
View File

@@ -4,7 +4,6 @@ build/
dist/
nginx/
test.py
safety_report.txt
app/helper/sites.py
app/helper/*.so
app/helper/*.pyd
@@ -12,12 +11,10 @@ app/helper/*.bin
app/plugins/**
!app/plugins/__init__.py
config/cookies/**
config/user.db*
config/user.db
config/sites/**
config/logs/
config/temp/
config/cache/
*.pyc
*.log
.vscode
venv
venv
.DS_Store

View File

@@ -1,16 +1,19 @@
FROM python:3.11.4-slim-bookworm
ARG MOVIEPILOT_VERSION
ENV LANG="C.UTF-8" \
TZ="Asia/Shanghai" \
HOME="/moviepilot" \
CONFIG_DIR="/config" \
TERM="xterm" \
DISPLAY=:987 \
PUID=0 \
PGID=0 \
UMASK=000 \
PORT=3001 \
NGINX_PORT=3000 \
MOVIEPILOT_AUTO_UPDATE=release
PROXY_HOST="" \
MOVIEPILOT_AUTO_UPDATE=false \
AUTH_SITE="iyuu" \
IYUU_SIGN=""
WORKDIR "/app"
RUN apt-get update -y \
&& apt-get upgrade -y \
@@ -27,6 +30,7 @@ RUN apt-get update -y \
busybox \
dumb-init \
jq \
haproxy \
fuse3 \
rsync \
ffmpeg \
@@ -38,7 +42,6 @@ RUN apt-get update -y \
then ln -s /usr/lib/aarch64-linux-musl/libc.so /lib/libc.musl-aarch64.so.1; \
fi \
&& curl https://rclone.org/install.sh | bash \
&& curl --insecure -fsSL https://raw.githubusercontent.com/DDS-Derek/Aria2-Pro-Core/master/aria2-install.sh | bash \
&& apt-get autoremove -y \
&& apt-get clean -y \
&& rm -rf \
@@ -46,12 +49,11 @@ RUN apt-get update -y \
/moviepilot/.cache \
/var/lib/apt/lists/* \
/var/tmp/*
COPY requirements.in requirements.in
COPY requirements.txt requirements.txt
RUN apt-get update -y \
&& apt-get install -y build-essential \
&& pip install --upgrade pip \
&& pip install Cython pip-tools \
&& pip-compile requirements.in \
&& pip install Cython \
&& pip install -r requirements.txt \
&& playwright install-deps chromium \
&& apt-get remove -y build-essential \
@@ -66,26 +68,23 @@ COPY . .
RUN cp -f /app/nginx.conf /etc/nginx/nginx.template.conf \
&& cp -f /app/update /usr/local/bin/mp_update \
&& cp -f /app/entrypoint /entrypoint \
&& cp -f /app/docker_http_proxy.conf /etc/nginx/docker_http_proxy.conf \
&& chmod +x /entrypoint /usr/local/bin/mp_update \
&& mkdir -p ${HOME} \
&& groupadd -r moviepilot -g 918 \
&& useradd -r moviepilot -g moviepilot -d ${HOME} -s /bin/bash -u 918 \
&& mkdir -p ${HOME} /var/lib/haproxy/server-state \
&& groupadd -r moviepilot -g 911 \
&& useradd -r moviepilot -g moviepilot -d ${HOME} -s /bin/bash -u 911 \
&& python_ver=$(python3 -V | awk '{print $2}') \
&& echo "/app/" > /usr/local/lib/python${python_ver%.*}/site-packages/app.pth \
&& echo 'fs.inotify.max_user_watches=5242880' >> /etc/sysctl.conf \
&& echo 'fs.inotify.max_user_instances=5242880' >> /etc/sysctl.conf \
&& locale-gen zh_CN.UTF-8 \
&& FRONTEND_VERSION=$(sed -n "s/^FRONTEND_VERSION\s*=\s*'\([^']*\)'/\1/p" /app/version.py) \
&& FRONTEND_VERSION=$(curl -sL "https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/latest" | jq -r .tag_name) \
&& curl -sL "https://github.com/jxxghp/MoviePilot-Frontend/releases/download/${FRONTEND_VERSION}/dist.zip" | busybox unzip -d / - \
&& mv /dist /public \
&& curl -sL "https://github.com/jxxghp/MoviePilot-Plugins/archive/refs/heads/main.zip" | busybox unzip -d /tmp - \
&& mv -f /tmp/MoviePilot-Plugins-main/plugins.v2/* /app/app/plugins/ \
&& cat /tmp/MoviePilot-Plugins-main/package.json | jq -r 'to_entries[] | select(.value.v2 == true) | .key' | awk '{print tolower($0)}' | \
while read -r i; do if [ ! -d "/app/app/plugins/$i" ]; then mv "/tmp/MoviePilot-Plugins-main/plugins/$i" "/app/app/plugins/"; else echo "跳过 $i"; fi; done \
&& mv -f /tmp/MoviePilot-Plugins-main/plugins/* /app/app/plugins/ \
&& curl -sL "https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip" | busybox unzip -d /tmp - \
&& mv -f /tmp/MoviePilot-Resources-main/resources/* /app/app/helper/ \
&& rm -rf /tmp/*
EXPOSE 3000
VOLUME [ "/config" ]
ENTRYPOINT [ "/entrypoint" ]
ENTRYPOINT [ "/entrypoint" ]

View File

@@ -6,7 +6,6 @@
![GitHub repo size](https://img.shields.io/github/repo-size/jxxghp/MoviePilot?style=for-the-badge)
![GitHub issues](https://img.shields.io/github/issues/jxxghp/MoviePilot?style=for-the-badge)
![Docker Pulls](https://img.shields.io/docker/pulls/jxxghp/moviepilot?style=for-the-badge)
![Docker Pulls V2](https://img.shields.io/docker/pulls/jxxghp/moviepilot-v2?style=for-the-badge)
![Platform](https://img.shields.io/badge/platform-Windows%20%7C%20Linux%20%7C%20Synology-blue?style=for-the-badge)

View File

@@ -1,44 +0,0 @@
from abc import ABC, abstractmethod
from pydantic.main import BaseModel
from app.schemas import ActionContext, ActionParams
class BaseAction(BaseModel, ABC):
"""
工作流动作基类
"""
@property
@abstractmethod
def name(self) -> str:
pass
@property
@abstractmethod
def description(self) -> str:
pass
@abstractmethod
def execute(self, params: ActionParams, context: ActionContext) -> ActionContext:
"""
执行动作
"""
raise NotImplementedError
@property
@abstractmethod
def done(self) -> bool:
"""
判断动作是否完成
"""
pass
@property
@abstractmethod
def success(self) -> bool:
"""
判断动作是否成功
"""
pass

View File

@@ -1,42 +0,0 @@
from typing import Optional
from pydantic import Field
from app.actions import BaseAction
from app.schemas import ActionParams, ActionContext
class FetchRssParams(ActionParams):
"""
获取RSS资源列表参数
"""
url: str = Field(None, description="RSS地址")
proxy: Optional[bool] = Field(False, description="是否使用代理")
timeout: Optional[int] = Field(15, description="超时时间")
headers: Optional[dict] = Field(None, description="请求头")
recognize: Optional[bool] = Field(False, description="是否识别")
class FetchRssAction(BaseAction):
"""
获取RSS资源列表
"""
@property
def name(self) -> str:
return "获取RSS资源列表"
@property
def description(self) -> str:
return "请求RSS地址获取数据并解析为资源列表"
async def execute(self, params: FetchRssParams, context: ActionContext) -> ActionContext:
pass
@property
def done(self) -> bool:
return True
@property
def success(self) -> bool:
return True

View File

@@ -1,42 +0,0 @@
from typing import Optional
from pydantic import Field
from app.actions import BaseAction
from app.schemas import ActionParams, ActionContext
class SearchTorrentsParams(ActionParams):
"""
搜索站点资源参数
"""
name: str = Field(None, description="资源名称")
year: Optional[int] = Field(None, description="年份")
type: Optional[str] = Field(None, description="资源类型 (电影/电视剧)")
season: Optional[int] = Field(None, description="季度")
recognize: Optional[bool] = Field(False, description="是否识别")
class SearchTorrentsAction(BaseAction):
"""
搜索站点资源
"""
@property
def name(self) -> str:
return "搜索站点资源"
@property
def description(self) -> str:
return "根据关键字搜索站点种子资源"
@property
def done(self) -> bool:
return True
@property
def success(self) -> bool:
return True
async def execute(self, params: SearchTorrentsParams, context: ActionContext) -> ActionContext:
pass

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow
local, transfer, mediaserver, bangumi, aliyun, u115
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -20,10 +20,9 @@ api_router.include_router(system.router, prefix="/system", tags=["system"])
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
api_router.include_router(download.router, prefix="/download", tags=["download"])
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
api_router.include_router(storage.router, prefix="/storage", tags=["storage"])
api_router.include_router(local.router, prefix="/local", tags=["local"])
api_router.include_router(transfer.router, prefix="/transfer", tags=["transfer"])
api_router.include_router(mediaserver.router, prefix="/mediaserver", tags=["mediaserver"])
api_router.include_router(bangumi.router, prefix="/bangumi", tags=["bangumi"])
api_router.include_router(discover.router, prefix="/discover", tags=["discover"])
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(aliyun.router, prefix="/aliyun", tags=["aliyun"])
api_router.include_router(u115.router, prefix="/u115", tags=["115"])

198
app/api/endpoints/aliyun.py Normal file
View File

@@ -0,0 +1,198 @@
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from starlette.responses import Response
from app import schemas
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_uri_token
from app.helper.aliyun import AliyunHelper
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey
router = APIRouter()
@router.get("/qrcode", summary="生成二维码内容", response_model=schemas.Response)
def qrcode(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
生成二维码
"""
qrcode_data, errmsg = AliyunHelper().generate_qrcode()
if qrcode_data:
return schemas.Response(success=True, data=qrcode_data)
return schemas.Response(success=False, message=errmsg)
@router.get("/check", summary="二维码登录确认", response_model=schemas.Response)
def check(ck: str, t: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
二维码登录确认
"""
if not ck or not t:
return schemas.Response(success=False, message="参数错误")
data, errmsg = AliyunHelper().check_login(ck, t)
if data:
return schemas.Response(success=True, data=data)
return schemas.Response(success=False, message=errmsg)
@router.get("/userinfo", summary="查询用户信息", response_model=schemas.Response)
def userinfo(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户信息
"""
aliyunhelper = AliyunHelper()
# 查询用户信息返回
info = aliyunhelper.user_info()
if info:
return schemas.Response(success=True, data=info)
return schemas.Response(success=False)
@router.post("/list", summary="所有目录和文件(阿里云盘)", response_model=List[schemas.FileItem])
def list_aliyun(fileitem: schemas.FileItem,
sort: str = 'updated_at',
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询当前目录下所有目录和文件
:param fileitem: 文件夹信息
:param sort: 排序方式name:按名称排序time:按修改时间排序
:param _: token
:return: 所有目录和文件
"""
if not fileitem.fileid:
return []
if not fileitem.path:
path = "/"
else:
path = fileitem.path
if sort == "time":
sort = "updated_at"
if fileitem.type == "file":
fileitem = AliyunHelper().detail(drive_id=fileitem.drive_id, file_id=fileitem.fileid, path=path)
if fileitem:
return [fileitem]
return []
return AliyunHelper().list(drive_id=fileitem.drive_id,
parent_file_id=fileitem.fileid,
path=path,
order_by=sort)
@router.post("/mkdir", summary="创建目录(阿里云盘)", response_model=schemas.Response)
def mkdir_aliyun(fileitem: schemas.FileItem,
name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
创建目录
"""
if not fileitem.fileid or not name:
return schemas.Response(success=False)
result = AliyunHelper().create_folder(drive_id=fileitem.drive_id, parent_file_id=fileitem.fileid,
name=name, path=fileitem.path)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.post("/delete", summary="删除文件或目录(阿里云盘)", response_model=schemas.Response)
def delete_aliyun(fileitem: schemas.FileItem,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除文件或目录
"""
if not fileitem.fileid:
return schemas.Response(success=False)
result = AliyunHelper().delete(drive_id=fileitem.drive_id, file_id=fileitem.fileid)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.get("/download", summary="下载文件(阿里云盘)")
def download_aliyun(fileid: str,
drive_id: str = None,
_: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
下载文件或目录
"""
if not fileid:
return schemas.Response(success=False)
url = AliyunHelper().download(drive_id=drive_id, file_id=fileid)
if url:
# 重定向
return Response(status_code=302, headers={"Location": url})
raise HTTPException(status_code=500, detail="下载文件出错")
@router.post("/rename", summary="重命名文件或目录(阿里云盘)", response_model=schemas.Response)
def rename_aliyun(fileitem: schemas.FileItem,
new_name: str,
recursive: bool = False,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重命名文件或目录
"""
if not fileitem.fileid or not new_name:
return schemas.Response(success=False)
result = AliyunHelper().rename(drive_id=fileitem.drive_id, file_id=fileitem.fileid, name=new_name)
if result:
if recursive:
transferchain = TransferChain()
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
# 递归修改目录内文件(智能识别命名)
sub_files: List[schemas.FileItem] = list_aliyun(fileitem=fileitem)
if sub_files:
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.BatchRename)
total = len(sub_files)
handled = 0
for sub_file in sub_files:
handled += 1
progress.update(value=handled / total * 100,
text=f"正在处理 {sub_file.name} ...",
key=ProgressKey.BatchRename)
if sub_file.type == "dir":
continue
if not sub_file.extension:
continue
if f".{sub_file.extension.lower()}" not in media_exts:
continue
sub_path = Path(f"{fileitem.path}{sub_file.name}")
meta = MetaInfoPath(sub_path)
mediainfo = transferchain.recognize_media(meta)
if not mediainfo:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
if not new_path:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
ret: schemas.Response = rename_aliyun(fileitem=sub_file,
new_name=Path(new_path).name,
recursive=False)
if not ret.success:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.get("/image", summary="读取图片(阿里云盘)", response_model=schemas.Response)
def image_aliyun(fileid: str, drive_id: str = None, _: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
读取图片
"""
if not fileid:
return schemas.Response(success=False)
url = AliyunHelper().download(drive_id=drive_id, file_id=fileid)
if url:
# 重定向
return Response(status_code=302, headers={"Location": url})
raise HTTPException(status_code=500, detail="下载图片出错")

View File

@@ -10,6 +10,19 @@ from app.core.security import verify_token
router = APIRouter()
@router.get("/calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
def calendar(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览Bangumi每日放送
"""
medias = BangumiChain().calendar()
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
def bangumi_credits(bangumiid: int,
page: int = 1,
@@ -50,14 +63,13 @@ def bangumi_person(person_id: int,
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def bangumi_person_credits(person_id: int,
page: int = 1,
count: int = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = BangumiChain().person_credits(person_id=person_id)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return [media.to_dict() for media in medias[(page - 1) * 20: page * 20]]
return []

View File

@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from app import schemas
from app.chain.dashboard import DashboardChain
from app.chain.storage import StorageChain
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models.transferhistory import TransferHistory
@@ -18,11 +17,11 @@ router = APIRouter()
@router.get("/statistic", summary="媒体数量统计", response_model=schemas.Statistic)
def statistic(name: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体数量统计信息
"""
media_statistics: Optional[List[schemas.Statistic]] = DashboardChain().media_statistic(name)
media_statistics: Optional[List[schemas.Statistic]] = DashboardChain().media_statistic()
if media_statistics:
# 汇总各媒体库统计信息
ret_statistic = schemas.Statistic()
@@ -44,31 +43,23 @@ def statistic2(_: str = Depends(verify_apitoken)) -> Any:
return statistic()
@router.get("/storage", summary="本地存储空间", response_model=schemas.Storage)
@router.get("/storage", summary="存储空间", response_model=schemas.Storage)
def storage(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询本地存储空间信息
查询存储空间信息
"""
total, available = 0, 0
dirs = DirectoryHelper().get_dirs()
if not dirs:
return schemas.Storage(total_storage=total, used_storage=total - available)
storages = set([d.library_storage for d in dirs if d.library_storage])
for _storage in storages:
_usage = StorageChain().storage_usage(_storage)
if _usage:
total += _usage.total
available += _usage.available
library_dirs = DirectoryHelper().get_library_dirs()
total_storage, free_storage = SystemUtils.space_usage([Path(d.path) for d in library_dirs if d.path])
return schemas.Storage(
total_storage=total,
used_storage=total - available
total_storage=total_storage,
used_storage=total_storage - free_storage
)
@router.get("/storage2", summary="本地存储空间API_TOKEN", response_model=schemas.Storage)
@router.get("/storage2", summary="存储空间API_TOKEN", response_model=schemas.Storage)
def storage2(_: str = Depends(verify_apitoken)) -> Any:
"""
查询本地存储空间信息 API_TOKEN认证?token=xxx
查询存储空间信息 API_TOKEN认证?token=xxx
"""
return storage()
@@ -82,16 +73,16 @@ def processes(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/downloader", summary="下载器信息", response_model=schemas.DownloaderInfo)
def downloader(name: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
def downloader(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询下载器信息
"""
# 下载目录空间
download_dirs = DirectoryHelper().get_local_download_dirs()
_, free_space = SystemUtils.space_usage([Path(d.download_path) for d in download_dirs])
download_dirs = DirectoryHelper().get_download_dirs()
_, free_space = SystemUtils.space_usage([Path(d.path) for d in download_dirs if d.path])
# 下载器信息
downloader_info = schemas.DownloaderInfo()
transfer_infos = DashboardChain().downloader_info(name)
transfer_infos = DashboardChain().downloader_info()
if transfer_infos:
for transfer_info in transfer_infos:
downloader_info.download_speed += transfer_info.download_speed

View File

@@ -1,130 +0,0 @@
from typing import Any, List
from fastapi import APIRouter, Depends
from app import schemas
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas import DiscoverSourceEventData
from app.schemas.types import ChainEventType, MediaType
from chain.bangumi import BangumiChain
from chain.douban import DoubanChain
from chain.tmdb import TmdbChain
router = APIRouter()
@router.get("/source", summary="获取探索数据源", response_model=List[schemas.DiscoverMediaSource])
def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取探索数据源
"""
# 广播事件,请示额外的探索数据源支持
event_data = DiscoverSourceEventData()
event = eventmanager.send_event(ChainEventType.DiscoverSource, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: DiscoverSourceEventData = event.event_data
if event_data.extra_sources:
return event_data.extra_sources
return []
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
def bangumi(type: int = 2,
cat: int = None,
sort: str = 'rank',
year: int = None,
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
探索Bangumi
"""
medias = BangumiChain().discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []

View File

@@ -1,16 +1,33 @@
from typing import Any, List
from typing import List, Any
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Response
from app import schemas
from app.chain.douban import DoubanChain
from app.core.config import settings
from app.core.context import MediaInfo
from app.core.security import verify_token
from app.schemas import MediaType
from app.utils.http import RequestUtils
router = APIRouter()
@router.get("/img", summary="豆瓣图片代理")
def douban_img(imgurl: str) -> Any:
"""
豆瓣图片代理
"""
if not imgurl:
return None
response = RequestUtils(headers={
'Referer': "https://movie.douban.com/"
}, ua=settings.USER_AGENT).get_res(url=imgurl)
if response:
return Response(content=response.content, media_type="image/jpeg")
return None
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@@ -33,9 +50,133 @@ def douban_person_credits(person_id: int,
return []
@router.get("/showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
def movie_showing(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣正在热映
"""
movies = DoubanChain().movie_showing(page=page, count=count)
if movies:
return [media.to_dict() for media in movies]
return []
@router.get("/movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
if movies:
return [media.to_dict() for media in movies]
return []
@router.get("/tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
if tvs:
return [media.to_dict() for media in tvs]
return []
@router.get("/movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def movie_top250(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
movies = DoubanChain().movie_top250(page=page, count=count)
if movies:
return [media.to_dict() for media in movies]
return []
@router.get("/tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_chinese(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
tvs = DoubanChain().tv_weekly_chinese(page=page, count=count)
if tvs:
return [media.to_dict() for media in tvs]
return []
@router.get("/tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def tv_weekly_global(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
tvs = DoubanChain().tv_weekly_global(page=page, count=count)
if tvs:
return [media.to_dict() for media in tvs]
return []
@router.get("/tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
def tv_animation(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
tvs = DoubanChain().tv_animation(page=page, count=count)
if tvs:
return [media.to_dict() for media in tvs]
return []
@router.get("/movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
def movie_hot(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电影
"""
movies = DoubanChain().movie_hot(page=page, count=count)
if movies:
return [media.to_dict() for media in movies]
return []
@router.get("/tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
def tv_hot(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电视剧
"""
tvs = DoubanChain().tv_hot(page=page, count=count)
if tvs:
return [media.to_dict() for media in tvs]
return []
@router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson])
def douban_credits(doubanid: str,
type_name: str,
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询演员阵容type_name: 电影/电视剧

View File

@@ -1,6 +1,6 @@
from typing import Any, List
from fastapi import APIRouter, Depends, Body
from fastapi import APIRouter, Depends
from app import schemas
from app.chain.download import DownloadChain
@@ -9,29 +9,24 @@ from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.schemas.types import SystemConfigKey
from app.db.userauth import get_current_active_user
router = APIRouter()
@router.get("/", summary="正在下载", response_model=List[schemas.DownloadingTorrent])
def current(
name: str = None,
def read(
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询正在下载的任务
"""
return DownloadChain().downloading(name)
return DownloadChain().downloading()
@router.post("/", summary="添加下载(含媒体信息)", response_model=schemas.Response)
def download(
media_in: schemas.MediaInfo,
torrent_in: schemas.TorrentInfo,
downloader: str = Body(None),
save_path: str = Body(None),
current_user: User = Depends(get_current_active_user)) -> Any:
"""
添加下载任务(含媒体信息)
@@ -50,8 +45,7 @@ def download(
media_info=mediainfo,
torrent_info=torrentinfo
)
did = DownloadChain().download_single(context=context, username=current_user.name,
downloader=downloader, save_path=save_path, source="Manual")
did = DownloadChain().download_single(context=context, username=current_user.name)
if not did:
return schemas.Response(success=False, message="任务添加失败")
return schemas.Response(success=True, data={
@@ -62,8 +56,6 @@ def download(
@router.post("/add", summary="添加下载(不含媒体信息)", response_model=schemas.Response)
def add(
torrent_in: schemas.TorrentInfo,
downloader: str = Body(None),
save_path: str = Body(None),
current_user: User = Depends(get_current_active_user)) -> Any:
"""
添加下载任务(不含媒体信息)
@@ -83,8 +75,7 @@ def add(
media_info=mediainfo,
torrent_info=torrentinfo
)
did = DownloadChain().download_single(context=context, username=current_user.name,
downloader=downloader, save_path=save_path, source="Manual")
did = DownloadChain().download_single(context=context, username=current_user.name)
if not did:
return schemas.Response(success=False, message="任务添加失败")
return schemas.Response(success=True, data={
@@ -104,8 +95,9 @@ def start(
@router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response)
def stop(hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def stop(
hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
暂停下载任务
"""
@@ -113,20 +105,10 @@ def stop(hashString: str,
return schemas.Response(success=True if ret else False)
@router.get("/clients", summary="查询可用下载器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用下载器
"""
downloaders: List[dict] = SystemConfigOper().get(SystemConfigKey.Downloaders)
if downloaders:
return [{"name": d.get("name"), "type": d.get("type")} for d in downloaders if d.get("enabled")]
return []
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def delete(hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def info(
hashString: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除下载任务
"""

View File

@@ -1,20 +1,19 @@
from pathlib import Path
from typing import List, Any
import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.chain.storage import StorageChain
from app.core.config import settings
from app.chain.transfer import TransferChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.db import get_db
from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.schemas.types import EventType, MediaType
from app.db.userauth import get_current_active_superuser
from app.schemas.types import EventType
router = APIRouter()
@@ -41,7 +40,7 @@ def delete_download_history(history_in: schemas.DownloadHistory,
return schemas.Response(success=True)
@router.get("/transfer", summary="查询整理记录", response_model=schemas.Response)
@router.get("/transfer", summary="查询转移历史记录", response_model=schemas.Response)
def transfer_history(title: str = None,
page: int = 1,
count: int = 30,
@@ -49,7 +48,7 @@ def transfer_history(title: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理记录
查询转移历史记录
"""
if title == "失败":
title = None
@@ -59,9 +58,6 @@ def transfer_history(title: str = None,
status = True
if title:
if settings.TOKENIZED_SEARCH:
words = jieba.cut(title, HMM=False)
title = "%".join(words)
total = TransferHistory.count_by_title(db, title=title, status=status)
result = TransferHistory.list_by_title(db, title=title, page=page,
count=count, status=status)
@@ -76,29 +72,28 @@ def transfer_history(title: str = None,
})
@router.delete("/transfer", summary="删除整理记录", response_model=schemas.Response)
@router.delete("/transfer", summary="删除转移历史记录", response_model=schemas.Response)
def delete_transfer_history(history_in: schemas.TransferHistory,
deletesrc: bool = False,
deletedest: bool = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除整理记录
删除转移历史记录
"""
history: TransferHistory = TransferHistory.get(db, history_in.id)
history = TransferHistory.get(db, history_in.id)
if not history:
return schemas.Response(success=False, message="记录不存在")
return schemas.Response(success=False, msg="记录不存在")
# 册除媒体库文件
if deletedest and history.dest_fileitem:
dest_fileitem = schemas.FileItem(**history.dest_fileitem)
StorageChain().delete_media_file(fileitem=dest_fileitem, mtype=MediaType(history.type))
# 删除源文件
if deletesrc and history.src_fileitem:
src_fileitem = schemas.FileItem(**history.src_fileitem)
state = StorageChain().delete_media_file(src_fileitem)
if deletedest and history.dest:
state, msg = TransferChain().delete_files(Path(history.dest))
if not state:
return schemas.Response(success=False, message=f"{src_fileitem.path} 删除失败")
return schemas.Response(success=False, msg=msg)
# 删除源文件
if deletesrc and history.src:
state, msg = TransferChain().delete_files(Path(history.src))
if not state:
return schemas.Response(success=False, msg=msg)
# 发送事件
eventmanager.send_event(
EventType.DownloadFileDeleted,
@@ -112,11 +107,11 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
return schemas.Response(success=True)
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
@router.get("/empty/transfer", summary="清空转移历史记录", response_model=schemas.Response)
def delete_transfer_history(db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
"""
清空整理记录
清空转移历史记录
"""
TransferHistory.truncate(db)
return schemas.Response(success=True)

273
app/api/endpoints/local.py Normal file
View File

@@ -0,0 +1,273 @@
import shutil
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from starlette.responses import FileResponse, Response
from app import schemas
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_uri_token
from app.helper.progress import ProgressHelper
from app.log import logger
from app.schemas.types import ProgressKey
from app.utils.system import SystemUtils
router = APIRouter()
IMAGE_TYPES = [".jpg", ".png", ".gif", ".bmp", ".jpeg", ".webp"]
@router.post("/list", summary="所有目录和文件(本地)", response_model=List[schemas.FileItem])
def list_local(fileitem: schemas.FileItem,
sort: str = 'time',
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询当前目录下所有目录和文件
:param fileitem: 文件项
:param sort: 排序方式name:按名称排序time:按修改时间排序
:param _: token
:return: 所有目录和文件
"""
# 返回结果
ret_items = []
path = fileitem.path
if not fileitem.path or fileitem.path == "/":
if SystemUtils.is_windows():
partitions = SystemUtils.get_windows_drives() or ["C:/"]
for partition in partitions:
ret_items.append(schemas.FileItem(
type="dir",
path=partition + "/",
name=partition,
basename=partition
))
return ret_items
else:
path = "/"
else:
if SystemUtils.is_windows():
path = path.lstrip("/")
elif not path.startswith("/"):
path = "/" + path
# 遍历目录
path_obj = Path(path)
if not path_obj.exists():
logger.warn(f"目录不存在:{path}")
return []
# 如果是文件
if path_obj.is_file():
ret_items.append(schemas.FileItem(
type="file",
path=str(path_obj).replace("\\", "/"),
name=path_obj.name,
basename=path_obj.stem,
extension=path_obj.suffix[1:],
size=path_obj.stat().st_size,
modify_time=path_obj.stat().st_mtime,
))
return ret_items
# 扁历所有目录
for item in SystemUtils.list_sub_directory(path_obj):
ret_items.append(schemas.FileItem(
type="dir",
path=str(item).replace("\\", "/") + "/",
name=item.name,
basename=item.stem,
modify_time=item.stat().st_mtime,
))
# 遍历所有文件,不含子目录
for item in SystemUtils.list_sub_files(path_obj,
settings.RMT_MEDIAEXT
+ settings.RMT_SUBEXT
+ IMAGE_TYPES
+ [".nfo"]):
ret_items.append(schemas.FileItem(
type="file",
path=str(item).replace("\\", "/"),
name=item.name,
basename=item.stem,
extension=item.suffix[1:],
size=item.stat().st_size,
modify_time=item.stat().st_mtime,
))
# 排序
if sort == 'time':
ret_items.sort(key=lambda x: x.modify_time, reverse=True)
else:
ret_items.sort(key=lambda x: x.name, reverse=False)
return ret_items
@router.get("/listdir", summary="所有目录(本地,不含文件)", response_model=List[schemas.FileItem])
def list_local_dir(path: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询当前目录下所有目录
"""
# 返回结果
ret_items = []
if not path or path == "/":
if SystemUtils.is_windows():
partitions = SystemUtils.get_windows_drives() or ["C:/"]
for partition in partitions:
ret_items.append(schemas.FileItem(
type="dir",
path=partition + "/",
name=partition,
children=[]
))
return ret_items
else:
path = "/"
else:
if not SystemUtils.is_windows() and not path.startswith("/"):
path = "/" + path
# 遍历目录
path_obj = Path(path)
if not path_obj.exists():
logger.warn(f"目录不存在:{path}")
return []
# 扁历所有目录
for item in SystemUtils.list_sub_directory(path_obj):
ret_items.append(schemas.FileItem(
type="dir",
path=str(item).replace("\\", "/") + "/",
name=item.name,
children=[]
))
return ret_items
@router.post("/mkdir", summary="创建目录(本地)", response_model=schemas.Response)
def mkdir_local(fileitem: schemas.FileItem,
name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
创建目录
"""
if not fileitem.path:
return schemas.Response(success=False)
path_obj = Path(fileitem.path) / name
if path_obj.exists():
return schemas.Response(success=False)
path_obj.mkdir(parents=True, exist_ok=True)
return schemas.Response(success=True)
@router.post("/delete", summary="删除文件或目录(本地)", response_model=schemas.Response)
def delete_local(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除文件或目录
"""
if not fileitem.path:
return schemas.Response(success=False)
path_obj = Path(fileitem.path)
if not path_obj.exists():
return schemas.Response(success=True)
if path_obj.is_file():
path_obj.unlink()
else:
shutil.rmtree(path_obj, ignore_errors=True)
return schemas.Response(success=True)
@router.get("/download", summary="下载文件(本地)")
def download_local(path: str, _: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
下载文件或目录
"""
if not path:
return schemas.Response(success=False)
path_obj = Path(path)
if not path_obj.exists():
raise HTTPException(status_code=404, detail="文件不存在")
if path_obj.is_file():
# 做为文件流式下载
return FileResponse(path_obj)
else:
# 做为压缩包下载
shutil.make_archive(base_name=path_obj.stem, format="zip", root_dir=path_obj)
reponse = Response(content=path_obj.read_bytes(), media_type="application/zip")
# 删除压缩包
Path(f"{path_obj.stem}.zip").unlink()
return reponse
@router.post("/rename", summary="重命名文件或目录(本地)", response_model=schemas.Response)
def rename_local(fileitem: schemas.FileItem,
new_name: str,
recursive: bool = False,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重命名文件或目录
"""
if not fileitem.path or not new_name:
return schemas.Response(success=False)
path_obj = Path(fileitem.path)
if not path_obj.exists():
return schemas.Response(success=False)
path_obj.rename(path_obj.parent / new_name)
if recursive:
transferchain = TransferChain()
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
# 递归修改目录内文件(智能识别命名)
sub_files: List[schemas.FileItem] = list_local(fileitem=fileitem)
if sub_files:
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.BatchRename)
total = len(sub_files)
handled = 0
for sub_file in sub_files:
handled += 1
progress.update(value=handled / total * 100,
text=f"正在处理 {sub_file.name} ...",
key=ProgressKey.BatchRename)
if sub_file.type == "dir":
continue
if not sub_file.extension:
continue
if f".{sub_file.extension.lower()}" not in media_exts:
continue
sub_path = Path(sub_file.path)
meta = MetaInfoPath(sub_path)
mediainfo = transferchain.recognize_media(meta)
if not mediainfo:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
if not new_path:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
ret: schemas.Response = rename_local(fileitem, new_name=Path(new_path).name, recursive=False)
if not ret.success:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=True)
@router.get("/image", summary="读取图片(本地)")
def image_local(path: str, _: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
读取图片
"""
if not path:
return None
path_obj = Path(path)
if not path_obj.exists():
return None
if not path_obj.is_file():
return None
# 判断是否图片文件
if path_obj.suffix.lower() not in IMAGE_TYPES:
raise HTTPException(status_code=500, detail="图片读取出错")
return Response(content=path_obj.read_bytes(), media_type="image/jpeg")

View File

@@ -1,50 +1,77 @@
from datetime import timedelta
from typing import Any, List
from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Form
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session
from app import schemas
from app.chain.tmdb import TmdbChain
from app.chain.user import UserChain
from app.chain.mediaserver import MediaServerChain
from app.core import security
from app.core.config import settings
from app.core.security import get_password_hash
from app.db import get_db
from app.db.models.user import User
from app.helper.sites import SitesHelper
from app.log import logger
from app.utils.web import WebUtils
router = APIRouter()
@router.post("/access-token", summary="获取token", response_model=schemas.Token)
def login_access_token(
async def login_access_token(
db: Session = Depends(get_db),
form_data: OAuth2PasswordRequestForm = Depends(),
otp_password: str = Form(None)
) -> Any:
"""
获取认证Token
"""
success, user_or_message = UserChain().user_authenticate(username=form_data.username,
password=form_data.password,
mfa_code=otp_password)
# 检查数据库
success, user = User.authenticate(
db=db,
name=form_data.username,
password=form_data.password,
otp_password=otp_password
)
if not success:
raise HTTPException(status_code=401, detail=user_or_message)
# 认证不成功
if not user:
# 未找到用户,请求协助认证
logger.warn(f"登录用户 {form_data.username} 本地不存在,尝试辅助认证 ...")
token = UserChain().user_authenticate(form_data.username, form_data.password)
if not token:
logger.warn(f"用户 {form_data.username} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码、二次校验码不正确")
else:
logger.info(f"用户 {form_data.username} 辅助认证成功,用户信息: {token},以普通用户登录...")
# 加入用户信息表
logger.info(f"创建用户: {form_data.username}")
user = User(name=form_data.username, is_active=True,
is_superuser=False, hashed_password=get_password_hash(token))
user.create(db)
else:
# 用户存在,但认证失败
logger.warn(f"用户 {user.name} 登录失败!")
raise HTTPException(status_code=401, detail="用户名、密码或二次校验码不正确")
elif user and not user.is_active:
raise HTTPException(status_code=403, detail="用户未启用")
logger.info(f"用户 {user.name} 登录成功!")
level = SitesHelper().auth_level
return schemas.Token(
access_token=security.create_access_token(
userid=user_or_message.id,
username=user_or_message.name,
super_user=user_or_message.is_superuser,
userid=user.id,
username=user.name,
super_user=user.is_superuser,
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
level=level
),
token_type="bearer",
super_user=user_or_message.is_superuser,
user_id=user_or_message.id,
user_name=user_or_message.name,
avatar=user_or_message.avatar,
super_user=user.is_superuser,
user_name=user.name,
avatar=user.avatar,
level=level
)
@@ -54,12 +81,10 @@ def wallpaper() -> Any:
"""
获取登录页面电影海报
"""
if settings.WALLPAPER == "bing":
url = WebUtils.get_bing_wallpaper()
elif settings.WALLPAPER == "mediaserver":
url = MediaServerChain().get_latest_wallpaper()
else:
if settings.WALLPAPER == "tmdb":
url = TmdbChain().get_random_wallpager()
else:
url = WebUtils.get_bing_wallpaper()
if url:
return schemas.Response(
success=True,
@@ -73,9 +98,7 @@ def wallpapers() -> Any:
"""
获取登录页面电影海报
"""
if settings.WALLPAPER == "bing":
return WebUtils.get_bing_wallpapers()
elif settings.WALLPAPER == "mediaserver":
return MediaServerChain().get_latest_wallpapers()
else:
if settings.WALLPAPER == "tmdb":
return TmdbChain().get_trending_wallpapers()
else:
return WebUtils.get_bing_wallpapers()

View File

@@ -5,14 +5,11 @@ from fastapi import APIRouter, Depends
from app import schemas
from app.chain.media import MediaChain
from app.chain.tmdb import TmdbChain
from app.core.config import settings
from app.core.context import Context
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo, MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.schemas import MediaType, MediaRecognizeConvertEventData
from app.schemas.types import ChainEventType
from app.schemas import MediaType
router = APIRouter()
@@ -75,8 +72,7 @@ def search(title: str,
"""
模糊搜索媒体/人物信息列表 media媒体信息person人物信息
"""
def __get_source(obj: Union[schemas.MediaInfo, schemas.MediaPerson, dict]):
def __get_source(obj: Union[dict, schemas.MediaPerson]):
"""
获取对象属性
"""
@@ -89,8 +85,6 @@ def search(title: str,
_, medias = MediaChain().search(title=title)
if medias:
result = [media.to_dict() for media in medias]
elif type == "collection":
result = MediaChain().search_collections(name=title)
else:
result = MediaChain().search_persons(name=title)
if result:
@@ -117,13 +111,16 @@ def scrape(fileitem: schemas.FileItem,
scrape_path = Path(fileitem.path)
meta = MetaInfoPath(scrape_path)
mediainfo = chain.recognize_by_meta(meta)
if not mediainfo:
if not media_info:
return schemas.Response(success=False, message="刮削失败,无法识别媒体信息")
if storage == "local":
if not scrape_path.exists():
return schemas.Response(success=False, message="刮削路径不存在")
else:
if not fileitem.fileid:
return schemas.Response(success=False, message="刮削文件ID无效")
# 手动刮削
chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True)
chain.manual_scrape(storage=storage, fileitem=fileitem, meta=meta, mediainfo=mediainfo)
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
@@ -135,90 +132,25 @@ def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
return MediaChain().media_category() or {}
@router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason])
def seasons(mediaid: str = None,
title: str = None,
year: int = None,
season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体季信息
"""
if mediaid:
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid[5:])
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
if title:
meta = MetaInfo(title)
if year:
meta.year = year
mediainfo = MediaChain().recognize_media(meta, mtype=MediaType.TV)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
seasons_info = TmdbChain().tmdb_seasons(tmdbid=mediainfo.tmdb_id)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
else:
sea = season or 1
return schemas.MediaSeason(
season_number=sea,
poster_path=mediainfo.poster_path,
name=f"{sea}",
air_date=mediainfo.release_date,
overview=mediainfo.overview,
vote_average=mediainfo.vote_average,
episode_count=mediainfo.number_of_episodes
)
return []
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def detail(mediaid: str, type_name: str, title: str = None, year: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def media_info(mediaid: str, type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
"""
mtype = MediaType(type_name)
mediainfo = None
tmdbid, doubanid, bangumiid = None, None, None
if mediaid.startswith("tmdb:"):
mediainfo = MediaChain().recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
tmdbid = int(mediaid[5:])
elif mediaid.startswith("douban:"):
mediainfo = MediaChain().recognize_media(doubanid=mediaid[7:], mtype=mtype)
doubanid = mediaid[7:]
elif mediaid.startswith("bangumi:"):
mediainfo = MediaChain().recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
else:
# 广播事件解析媒体信息
event_data = MediaRecognizeConvertEventData(
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
mediainfo = MediaChain().recognize_media(tmdbid=new_id, mtype=mtype)
elif event_data.convert_type == "douban":
mediainfo = MediaChain().recognize_media(doubanid=new_id, mtype=mtype)
elif title:
# 使用名称识别兜底
meta = MetaInfo(title)
if year:
meta.year = year
if mtype:
meta.type = mtype
mediainfo = MediaChain().recognize_media(meta=meta)
bangumiid = int(mediaid[8:])
if not tmdbid and not doubanid and not bangumiid:
return schemas.MediaInfo()
# 识别
mediainfo = MediaChain().recognize_media(tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, mtype=mtype)
if mediainfo:
MediaChain().obtain_images(mediainfo)
return mediainfo.to_dict()
return schemas.MediaInfo()

View File

@@ -6,40 +6,38 @@ from sqlalchemy.orm import Session
from app import schemas
from app.chain.download import DownloadChain
from app.chain.mediaserver import MediaServerChain
from app.core.config import settings
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.db.mediaserver_oper import MediaServerOper
from app.db.models import MediaServerItem
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.mediaserver import MediaServerHelper
from app.schemas import MediaType, NotExistMediaInfo
from app.schemas.types import SystemConfigKey
router = APIRouter()
@router.get("/play/{itemid:path}", summary="在线播放")
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
@router.get("/play/{itemid}", summary="在线播放")
def play_item(itemid: str) -> schemas.Response:
"""
获取媒体服务器播放页面地址
"""
if not itemid:
return schemas.Response(success=False, message="参数错误")
configs = MediaServerHelper().get_configs()
if not configs:
return schemas.Response(success=False, message="未配置媒体服务器")
media_chain = MediaServerChain()
for name in configs.keys():
item = media_chain.iteminfo(server=name, item_id=itemid)
if item:
play_url = media_chain.get_play_url(server=name, item_id=itemid)
if play_url:
return schemas.Response(success=True, data={
"url": play_url
})
return schemas.Response(success=False, message="未找到播放地址")
return schemas.Response(success=False, msg="参数错误")
if not settings.MEDIASERVER:
return schemas.Response(success=False, msg="未配置媒体服务器")
# 查找一个不为空的值
mediaserver = next((server for server in settings.MEDIASERVER.split(",") if server), None)
if not mediaserver:
return schemas.Response(success=False, msg="未配置媒体服务器")
play_url = MediaServerChain().get_play_url(server=mediaserver, item_id=itemid)
# 重定向到play_url
if not play_url:
return schemas.Response(success=False, msg="未找到播放地址")
return schemas.Response(success=True, data={
"url": play_url
})
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
@@ -121,38 +119,26 @@ def not_exists(media_in: schemas.MediaInfo,
@router.get("/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem])
def latest(server: str, count: int = 18,
def latest(count: int = 18,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取媒体服务器最新入库条目
"""
return MediaServerChain().latest(server=server, count=count, username=userinfo.username) or []
return MediaServerChain().latest(count=count, username=userinfo.username) or []
@router.get("/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem])
def playing(server: str, count: int = 12,
def playing(count: int = 12,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取媒体服务器正在播放条目
"""
return MediaServerChain().playing(server=server, count=count, username=userinfo.username) or []
return MediaServerChain().playing(count=count, username=userinfo.username) or []
@router.get("/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary])
def library(server: str, hidden: bool = False,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
def library(userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取媒体服务器媒体库列表
"""
return MediaServerChain().librarys(server=server, username=userinfo.username, hidden=hidden) or []
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用媒体服务器
"""
mediaservers: List[dict] = SystemConfigOper().get(SystemConfigKey.MediaServers)
if mediaservers:
return [{"name": d.get("name"), "type": d.get("type")} for d in mediaservers if d.get("enabled")]
return []
return MediaServerChain().librarys(username=userinfo.username) or []

View File

@@ -1,7 +1,8 @@
import json
from typing import Union, Any, List
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from fastapi import APIRouter, BackgroundTasks, Depends
from fastapi import Request
from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session
from starlette.responses import PlainTextResponse
@@ -9,15 +10,16 @@ from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.core.security import verify_token
from app.db import get_db
from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
from app.helper.service import ServiceConfigHelper
from app.db.systemconfig_oper import SystemConfigOper
from app.db.userauth import get_current_active_superuser
from app.log import logger
from app.modules.wechat.WXBizMsgCrypt3 import WXBizMsgCrypt
from app.schemas.types import MessageChannel
from app.schemas import NotificationSwitch
from app.schemas.types import SystemConfigKey, NotificationType, MessageChannel
router = APIRouter()
@@ -30,10 +32,9 @@ def start_message_chain(body: Any, form: Any, args: Any):
@router.post("/", summary="接收用户消息", response_model=schemas.Response)
async def user_message(background_tasks: BackgroundTasks, request: Request,
_: schemas.TokenPayload = Depends(verify_apitoken)):
async def user_message(background_tasks: BackgroundTasks, request: Request):
"""
用户消息响应配置请求中需要添加参数token=API_TOKEN&source=消息配置名
用户消息响应
"""
body = await request.body()
form = await request.form()
@@ -49,7 +50,6 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
"""
MessageChain().handle_message(
channel=MessageChannel.Web,
source=current_user.name,
userid=current_user.name,
username=current_user.name,
text=text
@@ -76,55 +76,87 @@ def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
return ret_messages
def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int], nonce: str,
source: str = None) -> Any:
def wechat_verify(echostr: str, msg_signature: str,
timestamp: Union[str, int], nonce: str) -> Any:
"""
微信验证响应
"""
# 获取服务配置
client_configs = ServiceConfigHelper.get_notification_configs()
if not client_configs:
return "未找到对应的消息配置"
client_config = next((config for config in client_configs if
config.type == "wechat" and config.enabled and (not source or config.name == source)), None)
if not client_config:
return "未找到对应的消息配置"
try:
wxcpt = WXBizMsgCrypt(sToken=client_config.config.get('WECHAT_TOKEN'),
sEncodingAESKey=client_config.config.get('WECHAT_ENCODING_AESKEY'),
sReceiveId=client_config.config.get('WECHAT_CORPID'))
ret, sEchoStr = wxcpt.VerifyURL(sMsgSignature=msg_signature,
sTimeStamp=timestamp,
sNonce=nonce,
sEchoStr=echostr)
if ret == 0:
# 验证URL成功将sEchoStr返回给企业号
return PlainTextResponse(sEchoStr)
return "微信验证失败"
wxcpt = WXBizMsgCrypt(sToken=settings.WECHAT_TOKEN,
sEncodingAESKey=settings.WECHAT_ENCODING_AESKEY,
sReceiveId=settings.WECHAT_CORPID)
except Exception as err:
logger.error(f"微信请求验证失败: {str(err)}")
return str(err)
ret, sEchoStr = wxcpt.VerifyURL(sMsgSignature=msg_signature,
sTimeStamp=timestamp,
sNonce=nonce,
sEchoStr=echostr)
if ret != 0:
logger.error("微信请求验证失败 VerifyURL ret: %s" % str(ret))
# 验证URL成功将sEchoStr返回给企业号
return PlainTextResponse(sEchoStr)
def vocechat_verify() -> Any:
def vocechat_verify(token: str) -> Any:
"""
VoceChat验证响应
"""
return {"status": "OK"}
if token == settings.API_TOKEN:
return {"status": "OK"}
return {"status": "ERROR"}
@router.get("/", summary="回调请求验证")
def incoming_verify(token: str = None, echostr: str = None, msg_signature: str = None,
timestamp: Union[str, int] = None, nonce: str = None, source: str = None,
_: schemas.TokenPayload = Depends(verify_apitoken)) -> Any:
timestamp: Union[str, int] = None, nonce: str = None) -> Any:
"""
微信/VoceChat等验证响应
"""
logger.info(f"收到验证请求: token={token}, echostr={echostr}, "
f"msg_signature={msg_signature}, timestamp={timestamp}, nonce={nonce}")
if echostr and msg_signature and timestamp and nonce:
return wechat_verify(echostr, msg_signature, timestamp, nonce, source)
return vocechat_verify()
return wechat_verify(echostr, msg_signature, timestamp, nonce)
return vocechat_verify(token)
@router.get("/switchs", summary="查询通知消息渠道开关", response_model=List[NotificationSwitch])
def read_switchs(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询通知消息渠道开关
"""
return_list = []
# 读取数据库
switchs = SystemConfigOper().get(SystemConfigKey.NotificationChannels)
if not switchs:
for noti in NotificationType:
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True,
telegram=True, slack=True,
synologychat=True, vocechat=True))
else:
for switch in switchs:
return_list.append(NotificationSwitch(**switch))
for noti in NotificationType:
if not any([x.mtype == noti.value for x in return_list]):
return_list.append(NotificationSwitch(mtype=noti.value, wechat=True,
telegram=True, slack=True,
synologychat=True, vocechat=True))
return return_list
@router.post("/switchs", summary="设置通知消息渠道开关", response_model=schemas.Response)
def set_switchs(switchs: List[NotificationSwitch],
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
设置通知消息渠道开关
"""
switch_list = []
for switch in switchs:
switch_list.append(switch.dict())
# 存入数据库
SystemConfigOper().set(SystemConfigKey.NotificationChannels, switch_list)
return schemas.Response(success=True)
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)

View File

@@ -1,124 +1,43 @@
from typing import Annotated, Any, List, Optional
from typing import Any, List, Annotated
from fastapi import APIRouter, Depends, Header
from app import schemas
from app.command import Command
from app.core.config import settings
from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token
from app.core.security import verify_token
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.factory import app
from app.helper.plugin import PluginHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey
PROTECTED_ROUTES = {"/api/v1/openapi.json", "/docs", "/docs/oauth2-redirect", "/redoc"}
PLUGIN_PREFIX = f"{settings.API_V1_STR}/plugin"
router = APIRouter()
def register_plugin_api(plugin_id: Optional[str] = None):
def register_plugin_api(plugin_id: str = None):
"""
动态注册插件 API
:param plugin_id: 插件 ID如果为 None则注册所有插件
注册插件API(先删除后新增)
"""
_update_plugin_api_routes(plugin_id, action="add")
for api in PluginManager().get_plugin_apis(plugin_id):
for r in router.routes:
if r.path == api.get("path"):
router.routes.remove(r)
break
router.add_api_route(**api)
def remove_plugin_api(plugin_id: str):
"""
动态移除单个插件的 API
:param plugin_id: 插件 ID
移除插件API
"""
_update_plugin_api_routes(plugin_id, action="remove")
def _update_plugin_api_routes(plugin_id: Optional[str], action: str):
"""
插件 API 路由注册和移除
:param plugin_id: 插件 ID如果 action 为 "add" 且 plugin_id 为 None则处理所有插件
如果 action 为 "remove"plugin_id 必须是有效的插件 ID
:param action: "add""remove",决定是添加还是移除路由
"""
if action not in {"add", "remove"}:
raise ValueError("Action must be 'add' or 'remove'")
is_modified = False
existing_paths = {route.path: route for route in app.routes}
plugin_ids = [plugin_id] if plugin_id else PluginManager().get_running_plugin_ids()
for plugin_id in plugin_ids:
routes_removed = _remove_routes(plugin_id)
if routes_removed:
is_modified = True
if action != "add":
continue
# 获取插件的 API 路由信息
plugin_apis = PluginManager().get_plugin_apis(plugin_id)
for api in plugin_apis:
api_path = f"{PLUGIN_PREFIX}{api.get('path', '')}"
try:
api["path"] = api_path
allow_anonymous = api.pop("allow_anonymous", False)
dependencies = api.setdefault("dependencies", [])
if not allow_anonymous and Depends(verify_apikey) not in dependencies:
dependencies.append(Depends(verify_apikey))
app.add_api_route(**api, tags=["plugin"])
is_modified = True
logger.debug(f"Added plugin route: {api_path}")
except Exception as e:
logger.error(f"Error adding plugin route {api_path}: {str(e)}")
if is_modified:
_clean_protected_routes(existing_paths)
app.openapi_schema = None
app.setup()
def _remove_routes(plugin_id: str) -> bool:
"""
移除与单个插件相关的路由
:param plugin_id: 插件 ID
:return: 是否有路由被移除
"""
if not plugin_id:
return False
prefix = f"{PLUGIN_PREFIX}/{plugin_id}/"
routes_to_remove = [route for route in app.routes if route.path.startswith(prefix)]
removed = False
for route in routes_to_remove:
try:
app.routes.remove(route)
removed = True
logger.debug(f"Removed plugin route: {route.path}")
except Exception as e:
logger.error(f"Error removing plugin route {route.path}: {str(e)}")
return removed
def _clean_protected_routes(existing_paths: dict):
"""
清理受保护的路由,防止在插件操作中被删除或重复添加
:param existing_paths: 当前应用的路由路径映射
"""
for protected_route in PROTECTED_ROUTES:
try:
existing_route = existing_paths.get(protected_route)
if existing_route:
app.routes.remove(existing_route)
except Exception as e:
logger.error(f"Error removing protected route {protected_route}: {str(e)}")
for api in PluginManager().get_plugin_apis(plugin_id):
for r in router.routes:
if r.path == api.get("path"):
router.routes.remove(r)
break
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
state: str = "all") -> List[schemas.Plugin]:
def all_plugins(_: schemas.TokenPayload = Depends(verify_token), state: str = "all") -> List[schemas.Plugin]:
"""
查询所有插件清单包括本地插件和在线插件插件状态installed, market, all
"""
@@ -164,7 +83,7 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
@router.get("/installed", summary="已安装插件", response_model=List[str])
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
def installed(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户已安装插件清单
"""
@@ -183,7 +102,7 @@ def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def install(plugin_id: str,
repo_url: str = "",
force: bool = False,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
安装插件
"""
@@ -211,8 +130,6 @@ def install(plugin_id: str,
PluginManager().reload_plugin(plugin_id)
# 注册插件服务
Scheduler().update_plugin_job(plugin_id)
# 注册菜单命令
Command().init_commands(plugin_id)
# 注册插件API
register_plugin_api(plugin_id)
return schemas.Response(success=True)
@@ -220,7 +137,7 @@ def install(plugin_id: str,
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
_: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
根据插件ID获取插件配置表单
"""
@@ -232,7 +149,7 @@ def plugin_form(plugin_id: str,
@router.get("/page/{plugin_id}", summary="获取插件数据页面")
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
"""
根据插件ID获取插件数据页面
"""
@@ -253,7 +170,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
"""
根据插件ID获取插件仪表板
"""
return PluginManager().get_plugin_dashboard(plugin_id, user_agent=user_agent)
return PluginManager().get_plugin_dashboard(plugin_id, key=None, user_agent=user_agent)
@router.get("/dashboard/{plugin_id}/{key}", summary="获取插件仪表板配置")
@@ -266,8 +183,7 @@ def plugin_dashboard(plugin_id: str, key: str, user_agent: Annotated[str | None,
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
def reset_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
def reset_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据插件ID重置插件配置及数据
"""
@@ -279,16 +195,13 @@ def reset_plugin(plugin_id: str,
PluginManager().reload_plugin(plugin_id)
# 注册插件服务
Scheduler().update_plugin_job(plugin_id)
# 注册菜单命令
Command().init_commands(plugin_id)
# 注册插件API
register_plugin_api(plugin_id)
return schemas.Response(success=True)
@router.get("/{plugin_id}", summary="获取插件配置")
def plugin_config(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
def plugin_config(plugin_id: str, _: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
根据插件ID获取插件配置信息
"""
@@ -297,7 +210,7 @@ def plugin_config(plugin_id: str,
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
def set_plugin_config(plugin_id: str, conf: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新插件配置
"""
@@ -307,8 +220,6 @@ def set_plugin_config(plugin_id: str, conf: dict,
PluginManager().init_plugin(plugin_id, conf)
# 注册插件服务
Scheduler().update_plugin_job(plugin_id)
# 注册菜单命令
Command().init_commands(plugin_id)
# 注册插件API
register_plugin_api(plugin_id)
return schemas.Response(success=True)
@@ -316,7 +227,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
卸载插件
"""
@@ -328,12 +239,12 @@ def uninstall_plugin(plugin_id: str,
break
# 保存
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 移除插件API
remove_plugin_api(plugin_id)
# 移除插件服务
Scheduler().remove_plugin_job(plugin_id)
# 移除插件
PluginManager().remove_plugin(plugin_id)
# 移除插件服务
Scheduler().remove_plugin_job(plugin_id)
# 移除插件API
remove_plugin_api(plugin_id)
return schemas.Response(success=True)

View File

@@ -1,191 +0,0 @@
from typing import Any, List
from fastapi import APIRouter, Depends
from app import schemas
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas.types import ChainEventType
from chain.recommend import RecommendChain
from schemas import RecommendSourceEventData
router = APIRouter()
@router.get("/source", summary="获取推荐数据源", response_model=List[schemas.RecommendMediaSource])
def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取推荐数据源
"""
# 广播事件,请示额外的推荐数据源支持
event_data = RecommendSourceEventData()
event = eventmanager.send_event(ChainEventType.RecommendSource, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: RecommendSourceEventData = event.event_data
if event_data.extra_sources:
return event_data.extra_sources
return []
@router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
def bangumi_calendar(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览Bangumi每日放送
"""
return RecommendChain().bangumi_calendar(page=page, count=count)
@router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
def douban_showing(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣正在热映
"""
return RecommendChain().douban_movie_showing(page=page, count=count)
@router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
return RecommendChain().douban_movies(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: str = "R",
tags: str = "",
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def douban_movie_top250(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_chinese(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_global(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
def douban_tv_animation(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return RecommendChain().douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
def douban_movie_hot(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电影
"""
return RecommendChain().douban_movie_hot(page=page, count=count)
@router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
def douban_tv_hot(page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电视剧
"""
return RecommendChain().douban_tv_hot(page=page, count=count)
@router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
return RecommendChain().tmdb_movies(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
return RecommendChain().tmdb_tvs(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
TMDB流行趋势
"""
return RecommendChain().tmdb_trending(page=page)

View File

@@ -6,11 +6,8 @@ from app import schemas
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.schemas import MediaRecognizeConvertEventData
from app.schemas.types import MediaType, ChainEventType
from app.schemas.types import MediaType
router = APIRouter()
@@ -28,8 +25,6 @@ def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
def search_by_id(mediaid: str,
mtype: str = None,
area: str = "title",
title: str = None,
year: int = None,
season: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -39,8 +34,6 @@ def search_by_id(mediaid: str,
mtype = MediaType(mtype)
if season:
season = int(season)
torrents = None
# 根据前缀识别媒体ID
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid.replace("tmdb:", ""))
if settings.RECOGNIZE_SOURCE == "douban":
@@ -86,44 +79,8 @@ def search_by_id(mediaid: str,
else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else:
# 未知前缀,广播事件解析媒体信息
event_data = MediaRecognizeConvertEventData(
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
search_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=search_id,
mtype=mtype, area=area, season=season)
elif event_data.convert_type == "douban":
torrents = SearchChain().search_by_id(doubanid=search_id,
mtype=mtype, area=area, season=season)
else:
if not title:
return schemas.Response(success=False, message="未知的媒体ID")
# 使用名称识别兜底
meta = MetaInfo(title)
if year:
meta.year = year
if mtype:
meta.type = mtype
if season:
meta.type = MediaType.TV
meta.begin_season = season
mediainfo = MediaChain().recognize_media(meta=meta)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=mediainfo.tmdb_id,
mtype=mtype, area=area, season=season)
else:
torrents = SearchChain().search_by_id(doubanid=mediainfo.douban_id,
mtype=mtype, area=area, season=season)
# 返回搜索结果
return schemas.Response(success=False, message="未知的媒体ID")
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
else:

View File

@@ -1,4 +1,4 @@
from typing import List, Any, Dict
from typing import List, Any
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
@@ -8,16 +8,14 @@ from app import schemas
from app.chain.site import SiteChain
from app.chain.torrents import TorrentsChain
from app.core.event import EventManager
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.db import get_db
from app.db.models import User
from app.db.models.site import Site
from app.db.models.siteicon import SiteIcon
from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.userauth import get_current_active_superuser
from app.helper.sites import SitesHelper
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey, EventType
@@ -28,7 +26,7 @@ router = APIRouter()
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
def read_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
"""
获取站点列表
"""
@@ -40,7 +38,7 @@ def add_site(
*,
db: Session = Depends(get_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
新增站点
@@ -77,7 +75,7 @@ def update_site(
*,
db: Session = Depends(get_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
更新站点信息
@@ -98,7 +96,7 @@ def update_site(
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
运行CookieCloud同步站点信息
"""
@@ -129,7 +127,7 @@ def reset(db: Session = Depends(get_db),
def update_sites_priority(
priorities: List[dict],
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
批量更新站点优先级
"""
@@ -147,7 +145,7 @@ def update_cookie(
password: str,
code: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
使用用户密码更新站点Cookie
"""
@@ -166,61 +164,6 @@ def update_cookie(
return schemas.Response(success=state, message=message)
@router.post("/userdata/{site_id}", summary="更新站点用户数据", response_model=schemas.Response)
def refresh_userdata(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
"""
刷新站点用户数据
"""
site = Site.get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
indexer = SitesHelper().get_indexer(site.domain)
if not indexer:
return schemas.Response(success=False, message="站点不支持索引或未通过用户认证!")
user_data = SiteChain().refresh_userdata(site=indexer) or {}
return schemas.Response(success=True, data=user_data)
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
def read_userdata_latest(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
"""
查询所有站点最新用户数据
"""
user_datas = SiteUserData.get_latest(db)
if not user_datas:
return []
return [user_data.to_dict() for user_data in user_datas]
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
def read_userdata(
site_id: int,
workdate: str = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
"""
查询站点用户数据
"""
site = Site.get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate)
if not user_data:
return schemas.Response(success=False, data=[])
return schemas.Response(success=True, data=user_data)
@router.get("/test/{site_id}", summary="连接测试", response_model=schemas.Response)
def test_site(site_id: int,
db: Session = Depends(get_db),
@@ -259,43 +202,10 @@ def site_icon(site_id: int,
})
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
def site_category(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点分类
"""
site = Site.get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
indexer = SitesHelper().get_indexer(site.domain)
if not indexer:
raise HTTPException(
status_code=404,
detail=f"站点 {site.domain} 不支持",
)
category: Dict[str, List[dict]] = indexer.get('category') or []
if not category:
return []
result = []
for cats in category.values():
for cat in cats:
if cat not in result:
result.append(cat)
return result
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
def site_resource(site_id: int,
keyword: str = None,
cat: str = None,
page: int = 0,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览站点资源
"""
@@ -305,7 +215,7 @@ def site_resource(site_id: int,
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
torrents = TorrentsChain().browse(domain=site.domain)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]
@@ -347,8 +257,7 @@ def read_site_by_domain(
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
def read_rss_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
def read_rss_sites(db: Session = Depends(get_db)) -> List[dict]:
"""
获取站点列表
"""
@@ -365,36 +274,11 @@ def read_rss_sites(db: Session = Depends(get_db),
return rss_sites
@router.get("/auth", summary="查询认证站点", response_model=dict)
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
获取可认证站点列表
"""
return SitesHelper().get_authsites()
@router.post("/auth", summary="用户站点认证", response_model=schemas.Response)
def auth_site(
auth_info: schemas.SiteAuth,
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
用户站点认证
"""
if not auth_info or not auth_info.site or not auth_info.params:
return schemas.Response(success=False, message="请输入认证站点和认证参数")
status, msg = SitesHelper().check_user(auth_info.site, auth_info.params)
SystemConfigOper().set(SystemConfigKey.UserSiteAuthParams, auth_info.dict())
PluginManager().init_config()
Scheduler().init_plugin_jobs()
return schemas.Response(success=status, message=msg)
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
def read_site(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过ID获取站点信息

View File

@@ -1,218 +0,0 @@
from datetime import datetime
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from starlette.responses import FileResponse, Response
from app import schemas
from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey
router = APIRouter()
@router.get("/qrcode/{name}", summary="生成二维码内容", response_model=schemas.Response)
def qrcode(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
生成二维码
"""
qrcode_data, errmsg = StorageChain().generate_qrcode(name)
if qrcode_data:
return schemas.Response(success=True, data=qrcode_data, message=errmsg)
return schemas.Response(success=False)
@router.get("/check/{name}", summary="二维码登录确认", response_model=schemas.Response)
def check(name: str, ck: str = None, t: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
二维码登录确认
"""
if ck or t:
data, errmsg = StorageChain().check_login(name, ck=ck, t=t)
else:
data, errmsg = StorageChain().check_login(name)
if data:
return schemas.Response(success=True, data=data)
return schemas.Response(success=False, message=errmsg)
@router.post("/save/{name}", summary="保存存储配置", response_model=schemas.Response)
def save(name: str,
conf: dict,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
保存存储配置
"""
StorageChain().save_config(name, conf)
return schemas.Response(success=True)
@router.post("/list", summary="所有目录和文件", response_model=List[schemas.FileItem])
def list_files(fileitem: schemas.FileItem,
sort: str = 'updated_at',
_: User = Depends(get_current_active_superuser)) -> Any:
"""
查询当前目录下所有目录和文件
:param fileitem: 文件项
:param sort: 排序方式name:按名称排序time:按修改时间排序
:param _: token
:return: 所有目录和文件
"""
file_list = StorageChain().list_files(fileitem)
if file_list:
if sort == "name":
file_list.sort(key=lambda x: x.name or "")
else:
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
return file_list
@router.post("/mkdir", summary="创建目录", response_model=schemas.Response)
def mkdir(fileitem: schemas.FileItem,
name: str,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
创建目录
:param fileitem: 文件项
:param name: 目录名称
:param _: token
"""
if not name:
return schemas.Response(success=False)
result = StorageChain().create_folder(fileitem, name)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.post("/delete", summary="删除文件或目录", response_model=schemas.Response)
def delete(fileitem: schemas.FileItem,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
删除文件或目录
:param fileitem: 文件项
:param _: token
"""
result = StorageChain().delete_file(fileitem)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.post("/download", summary="下载文件")
def download(fileitem: schemas.FileItem,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
下载文件或目录
:param fileitem: 文件项
:param _: token
"""
# 临时目录
tmp_file = StorageChain().download_file(fileitem)
if tmp_file:
return FileResponse(path=tmp_file)
return schemas.Response(success=False)
@router.post("/image", summary="预览图片")
def image(fileitem: schemas.FileItem,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
下载文件或目录
:param fileitem: 文件项
:param _: token
"""
# 临时目录
tmp_file = StorageChain().download_file(fileitem)
if not tmp_file:
raise HTTPException(status_code=500, detail="图片读取出错")
return Response(content=tmp_file.read_bytes(), media_type="image/jpeg")
@router.post("/rename", summary="重命名文件或目录", response_model=schemas.Response)
def rename(fileitem: schemas.FileItem,
new_name: str,
recursive: bool = False,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
重命名文件或目录
:param fileitem: 文件项
:param new_name: 新名称
:param recursive: 是否递归修改
:param _: token
"""
if not new_name:
return schemas.Response(success=False, message="新名称为空")
result = StorageChain().rename_file(fileitem, new_name)
if result:
if recursive:
transferchain = TransferChain()
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
# 递归修改目录内文件(智能识别命名)
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
if sub_files:
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.BatchRename)
total = len(sub_files)
handled = 0
for sub_file in sub_files:
handled += 1
progress.update(value=handled / total * 100,
text=f"正在处理 {sub_file.name} ...",
key=ProgressKey.BatchRename)
if sub_file.type == "dir":
continue
if not sub_file.extension:
continue
if f".{sub_file.extension.lower()}" not in media_exts:
continue
sub_path = Path(f"{fileitem.path}{sub_file.name}")
meta = MetaInfoPath(sub_path)
mediainfo = transferchain.recognize_media(meta)
if not mediainfo:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
if not new_path:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
ret: schemas.Response = rename(fileitem=sub_file,
new_name=Path(new_path).name,
recursive=False)
if not ret.success:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.get("/usage/{name}", summary="存储空间信息", response_model=schemas.StorageUsage)
def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
"""
查询存储空间
"""
ret = StorageChain().storage_usage(name)
if ret:
return ret
return schemas.StorageUsage()
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
"""
查询支持的整理方式
"""
ret = StorageChain().support_transtype(name)
if ret:
return schemas.StorageTransType(transtype=ret)
return schemas.StorageTransType()

View File

@@ -1,3 +1,4 @@
import json
from typing import List, Any
import cn2an
@@ -8,18 +9,16 @@ from app import schemas
from app.chain.subscribe import SubscribeChain
from app.core.config import settings
from app.core.context import MediaInfo
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.db.userauth import get_current_active_user
from app.helper.subscribe import SubscribeHelper
from app.scheduler import Scheduler
from app.schemas.types import MediaType, EventType, SystemConfigKey
from app.schemas.types import MediaType
router = APIRouter()
@@ -40,7 +39,16 @@ def read_subscribes(
"""
查询所有订阅
"""
return Subscribe.list(db)
subscribes = Subscribe.list(db)
for subscribe in subscribes:
if subscribe.sites:
try:
subscribe.sites = json.loads(str(subscribe.sites))
except json.JSONDecodeError:
subscribe.sites = []
else:
subscribe.sites = []
return subscribes
@router.get("/list", summary="查询所有订阅API_TOKEN", response_model=List[schemas.Subscribe])
@@ -56,7 +64,7 @@ def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
current_user: User = Depends(get_current_active_user),
) -> schemas.Response:
) -> Any:
"""
新增订阅
"""
@@ -82,14 +90,10 @@ def create_subscribe(
season=subscribe_in.season,
doubanid=subscribe_in.doubanid,
bangumiid=subscribe_in.bangumiid,
mediaid=subscribe_in.mediaid,
username=current_user.name,
best_version=subscribe_in.best_version,
save_path=subscribe_in.save_path,
search_imdbid=subscribe_in.search_imdbid,
custom_words=subscribe_in.custom_words,
media_category=subscribe_in.media_category,
filter_groups=subscribe_in.filter_groups,
exist_ok=True)
return schemas.Response(
success=bool(sid), message=message, data={"id": sid}
@@ -109,8 +113,9 @@ def update_subscribe(
subscribe = Subscribe.get(db, subscribe_in.id)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
if subscribe_in.sites is not None:
subscribe_in.sites = json.dumps(subscribe_in.sites)
# 避免更新缺失集数
old_subscribe_dict = subscribe.to_dict()
subscribe_dict = subscribe_in.dict()
if not subscribe_in.lack_episode:
# 没有缺失集数时缺失集数清空避免更新为0
@@ -125,40 +130,6 @@ def update_subscribe(
if subscribe_in.total_episode != subscribe.total_episode:
subscribe_dict["manual_total_episode"] = 1
subscribe.update(db, subscribe_dict)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
})
return schemas.Response(success=True)
@router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response)
def update_subscribe_status(
subid: int,
state: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新订阅状态
"""
subscribe = Subscribe.get(db, subid)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
valid_states = ["R", "P", "S"]
if state not in valid_states:
return schemas.Response(success=False, message="无效的订阅状态")
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
"state": state
})
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
})
return schemas.Response(success=True)
@@ -172,6 +143,7 @@ def subscribe_mediaid(
"""
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
"""
result = None
title_check = False
if mediaid.startswith("tmdb:"):
tmdbid = mediaid[5:]
@@ -192,16 +164,17 @@ def subscribe_mediaid(
result = Subscribe.get_by_bangumiid(db, int(bangumiid))
if not result and title:
title_check = True
else:
result = Subscribe.get_by_mediaid(db, mediaid)
if not result and title:
title_check = True
# 使用名称检查订阅
if title_check and title:
meta = MetaInfo(title)
if season:
meta.begin_season = season
result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season)
if result and result.sites:
try:
result.sites = json.loads(result.sites)
except json.JSONDecodeError:
result.sites = []
return result if result else Subscribe()
@@ -226,17 +199,9 @@ def reset_subscribes(
"""
subscribe = Subscribe.get(db, subid)
if subscribe:
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
"note": [],
"lack_episode": subscribe.total_episode,
"state": "R"
})
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"note": "",
"lack_episode": subscribe.total_episode
})
return schemas.Response(success=True)
return schemas.Response(success=False, message="订阅不存在")
@@ -301,31 +266,17 @@ def delete_subscribe_by_mediaid(
"""
根据TMDBID或豆瓣ID删除订阅 tmdb:/douban:
"""
delete_subscribes = []
if mediaid.startswith("tmdb:"):
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return schemas.Response(success=False)
subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season)
delete_subscribes.extend(subscribes)
Subscribe().delete_by_tmdbid(db, int(tmdbid), season)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return schemas.Response(success=False)
subscribe = Subscribe().get_by_doubanid(db, doubanid)
if subscribe:
delete_subscribes.append(subscribe)
else:
subscribe = Subscribe().get_by_mediaid(db, mediaid)
if subscribe:
delete_subscribes.append(subscribe)
for subscribe in delete_subscribes:
Subscribe().delete(db, subscribe.id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict()
})
Subscribe().delete_by_doubanid(db, doubanid)
return schemas.Response(success=True)
@@ -383,7 +334,7 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe])
def subscribe_history(
def read_subscribe(
mtype: str,
page: int = 1,
count: int = 30,
@@ -392,7 +343,14 @@ def subscribe_history(
"""
查询电影/电视剧订阅历史
"""
return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count)
historys = SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count)
for history in historys:
if history and history.sites:
try:
history.sites = json.loads(history.sites)
except json.JSONDecodeError:
history.sites = []
return historys
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
@@ -453,123 +411,6 @@ def popular_subscribes(
return []
@router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe])
def user_subscribes(
username: str,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户订阅
"""
return Subscribe.list_by_username(db, username)
@router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo)
def subscribe_files(
subscribe_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
订阅相关文件信息
"""
subscribe = Subscribe.get(db, subscribe_id)
if subscribe:
return SubscribeChain().subscribe_files_info(subscribe)
return schemas.SubscrbieInfo()
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
def subscribe_share(
sub: schemas.SubscribeShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
分享订阅
"""
state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
share_comment=sub.share_comment,
share_user=sub.share_user)
return schemas.Response(success=state, message=errmsg)
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
def subscribe_share_delete(
share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除分享
"""
state, errmsg = SubscribeHelper().share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用订阅", response_model=schemas.Response)
def subscribe_fork(
sub: schemas.SubscribeShare,
current_user: User = Depends(get_current_active_user)) -> Any:
"""
复用订阅
"""
sub_dict = sub.dict()
sub_dict.pop("id")
for key in list(sub_dict.keys()):
if not hasattr(schemas.Subscribe(), key):
sub_dict.pop(key)
result = create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user)
if result.success:
SubscribeHelper().sub_fork(share_id=sub.id)
return result
@router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str])
def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询已Follow的订阅分享人
"""
return SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
@router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response)
def follow_subscriber(
share_uid: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
Follow订阅分享人
"""
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid not in subscribers:
subscribers.append(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response)
def unfollow_subscriber(
share_uid: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
取消Follow订阅分享人
"""
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid in subscribers:
subscribers.remove(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare])
def popular_subscribes(
name: str = None,
page: int = 1,
count: int = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询分享的订阅
"""
return SubscribeHelper().get_shares(name=name, page=page, count=count)
@router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe)
def read_subscribe(
subscribe_id: int,
@@ -580,7 +421,13 @@ def read_subscribe(
"""
if not subscribe_id:
return Subscribe()
return Subscribe.get(db, subscribe_id)
subscribe = Subscribe.get(db, subscribe_id)
if subscribe and subscribe.sites:
try:
subscribe.sites = json.loads(subscribe.sites)
except json.JSONDecodeError:
subscribe.sites = []
return subscribe
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
@@ -595,14 +442,9 @@ def delete_subscribe(
subscribe = Subscribe.get(db, subscribe_id)
if subscribe:
subscribe.delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict()
})
# 统计订阅
SubscribeHelper().sub_done_async({
"tmdbid": subscribe.tmdbid,
"doubanid": subscribe.doubanid
})
# 统计订阅
SubscribeHelper().sub_done_async({
"tmdbid": subscribe.tmdbid,
"doubanid": subscribe.doubanid
})
return schemas.Response(success=True)

View File

@@ -1,199 +1,57 @@
import asyncio
import io
import json
import tempfile
from collections import deque
import time
from datetime import datetime
from pathlib import Path
from typing import Optional, Union
from typing import Union, Any
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from fastapi import APIRouter, Depends, HTTPException, Header, Request, Response
import tailer
from dotenv import set_key
from fastapi import APIRouter, HTTPException, Depends, Response
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.config import global_vars, settings
from app.core.metainfo import MetaInfo
from app.core.config import settings, global_vars
from app.core.module import ModuleManager
from app.core.security import verify_apitoken, verify_resource_token, verify_token
from app.core.security import verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.helper.mediaserver import MediaServerHelper
from app.db.userauth import get_current_active_superuser
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.sites import SitesHelper
from app.log import logger
from app.monitor import Monitor
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey
from app.utils.crypto import HashUtils
from app.utils.http import RequestUtils
from app.utils.security import SecurityUtils
from app.utils.system import SystemUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
router = APIRouter()
def fetch_image(
url: str,
proxy: bool = False,
use_disk_cache: bool = False,
if_none_match: Optional[str] = None,
allowed_domains: Optional[set[str]] = None) -> Response:
"""
处理图片缓存逻辑支持HTTP缓存和磁盘缓存
"""
if not url:
raise HTTPException(status_code=404, detail="URL not provided")
if allowed_domains is None:
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS)
# 验证URL安全性
if not SecurityUtils.is_safe_url(url, allowed_domains):
raise HTTPException(status_code=404, detail="Unsafe URL")
# 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存
cache_path = None
if use_disk_cache:
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
if cache_path.exists():
try:
content = cache_path.read_bytes()
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
return Response(status_code=304, headers=headers)
return Response(content=content, media_type="image/jpeg", headers=headers)
except Exception as e:
# 如果读取磁盘缓存发生异常,这里仅记录日志,尝试再次请求远端进行处理
logger.debug(f"Failed to read cache file {cache_path}: {e}")
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
if not response:
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
raise HTTPException(status_code=502, detail="Invalid image format")
content = response.content
response_headers = response.headers
cache_control_header = response_headers.get("Cache-Control", "")
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
# 如果需要使用磁盘缓存,则保存到磁盘
if use_disk_cache and cache_path:
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path}: {e}")
# 检查 If-None-Match
etag = HashUtils.md5(content)
if if_none_match == etag:
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(status_code=304, headers=headers)
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(
content=content,
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers
)
@router.get("/img/{proxy}", summary="图片代理")
def proxy_img(
imgurl: str,
proxy: bool = False,
if_none_match: Optional[str] = Header(None),
_: schemas.TokenPayload = Depends(verify_resource_token)
) -> Response:
def get_img(imgurl: str, proxy: bool = False) -> Any:
"""
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
通过图片代理使用代理服务器
"""
# 媒体服务器添加图片代理支持
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
config and config.config and config.config.get("host")]
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
return fetch_image(url=imgurl, proxy=proxy, use_disk_cache=False,
if_none_match=if_none_match, allowed_domains=allowed_domains)
if not imgurl:
return None
if proxy:
response = RequestUtils(ua=settings.USER_AGENT, proxies=settings.PROXY).get_res(url=imgurl)
else:
response = RequestUtils(ua=settings.USER_AGENT).get_res(url=imgurl)
if response:
return Response(content=response.content, media_type="image/jpeg")
return None
@router.get("/cache/image", summary="图片缓存")
def cache_img(
url: str,
if_none_match: Optional[str] = Header(None),
_: schemas.TokenPayload = Depends(verify_resource_token)
) -> Response:
"""
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
"""
# 如果没有启用全局图片缓存,则不使用磁盘缓存
proxy = "doubanio.com" not in url
return fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match)
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
def get_global_setting():
"""
查询非敏感系统设置(无需鉴权)
"""
# FIXME: 新增敏感配置项时要在此处添加排除项
info = settings.dict(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN"}
)
# 追加用户唯一ID
info.update({
"USER_UNIQUE_ID": SystemUtils.generate_user_unique_id()
})
return schemas.Response(success=True,
data=info)
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
@router.get("/env", summary="查询系统环境变量", response_model=schemas.Response)
def get_env_setting(_: User = Depends(get_current_active_superuser)):
"""
查询系统环境变量,包括当前版本号(仅管理员)
查询系统环境变量,包括当前版本号
"""
info = settings.dict(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
exclude={"SECRET_KEY", "SUPERUSER_PASSWORD"}
)
info.update({
"VERSION": APP_VERSION,
@@ -205,53 +63,47 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
data=info)
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
@router.post("/env", summary="更新系统环境变量", response_model=schemas.Response)
def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser)):
"""
更新系统环境变量(仅管理员)
更新系统环境变量
"""
result = settings.update_settings(env=env)
# 统计成功和失败的结果
success_updates = {k: v for k, v in result.items() if v[0]}
failed_updates = {k: v for k, v in result.items() if not v[0]}
if failed_updates:
return schemas.Response(
success=False,
message="部分配置项更新失败",
data={
"success_updates": success_updates,
"failed_updates": failed_updates
}
)
return schemas.Response(
success=True,
message="所有配置项更新成功",
data={
"success_updates": success_updates
}
)
for k, v in env.items():
if k == "undefined":
continue
if hasattr(settings, k):
if v == "None":
v = None
setattr(settings, k, v)
if v is None:
v = ''
else:
v = str(v)
set_key(settings.CONFIG_PATH / "app.env", k, v)
return schemas.Response(success=True)
@router.get("/progress/{process_type}", summary="实时进度")
async def get_progress(request: Request, process_type: str, _: schemas.TokenPayload = Depends(verify_resource_token)):
def get_progress(process_type: str, token: str):
"""
实时获取处理进度返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
progress = ProgressHelper()
async def event_generator():
try:
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
detail = progress.get(process_type)
yield f"data: {json.dumps(detail)}\n\n"
await asyncio.sleep(0.2)
except asyncio.CancelledError:
return
def event_generator():
while True:
if global_vars.is_system_stopped():
break
detail = progress.get(process_type)
yield 'data: %s\n\n' % json.dumps(detail)
time.sleep(0.2)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@@ -260,7 +112,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
def get_setting(key: str,
_: User = Depends(get_current_active_superuser)):
"""
查询系统设置(仅管理员)
查询系统设置
"""
if hasattr(settings, key):
value = getattr(settings, key)
@@ -275,89 +127,82 @@ def get_setting(key: str,
def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
_: User = Depends(get_current_active_superuser)):
"""
更新系统设置(仅管理员)
更新系统设置
"""
if hasattr(settings, key):
success, message = settings.update_setting(key=key, value=value)
return schemas.Response(success=success, message=message)
elif key in {item.value for item in SystemConfigKey}:
SystemConfigOper().set(key, value)
return schemas.Response(success=True)
if value == "None":
value = None
setattr(settings, key, value)
if value is None:
value = ''
else:
value = str(value)
set_key(settings.CONFIG_PATH / "app.env", key, value)
else:
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
SystemConfigOper().set(key, value)
return schemas.Response(success=True)
@router.get("/message", summary="实时消息")
async def get_message(request: Request, role: str = "system", _: schemas.TokenPayload = Depends(verify_resource_token)):
def get_message(token: str, role: str = "system"):
"""
实时获取系统消息返回格式为SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
message = MessageHelper()
async def event_generator():
try:
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
detail = message.get(role)
yield f"data: {detail or ''}\n\n"
await asyncio.sleep(3)
except asyncio.CancelledError:
return
def event_generator():
while True:
if global_vars.is_system_stopped():
break
detail = message.get(role)
yield 'data: %s\n\n' % (detail or '')
time.sleep(3)
return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/logging", summary="实时日志")
async def get_logging(request: Request, length: int = 50, logfile: str = "moviepilot.log",
_: schemas.TokenPayload = Depends(verify_resource_token)):
def get_logging(token: str, length: int = 50, logfile: str = "moviepilot.log"):
"""
实时获取系统日志
length = -1 时, 返回text/plain
否则 返回格式SSE
"""
if not token or not verify_token(token):
raise HTTPException(
status_code=403,
detail="认证失败!",
)
log_path = settings.LOG_PATH / logfile
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
raise HTTPException(status_code=404, detail="Not Found")
if not log_path.exists() or not log_path.is_file():
raise HTTPException(status_code=404, detail="Not Found")
async def log_generator():
try:
# 使用固定大小的双向队列来限制内存使用
lines_queue = deque(maxlen=max(length, 50))
# 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f:
# 逐行读取文件,将每一行存入队列
file_content = await f.read()
for line in file_content.splitlines():
lines_queue.append(line)
for line in lines_queue:
yield f"data: {line}\n\n"
# 移动文件指针到文件末尾,继续监听新增内容
await f.seek(0, 2)
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
line = await f.readline()
if not line:
await asyncio.sleep(0.5)
continue
yield f"data: {line}\n\n"
except asyncio.CancelledError:
return
def log_generator():
# 读取文件末尾50行不使用tailer模块
with open(log_path, 'r', encoding='utf-8') as f:
for line in f.readlines()[-max(length, 50):]:
yield 'data: %s\n\n' % line
while True:
if global_vars.is_system_stopped():
break
for t in tailer.follow(open(log_path, 'r', encoding='utf-8')):
yield 'data: %s\n\n' % (t or '')
time.sleep(1)
# 根据length参数返回不同的响应
if length == -1:
# 返回全部日志作为文本响应
if not log_path.exists():
return Response(content="日志文件不存在!", media_type="text/plain")
with open(log_path, "r", encoding='utf-8') as file:
with open(log_path, 'r', encoding='utf-8') as file:
text = file.read()
# 倒序输出
text = "\n".join(text.split("\n")[::-1])
text = '\n'.join(text.split('\n')[::-1])
return Response(content=text, media_type="text/plain")
else:
# 返回SSE流响应
@@ -378,10 +223,10 @@ def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
return schemas.Response(success=False)
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
@router.get("/ruletest", summary="优先级规则测试", response_model=schemas.Response)
def ruletest(title: str,
rulegroup_name: str,
subtitle: str = None,
ruletype: str = None,
_: schemas.TokenPayload = Depends(verify_token)):
"""
过滤规则测试,规则类型 1-订阅2-洗版3-搜索
@@ -390,21 +235,20 @@ def ruletest(title: str,
title=title,
description=subtitle,
)
# 查询规则组详情
rulegroup = RuleHelper().get_rule_group(rulegroup_name)
if not rulegroup:
return schemas.Response(success=False, message=f"过滤规则组 {rulegroup_name} 不存在!")
# 根据标题查询媒体信息
media_info = SearchChain().recognize_media(MetaInfo(title=title, subtitle=subtitle))
if not media_info:
return schemas.Response(success=False, message="未识别到媒体信息!")
if ruletype == "2":
rule_string = SystemConfigOper().get(SystemConfigKey.BestVersionFilterRules)
elif ruletype == "3":
rule_string = SystemConfigOper().get(SystemConfigKey.SearchFilterRules)
else:
rule_string = SystemConfigOper().get(SystemConfigKey.SubscribeFilterRules)
if not rule_string:
return schemas.Response(success=False, message="优先级规则未设置!")
# 过滤
result = SearchChain().filter_torrents(rule_groups=[rulegroup.name],
torrent_list=[torrent], mediainfo=media_info)
result = SearchChain().filter_torrents(rule_string=rule_string,
torrent_list=[torrent])
if not result:
return schemas.Response(success=False, message="不符合过滤规则!")
return schemas.Response(success=False, message="不符合优先级规则!")
return schemas.Response(success=True, data={
"priority": 100 - result[0].pri_order + 1
})
@@ -463,7 +307,7 @@ def moduletest(moduleid: str, _: schemas.TokenPayload = Depends(verify_token)):
@router.get("/restart", summary="重启系统", response_model=schemas.Response)
def restart_system(_: User = Depends(get_current_active_superuser)):
"""
重启系统(仅管理员)
重启系统
"""
if not SystemUtils.can_restart():
return schemas.Response(success=False, message="当前运行环境不支持重启操作!")
@@ -477,34 +321,20 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
@router.get("/reload", summary="重新加载模块", response_model=schemas.Response)
def reload_module(_: User = Depends(get_current_active_superuser)):
"""
重新加载模块(仅管理员)
重新加载模块
"""
ModuleManager().reload()
Scheduler().init()
Monitor().init()
return schemas.Response(success=True)
@router.get("/runscheduler", summary="运行服务", response_model=schemas.Response)
def run_scheduler(jobid: str,
_: User = Depends(get_current_active_superuser)):
def execute_command(jobid: str,
_: User = Depends(get_current_active_superuser)):
"""
执行命令(仅管理员)
执行命令
"""
if not jobid:
return schemas.Response(success=False, message="命令不能为空!")
Scheduler().start(jobid)
return schemas.Response(success=True)
@router.get("/runscheduler2", summary="运行服务API_TOKEN", response_model=schemas.Response)
def run_scheduler2(jobid: str,
_: str = Depends(verify_apitoken)):
"""
执行命令API_TOKEN认证
"""
if not jobid:
return schemas.Response(success=False, message="命令不能为空!")
Scheduler().start(jobid)
return schemas.Response(success=True)

View File

@@ -59,20 +59,6 @@ def tmdb_recommend(tmdbid: int,
return []
@router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo])
def tmdb_collection(collection_id: int,
page: int = 1,
count: int = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据合集ID查询合集详情
"""
medias = TmdbChain().tmdb_collection(collection_id=collection_id)
if medias:
return [media.to_dict() for media in medias][(page - 1) * count:page * count]
return []
@router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson])
def tmdb_credits(tmdbid: int,
type_name: str,
@@ -113,6 +99,56 @@ def tmdb_person_credits(person_id: int,
return []
@router.get("/movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
page=page)
if not movies:
return []
return [movie.to_dict() for movie in movies]
@router.get("/tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
page=page)
if not tvs:
return []
return [tv.to_dict() for tv in tvs]
@router.get("/trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: int = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
infos = TmdbChain().tmdb_trending(page=page)
if not infos:
return []
return [info.to_dict() for info in infos]
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:

View File

@@ -1,19 +1,17 @@
from pathlib import Path
from typing import Any, List
from typing import Any
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app import schemas
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.schemas import MediaType, FileItem, ManualTransferItem
from app.schemas import MediaType
router = APIRouter()
@@ -47,113 +45,103 @@ def query_name(path: str, filetype: str,
})
@router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob])
def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理队列
:param _: Token校验
"""
return TransferChain().get_queue_tasks()
@router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response)
def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理队列
:param fileitem: 文件项
:param _: Token校验
"""
TransferChain().remove_from_queue(fileitem)
return schemas.Response(success=True)
@router.post("/manual", summary="手动转移", response_model=schemas.Response)
def manual_transfer(transer_item: ManualTransferItem,
background: bool = False,
def manual_transfer(storage: str = "local",
path: str = None,
drive_id: str = None,
fileid: str = None,
filetype: str = None,
logid: int = None,
target: str = None,
tmdbid: int = None,
doubanid: str = None,
type_name: str = None,
season: int = None,
transfer_type: str = None,
episode_format: str = None,
episode_detail: str = None,
episode_part: str = None,
episode_offset: int = 0,
min_filesize: int = 0,
scrape: bool = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
手动转移,文件或历史记录,支持自定义剧集识别格式
:param transer_item: 手工整理项
:param background: 后台运行
:param storage: 存储类型local/aliyun/u115
:param path: 转移路径或文件
:param drive_id: 云盘ID网盘等
:param fileid: 文件ID网盘等
:param filetype: 文件类型dir/file
:param logid: 转移历史记录ID
:param target: 目标路径
:param type_name: 媒体类型、电影/电视剧
:param tmdbid: tmdbid
:param doubanid: 豆瓣ID
:param season: 剧集季号
:param transfer_type: 转移类型move/copy 等
:param episode_format: 剧集识别格式
:param episode_detail: 剧集识别详细信息
:param episode_part: 剧集识别分集信息
:param episode_offset: 剧集识别偏移量
:param min_filesize: 最小文件大小(MB)
:param scrape: 是否刮削元数据
:param db: 数据库
:param _: Token校验
"""
force = False
target_path = Path(transer_item.target_path) if transer_item.target_path else None
if transer_item.logid:
target = Path(target) if target else None
transfer = TransferChain()
if logid:
# 查询历史记录
history: TransferHistory = TransferHistory.get(db, transer_item.logid)
history: TransferHistory = TransferHistory.get(db, logid)
if not history:
return schemas.Response(success=False, message=f"整理记录不存在ID{transer_item.logid}")
return schemas.Response(success=False, message=f"历史记录不存在ID{logid}")
# 强制转移
force = True
if history.status and ("move" in history.mode):
# 重新整理成功的转移,则使用成功的 dest 做 in_path
src_fileitem = FileItem(**history.dest_fileitem)
in_path = Path(history.dest)
else:
# 源路径
src_fileitem = FileItem(**history.src_fileitem)
in_path = Path(history.src)
# 目的路径
if history.dest_fileitem:
if history.dest and str(history.dest) != "None":
# 删除旧的已整理文件
dest_fileitem = FileItem(**history.dest_fileitem)
state = StorageChain().delete_media_file(dest_fileitem, mtype=MediaType(history.type))
if not state:
return schemas.Response(success=False, message=f"{dest_fileitem.path} 删除失败")
# 从历史数据获取信息
if transer_item.from_history:
transer_item.type_name = history.type if history.type else transer_item.type_name
transer_item.tmdbid = int(history.tmdbid) if history.tmdbid else transer_item.tmdbid
transer_item.doubanid = str(history.doubanid) if history.doubanid else transer_item.doubanid
transer_item.season = int(str(history.seasons).replace("S", "")) if history.seasons else transer_item.season
if history.episodes:
if "-" in str(history.episodes):
# E01-E03多集合并
episode_start, episode_end = str(history.episodes).split("-")
episode_list: list[int] = []
for i in range(int(episode_start.replace("E", "")), int(episode_end.replace("E", "")) + 1):
episode_list.append(i)
transer_item.episode_detail = ",".join(str(e) for e in episode_list)
else:
# E01单集
transer_item.episode_detail = str(history.episodes).replace("E", "")
elif transer_item.fileitem:
src_fileitem = transer_item.fileitem
transfer.delete_files(Path(history.dest))
elif path:
in_path = Path(path)
else:
return schemas.Response(success=False, message=f"缺少参数")
return schemas.Response(success=False, message=f"缺少参数path/logid")
# 类型
mtype = MediaType(transer_item.type_name) if transer_item.type_name else None
mtype = MediaType(type_name) if type_name else None
# 自定义格式
epformat = None
if transer_item.episode_offset or transer_item.episode_part \
or transer_item.episode_detail or transer_item.episode_format:
if episode_offset or episode_part or episode_detail or episode_format:
epformat = schemas.EpisodeFormat(
format=transer_item.episode_format,
detail=transer_item.episode_detail,
part=transer_item.episode_part,
offset=transer_item.episode_offset,
format=episode_format,
detail=episode_detail,
part=episode_part,
offset=episode_offset,
)
# 开始转移
state, errormsg = TransferChain().manual_transfer(
fileitem=src_fileitem,
target_storage=transer_item.target_storage,
target_path=target_path,
tmdbid=transer_item.tmdbid,
doubanid=transer_item.doubanid,
state, errormsg = transfer.manual_transfer(
storage=storage,
in_path=in_path,
drive_id=drive_id,
fileid=fileid,
filetype=filetype,
target=target,
tmdbid=tmdbid,
doubanid=doubanid,
mtype=mtype,
season=transer_item.season,
transfer_type=transer_item.transfer_type,
season=season,
transfer_type=transfer_type,
epformat=epformat,
min_filesize=transer_item.min_filesize,
scrape=transer_item.scrape,
library_type_folder=transer_item.library_type_folder,
library_category_folder=transer_item.library_category_folder,
force=force,
background=background
min_filesize=min_filesize,
scrape=scrape,
force=force
)
# 失败
if not state:

213
app/api/endpoints/u115.py Normal file
View File

@@ -0,0 +1,213 @@
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, Depends, HTTPException
from starlette.responses import Response
from app import schemas
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_uri_token
from app.helper.progress import ProgressHelper
from app.helper.u115 import U115Helper
from app.schemas.types import ProgressKey
from app.utils.http import RequestUtils
router = APIRouter()
@router.get("/qrcode", summary="生成二维码内容", response_model=schemas.Response)
def qrcode(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
生成二维码
"""
qrcode_data = U115Helper().generate_qrcode()
if qrcode_data:
return schemas.Response(success=True, data={
'codeContent': qrcode_data
})
return schemas.Response(success=False)
@router.get("/check", summary="二维码登录确认", response_model=schemas.Response)
def check(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
二维码登录确认
"""
data, errmsg = U115Helper().check_login()
if data:
return schemas.Response(success=True, data=data)
return schemas.Response(success=False, message=errmsg)
@router.get("/storage", summary="查询存储空间信息", response_model=schemas.Response)
def storage(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询存储空间信息
"""
storage_info = U115Helper().storage()
if storage_info:
return schemas.Response(success=True, data={
"total": storage_info[0],
"used": storage_info[1]
})
return schemas.Response(success=False)
@router.post("/list", summary="所有目录和文件115网盘", response_model=List[schemas.FileItem])
def list_115(fileitem: schemas.FileItem,
sort: str = 'updated_at',
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询当前目录下所有目录和文件
:param fileitem: 文件项
:param sort: 排序方式name:按名称排序time:按修改时间排序
:param _: token
:return: 所有目录和文件
"""
if not fileitem.fileid:
return []
if not fileitem.path:
path = "/"
else:
path = fileitem.path
if fileitem.fileid == "root":
fileid = "0"
else:
fileid = fileitem.fileid
if fileitem.type == "file":
name = Path(path).name
suffix = Path(name).suffix[1:]
return [schemas.FileItem(
fileid=fileid,
type="file",
path=path.rstrip('/'),
name=name,
extension=suffix,
pickcode=fileitem.pickcode
)]
file_list = U115Helper().list(parent_file_id=fileid, path=path)
if sort == "name":
file_list.sort(key=lambda x: x.name)
else:
file_list.sort(key=lambda x: x.modify_time, reverse=True)
return file_list
@router.post("/mkdir", summary="创建目录115网盘", response_model=schemas.Response)
def mkdir_115(fileitem: schemas.FileItem,
name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
创建目录
"""
if not fileitem.fileid or not name:
return schemas.Response(success=False)
result = U115Helper().create_folder(parent_file_id=fileitem.fileid, name=name, path=fileitem.path)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.post("/delete", summary="删除文件或目录115网盘", response_model=schemas.Response)
def delete_115(fileitem: schemas.FileItem,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除文件或目录
"""
if not fileitem.fileid:
return schemas.Response(success=False)
result = U115Helper().delete(fileitem.fileid)
if result:
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.get("/download", summary="下载文件115网盘")
def download_115(pickcode: str,
_: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
下载文件或目录
"""
if not pickcode:
return schemas.Response(success=False)
ticket = U115Helper().download(pickcode)
if ticket:
# 请求数据,并以文件流的方式返回
res = RequestUtils(headers=ticket.headers).get_res(ticket.url)
if res:
return Response(content=res.content, media_type="application/octet-stream")
return schemas.Response(success=False)
@router.post("/rename", summary="重命名文件或目录115网盘", response_model=schemas.Response)
def rename_115(fileitem: schemas.FileItem,
new_name: str,
recursive: bool = False,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重命名文件或目录
"""
if not fileitem.fileid or not new_name:
return schemas.Response(success=False)
result = U115Helper().rename(fileitem.fileid, new_name)
if result:
if recursive:
transferchain = TransferChain()
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
# 递归修改目录内文件(智能识别命名)
sub_files: List[schemas.FileItem] = list_115(fileitem)
if sub_files:
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.BatchRename)
total = len(sub_files)
handled = 0
for sub_file in sub_files:
handled += 1
progress.update(value=handled / total * 100,
text=f"正在处理 {sub_file.name} ...",
key=ProgressKey.BatchRename)
if sub_file.type == "dir":
continue
if not sub_file.extension:
continue
if f".{sub_file.extension.lower()}" not in media_exts:
continue
sub_path = Path(f"{fileitem.path}{sub_file.name}")
meta = MetaInfoPath(sub_path)
mediainfo = transferchain.recognize_media(meta)
if not mediainfo:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
if not new_path:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
ret: schemas.Response = rename_115(fileitem=sub_file,
new_name=Path(new_path).name,
recursive=False)
if not ret.success:
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
progress.end(ProgressKey.BatchRename)
return schemas.Response(success=True)
return schemas.Response(success=False)
@router.get("/image", summary="读取图片115网盘")
def image_115(pickcode: str, _: schemas.TokenPayload = Depends(verify_uri_token)) -> Any:
"""
读取图片
"""
if not pickcode:
return schemas.Response(success=False)
ticket = U115Helper().download(pickcode)
if ticket:
# 请求数据获取内容编码为图片base64返回
res = RequestUtils(headers=ticket.headers).get_res(ticket.url)
if res:
content_type = res.headers.get("Content-Type")
return Response(content=res.content, media_type=content_type)
raise HTTPException(status_code=500, detail="下载图片出错")

View File

@@ -9,7 +9,7 @@ from app import schemas
from app.core.security import get_password_hash
from app.db import get_db
from app.db.models.user import User
from app.db.user_oper import get_current_active_superuser, get_current_active_user
from app.db.userauth import get_current_active_superuser, get_current_active_user
from app.db.userconfig_oper import UserConfigOper
from app.utils.otp import OtpUtils
@@ -17,7 +17,7 @@ router = APIRouter()
@router.get("/", summary="所有用户", response_model=List[schemas.User])
def list_users(
def read_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
) -> Any:
@@ -54,7 +54,7 @@ def create_user(
def update_user(
*,
db: Session = Depends(get_db),
user_in: schemas.UserUpdate,
user_in: schemas.UserCreate,
_: User = Depends(get_current_active_superuser),
) -> Any:
"""
@@ -69,15 +69,7 @@ def update_user(
message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位")
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = User.get_by_id(db, user_id=user_info["id"])
user_name = user_info.get("name")
if not user_name:
return schemas.Response(success=False, message="用户名不能为空")
# 新用户名去重
users = User.list(db)
for u in users:
if u.name == user_name and u.id != user_info["id"]:
return schemas.Response(success=False, message="用户名已被使用")
user = User.get_by_name(db, name=user_info["name"])
if not user:
return schemas.Response(success=False, message="用户不存在")
user.update(db, user_info)
@@ -147,7 +139,7 @@ def otp_disable(
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any:
user: User = User.get_by_name(db, userid)
if not user:
return schemas.Response(success=False)
return schemas.Response(success=False, message="用户不存在")
return schemas.Response(success=user.is_otp)
@@ -173,32 +165,15 @@ def set_config(key: str, value: Union[list, dict, bool, int, str] = None,
return schemas.Response(success=True)
@router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_id(
*,
db: Session = Depends(get_db),
user_id: int,
current_user: User = Depends(get_current_active_superuser),
) -> Any:
"""
通过唯一ID删除用户
"""
user = current_user.get_by_id(db, user_id=user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.delete_by_id(db, user_id)
return schemas.Response(success=True)
@router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_name(
@router.delete("/{user_name}", summary="删除用户", response_model=schemas.Response)
def delete_user(
*,
db: Session = Depends(get_db),
user_name: str,
current_user: User = Depends(get_current_active_superuser),
) -> Any:
"""
通过用户名删除用户
删除用户
"""
user = current_user.get_by_name(db, name=user_name)
if not user:
@@ -207,16 +182,16 @@ def delete_user_by_name(
return schemas.Response(success=True)
@router.get("/{username}", summary="用户详情", response_model=schemas.User)
def read_user_by_name(
username: str,
@router.get("/{user_id}", summary="用户详情", response_model=schemas.User)
def read_user_by_id(
user_id: int,
current_user: User = Depends(get_current_active_user),
db: Session = Depends(get_db),
) -> Any:
"""
查询用户详情
"""
user = current_user.get_by_name(db, name=username)
user = current_user.get(db, rid=user_id)
if not user:
raise HTTPException(
status_code=404,
@@ -224,7 +199,7 @@ def read_user_by_name(
)
if user == current_user:
return user
if not current_user.is_superuser:
if not user.is_superuser:
raise HTTPException(
status_code=400,
detail="用户权限不足"

View File

@@ -22,7 +22,7 @@ async def webhook_message(background_tasks: BackgroundTasks,
_: str = Depends(verify_apitoken)
) -> Any:
"""
Webhook响应配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名
Webhook响应
"""
body = await request.body()
form = await request.form()
@@ -35,7 +35,7 @@ async def webhook_message(background_tasks: BackgroundTasks,
def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: str = Depends(verify_apitoken)) -> Any:
"""
Webhook响应配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名
Webhook响应
"""
args = request.query_params
background_tasks.add_task(start_webhook_chain, None, None, args)

View File

@@ -1,3 +0,0 @@
from fastapi import APIRouter
router = APIRouter()

View File

@@ -680,14 +680,6 @@ def arr_add_series(tv: schemas.SonarrSeries,
)
@arr_router.put("/series", summary="更新剧集订阅")
def arr_update_series(tv: schemas.SonarrSeries) -> Any:
"""
更新Sonarr剧集订阅
"""
return arr_add_series(tv)
@arr_router.delete("/series/{tid}", summary="删除剧集订阅")
def arr_remove_series(tid: int, db: Session = Depends(get_db), _: str = Depends(verify_apikey)) -> Any:
"""

View File

@@ -1,6 +1,8 @@
import gzip
import json
from typing import Annotated, Callable, Any, Dict, Optional
from hashlib import md5
from typing import Annotated, Callable
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse
@@ -9,7 +11,7 @@ from fastapi.routing import APIRoute
from app import schemas
from app.core.config import settings
from app.log import logger
from app.utils.crypto import CryptoJsUtils, HashUtils
from app.utils.common import decrypt
class GzipRequest(Request):
@@ -19,7 +21,7 @@ class GzipRequest(Request):
body = await super().body()
if "gzip" in self.headers.getlist("Content-Encoding"):
body = gzip.decompress(body)
self._body = body # noqa
self._body = body
return self._body
@@ -45,7 +47,7 @@ async def verify_server_enabled():
cookie_router = APIRouter(route_class=GzipRoute,
tags=["servcookie"],
tags=['servcookie'],
dependencies=[Depends(verify_server_enabled)])
@@ -98,14 +100,15 @@ def get_decrypted_cookie_data(uuid: str, password: str,
"""
加载本地加密数据并解密为Cookie
"""
combined_string = f"{uuid}-{password}"
aes_key = HashUtils.md5(combined_string)[:16].encode("utf-8")
key_md5 = md5()
key_md5.update((uuid + '-' + password).encode('utf-8'))
aes_key = (key_md5.hexdigest()[:16]).encode('utf-8')
if encrypted:
try:
decrypted_data = CryptoJsUtils.decrypt(encrypted, aes_key).decode("utf-8")
decrypted_data = decrypt(encrypted, aes_key).decode('utf-8')
decrypted_data = json.loads(decrypted_data)
if "cookie_data" in decrypted_data:
if 'cookie_data' in decrypted_data:
return decrypted_data
except Exception as e:
logger.error(f"解密Cookie数据失败{str(e)}")

View File

@@ -1,4 +1,3 @@
import copy
import gc
import pickle
import traceback
@@ -7,6 +6,7 @@ from pathlib import Path
from typing import Optional, Any, Tuple, List, Set, Union, Dict
from qbittorrentapi import TorrentFilesList
from ruamel.yaml import CommentedMap
from transmission_rpc import File
from app.core.config import settings
@@ -15,12 +15,10 @@ from app.core.event import EventManager
from app.core.meta import MetaBase
from app.core.module import ModuleManager
from app.db.message_oper import MessageOper
from app.db.user_oper import UserOper
from app.helper.message import MessageHelper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
WebhookEventInfo, TmdbEpisode, MediaPerson
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType
from app.utils.object import ObjectUtils
@@ -38,7 +36,6 @@ class ChainBase(metaclass=ABCMeta):
self.eventmanager = EventManager()
self.messageoper = MessageOper()
self.messagehelper = MessageHelper()
self.useroper = UserOper()
@staticmethod
def load_cache(filename: str) -> Any:
@@ -61,7 +58,7 @@ class ChainBase(metaclass=ABCMeta):
"""
try:
with open(settings.TEMP_PATH / filename, 'wb') as f:
pickle.dump(cache, f) # noqa
pickle.dump(cache, f)
except Exception as err:
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
finally:
@@ -76,7 +73,7 @@ class ChainBase(metaclass=ABCMeta):
"""
cache_path = settings.TEMP_PATH / filename
if cache_path.exists():
cache_path.unlink()
Path(cache_path).unlink()
def run_module(self, method: str, *args, **kwargs) -> Any:
"""
@@ -96,14 +93,12 @@ class ChainBase(metaclass=ABCMeta):
logger.debug(f"请求模块执行:{method} ...")
result = None
modules = self.modulemanager.get_running_modules(method)
# 按优先级排序
modules = sorted(modules, key=lambda x: x.get_priority())
for module in modules:
module_id = module.__class__.__name__
try:
module_name = module.get_name()
except Exception as err:
logger.debug(f"获取模块名称出错:{str(err)}")
logger.error(f"获取模块名称出错:{str(err)}")
module_name = module_id
try:
func = getattr(module, method)
@@ -223,8 +218,7 @@ class ChainBase(metaclass=ABCMeta):
image_prefix=image_prefix, image_type=image_type,
season=season, episode=episode)
def douban_info(self, doubanid: str, mtype: MediaType = None,
raise_exception: bool = False) -> Optional[dict]:
def douban_info(self, doubanid: str, mtype: MediaType = None, raise_exception: bool = False) -> Optional[dict]:
"""
获取豆瓣信息
:param doubanid: 豆瓣ID
@@ -260,20 +254,19 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("bangumi_info", bangumiid=bangumiid)
def message_parser(self, source: str, body: Any, form: Any,
def message_parser(self, body: Any, form: Any,
args: Any) -> Optional[CommingMessage]:
"""
解析消息内容,返回字典,注意以下约定值:
userid: 用户ID
username: 用户名
text: 内容
:param source: 消息来源(渠道配置名称)
:param body: 请求体
:param form: 表单
:param args: 参数
:return: 消息渠道、消息内容
"""
return self.run_module("message_parser", source=source, body=body, form=form, args=args)
return self.run_module("message_parser", body=body, form=form, args=args)
def webhook_parser(self, body: Any, form: Any, args: Any) -> Optional[WebhookEventInfo]:
"""
@@ -300,14 +293,7 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_persons", name=name)
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息
:param name: 集合名称
"""
return self.run_module("search_collections", name=name)
def search_torrents(self, site: dict,
def search_torrents(self, site: CommentedMap,
keywords: List[str],
mtype: MediaType = None,
page: int = 0) -> List[TorrentInfo]:
@@ -322,34 +308,34 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("search_torrents", site=site, keywords=keywords,
mtype=mtype, page=page)
def refresh_torrents(self, site: dict, keyword: str = None, cat: str = None, page: int = 0) -> List[TorrentInfo]:
def refresh_torrents(self, site: CommentedMap) -> List[TorrentInfo]:
"""
获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 标题
:param cat: 分类
:param page: 页码
:reutrn: 种子资源列表
"""
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
return self.run_module("refresh_torrents", site=site)
def filter_torrents(self, rule_groups: List[str],
def filter_torrents(self, rule_string: str,
torrent_list: List[TorrentInfo],
season_episodes: Dict[int, list] = None,
mediainfo: MediaInfo = None) -> List[TorrentInfo]:
"""
过滤种子资源
:param rule_groups: 过滤规则组名称列表
:param rule_string: 过滤规则
:param torrent_list: 资源列表
:param season_episodes: 季集数过滤 {season:[episodes]}
:param mediainfo: 识别的媒体信息
:return: 过滤后的资源列表,添加资源优先级
"""
return self.run_module("filter_torrents", rule_groups=rule_groups,
torrent_list=torrent_list, mediainfo=mediainfo)
return self.run_module("filter_torrents", rule_string=rule_string,
torrent_list=torrent_list, season_episodes=season_episodes,
mediainfo=mediainfo)
def download(self, content: Union[Path, str], download_dir: Path, cookie: str,
episodes: Set[int] = None, category: str = None,
downloader: str = None
) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]:
downloader: str = settings.DEFAULT_DOWNLOADER
) -> Optional[Tuple[Optional[str], str]]:
"""
根据种子文件,选择并添加下载任务
:param content: 种子文件地址或者磁力链接
@@ -358,7 +344,7 @@ class ChainBase(metaclass=ABCMeta):
:param episodes: 需要下载的集数
:param category: 种子分类
:param downloader: 下载器
:return: 下载器名称、种子Hash、种子文件布局、错误原因
:return: 种子Hash错误信息
"""
return self.run_module("download", content=content, download_dir=download_dir,
cookie=cookie, episodes=episodes, category=category,
@@ -377,7 +363,7 @@ class ChainBase(metaclass=ABCMeta):
def list_torrents(self, status: TorrentStatus = None,
hashs: Union[list, str] = None,
downloader: str = None
downloader: str = settings.DEFAULT_DOWNLOADER
) -> Optional[List[Union[TransferTorrent, DownloadingTorrent]]]:
"""
获取下载器种子列表
@@ -388,46 +374,37 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("list_torrents", status=status, hashs=hashs, downloader=downloader)
def transfer(self, fileitem: FileItem, meta: MetaBase, mediainfo: MediaInfo,
target_directory: TransferDirectoryConf = None,
target_storage: str = None, target_path: Path = None,
transfer_type: str = None, scrape: bool = None,
library_type_folder: bool = None, library_category_folder: bool = None,
episodes_info: List[TmdbEpisode] = None) -> Optional[TransferInfo]:
def transfer(self, path: Path, meta: MetaBase, mediainfo: MediaInfo,
transfer_type: str, target: Path = None,
episodes_info: List[TmdbEpisode] = None,
scrape: bool = None) -> Optional[TransferInfo]:
"""
文件转移
:param fileitem: 文件信息
:param path: 文件路径
:param meta: 预识别的元数据
:param mediainfo: 识别的媒体信息
:param target_directory: 目标目录配置
:param target_storage: 目标存储
:param target_path: 目标路径
:param transfer_type: 转移模式
:param scrape: 是否刮削元数据
:param library_type_folder: 是否按类型创建目录
:param library_category_folder: 是否按类别创建目录
:param target: 转移目标路径
:param episodes_info: 当前季的全部集信息
:param scrape: 是否刮削元数据
:return: {path, target_path, message}
"""
return self.run_module("transfer",
fileitem=fileitem, meta=meta, mediainfo=mediainfo,
target_directory=target_directory,
target_path=target_path, target_storage=target_storage,
transfer_type=transfer_type, scrape=scrape,
library_type_folder=library_type_folder,
library_category_folder=library_category_folder,
episodes_info=episodes_info)
return self.run_module("transfer", path=path, meta=meta, mediainfo=mediainfo,
transfer_type=transfer_type, target=target, episodes_info=episodes_info,
scrape=scrape)
def transfer_completed(self, hashs: str, downloader: str = None) -> None:
def transfer_completed(self, hashs: str, path: Path = None,
downloader: str = settings.DEFAULT_DOWNLOADER) -> None:
"""
下载器转移完成后的处理
转移完成后的处理
:param hashs: 种子Hash
:param path: 源目录
:param downloader: 下载器
"""
return self.run_module("transfer_completed", hashs=hashs, downloader=downloader)
return self.run_module("transfer_completed", hashs=hashs, path=path, downloader=downloader)
def remove_torrents(self, hashs: Union[str, list], delete_file: bool = True,
downloader: str = None) -> bool:
downloader: str = settings.DEFAULT_DOWNLOADER) -> bool:
"""
删除下载器种子
:param hashs: 种子Hash
@@ -437,7 +414,7 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("remove_torrents", hashs=hashs, delete_file=delete_file, downloader=downloader)
def start_torrents(self, hashs: Union[list, str], downloader: str = None) -> bool:
def start_torrents(self, hashs: Union[list, str], downloader: str = settings.DEFAULT_DOWNLOADER) -> bool:
"""
开始下载
:param hashs: 种子Hash
@@ -446,7 +423,7 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("start_torrents", hashs=hashs, downloader=downloader)
def stop_torrents(self, hashs: Union[list, str], downloader: str = None) -> bool:
def stop_torrents(self, hashs: Union[list, str], downloader: str = settings.DEFAULT_DOWNLOADER) -> bool:
"""
停止下载
:param hashs: 种子Hash
@@ -456,7 +433,7 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("stop_torrents", hashs=hashs, downloader=downloader)
def torrent_files(self, tid: str,
downloader: str = None) -> Optional[Union[TorrentFilesList, List[File]]]:
downloader: str = settings.DEFAULT_DOWNLOADER) -> Optional[Union[TorrentFilesList, List[File]]]:
"""
获取种子文件
:param tid: 种子Hash
@@ -465,24 +442,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("torrent_files", tid=tid, downloader=downloader)
def media_exists(self, mediainfo: MediaInfo, itemid: str = None,
server: str = None) -> Optional[ExistMediaInfo]:
def media_exists(self, mediainfo: MediaInfo, itemid: str = None) -> Optional[ExistMediaInfo]:
"""
判断媒体文件是否存在
:param mediainfo: 识别的媒体信息
:param itemid: 媒体服务器ItemID
:param server: 媒体服务器
:return: 如不存在返回None存在时返回信息包括每季已存在所有集{type: movie/tv, seasons: {season: [episodes]}}
"""
return self.run_module("media_exists", mediainfo=mediainfo, itemid=itemid, server=server)
def media_files(self, mediainfo: MediaInfo) -> Optional[List[FileItem]]:
"""
获取媒体文件清单
:param mediainfo: 识别的媒体信息
:return: 媒体文件列表
"""
return self.run_module("media_files", mediainfo=mediainfo)
return self.run_module("media_exists", mediainfo=mediainfo, itemid=itemid)
def post_message(self, message: Notification) -> None:
"""
@@ -491,68 +458,29 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
logger.info(f"发送消息channel={message.channel}"
f"source={message.source},"
f"title={message.title}, "
f"text={message.text}"
f"userid={message.userid}")
# 保存原消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.dict())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
# 是否已发送管理员标志
admin_sended = False
send_orignal = False
for action in actions:
send_message = copy.deepcopy(message)
if action == "admin" and not admin_sended:
# 发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
# 读取管理员消息IDS
send_message.targets = self.useroper.get_settings(settings.SUPERUSER)
admin_sended = True
elif action == "user" and send_message.username:
# 发送对应用户
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
# 读取用户消息IDS
send_message.targets = self.useroper.get_settings(send_message.username)
if send_message.targets is None:
# 没有找到用户
if not admin_sended:
# 回滚发送管理员
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
# 读取管理员消息IDS
send_message.targets = self.useroper.get_settings(settings.SUPERUSER)
admin_sended = True
else:
# 管理员发过了,此消息不发了
logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户")
continue
elif send_message.username == settings.SUPERUSER:
# 管理员同名已发送
admin_sended = True
else:
# 按原消息发送全体
if not admin_sended:
send_orignal = True
break
# 按设定发送
self.eventmanager.send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
self.run_module("post_message", message=send_message)
if not send_orignal:
return
# 发送消息事件
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
# 按原消息发送
# 发送事件
self.eventmanager.send_event(etype=EventType.NoticeMessage,
data={
"channel": message.channel,
"type": message.mtype,
"title": message.title,
"text": message.text,
"image": message.image,
"userid": message.userid,
})
# 保存消息
self.messagehelper.put(message, role="user")
self.messageoper.add(channel=message.channel, mtype=message.mtype,
title=message.title, text=message.text,
image=message.image, link=message.link,
userid=message.userid, action=1)
# 发送
self.run_module("post_message", message=message)
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> Optional[bool]:
"""
发送媒体信息选择列表
:param message: 消息体
@@ -560,11 +488,15 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [media.to_dict() for media in medias]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
self.messagehelper.put(message, role="user", note=note_list)
self.messageoper.add(channel=message.channel, mtype=message.mtype,
title=message.title, text=message.text,
image=message.image, link=message.link,
userid=message.userid, action=1,
note=note_list)
return self.run_module("post_medias_message", message=message, medias=medias)
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> Optional[bool]:
"""
发送种子信息选择列表
:param message: 消息体
@@ -572,18 +504,36 @@ class ChainBase(metaclass=ABCMeta):
:return: 成功或失败
"""
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
self.messagehelper.put(message, role="user", note=note_list)
self.messageoper.add(channel=message.channel, mtype=message.mtype,
title=message.title, text=message.text,
image=message.image, link=message.link,
userid=message.userid, action=1,
note=note_list)
return self.run_module("post_torrents_message", message=message, torrents=torrents)
def metadata_img(self, mediainfo: MediaInfo, season: int = None, episode: int = None) -> Optional[dict]:
def scrape_metadata(self, path: Path, mediainfo: MediaInfo, transfer_type: str,
metainfo: MetaBase = None, force_nfo: bool = False, force_img: bool = False) -> None:
"""
刮削元数据
:param path: 媒体文件路径
:param mediainfo: 识别的媒体信息
:param metainfo: 源文件的识别元数据
:param transfer_type: 转移模式
:param force_nfo: 强制刮削nfo
:param force_img: 强制刮削图片
:return: 成功或失败
"""
self.run_module("scrape_metadata", path=path, mediainfo=mediainfo, metainfo=metainfo,
transfer_type=transfer_type, force_nfo=force_nfo, force_img=force_img)
def metadata_img(self, mediainfo: MediaInfo, season: int = None) -> Optional[dict]:
"""
获取图片名称和url
:param mediainfo: 媒体信息
:param season: 季号
:param episode: 集号
"""
return self.run_module("metadata_img", mediainfo=mediainfo, season=season, episode=episode)
return self.run_module("metadata_img", mediainfo=mediainfo, season=season)
def media_category(self) -> Optional[Dict[str, list]]:
"""

View File

@@ -17,12 +17,6 @@ class BangumiChain(ChainBase, metaclass=Singleton):
"""
return self.run_module("bangumi_calendar")
def discover(self, **kwargs) -> Optional[List[MediaInfo]]:
"""
发现Bangumi番剧
"""
return self.run_module("bangumi_discover", **kwargs)
def bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息

View File

@@ -9,14 +9,14 @@ class DashboardChain(ChainBase, metaclass=Singleton):
"""
各类仪表板统计处理链
"""
def media_statistic(self, server: str = None) -> Optional[List[schemas.Statistic]]:
def media_statistic(self) -> Optional[List[schemas.Statistic]]:
"""
媒体数量统计
"""
return self.run_module("media_statistic", server=server)
return self.run_module("media_statistic")
def downloader_info(self, downloader: str = None) -> Optional[List[schemas.DownloaderInfo]]:
def downloader_info(self) -> Optional[List[schemas.DownloaderInfo]]:
"""
下载器信息
"""
return self.run_module("downloader_info", downloader=downloader)
return self.run_module("downloader_info")

View File

@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple, Set, Dict, Union
from app import schemas
from app.chain import ChainBase
from app.core.config import settings, global_vars
from app.core.config import settings
from app.core.context import MediaInfo, TorrentInfo, Context
from app.core.event import eventmanager, Event
from app.core.meta import MetaBase
@@ -19,8 +19,8 @@ from app.helper.directory import DirectoryHelper
from app.helper.message import MessageHelper
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, ResourceDownloadEventData
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ChainEventType
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
@@ -39,18 +39,18 @@ class DownloadChain(ChainBase):
self.messagehelper = MessageHelper()
def post_download_message(self, meta: MetaBase, mediainfo: MediaInfo, torrent: TorrentInfo,
channel: MessageChannel = None, username: str = None,
channel: MessageChannel = None, userid: str = None, username: str = None,
download_episodes: str = None):
"""
发送添加下载的消息,根据消息场景开关决定发给谁
发送添加下载的消息
:param meta: 元数据
:param mediainfo: 媒体信息
:param torrent: 种子信息
:param channel: 通知渠道
:param userid: 用户ID指定时精确发送对应用户
:param username: 通知显示的下载用户信息
:param download_episodes: 下载的集数
"""
# 拼装消息内容
msg_text = ""
if username:
msg_text = f"用户:{username}"
@@ -84,20 +84,18 @@ class DownloadChain(ChainBase):
torrent.description = re.sub(r'<[^>]+>', '', description)
msg_text = f"{msg_text}\n描述:{torrent.description}"
# 下载成功按规则发送消息
self.post_message(Notification(
channel=channel,
mtype=NotificationType.Download,
userid=userid,
title=f"{mediainfo.title_year} "
f"{'%s %s' % (meta.season, download_episodes) if download_episodes else meta.season_episode} 开始下载",
text=msg_text,
image=mediainfo.get_message_image(),
link=settings.MP_DOMAIN('/#/downloading'),
username=username))
link=settings.MP_DOMAIN('/#/downloading')))
def download_torrent(self, torrent: TorrentInfo,
channel: MessageChannel = None,
source: str = None,
userid: Union[str, int] = None
) -> Tuple[Optional[Union[Path, str]], str, list]:
"""
@@ -180,7 +178,7 @@ class DownloadChain(ChainBase):
torrent_file, content, download_folder, files, error_msg = self.torrent.download_torrent(
url=torrent_url,
cookie=site_cookie,
ua=torrent.site_ua or settings.USER_AGENT,
ua=torrent.site_ua,
proxy=torrent.site_proxy)
if isinstance(content, str):
@@ -191,7 +189,6 @@ class DownloadChain(ChainBase):
logger.error(f"下载种子文件失败:{torrent.title} - {torrent_url}")
self.post_message(Notification(
channel=channel,
source=source if channel else None,
mtype=NotificationType.Manual,
title=f"{torrent.title} 种子下载失败!",
text=f"错误信息:{error_msg}\n站点:{torrent.site_name}",
@@ -204,54 +201,22 @@ class DownloadChain(ChainBase):
def download_single(self, context: Context, torrent_file: Path = None,
episodes: Set[int] = None,
channel: MessageChannel = None,
source: str = None,
downloader: str = None,
save_path: str = None,
userid: Union[str, int] = None,
username: str = None,
media_category: str = None) -> Optional[str]:
username: str = None) -> Optional[str]:
"""
下载及发送通知
:param context: 资源上下文
:param torrent_file: 种子文件路径
:param episodes: 需要下载的集数
:param channel: 通知渠道
:param source: 来源消息通知、Subscribe、Manual等
:param downloader: 下载器
:param save_path: 保存路径
:param userid: 用户ID
:param username: 调用下载的用户名/插件名
:param media_category: 自定义媒体类别
"""
# 发送资源下载事件,允许外部拦截下载
event_data = ResourceDownloadEventData(
context=context,
episodes=episodes or context.meta_info.episode_list,
channel=channel,
origin=source,
downloader=downloader,
options={
"save_path": save_path,
"userid": userid,
"username": username,
"media_category": media_category
}
)
# 触发资源下载事件
event = eventmanager.send_event(ChainEventType.ResourceDownload, event_data)
if event and event.event_data:
event_data: ResourceDownloadEventData = event.event_data
# 如果事件被取消,跳过资源下载
if event_data.cancel:
logger.debug(
f"Resource download canceled by event: {event_data.source},"
f"Reason: {event_data.reason}")
return None
_torrent = context.torrent_info
_media = context.media_info
_meta = context.meta_info
_site_downloader = _torrent.site_downloader
# 补充完整的media数据
if not _media.genre_ids:
@@ -267,7 +232,6 @@ class DownloadChain(ChainBase):
# 下载种子文件,得到的可能是文件也可能是磁力链
content, _folder_name, _file_list = self.download_torrent(_torrent,
channel=channel,
source=source,
userid=userid)
if not content:
return None
@@ -278,57 +242,52 @@ class DownloadChain(ChainBase):
# 下载目录
if save_path:
# 下载目录使用自定义的
download_dir = Path(save_path)
# 有自定义下载目录时,尝试匹配目录配置
dir_info = self.directoryhelper.get_download_dir(_media, to_path=Path(save_path))
else:
# 根据媒体信息查询下载目录配置
dir_info = self.directoryhelper.get_dir(_media, storage="local", include_unsorted=True)
# 拼装子目录
if dir_info:
# 一级目录
if not dir_info.media_type and dir_info.download_type_folder:
# 一级自动分类
download_dir = Path(dir_info.download_path) / _media.type.value
else:
# 一级不分类
download_dir = Path(dir_info.download_path)
# 二级目录
if not dir_info.media_category and dir_info.download_category_folder and _media and _media.category:
# 二级自动分类
download_dir = download_dir / _media.category
dir_info = self.directoryhelper.get_download_dir(_media)
# 拼装子目录
if dir_info:
# 一级目录
if not dir_info.media_type and dir_info.auto_category:
# 一级自动分类
download_dir = Path(dir_info.path) / _media.type.value
else:
# 未找到下载目录,且没有自定义下载目录
logger.error(f"未找到下载目录:{_media.type.value} {_media.title_year}")
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
title="下载失败", role="system")
return None
# 一级不分类
download_dir = Path(dir_info.path)
# 二级目录
if not dir_info.category and dir_info.auto_category and _media and _media.category:
# 二级自动分类
download_dir = download_dir / _media.category
elif save_path:
# 自定义下载目录
download_dir = Path(save_path)
else:
# 未找到下载目录,且没有自定义下载目录
logger.error(f"未找到下载目录:{_media.type.value} {_media.title_year}")
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
title="下载失败", role="system")
return None
# 添加下载
result: Optional[tuple] = self.download(content=content,
cookie=_torrent.site_cookie,
episodes=episodes,
download_dir=download_dir,
category=_media.category,
downloader=downloader or _site_downloader)
category=_media.category)
if result:
_downloader, _hash, _layout, error_msg = result
_hash, error_msg = result
else:
_downloader, _hash, _layout, error_msg = None, None, None, "找到下载器"
_hash, error_msg = None, "知错误"
if _hash:
# `不创建子文件夹` 或 `不存在子文件夹`
if _layout == "NoSubfolder" or not _folder_name:
# 下载路径记录至文件
download_path = download_dir / _file_list[0] if _file_list else download_dir
# 原始布局
elif _folder_name:
# 下载文件路径
if _folder_name:
download_path = download_dir / _folder_name
# 创建子文件夹
else:
download_path = download_dir / Path(_file_list[0]).stem if _file_list else download_dir
# 文件保存路径
_save_path = download_dir if _layout == "NoSubfolder" or not _folder_name else download_path
download_path = download_dir / _file_list[0] if _file_list else download_dir
# 登记下载记录
self.downloadhis.add(
@@ -343,7 +302,6 @@ class DownloadChain(ChainBase):
seasons=_meta.season,
episodes=download_episodes or _meta.episode,
image=_media.get_backdrop_image(),
downloader=_downloader,
download_hash=_hash,
torrent_name=_torrent.title,
torrent_description=_torrent.description,
@@ -351,9 +309,7 @@ class DownloadChain(ChainBase):
userid=userid,
username=username,
channel=channel.value if channel else None,
date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
media_category=media_category,
note={"source": source}
date=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
)
# 登记下载文件
@@ -367,20 +323,20 @@ class DownloadChain(ChainBase):
continue
# 只处理视频格式
if not Path(file).suffix \
or Path(file).suffix.lower() not in settings.RMT_MEDIAEXT:
or Path(file).suffix not in settings.RMT_MEDIAEXT:
continue
files_to_add.append({
"download_hash": _hash,
"downloader": _downloader,
"fullpath": str(_save_path / file),
"savepath": str(_save_path),
"downloader": settings.DEFAULT_DOWNLOADER,
"fullpath": str(download_dir / _folder_name / file),
"savepath": str(download_dir / _folder_name),
"filepath": file,
"torrentname": _meta.org_string,
})
if files_to_add:
self.downloadhis.add_files(files_to_add)
# 下载成功发送消息
# 发送消息群发不带channel和userid
self.post_download_message(meta=_meta, mediainfo=_media, torrent=_torrent,
username=username, download_episodes=download_episodes)
# 下载成功后处理
@@ -389,10 +345,7 @@ class DownloadChain(ChainBase):
self.eventmanager.send_event(EventType.DownloadAdded, {
"hash": _hash,
"context": context,
"username": username,
"downloader": _downloader,
"episodes": episodes or _meta.episode_list,
"source": source
"username": username
})
else:
# 下载失败
@@ -401,7 +354,6 @@ class DownloadChain(ChainBase):
# 只发送给对应渠道和用户
self.post_message(Notification(
channel=channel,
source=source if channel else None,
mtype=NotificationType.Manual,
title="添加下载任务失败:%s %s"
% (_media.title_year, _meta.season_episode),
@@ -417,11 +369,8 @@ class DownloadChain(ChainBase):
no_exists: Dict[Union[int, str], Dict[int, NotExistMediaInfo]] = None,
save_path: str = None,
channel: MessageChannel = None,
source: str = None,
userid: str = None,
username: str = None,
media_category: str = None,
downloader: str = None
username: str = None
) -> Tuple[List[Context], Dict[Union[int, str], Dict[int, NotExistMediaInfo]]]:
"""
根据缺失数据,自动种子列表中组合择优下载
@@ -429,11 +378,8 @@ class DownloadChain(ChainBase):
:param no_exists: 缺失的剧集信息
:param save_path: 保存路径
:param channel: 通知渠道
:param source: 来源(消息通知、订阅、手工下载等)
:param userid: 用户ID
:param username: 调用下载的用户名/插件名
:param media_category: 自定义媒体类别
:param downloader: 下载器
:return: 已经下载的资源列表、剩余未下载到的剧集 no_exists[tmdb_id/douban_id] = {season: NotExistMediaInfo}
"""
# 已下载的项目
@@ -494,41 +440,22 @@ class DownloadChain(ChainBase):
return 9999
return no_exist[season].total_episode
# 发送资源选择事件,允许外部修改上下文数据
logger.debug(f"Initial contexts: {len(contexts)} items, Downloader: {downloader}")
event_data = ResourceSelectionEventData(
contexts=contexts,
downloader=downloader,
origin=source
)
event = eventmanager.send_event(ChainEventType.ResourceSelection, event_data)
# 如果事件修改了上下文数据,使用更新后的数据
if event and event.event_data:
event_data: ResourceSelectionEventData = event.event_data
if event_data.updated and event_data.updated_contexts is not None:
logger.debug(f"Contexts updated by event: "
f"{len(event_data.updated_contexts)} items (source: {event_data.source})")
contexts = event_data.updated_contexts
# 分组排序
contexts = TorrentHelper().sort_group_torrents(contexts)
# 如果是电影,直接下载
for context in contexts:
if global_vars.is_system_stopped:
break
if context.media_info.type == MediaType.MOVIE:
logger.info(f"开始下载电影 {context.torrent_info.title} ...")
if self.download_single(context, save_path=save_path, channel=channel,
source=source, userid=userid, username=username,
media_category=media_category, downloader=downloader):
userid=userid, username=username):
# 下载成功
logger.info(f"{context.torrent_info.title} 添加下载成功")
downloaded_list.append(context)
# 电视剧整季匹配
logger.info(f"开始匹配电视剧整季:{no_exists}")
if no_exists:
logger.info(f"开始匹配电视剧整季:{no_exists}")
# 先把整季缺失的拿出来,看是否刚好有所有季都满足的种子 {tmdbid: [seasons]}
need_seasons: Dict[int, list] = {}
for need_mid, need_tv in no_exists.items():
@@ -545,8 +472,6 @@ class DownloadChain(ChainBase):
for need_mid, need_season in need_seasons.items():
# 循环种子
for context in contexts:
if global_vars.is_system_stopped:
break
# 媒体信息
media = context.media_info
# 识别元数据
@@ -603,20 +528,15 @@ class DownloadChain(ChainBase):
torrent_file=content if isinstance(content, Path) else None,
save_path=save_path,
channel=channel,
source=source,
userid=userid,
username=username,
media_category=media_category,
downloader=downloader,
username=username
)
else:
# 下载
logger.info(f"开始下载 {torrent.title} ...")
download_id = self.download_single(context, save_path=save_path,
channel=channel, source=source,
userid=userid, username=username,
media_category=media_category,
downloader=downloader)
download_id = self.download_single(context,
save_path=save_path, channel=channel,
userid=userid, username=username)
if download_id:
# 下载成功
@@ -631,8 +551,8 @@ class DownloadChain(ChainBase):
# 全部下载完成
break
# 电视剧季内的集匹配
logger.info(f"开始电视剧完整集匹配:{no_exists}")
if no_exists:
logger.info(f"开始电视剧完整集匹配:{no_exists}")
# TMDBID列表
need_tv_list = list(no_exists)
for need_mid in need_tv_list:
@@ -656,8 +576,6 @@ class DownloadChain(ChainBase):
need_episodes = list(range(start_episode, total_episode + 1))
# 循环种子
for context in contexts:
if global_vars.is_system_stopped:
break
# 媒体信息
media = context.media_info
# 识别元数据
@@ -684,11 +602,9 @@ class DownloadChain(ChainBase):
if torrent_episodes.issubset(set(need_episodes)):
# 下载
logger.info(f"开始下载 {meta.title} ...")
download_id = self.download_single(context, save_path=save_path,
channel=channel, source=source,
userid=userid, username=username,
media_category=media_category,
downloader=downloader)
download_id = self.download_single(context,
save_path=save_path, channel=channel,
userid=userid, username=username)
if download_id:
# 下载成功
logger.info(f"{meta.title} 添加下载成功")
@@ -701,8 +617,8 @@ class DownloadChain(ChainBase):
logger.info(f"{need_season} 剩余需要集:{need_episodes}")
# 仍然缺失的剧集从整季中选择需要的集数文件下载仅支持QB和TR
logger.info(f"开始电视剧多集拆包匹配:{no_exists}")
if no_exists:
logger.info(f"开始电视剧多集拆包匹配:{no_exists}")
# TMDBID列表
no_exists_list = list(no_exists)
for need_mid in no_exists_list:
@@ -725,8 +641,6 @@ class DownloadChain(ChainBase):
continue
# 循环种子
for context in contexts:
if global_vars.is_system_stopped:
break
# 媒体信息
media = context.media_info
# 识别元数据
@@ -774,11 +688,8 @@ class DownloadChain(ChainBase):
episodes=selected_episodes,
save_path=save_path,
channel=channel,
source=source,
userid=userid,
username=username,
media_category=media_category,
downloader=downloader
username=username
)
if not download_id:
continue
@@ -930,7 +841,7 @@ class DownloadChain(ChainBase):
# 全部存在
return True, no_exists
def remote_downloading(self, channel: MessageChannel, userid: Union[str, int] = None, source: str = None):
def remote_downloading(self, channel: MessageChannel, userid: Union[str, int] = None):
"""
查询正在下载的任务,并发送消息
"""
@@ -938,7 +849,6 @@ class DownloadChain(ChainBase):
if not torrents:
self.post_message(Notification(
channel=channel,
source=source,
mtype=NotificationType.Download,
title="没有正在下载的任务!",
userid=userid,
@@ -956,7 +866,6 @@ class DownloadChain(ChainBase):
index += 1
self.post_message(Notification(
channel=channel,
source=source,
mtype=NotificationType.Download,
title=title,
text="\n".join(messages),
@@ -964,11 +873,11 @@ class DownloadChain(ChainBase):
link=settings.MP_DOMAIN('#/downloading')
))
def downloading(self, name: str = None) -> List[DownloadingTorrent]:
def downloading(self) -> List[DownloadingTorrent]:
"""
查询正在下载的任务
"""
torrents = self.list_torrents(downloader=name, status=TorrentStatus.DOWNLOADING)
torrents = self.list_torrents(status=TorrentStatus.DOWNLOADING)
if not torrents:
return []
ret_torrents = []

View File

@@ -1,35 +1,36 @@
import copy
import time
from pathlib import Path
from threading import Lock
from typing import Optional, List, Tuple, Union
from app import schemas
from app.chain import ChainBase
from app.chain.storage import StorageChain
from app.core.config import settings
from app.core.context import Context, MediaInfo
from app.core.event import eventmanager, Event
from app.core.meta import MetaBase
from app.core.metainfo import MetaInfo, MetaInfoPath
from app.helper.aliyun import AliyunHelper
from app.helper.u115 import U115Helper
from app.log import logger
from app.schemas import FileItem
from app.schemas.types import EventType, MediaType, ChainEventType
from app.schemas.types import EventType, MediaType
from app.utils.http import RequestUtils
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
recognize_lock = Lock()
scraping_lock = Lock()
scraping_files = []
class MediaChain(ChainBase, metaclass=Singleton):
"""
媒体信息处理链,单例运行
"""
def __init__(self):
super().__init__()
self.storagechain = StorageChain()
# 临时识别标题
recognize_title: Optional[str] = None
# 临时识别结果 {title, name, year, season, episode}
recognize_temp: Optional[dict] = None
def metadata_nfo(self, meta: MetaBase, mediainfo: MediaInfo,
season: int = None, episode: int = None) -> Optional[str]:
@@ -51,7 +52,7 @@ class MediaChain(ChainBase, metaclass=Singleton):
mediainfo: MediaInfo = self.recognize_media(meta=metainfo)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
if eventmanager.check(EventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{title} ...')
mediainfo = self.recognize_help(title=title, org_meta=metainfo)
if not mediainfo:
@@ -70,47 +71,83 @@ class MediaChain(ChainBase, metaclass=Singleton):
:param title: 标题
:param org_meta: 原始元数据
"""
# 发送请求事件,等待结果
result: Event = eventmanager.send_event(
ChainEventType.NameRecognize,
with recognize_lock:
self.recognize_temp = None
self.recognize_title = title
# 发送请求事件
eventmanager.send_event(
EventType.NameRecognize,
{
'title': title,
}
)
if not result:
return None
# 获取返回事件数据
event_data = result.event_data or {}
logger.info(f'获取到辅助识别结果:{event_data}')
# 处理数据格式
title, year, season_number, episode_number = None, None, None, None
if event_data.get("name"):
title = str(event_data["name"]).split("/")[0].strip().replace(".", " ")
if event_data.get("year"):
year = str(event_data["year"]).split("/")[0].strip()
if event_data.get("season") and str(event_data["season"]).isdigit():
season_number = int(event_data["season"])
if event_data.get("episode") and str(event_data["episode"]).isdigit():
episode_number = int(event_data["episode"])
if not title:
return None
if title == 'Unknown':
return None
if not str(year).isdigit():
year = None
# 结果赋值
if title == org_meta.name and year == org_meta.year:
logger.info(f'辅助识别与原始识别结果一致,无需重新识别媒体信息')
return None
logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...')
org_meta.name = title
org_meta.year = year
org_meta.begin_season = season_number
org_meta.begin_episode = episode_number
if org_meta.begin_season or org_meta.begin_episode:
org_meta.type = MediaType.TV
# 重新识别
return self.recognize_media(meta=org_meta)
# 每0.5秒循环一次等待结果直到10秒后超时
for i in range(20):
if self.recognize_temp is not None:
break
time.sleep(0.5)
# 加锁
with recognize_lock:
mediainfo = None
if not self.recognize_temp or self.recognize_title != title:
# 没有识别结果或者识别标题已改变
return None
# 有识别结果
meta_dict = copy.deepcopy(self.recognize_temp)
logger.info(f'获取到辅助识别结果:{meta_dict}')
if meta_dict.get("name") == org_meta.name and meta_dict.get("year") == org_meta.year:
logger.info(f'辅助识别结果与原始识别结果一致')
else:
logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...')
org_meta.name = meta_dict.get("name")
org_meta.year = meta_dict.get("year")
org_meta.begin_season = meta_dict.get("season")
org_meta.begin_episode = meta_dict.get("episode")
if org_meta.begin_season or org_meta.begin_episode:
org_meta.type = MediaType.TV
# 重新识别
mediainfo = self.recognize_media(meta=org_meta)
return mediainfo
@eventmanager.register(EventType.NameRecognizeResult)
def recognize_result(self, event: Event):
"""
监控识别结果事件,获取辅助识别结果,结果格式:{title, name, year, season, episode}
"""
if not event:
return
event_data = event.event_data or {}
# 加锁
with recognize_lock:
# 不是原标题的结果不要
if event_data.get("title") != self.recognize_title:
return
# 标志收到返回
self.recognize_temp = {}
# 处理数据格式
file_title, file_year, season_number, episode_number = None, None, None, None
if event_data.get("name"):
file_title = str(event_data["name"]).split("/")[0].strip().replace(".", " ")
if event_data.get("year"):
file_year = str(event_data["year"]).split("/")[0].strip()
if event_data.get("season") and str(event_data["season"]).isdigit():
season_number = int(event_data["season"])
if event_data.get("episode") and str(event_data["episode"]).isdigit():
episode_number = int(event_data["episode"])
if not file_title:
return
if file_title == 'Unknown':
return
if not str(file_year).isdigit():
file_year = None
# 结果赋值
self.recognize_temp = {
"name": file_title,
"year": file_year,
"season": season_number,
"episode": episode_number
}
def recognize_by_path(self, path: str) -> Optional[Context]:
"""
@@ -124,7 +161,7 @@ class MediaChain(ChainBase, metaclass=Singleton):
mediainfo = self.recognize_media(meta=file_meta)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
if eventmanager.check(EventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{file_path.name} ...')
mediainfo = self.recognize_help(title=path, org_meta=file_meta)
if not mediainfo:
@@ -296,91 +333,54 @@ class MediaChain(ChainBase, metaclass=Singleton):
)
return None
@eventmanager.register(EventType.MetadataScrape)
def scrape_metadata_event(self, event: Event):
"""
监控手动刮削事件
"""
if not event:
return
event_data = event.event_data or {}
fileitem: FileItem = event_data.get("fileitem")
meta: MetaBase = event_data.get("meta")
mediainfo: MediaInfo = event_data.get("mediainfo")
overwrite = event_data.get("overwrite", False)
if not fileitem:
return
# 刮削锁
with scraping_lock:
if fileitem.path in scraping_files:
return
scraping_files.append(fileitem.path)
try:
# 执行刮削
self.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=overwrite)
finally:
# 释放锁
with scraping_lock:
scraping_files.remove(fileitem.path)
def scrape_metadata(self, fileitem: schemas.FileItem,
meta: MetaBase = None, mediainfo: MediaInfo = None,
init_folder: bool = True, parent: schemas.FileItem = None,
overwrite: bool = False):
def manual_scrape(self, storage: str, fileitem: schemas.FileItem,
meta: MetaBase = None, mediainfo: MediaInfo = None, init_folder: bool = True):
"""
手动刮削媒体信息
:param fileitem: 刮削目录或文件
:param meta: 元数据
:param mediainfo: 媒体信息
:param init_folder: 是否刮削根目录
:param parent: 上级目录
:param overwrite: 是否覆盖已有文件
"""
def is_bluray_folder(_fileitem: schemas.FileItem) -> bool:
"""
判断是否为原盘目录
"""
if not _fileitem or _fileitem.type != "dir":
return False
# 蓝光原盘目录必备的文件或文件夹
required_files = ['BDMV', 'CERTIFICATE']
# 检查目录下是否存在所需文件或文件夹
for item in self.storagechain.list_files(_fileitem):
if item.name in required_files:
return True
return False
def __list_files(_fileitem: schemas.FileItem):
def __list_files(_storage: str, _fileid: str, _path: str = None, _drive_id: str = None):
"""
列出下级文件
"""
return self.storagechain.list_files(fileitem=_fileitem)
if _storage == "aliyun":
return AliyunHelper().list(drive_id=_drive_id, parent_file_id=_fileid, path=_path)
elif _storage == "u115":
return U115Helper().list(parent_file_id=_fileid, path=_path)
else:
items = SystemUtils.list_sub_all(Path(_path))
return [schemas.FileItem(
type="file" if item.is_file() else "dir",
path=str(item),
name=item.name,
basename=item.stem,
extension=item.suffix[1:],
size=item.stat().st_size,
modify_time=item.stat().st_mtime
) for item in items]
def __save_file(_fileitem: schemas.FileItem, _path: Path, _content: Union[bytes, str]):
def __save_file(_storage: str, _drive_id: str, _fileid: str, _path: Path, _content: Union[bytes, str]):
"""
保存或上传文件
:param _fileitem: 关联的媒体文件项
:param _path: 元数据文件路径
:param _content: 文件内容
"""
if not _fileitem or not _content or not _path:
return
# 保存文件到临时目录,文件名随机
tmp_file = settings.TEMP_PATH / f"{_path.name}.{StringUtils.generate_random_str(10)}"
tmp_file.write_bytes(_content)
# 获取文件的父目录
try:
item = self.storagechain.upload_file(fileitem=_fileitem, path=tmp_file, new_name=_path.name)
if item:
logger.info(f"已保存文件:{item.path}")
else:
logger.warn(f"文件保存失败:{item.path}")
finally:
if tmp_file.exists():
tmp_file.unlink()
if _storage != "local":
# 写入到临时目录
temp_path = settings.TEMP_PATH / _path.name
temp_path.write_bytes(_content)
# 上传文件
logger.info(f"正在上传 {_path.name} ...")
if _storage == "aliyun":
AliyunHelper().upload(drive_id=_drive_id, parent_file_id=_fileid, file_path=temp_path)
elif _storage == "u115":
U115Helper().upload(parent_file_id=_fileid, file_path=temp_path)
logger.info(f"{_path.name} 上传完成")
else:
# 保存到本地
logger.info(f"正在保存 {_path.name} ...")
_path.write_bytes(_content)
logger.info(f"{_path} 已保存")
def __download_image(_url: str) -> Optional[bytes]:
def __save_image(_url: str) -> Optional[bytes]:
"""
下载图片并保存
"""
@@ -393,7 +393,6 @@ class MediaChain(ChainBase, metaclass=Singleton):
logger.info(f"{_url} 图片下载失败,请检查网络连通性!")
except Exception as err:
logger.error(f"{_url} 图片下载失败:{str(err)}")
return None
# 当前文件路径
filepath = Path(fileitem.path)
@@ -411,41 +410,23 @@ class MediaChain(ChainBase, metaclass=Singleton):
if mediainfo.type == MediaType.MOVIE:
# 电影
if fileitem.type == "file":
# 是否已存在
nfo_path = filepath.with_suffix(".nfo")
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 电影文件
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo:
# 保存或上传nfo文件到上级目录
__save_file(_fileitem=parent, _path=nfo_path, _content=movie_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
else:
logger.info(f"已存在nfo文件{nfo_path}")
# 电影文件
logger.info(f"正在生成电影nfo{mediainfo.title_year} - {filepath.name}")
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if not movie_nfo:
logger.warn(f"{filepath.name} nfo文件生成失败")
return
# 保存或上传nfo文件
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.parent_fileid,
_path=filepath.with_suffix(".nfo"), _content=movie_nfo)
else:
# 电影目录
if is_bluray_folder(fileitem):
# 原盘目录
nfo_path = filepath / (filepath.name + ".nfo")
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 生成原盘nfo
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo:
# 保存或上传nfo文件到当前目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
else:
logger.info(f"已存在nfo文件{nfo_path}")
else:
# 处理目录内的文件
files = __list_files(_fileitem=fileitem)
for file in files:
self.scrape_metadata(fileitem=file,
meta=meta, mediainfo=mediainfo,
init_folder=False, parent=fileitem,
overwrite=overwrite)
files = __list_files(_storage=storage, _fileid=fileitem.fileid,
_drive_id=fileitem.drive_id, _path=fileitem.path)
for file in files:
self.manual_scrape(storage=storage, fileitem=file,
meta=meta, mediainfo=mediainfo,
init_folder=False)
# 生成目录内图片文件
if init_folder:
# 图片
@@ -457,155 +438,83 @@ class MediaChain(ChainBase, metaclass=Singleton):
and attr_value.startswith("http"):
image_name = attr_name.replace("_path", "") + Path(attr_value).suffix
image_path = filepath / image_name
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(_url=attr_value)
# 写入图片到当前目录
if content:
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
else:
logger.info(f"已存在图片文件:{image_path}")
# 下载图片
content = __save_image(_url=attr_value)
# 写入nfo到根目录
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.fileid,
_path=image_path, _content=content)
else:
# 电视剧
if fileitem.type == "file":
# 重新识别季集
# 当前为集文件,重新识别季集
file_meta = MetaInfoPath(filepath)
if not file_meta.begin_episode:
logger.warn(f"{filepath.name} 无法识别文件集数!")
return
file_mediainfo = self.recognize_media(meta=file_meta, tmdbid=mediainfo.tmdb_id)
file_mediainfo = self.recognize_media(meta=file_meta)
if not file_mediainfo:
logger.warn(f"{filepath.name} 无法识别文件媒体信息!")
return
# 是否已存在
nfo_path = filepath.with_suffix(".nfo")
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 获取集的nfo文件
episode_nfo = self.metadata_nfo(meta=file_meta, mediainfo=file_mediainfo,
season=file_meta.begin_season, episode=file_meta.begin_episode)
if episode_nfo:
# 保存或上传nfo文件到上级目录
if not parent:
parent = self.storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=nfo_path, _content=episode_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
else:
logger.info(f"已存在nfo文件{nfo_path}")
# 获取集的图片
image_dict = self.metadata_img(mediainfo=file_mediainfo,
season=file_meta.begin_season, episode=file_meta.begin_episode)
if image_dict:
for episode, image_url in image_dict.items():
image_path = filepath.with_suffix(Path(image_url).suffix)
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
if not parent:
parent = self.storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
else:
logger.info(f"已存在图片文件:{image_path}")
# 获取集的nfo文件
episode_nfo = self.metadata_nfo(meta=file_meta, mediainfo=file_mediainfo,
season=file_meta.begin_season, episode=file_meta.begin_episode)
if not episode_nfo:
logger.warn(f"{filepath.name} nfo生成失败")
return
# 保存或上传nfo文件
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.parent_fileid,
_path=filepath.with_suffix(".nfo"), _content=episode_nfo)
else:
# 当前为目录,处理目录内的文件
files = __list_files(_fileitem=fileitem)
files = __list_files(_storage=storage, _fileid=fileitem.fileid,
_drive_id=fileitem.drive_id, _path=fileitem.path)
for file in files:
self.scrape_metadata(fileitem=file,
meta=meta, mediainfo=mediainfo,
parent=fileitem if file.type == "file" else None,
init_folder=True if file.type == "dir" else False,
overwrite=overwrite)
self.manual_scrape(storage=storage, fileitem=file,
meta=meta, mediainfo=mediainfo,
init_folder=True if file.type == "dir" else False)
# 生成目录的nfo和图片
if init_folder:
# 识别文件夹名称
season_meta = MetaInfo(filepath.name)
# 当前文件夹为Specials或者SPs时设置为S0
if filepath.name in settings.RENAME_FORMAT_S0_NAMES:
season_meta.begin_season = 0
if season_meta.begin_season is not None:
# 是否已存在
if season_meta.begin_season:
# 当前目录有季号生成季nfo
season_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo, season=meta.begin_season)
if not season_nfo:
logger.warn(f"无法生成电视剧季nfo文件{meta.name}")
return
# 写入nfo到根目录
nfo_path = filepath / "season.nfo"
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 当前目录有季号生成季nfo
season_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo,
season=season_meta.begin_season)
if season_nfo:
# 写入nfo到根目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=season_nfo)
else:
logger.warn(f"无法生成电视剧季nfo文件{meta.name}")
else:
logger.info(f"已存在nfo文件{nfo_path}")
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.fileid,
_path=nfo_path, _content=season_nfo)
# TMDB季poster图片
image_dict = self.metadata_img(mediainfo=mediainfo, season=season_meta.begin_season)
if image_dict:
for image_name, image_url in image_dict.items():
image_path = filepath.with_name(image_name)
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到剧集目录
if content:
if not parent:
parent = self.storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
else:
logger.info(f"已存在图片文件:{image_path}")
# 额外fanart季图片poster thumb banner
image_dict = self.metadata_img(mediainfo=mediainfo)
if image_dict:
for image_name, image_url in image_dict.items():
if image_name.startswith("season"):
image_path = filepath.with_name(image_name)
# 只下载当前刮削季的图片
image_season = "00" if "specials" in image_name else image_name[6:8]
if image_season != str(season_meta.begin_season).rjust(2, '0'):
logger.info(f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
continue
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
if not parent:
parent = self.storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=image_path, _content=content)
else:
logger.info(f"已存在图片文件:{image_path}")
# 判断当前目录是不是剧集根目录
if not season_meta.season:
# 是否已存在
# 下载图片
content = __save_image(image_url)
# 保存图片文件到当前目录
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.fileid,
_path=image_path, _content=content)
if season_meta.name:
# 当前目录有名称生成tvshow nfo 和 tv图片
tv_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if not tv_nfo:
logger.warn(f"无法生成电视剧nfo文件{meta.name}")
return
# 写入tvshow nfo到根目录
nfo_path = filepath / "tvshow.nfo"
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 当前目录有名称生成tvshow nfo 和 tv图片
tv_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if tv_nfo:
# 写入tvshow nfo到根目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=tv_nfo)
else:
logger.warn(f"无法生成电视剧nfo文件{meta.name}")
else:
logger.info(f"已存在nfo文件{nfo_path}")
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.fileid,
_path=nfo_path, _content=tv_nfo)
# 生成目录图片
image_dict = self.metadata_img(mediainfo=mediainfo)
if image_dict:
for image_name, image_url in image_dict.items():
# 不下载季图片
if image_name.startswith("season"):
continue
image_path = filepath / image_name
if overwrite or not self.storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 下载图片
content = __download_image(image_url)
# 保存图片文件到当前目录
if content:
__save_file(_fileitem=fileitem, _path=image_path, _content=content)
else:
logger.info(f"已存在图片文件:{image_path}")
image_path = filepath.parent.with_name(image_name)
# 下载图片
content = __save_image(image_url)
# 保存图片文件到当前目录
__save_file(_storage=storage, _drive_id=fileitem.drive_id, _fileid=fileitem.fileid,
_path=image_path, _content=content)
logger.info(f"{filepath.name} 刮削完成")

View File

@@ -1,13 +1,12 @@
import json
import threading
from typing import List, Union, Optional, Generator
from typing import List, Union, Optional
from app import schemas
from app.chain import ChainBase
from app.core.cache import cached
from app.core.config import global_vars
from app.core.config import settings
from app.db.mediaserver_oper import MediaServerOper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import MediaServerLibrary, MediaServerItem, MediaServerSeasonInfo, MediaServerPlayItem
lock = threading.Lock()
@@ -21,94 +20,42 @@ class MediaServerChain(ChainBase):
super().__init__()
self.dboper = MediaServerOper()
def librarys(self, server: str, username: str = None, hidden: bool = False) -> List[MediaServerLibrary]:
def librarys(self, server: str = None, username: str = None) -> List[schemas.MediaServerLibrary]:
"""
获取媒体服务器所有媒体库
"""
return self.run_module("mediaserver_librarys", server=server, username=username, hidden=hidden)
return self.run_module("mediaserver_librarys", server=server, username=username)
def items(self, server: str, library_id: Union[str, int], start_index: int = 0, limit: Optional[int] = -1) \
-> Optional[Generator]:
def items(self, server: str, library_id: Union[str, int]) -> List[schemas.MediaServerItem]:
"""
获取媒体服务器项目列表,支持分页和不分页逻辑,默认不分页获取所有数据
:param server: 媒体服务器名称
:param library_id: 媒体库ID用于标识要获取的媒体库
:param start_index: 起始索引,用于分页获取数据。默认为 0即从第一个项目开始获取
:param limit: 每次请求的最大项目数,用于分页。如果为 None 或 -1则表示一次性获取所有数据默认为 -1
:return: 返回一个生成器对象,用于逐步获取媒体服务器中的项目
说明:
- 特别注意的是这里使用yield from返回迭代器避免同时使用return与yield导致Python生成器解析异常
- 如果 `limit` 为 None 或 -1 时,表示一次性获取所有数据,分页处理将不再生效
- 在这种情况下,内存消耗可能会较大,特别是在数据量非常大的场景下
- 如果未来评估结果显示,不分页场景下的内存消耗远大于分页处理时的网络请求开销,可以考虑在此方法中实现自分页的处理
- 即通过 `while` 循环在上层进行分页控制,逐步获取所有数据,避免内存爆炸,当前该逻辑由具体实例来实现不分页的处理
- Plex 实际上已默认支持内部分页处理Jellyfin 与 Emby 获取数据时存在内部过滤场景,如排除合集等,分页数据可能是错误的
if limit is not None and limit != -1:
yield from self.run_module("mediaserver_items", server=server, library_id=library_id,
start_index=start_index, limit=limit)
else:
# 自分页逻辑,通过循环逐步获取所有数据
page_size = 10
while True:
data_generator = self.run_module("mediaserver_items", server=server, library_id=library_id,
start_index=start_index, limit=page_size)
if not data_generator:
break
count = 0
for item in data_generator:
if item:
count += 1
yield item
if count < page_size:
break
start_index += page_size
获取媒体服务器所有项目
"""
yield from self.run_module("mediaserver_items", server=server, library_id=library_id,
start_index=start_index, limit=limit)
return self.run_module("mediaserver_items", server=server, library_id=library_id)
def iteminfo(self, server: str, item_id: Union[str, int]) -> MediaServerItem:
def iteminfo(self, server: str, item_id: Union[str, int]) -> schemas.MediaServerItem:
"""
获取媒体服务器项目信息
"""
return self.run_module("mediaserver_iteminfo", server=server, item_id=item_id)
def episodes(self, server: str, item_id: Union[str, int]) -> List[MediaServerSeasonInfo]:
def episodes(self, server: str, item_id: Union[str, int]) -> List[schemas.MediaServerSeasonInfo]:
"""
获取媒体服务器剧集信息
"""
return self.run_module("mediaserver_tv_episodes", server=server, item_id=item_id)
def playing(self, server: str, count: int = 20, username: str = None) -> List[MediaServerPlayItem]:
def playing(self, count: int = 20, server: str = None, username: str = None) -> List[schemas.MediaServerPlayItem]:
"""
获取媒体服务器正在播放信息
"""
return self.run_module("mediaserver_playing", count=count, server=server, username=username)
def latest(self, server: str, count: int = 20, username: str = None) -> List[MediaServerPlayItem]:
def latest(self, count: int = 20, server: str = None, username: str = None) -> List[schemas.MediaServerPlayItem]:
"""
获取媒体服务器最新入库条目
"""
return self.run_module("mediaserver_latest", count=count, server=server, username=username)
@cached(maxsize=1, ttl=3600)
def get_latest_wallpapers(self, server: str = None, count: int = 10,
remote: bool = True, username: str = None) -> List[str]:
"""
获取最新最新入库条目海报作为壁纸缓存1小时
"""
return self.run_module("mediaserver_latest_images", server=server, count=count,
remote=remote, username=username)
def get_latest_wallpaper(self, server: str = None, remote: bool = True, username: str = None) -> Optional[str]:
"""
获取最新最新入库条目海报作为壁纸缓存1小时
"""
wallpapers = self.get_latest_wallpapers(server=server, count=1, remote=remote, username=username)
return wallpapers[0] if wallpapers else None
def get_play_url(self, server: str, item_id: Union[str, int]) -> Optional[str]:
"""
获取播放地址
@@ -120,9 +67,12 @@ class MediaServerChain(ChainBase):
同步媒体库所有数据到本地数据库
"""
# 设置的媒体服务器
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
if not mediaservers:
if not settings.MEDIASERVER:
return
# 同步黑名单
sync_blacklist = settings.MEDIASERVER_SYNC_BLACKLIST.split(
",") if settings.MEDIASERVER_SYNC_BLACKLIST else []
mediaservers = settings.MEDIASERVER.split(",")
with lock:
# 汇总统计
total_count = 0
@@ -132,47 +82,35 @@ class MediaServerChain(ChainBase):
for mediaserver in mediaservers:
if not mediaserver:
continue
logger.info(f"正在准备同步媒体服务器 {mediaserver.name} 的数据")
if not mediaserver.enabled:
logger.info(f"媒体服务器 {mediaserver.name} 未启用,跳过")
continue
server_name = mediaserver.name
sync_libraries = mediaserver.sync_libraries or []
logger.info(f"开始同步媒体服务器 {server_name} 的数据 ...")
libraries = self.librarys(server_name)
if not libraries:
logger.info(f"没有获取到媒体服务器 {server_name} 的媒体库,跳过")
continue
for library in libraries:
if sync_libraries \
and "all" not in sync_libraries \
and str(library.id) not in sync_libraries:
logger.info(f"{library.name} 未在 {server_name} 同步媒体库列表中,跳过")
logger.info(f"开始同步媒体 {mediaserver} 的数据 ...")
for library in self.librarys(mediaserver):
# 同步黑名单 跳过
if library.name in sync_blacklist:
continue
logger.info(f"正在同步 {server_name} 媒体库 {library.name} ...")
logger.info(f"正在同步 {mediaserver} 媒体库 {library.name} ...")
library_count = 0
for item in self.items(server=server_name, library_id=library.id):
if global_vars.is_system_stopped:
return
if not item or not item.item_id:
for item in self.items(mediaserver, library.id):
if not item:
continue
if not item.item_id:
continue
logger.debug(f"正在同步 {item.title} ...")
# 计数
library_count += 1
seasoninfo = {}
# 类型
item_type = "电视剧" if item.item_type in ["Series", "show"] else "电影"
item_type = "电视剧" if item.item_type in ['Series', 'show'] else "电影"
if item_type == "电视剧":
# 查询剧集信息
espisodes_info = self.episodes(server_name, item.item_id) or []
espisodes_info = self.episodes(mediaserver, item.item_id) or []
for episode in espisodes_info:
seasoninfo[episode.season] = episode.episodes
# 插入数据
item_dict = item.dict()
item_dict["seasoninfo"] = seasoninfo
item_dict["item_type"] = item_type
item_dict['seasoninfo'] = json.dumps(seasoninfo)
item_dict['item_type'] = item_type
self.dboper.add(**item_dict)
logger.info(f"{server_name} 媒体库 {library.name} 同步完成,共同步数量:{library_count}")
logger.info(f"{mediaserver} 媒体库 {library.name} 同步完成,共同步数量:{library_count}")
# 总数累加
total_count += library_count
logger.info(f"媒体服务器 {server_name} 数据同步完成,同步数量:{total_count}")
logger.info("【MediaServer】媒体库数据同步完成,同步数量:%s" % total_count)

View File

@@ -1,4 +1,5 @@
import copy
import json
import re
from typing import Any, Optional, Dict, Union
@@ -105,14 +106,10 @@ class MessageChain(ChainBase):
"""
调用模块识别消息内容
"""
# 消息来源
source = args.get("source")
# 获取消息内容
info = self.message_parser(source=source, body=body, form=form, args=args)
info = self.message_parser(body=body, form=form, args=args)
if not info:
return
# 更新消息来源
source = info.source
# 渠道
channel = info.channel
# 用户ID
@@ -128,10 +125,9 @@ class MessageChain(ChainBase):
logger.debug(f'未识别到消息内容::{body}{form}{args}')
return
# 处理消息
self.handle_message(channel=channel, source=source, userid=userid, username=username, text=text)
self.handle_message(channel=channel, userid=userid, username=username, text=text)
def handle_message(self, channel: MessageChannel, source: str,
userid: Union[str, int], username: str, text: str) -> None:
def handle_message(self, channel: MessageChannel, userid: Union[str, int], username: str, text: str) -> None:
"""
识别消息内容,执行操作
"""
@@ -147,12 +143,10 @@ class MessageChain(ChainBase):
userid=userid,
username=username,
channel=channel,
source=source,
text=text
), role="user")
self.messageoper.add(
channel=channel,
source=source,
userid=username or userid,
text=text,
action=0
@@ -165,8 +159,7 @@ class MessageChain(ChainBase):
{
"cmd": text,
"user": userid,
"channel": channel,
"source": source
"channel": channel
}
)
@@ -179,7 +172,7 @@ class MessageChain(ChainBase):
or not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
self.post_message(Notification(channel=channel, title="输入有误!", userid=userid))
return
# 选择的序号
_choice = int(text) + _current_page * self._page_size - 1
@@ -199,7 +192,6 @@ class MessageChain(ChainBase):
# 媒体库中已存在
self.post_message(
Notification(channel=channel,
source=source,
title=f"{_current_media.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需重新下载请发送:搜索 名称 或 下载 名称】",
userid=userid))
@@ -223,14 +215,12 @@ class MessageChain(ChainBase):
for sea, no_exist in no_exists.get(mediakey).items()]
if messages:
self.post_message(Notification(channel=channel,
source=source,
title=f"{mediainfo.title_year}\n" + "\n".join(messages),
userid=userid))
# 搜索种子,过滤掉不需要的剧集,以便选择
logger.info(f"开始搜索 {mediainfo.title_year} ...")
self.post_message(
Notification(channel=channel,
source=source,
title=f"开始搜索 {mediainfo.type.value} {mediainfo.title_year} ...",
userid=userid))
# 开始搜索
@@ -239,10 +229,8 @@ class MessageChain(ChainBase):
if not contexts:
# 没有数据
self.post_message(Notification(
channel=channel,
source=source,
title=f"{mediainfo.title}"
f"{_current_meta.sea} 未搜索到需要的资源!",
channel=channel, title=f"{mediainfo.title}"
f"{_current_meta.sea} 未搜索到需要的资源!",
userid=userid))
return
# 搜索结果排序
@@ -256,7 +244,6 @@ class MessageChain(ChainBase):
logger.info(f"用户 {userid} 在自动下载用户中,开始自动择优下载 ...")
# 自动选择下载
self.__auto_download(channel=channel,
source=source,
cache_list=contexts,
userid=userid,
username=username,
@@ -270,7 +257,6 @@ class MessageChain(ChainBase):
# 发送种子数据
logger.info(f"搜索到 {len(contexts)} 条数据,开始发送选择消息 ...")
self.__post_torrents_message(channel=channel,
source=source,
title=mediainfo.title,
items=contexts[:self._page_size],
userid=userid,
@@ -288,15 +274,12 @@ class MessageChain(ChainBase):
if exist_flag:
self.post_message(Notification(
channel=channel,
source=source,
title=f"{mediainfo.title_year}"
f"{_current_meta.sea} 媒体库中已存在,如需洗版请发送:洗版 XXX】",
userid=userid))
return
else:
best_version = True
# 转换用户名
mp_name = self.useroper.get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
# 添加订阅状态为N
self.subscribechain.add(title=mediainfo.title,
year=mediainfo.year,
@@ -304,15 +287,13 @@ class MessageChain(ChainBase):
tmdbid=mediainfo.tmdb_id,
season=_current_meta.begin_season,
channel=channel,
source=source,
userid=userid,
username=mp_name or username,
username=username,
best_version=best_version)
elif cache_type == "Torrent":
if int(text) == 0:
# 自动选择下载,强制下载模式
self.__auto_download(channel=channel,
source=source,
cache_list=cache_list,
userid=userid,
username=username)
@@ -320,7 +301,7 @@ class MessageChain(ChainBase):
# 下载种子
context: Context = cache_list[_choice]
# 下载
self.downloadchain.download_single(context, channel=channel, source=source,
self.downloadchain.download_single(context, channel=channel,
userid=userid, username=username)
elif text.lower() == "p":
@@ -329,13 +310,13 @@ class MessageChain(ChainBase):
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
channel=channel, title="输入有误!", userid=userid))
return
if _current_page == 0:
# 第一页
self.post_message(Notification(
channel=channel, source=source, title="已经是第一页了!", userid=userid))
channel=channel, title="已经是第一页了!", userid=userid))
return
# 减一页
_current_page -= 1
@@ -351,7 +332,6 @@ class MessageChain(ChainBase):
if cache_type == "Torrent":
# 发送种子数据
self.__post_torrents_message(channel=channel,
source=source,
title=_current_media.title,
items=cache_list[start:end],
userid=userid,
@@ -359,7 +339,6 @@ class MessageChain(ChainBase):
else:
# 发送媒体数据
self.__post_medias_message(channel=channel,
source=source,
title=_current_meta.name,
items=cache_list[start:end],
userid=userid,
@@ -371,7 +350,7 @@ class MessageChain(ChainBase):
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
channel=channel, title="输入有误!", userid=userid))
return
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
@@ -383,7 +362,7 @@ class MessageChain(ChainBase):
if not cache_list:
# 没有数据
self.post_message(Notification(
channel=channel, source=source, title="已经是最后一页了!", userid=userid))
channel=channel, title="已经是最后一页了!", userid=userid))
return
else:
# 加一页
@@ -391,13 +370,11 @@ class MessageChain(ChainBase):
if cache_type == "Torrent":
# 发送种子数据
self.__post_torrents_message(channel=channel,
source=source,
title=_current_media.title,
items=cache_list, userid=userid, total=total)
else:
# 发送媒体数据
self.__post_medias_message(channel=channel,
source=source,
title=_current_meta.name,
items=cache_list, userid=userid, total=total)
@@ -434,12 +411,12 @@ class MessageChain(ChainBase):
# 识别
if not meta.name:
self.post_message(Notification(
channel=channel, source=source, title="无法识别输入内容!", userid=userid))
channel=channel, title="无法识别输入内容!", userid=userid))
return
# 开始搜索
if not medias:
self.post_message(Notification(
channel=channel, source=source, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
channel=channel, title=f"{meta.name} 没有找到对应的媒体信息!", userid=userid))
return
logger.info(f"搜索到 {len(medias)} 条相关媒体信息")
# 记录当前状态
@@ -452,7 +429,6 @@ class MessageChain(ChainBase):
_current_media = None
# 发送媒体列表
self.__post_medias_message(channel=channel,
source=source,
title=meta.name,
items=medias[:self._page_size],
userid=userid, total=len(medias))
@@ -463,15 +439,14 @@ class MessageChain(ChainBase):
{
"text": content,
"userid": userid,
"channel": channel,
"source": source
"channel": channel
}
)
# 保存缓存
self.save_cache(user_cache, self._cache_file)
def __auto_download(self, channel: MessageChannel, source: str, cache_list: list[Context],
def __auto_download(self, channel: MessageChannel, cache_list: list[Context],
userid: Union[str, int], username: str,
no_exists: Optional[Dict[Union[int, str], Dict[int, NotExistMediaInfo]]] = None):
"""
@@ -491,7 +466,6 @@ class MessageChain(ChainBase):
downloads, lefts = self.downloadchain.batch_download(contexts=cache_list,
no_exists=no_exists,
channel=channel,
source=source,
userid=userid,
username=username)
if downloads and not lefts:
@@ -504,11 +478,9 @@ class MessageChain(ChainBase):
# 获取已下载剧集
downloaded = [download.meta_info.begin_episode for download in downloads
if download.meta_info.begin_episode]
note = downloaded
note = json.dumps(downloaded)
else:
note = None
# 转换用户名
mp_name = self.useroper.get_name(**{f"{channel.name.lower()}_userid": userid}) if channel else None
# 添加订阅状态为R
self.subscribechain.add(title=_current_media.title,
year=_current_media.year,
@@ -516,13 +488,12 @@ class MessageChain(ChainBase):
tmdbid=_current_media.tmdb_id,
season=_current_meta.begin_season,
channel=channel,
source=source,
userid=userid,
username=mp_name or username,
username=username,
state="R",
note=note)
def __post_medias_message(self, channel: MessageChannel, source: str,
def __post_medias_message(self, channel: MessageChannel,
title: str, items: list, userid: str, total: int):
"""
发送媒体列表消息
@@ -533,13 +504,11 @@ class MessageChain(ChainBase):
title = f"{title}】共找到{total}条相关信息,请回复对应数字选择"
self.post_medias_message(Notification(
channel=channel,
source=source,
title=title,
userid=userid
), medias=items)
def __post_torrents_message(self, channel: MessageChannel, source: str,
title: str, items: list,
def __post_torrents_message(self, channel: MessageChannel, title: str, items: list,
userid: str, total: int):
"""
发送种子列表消息
@@ -550,7 +519,6 @@ class MessageChain(ChainBase):
title = f"{title}】共找到{total}条相关资源请回复对应数字下载0: 自动选择)"
self.post_torrents_message(Notification(
channel=channel,
source=source,
title=title,
userid=userid,
link=settings.MP_DOMAIN('#/resource')

View File

@@ -1,314 +0,0 @@
import io
import tempfile
from pathlib import Path
from typing import List
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from app.chain import ChainBase
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.cache import cache_backend, cached
from app.core.config import settings, global_vars
from app.log import logger
from app.schemas import MediaType
from app.utils.common import log_execution_time
from app.utils.http import RequestUtils
from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton
# 推荐相关的专用缓存
recommend_ttl = 24 * 3600
recommend_cache_region = "recommend"
class RecommendChain(ChainBase, metaclass=Singleton):
"""
推荐处理链,单例运行
"""
def __init__(self):
super().__init__()
self.tmdbchain = TmdbChain()
self.doubanchain = DoubanChain()
self.bangumichain = BangumiChain()
self.cache_max_pages = 5
def refresh_recommend(self):
"""
刷新推荐
"""
logger.debug("Starting to refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
self.tmdb_movies,
self.tmdb_tvs,
self.tmdb_trending,
self.bangumi_calendar,
self.douban_movie_showing,
self.douban_movies,
self.douban_tvs,
self.douban_movie_top250,
self.douban_tv_weekly_chinese,
self.douban_tv_weekly_global,
self.douban_tv_animation,
self.douban_movie_hot,
self.douban_tv_hot,
]
# 缓存并刷新所有推荐数据
recommends = []
# 记录哪些方法已完成
methods_finished = set()
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
for page in range(1, self.cache_max_pages + 1):
for method in recommend_methods:
if global_vars.is_system_stopped:
return
if method in methods_finished:
continue
logger.debug(f"Fetch {method.__name__} data for page {page}.")
data = method(page=page)
if not data:
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
methods_finished.add(method)
continue
recommends.extend(data)
# 如果所有方法都已经完成,提前结束循环
if len(methods_finished) == len(recommend_methods):
break
# 缓存收集到的海报
self.__cache_posters(recommends)
logger.debug("Recommend data refresh completed.")
def __cache_posters(self, datas: List[dict]):
"""
提取 poster_path 并缓存图片
:param datas: 数据列表
"""
if not settings.GLOBAL_IMAGE_CACHE:
return
for data in datas:
if global_vars.is_system_stopped:
return
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Caching poster image: {poster_url}")
self.__fetch_and_save_image(poster_url)
@staticmethod
def __fetch_and_save_image(url: str):
"""
请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 本地存在缓存图片,则直接跳过
if cache_path.exists():
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
return
if not cache_path:
return
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(response.content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def tmdb_movies(self, sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1) -> List[dict]:
"""
TMDB热门电影
"""
movies = self.tmdbchain.tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def tmdb_tvs(self, sort_by: str = "popularity.desc",
with_genres: str = "",
with_original_language: str = "zh|en|ja|ko",
with_keywords: str = "",
with_watch_providers: str = "",
vote_average: float = 0,
vote_count: int = 0,
release_date: str = "",
page: int = 1) -> List[dict]:
"""
TMDB热门电视剧
"""
tvs = self.tmdbchain.tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def tmdb_trending(self, page: int = 1) -> List[dict]:
"""
TMDB流行趋势
"""
infos = self.tmdbchain.tmdb_trending(page=page)
return [info.to_dict() for info in infos] if infos else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def bangumi_calendar(self, page: int = 1, count: int = 30) -> List[dict]:
"""
Bangumi每日放送
"""
medias = self.bangumichain.calendar()
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_movie_showing(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣正在热映
"""
movies = self.doubanchain.movie_showing(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_movies(self, sort: str = "R", tags: str = "", page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣最新电影
"""
movies = self.doubanchain.douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_tvs(self, sort: str = "R", tags: str = "", page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣最新电视剧
"""
tvs = self.doubanchain.douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_movie_top250(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣电影TOP250
"""
movies = self.doubanchain.movie_top250(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_tv_weekly_chinese(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣国产剧集榜
"""
tvs = self.doubanchain.tv_weekly_chinese(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_tv_weekly_global(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣全球剧集榜
"""
tvs = self.doubanchain.tv_weekly_global(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_tv_animation(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣热门动漫
"""
tvs = self.doubanchain.tv_animation(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_movie_hot(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣热门电影
"""
movies = self.doubanchain.movie_hot(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
def douban_tv_hot(self, page: int = 1, count: int = 30) -> List[dict]:
"""
豆瓣热门电视剧
"""
tvs = self.doubanchain.tv_hot(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []

View File

@@ -6,7 +6,6 @@ from typing import Dict
from typing import List, Optional
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo
from app.core.event import eventmanager, Event
@@ -25,8 +24,6 @@ class SearchChain(ChainBase):
站点资源搜索处理链
"""
__result_temp_file = "__search_result__"
def __init__(self):
super().__init__()
self.siteshelper = SitesHelper()
@@ -37,7 +34,7 @@ class SearchChain(ChainBase):
def search_by_id(self, tmdbid: int = None, doubanid: str = None,
mtype: MediaType = None, area: str = "title", season: int = None) -> List[Context]:
"""
根据TMDBID/豆瓣ID搜索资源精确匹配不过滤本地存在的资源
根据TMDBID/豆瓣ID搜索资源精确匹配但不不过滤本地存在的资源
:param tmdbid: TMDB ID
:param doubanid: 豆瓣 ID
:param mtype: 媒体,电影 or 电视剧
@@ -56,9 +53,9 @@ class SearchChain(ChainBase):
}
}
results = self.process(mediainfo=mediainfo, area=area, no_exists=no_exists)
# 保存到本地文件
# 保存结果
bytes_results = pickle.dumps(results)
self.save_cache(bytes_results, self.__result_temp_file)
self.systemconfig.set(SystemConfigKey.SearchResults, bytes_results)
return results
def search_by_title(self, title: str, page: int = 0, site: int = None) -> List[Context]:
@@ -80,21 +77,20 @@ class SearchChain(ChainBase):
# 组装上下文
contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
torrent_info=torrent) for torrent in torrents]
# 保存到本地文件
# 保存结果
bytes_results = pickle.dumps(contexts)
self.save_cache(bytes_results, self.__result_temp_file)
self.systemconfig.set(SystemConfigKey.SearchResults, bytes_results)
return contexts
def last_search_results(self) -> List[Context]:
"""
获取上次搜索结果
"""
# 读取本地文件缓存
content = self.load_cache(self.__result_temp_file)
if not content:
results = self.systemconfig.get(SystemConfigKey.SearchResults)
if not results:
return []
try:
return pickle.loads(content)
return pickle.loads(results)
except Exception as e:
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return []
@@ -103,28 +99,27 @@ class SearchChain(ChainBase):
keyword: str = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: str = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
priority_rule: str = None,
filter_rule: Dict[str, str] = None,
area: str = "title") -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param priority_rule: 优先级规则,为空时使用搜索优先级规则
:param filter_rule: 过滤规则,为空是使用默认过滤规则
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
return self.filter_torrents(rule_string=priority_rule,
torrent_list=torrent_list,
season_episodes=season_episodes,
mediainfo=mediainfo) or []
# 豆瓣标题处理
@@ -163,8 +158,6 @@ class SearchChain(ChainBase):
keywords = list(dict.fromkeys([k for k in [mediainfo.title,
mediainfo.original_title,
mediainfo.en_title,
mediainfo.hk_title,
mediainfo.tw_title,
mediainfo.sg_title] if k]))
# 执行搜索
@@ -181,75 +174,40 @@ class SearchChain(ChainBase):
# 开始新进度
self.progress.start(ProgressKey.Search)
# 开始过滤
self.progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
key=ProgressKey.Search)
# 匹配订阅附加参数
if filter_params:
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
torrents = [torrent for torrent in torrents if self.torrenthelper.filter_torrent(torrent, filter_params)]
# 开始过滤规则过滤
if rule_groups is None:
# 取搜索过滤规则
rule_groups: List[str] = self.systemconfig.get(SystemConfigKey.SearchFilterRuleGroups)
if rule_groups:
logger.info(f'开始过滤规则/剧集过滤,使用规则组:{rule_groups} ...')
torrents = __do_filter(torrents)
if not torrents:
logger.warn(f'{keyword or mediainfo.title} 没有符合过滤规则的资源')
return []
logger.info(f"过滤规则/剧集过滤完成,剩余 {len(torrents)} 个资源")
# 过滤完成
self.progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
# 开始匹配
_match_torrents = []
# 总数
_total = len(torrents)
# 已处理数
_count = 0
if mediainfo:
# 英文标题应该在别名/原标题中,不需要再匹配
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
self.progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
self.progress.update(value=0, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
for torrent in torrents:
if global_vars.is_system_stopped:
break
_count += 1
self.progress.update(value=(_count / _total) * 96,
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
key=ProgressKey.Search)
if not torrent.title:
continue
# 识别元数据
torrent_meta = MetaInfo(title=torrent.title, subtitle=torrent.description,
custom_words=custom_words)
if torrent.title != torrent_meta.org_string:
logger.info(f"种子名称应用识别词后发生改变:{torrent.title} => {torrent_meta.org_string}")
# 季集数过滤
if season_episodes \
and not self.torrenthelper.match_season_episodes(
torrent=torrent,
meta=torrent_meta,
season_episodes=season_episodes):
continue
# 比对IMDBID
if torrent.imdbid \
and mediainfo.imdb_id \
and torrent.imdbid == mediainfo.imdb_id:
logger.info(f'{mediainfo.title} 通过IMDBID匹配到资源{torrent.site_name} - {torrent.title}')
_match_torrents.append((torrent, torrent_meta))
_match_torrents.append(torrent)
continue
# 识别
torrent_meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
if torrent.title != torrent_meta.org_string:
logger.info(f"种子名称应用识别词后发生改变:{torrent.title} => {torrent_meta.org_string}")
# 比对种子
if self.torrenthelper.match_torrent(mediainfo=mediainfo,
torrent_meta=torrent_meta,
torrent=torrent):
# 匹配成功
_match_torrents.append((torrent, torrent_meta))
_match_torrents.append(torrent)
continue
# 匹配完成
logger.info(f"匹配完成,共匹配到 {len(_match_torrents)} 个资源")
@@ -257,15 +215,44 @@ class SearchChain(ChainBase):
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
key=ProgressKey.Search)
else:
_match_torrents = [(t, MetaInfo(title=t.title, subtitle=t.description)) for t in torrents]
_match_torrents = torrents
# 开始过滤
self.progress.update(value=98, text=f'开始过滤,总 {len(_match_torrents)} 个资源,请稍候...',
key=ProgressKey.Search)
# 开始过滤规则过滤
if _match_torrents:
logger.info(f'开始过滤规则过滤,当前规则:{filter_rule} ...')
_match_torrents = self.filter_torrents_by_rule(torrents=_match_torrents,
mediainfo=mediainfo,
filter_rule=filter_rule)
if not _match_torrents:
logger.warn(f'{keyword or mediainfo.title} 没有符合过滤规则的资源')
return []
logger.info(f"过滤规则过滤完成,剩余 {len(_match_torrents)} 个资源")
# 开始优先级规则/剧集过滤
if priority_rule is None:
# 取搜索优先级规则
priority_rule = self.systemconfig.get(SystemConfigKey.SearchFilterRules)
if priority_rule:
logger.info(f'开始优先级规则/剧集过滤,当前规则:{priority_rule} ...')
_match_torrents = __do_filter(_match_torrents)
if not _match_torrents:
logger.warn(f'{keyword or mediainfo.title} 没有符合优先级规则的资源')
return []
logger.info(f"优先级规则/剧集过滤完成,剩余 {len(_match_torrents)} 个资源")
# 去掉mediainfo中多余的数据
mediainfo.clear()
# 组装上下文
contexts = [Context(torrent_info=t[0],
contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
media_info=mediainfo,
meta_info=t[1]) for t in _match_torrents]
torrent_info=torrent) for torrent in _match_torrents]
self.progress.update(value=99, text=f'过滤完成,剩余 {len(contexts)} 个资源', key=ProgressKey.Search)
# 排序
self.progress.update(value=99,
@@ -274,10 +261,10 @@ class SearchChain(ChainBase):
contexts = self.torrenthelper.sort_torrents(contexts)
# 结束进度
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
self.progress.update(value=100,
text=f'搜索完成,共 {len(contexts)} 个资源',
key=ProgressKey.Search)
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
self.progress.end(ProgressKey.Search)
# 返回
@@ -329,36 +316,34 @@ class SearchChain(ChainBase):
self.progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
# 多线程
executor = ThreadPoolExecutor(max_workers=len(indexer_sites))
all_task = []
for site in indexer_sites:
if area == "imdbid":
# 搜索IMDBID
task = executor.submit(self.search_torrents, site=site,
keywords=[mediainfo.imdb_id] if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = executor.submit(self.search_torrents, site=site,
keywords=keywords,
mtype=mediainfo.type if mediainfo else None,
page=page)
all_task.append(task)
# 结果集
results = []
for future in as_completed(all_task):
if global_vars.is_system_stopped:
break
finish_count += 1
result = future.result()
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
self.progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
# 多线程
with ThreadPoolExecutor(max_workers=len(indexer_sites)) as executor:
all_task = []
for site in indexer_sites:
if area == "imdbid":
# 搜索IMDBID
task = executor.submit(self.search_torrents, site=site,
keywords=[mediainfo.imdb_id] if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = executor.submit(self.search_torrents, site=site,
keywords=keywords,
mtype=mediainfo.type if mediainfo else None,
page=page)
all_task.append(task)
for future in as_completed(all_task):
finish_count += 1
result = future.result()
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
self.progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
# 计算耗时
end_time = datetime.now()
# 更新进度
@@ -371,6 +356,34 @@ class SearchChain(ChainBase):
# 返回
return results
def filter_torrents_by_rule(self,
torrents: List[TorrentInfo],
mediainfo: MediaInfo,
filter_rule: Dict[str, str] = None,
) -> List[TorrentInfo]:
"""
使用过滤规则过滤种子
:param torrents: 种子列表
:param filter_rule: 过滤规则
:param mediainfo: 媒体信息
"""
if not filter_rule:
# 没有则取搜索默认过滤规则
filter_rule = self.systemconfig.get(SystemConfigKey.DefaultSearchFilterRules)
if not filter_rule:
return torrents
# 使用默认过滤规则再次过滤
return list(filter(
lambda t: self.torrenthelper.filter_torrent(
torrent_info=t,
filter_rule=filter_rule,
mediainfo=mediainfo
),
torrents
))
@eventmanager.register(EventType.SiteDeleted)
def remove_site(self, event: Event):
"""

View File

@@ -1,18 +1,20 @@
import base64
import re
from datetime import datetime
from time import time
from typing import Optional, Tuple, Union, Dict
from typing import Tuple, Optional
from typing import Union
from urllib.parse import urljoin
from lxml import etree
from app.chain import ChainBase
from app.core.config import global_vars, settings
from app.core.event import Event, EventManager, eventmanager
from app.core.config import settings
from app.core.event import eventmanager, Event, EventManager
from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.db.siteicon_oper import SiteIconOper
from app.db.systemconfig_oper import SystemConfigOper
from app.db.sitestatistic_oper import SiteStatisticOper
from app.helper.browser import PlaywrightHelper
from app.helper.cloudflare import under_challenge
from app.helper.cookie import CookieHelper
@@ -21,8 +23,8 @@ from app.helper.message import MessageHelper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper
from app.log import logger
from app.schemas import MessageChannel, Notification, SiteUserData
from app.schemas.types import EventType, NotificationType
from app.schemas import MessageChannel, Notification
from app.schemas.types import EventType
from app.utils.http import RequestUtils
from app.utils.site import SiteUtils
from app.utils.string import StringUtils
@@ -36,12 +38,14 @@ class SiteChain(ChainBase):
def __init__(self):
super().__init__()
self.siteoper = SiteOper()
self.siteiconoper = SiteIconOper()
self.siteshelper = SitesHelper()
self.rsshelper = RssHelper()
self.cookiehelper = CookieHelper()
self.message = MessageHelper()
self.cookiecloud = CookieCloudHelper()
self.systemconfig = SystemConfigOper()
self.sitestatistic = SiteStatisticOper()
# 特殊站点登录验证
self.special_site_test = {
@@ -54,69 +58,6 @@ class SiteChain(ChainBase):
"yemapt.org": self.__yema_test,
}
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
"""
刷新站点的用户数据
:param site: 站点
:return: 用户数据
"""
userdata: SiteUserData = self.run_module("refresh_userdata", site=site)
if userdata:
self.siteoper.update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
name=site.get("name"),
payload=userdata.dict())
# 发送事件
EventManager().send_event(EventType.SiteRefreshed, {
"site_id": site.get("id")
})
# 发送站点消息
if userdata.message_unread:
if userdata.message_unread_contents and len(userdata.message_unread_contents) > 0:
for head, date, content in userdata.message_unread_contents:
msg_title = f"【站点 {site.get('name')} 消息】"
msg_text = f"时间:{date}\n标题:{head}\n内容:\n{content}"
self.post_message(Notification(
mtype=NotificationType.SiteMessage,
title=msg_title, text=msg_text, link=site.get("url")
))
else:
self.post_message(Notification(
mtype=NotificationType.SiteMessage,
title=f"站点 {site.get('name')} 收到 "
f"{userdata.message_unread} 条新消息,请登陆查看",
link=site.get("url")
))
# 低分享率警告
if userdata.ratio and float(userdata.ratio) < 1 and not bool(
re.search(r"(贵宾|VIP?)", userdata.user_level or "", re.IGNORECASE)):
self.post_message(Notification(
mtype=NotificationType.SiteMessage,
title=f"【站点分享率低预警】",
text=f"站点 {site.get('name')} 分享率 {userdata.ratio},请注意!"
))
return userdata
def refresh_userdatas(self) -> Optional[Dict[str, SiteUserData]]:
"""
刷新所有站点的用户数据
"""
sites = self.siteshelper.get_indexers()
any_site_updated = False
result = {}
for site in sites:
if global_vars.is_system_stopped:
return None
if site.get("is_active"):
userdata = self.refresh_userdata(site)
if userdata:
any_site_updated = True
result[site.get("name")] = userdata
if any_site_updated:
EventManager().send_event(EventType.SiteRefreshed, {
"site_id": "*"
})
return result
def is_special_site(self, domain: str) -> bool:
"""
判断是否特殊站点
@@ -137,14 +78,10 @@ class SiteChain(ChainBase):
proxies=settings.PROXY if site.proxy else None,
timeout=site.timeout or 15
).get_res(url=site.url)
if res is None:
return False, "无法打开网站!"
if res.status_code == 200:
if res and res.status_code == 200:
csrf_token = re.search(r'<meta name="x-csrf-token" content="(.+?)">', res.text)
if csrf_token:
token = csrf_token.group(1)
else:
return False, f"错误:{res.status_code} {res.reason}"
if not token:
return False, "无法获取Token"
# 调用查询用户信息接口
@@ -158,15 +95,11 @@ class SiteChain(ChainBase):
proxies=settings.PROXY if site.proxy else None,
timeout=site.timeout or 15
).get_res(url=f"{site.url}api/user/getInfo")
if user_res is None:
return False, "无法打开网站!"
if user_res.status_code == 200:
if user_res and user_res.status_code == 200:
user_info = user_res.json()
if user_info and user_info.get("data"):
return True, "连接成功"
return False, "Cookie已失效"
else:
return False, f"错误:{user_res.status_code} {user_res.reason}"
return False, "Cookie已失效"
@staticmethod
def __mteam_test(site: Site) -> Tuple[bool, str]:
@@ -177,12 +110,9 @@ class SiteChain(ChainBase):
domain = StringUtils.get_url_domain(site.url)
url = f"https://api.{domain}/api/member/profile"
headers = {
"Content-Type": "application/json",
"User-Agent": user_agent,
"Accept": "application/json, text/plain, */*",
"Authorization": site.token,
"x-api-key": site.apikey,
"ts": str(int(time()))
}
res = RequestUtils(
headers=headers,
@@ -192,27 +122,10 @@ class SiteChain(ChainBase):
if res is None:
return False, "无法打开网站!"
if res.status_code == 200:
state = False
message = "鉴权已过期或无效"
user_info = res.json() or {}
if user_info.get("data"):
# 更新最后访问时间
del headers["x-api-key"]
res = RequestUtils(headers=headers,
timeout=site.timeout or 15,
proxies=settings.PROXY if site.proxy else None,
referer=f"{site.url}index"
).post_res(url=f"https://api.{domain}/api/member/updateLastBrowse")
state = True
message = "连接成功,但更新状态失败"
if res and res.status_code == 200:
update_info = res.json() or {}
if "code" in update_info and int(update_info["code"]) == 0:
message = "连接成功"
elif user_info.get("message"):
# 使用馒头的错误提示
message = user_info.get("message")
return state, message
return True, "连接成功"
return False, user_info.get("message", "鉴权已过期或无效")
else:
return False, f"错误:{res.status_code} {res.reason}"
@@ -234,15 +147,11 @@ class SiteChain(ChainBase):
proxies=settings.PROXY if site.proxy else None,
timeout=site.timeout or 15
).get_res(url=url)
if res is None:
return False, "无法打开网站!"
if res.status_code == 200:
if res and res.status_code == 200:
user_info = res.json()
if user_info and user_info.get("success"):
return True, "连接成功"
return False, "Cookie已过期"
else:
return False, f"错误:{res.status_code} {res.reason}"
return False, "Cookie已过期"
def __indexphp_test(self, site: Site) -> Tuple[bool, str]:
"""
@@ -268,7 +177,7 @@ class SiteChain(ChainBase):
logger.error(f"获取站点页面失败:{url}")
return favicon_url, None
html = etree.HTML(html_text)
if StringUtils.is_valid_html_element(html):
if html:
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
if fav_link:
favicon_url = urljoin(url, fav_link[0])
@@ -345,7 +254,6 @@ class SiteChain(ChainBase):
continue
# 新增站点
domain_url = __indexer_domain(inx=indexer, sub_domain=domain)
proxy = False
res = RequestUtils(cookies=cookie,
ua=settings.USER_AGENT
).get_res(url=domain_url)
@@ -363,37 +271,16 @@ class SiteChain(ChainBase):
logger.warn(f"站点 {indexer.get('name')} 连接状态码:{res.status_code},无法添加站点")
continue
else:
if not settings.PROXY_HOST:
_fail_count += 1
logger.warn(f"站点 {indexer.get('name')} 连接失败,无法添加站点")
continue
else:
# 如果配置了代理,尝试通过代理重试
logger.info(f"站点 {indexer.get('name')} 初次连接失败,尝试通过代理重试...")
proxy = True
res = RequestUtils(cookies=cookie,
ua=settings.USER_AGENT,
proxies=settings.PROXY
).get_res(url=domain_url)
if res and res.status_code in [200, 500, 403]:
if not indexer.get("public") and not SiteUtils.is_logged_in(res.text):
logger.warn(f"站点 {indexer.get('name')} 登录失败,即使通过代理,无法添加站点")
_fail_count += 1
continue
logger.info(f"站点 {indexer.get('name')} 通过代理连接成功")
else:
logger.warn(f"站点 {indexer.get('name')} 通过代理连接失败,无法添加站点")
_fail_count += 1
continue
_fail_count += 1
logger.warn(f"站点 {indexer.get('name')} 连接失败,无法添加站点")
continue
# 获取rss地址
rss_url = None
if not indexer.get("public") and domain_url:
# 自动生成rss地址
rss_url, errmsg = self.rsshelper.get_rss_link(url=domain_url,
cookie=cookie,
ua=settings.USER_AGENT,
proxy=proxy)
ua=settings.USER_AGENT)
if errmsg:
logger.warn(errmsg)
# 插入数据库
@@ -403,7 +290,6 @@ class SiteChain(ChainBase):
domain=domain,
cookie=cookie,
rss=rss_url,
proxy=1 if proxy else 0,
public=1 if indexer.get("public") else 0)
_add_count += 1
@@ -448,17 +334,17 @@ class SiteChain(ChainBase):
logger.warn(f"站点 {domain} 索引器不存在!")
return
# 查询站点图标
site_icon = self.siteoper.get_icon_by_domain(domain)
site_icon = self.siteiconoper.get_by_domain(domain)
if not site_icon or not site_icon.base64:
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
cookie=cookie,
ua=settings.USER_AGENT)
if icon_url:
self.siteoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
self.siteiconoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
else:
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
@@ -484,26 +370,6 @@ class SiteChain(ChainBase):
logger.info(f"清理站点配置:{key}")
self.systemconfig.delete(key)
@eventmanager.register(EventType.SiteUpdated)
def cache_site_userdata(self, event: Event):
"""
缓存站点用户数据
"""
if not event:
return
event_data = event.event_data or {}
# 主域名
domain = event_data.get("domain")
if not domain:
return
if str(domain).startswith("http"):
domain = StringUtils.get_url_domain(domain)
indexer = self.siteshelper.get_indexer(domain)
if not indexer:
return
# 刷新站点用户数据
self.refresh_userdata(site=indexer) or {}
def test(self, url: str) -> Tuple[bool, str]:
"""
测试站点是否可用
@@ -529,9 +395,9 @@ class SiteChain(ChainBase):
# 统计
seconds = (datetime.now() - start_time).seconds
if state:
self.siteoper.success(domain=domain, seconds=seconds)
self.sitestatistic.success(domain=domain, seconds=seconds)
else:
self.siteoper.fail(domain)
self.sitestatistic.fail(domain)
return state, message
except Exception as e:
return False, f"{str(e)}"
@@ -572,18 +438,17 @@ class SiteChain(ChainBase):
elif res.status_code == 200:
msg = "Cookie已失效"
else:
msg = f"错误{res.status_code} {res.reason}"
msg = f"状态码{res.status_code}"
return False, f"{msg}"
elif public and res.status_code != 200:
return False, f"错误{res.status_code} {res.reason}"
return False, f"状态码{res.status_code}"
elif res is not None:
return False, f"错误{res.status_code} {res.reason}"
return False, f"状态码{res.status_code}"
else:
return False, f"无法打开网站!"
return True, "连接成功"
def remote_list(self, channel: MessageChannel,
userid: Union[str, int] = None, source: str = None):
def remote_list(self, channel: MessageChannel, userid: Union[str, int] = None):
"""
查询所有站点,发送消息
"""
@@ -611,13 +476,10 @@ class SiteChain(ChainBase):
# 发送列表
self.post_message(Notification(
channel=channel,
source=source,
title=title, text="\n".join(messages), userid=userid,
link=settings.MP_DOMAIN('#/site'))
)
link=settings.MP_DOMAIN('#/site')))
def remote_disable(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: str = None):
def remote_disable(self, arg_str, channel: MessageChannel, userid: Union[str, int] = None):
"""
禁用站点
"""
@@ -639,10 +501,9 @@ class SiteChain(ChainBase):
"is_active": False
})
# 重新发送消息
self.remote_list(channel=channel, userid=userid, source=source)
self.remote_list(channel, userid)
def remote_enable(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: str = None):
def remote_enable(self, arg_str, channel: MessageChannel, userid: Union[str, int] = None):
"""
启用站点
"""
@@ -665,7 +526,7 @@ class SiteChain(ChainBase):
"is_active": True
})
# 重新发送消息
self.remote_list(channel=channel, userid=userid, source=source)
self.remote_list(channel, userid)
def update_cookie(self, site_info: Site,
username: str, password: str, two_step_code: str = None) -> Tuple[bool, str]:
@@ -696,8 +557,7 @@ class SiteChain(ChainBase):
return True, msg
return False, "未知错误"
def remote_cookie(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: str = None):
def remote_cookie(self, arg_str: str, channel: MessageChannel, userid: Union[str, int] = None):
"""
使用用户名密码更新站点Cookie
"""
@@ -706,7 +566,6 @@ class SiteChain(ChainBase):
if not arg_str:
self.post_message(Notification(
channel=channel,
source=source,
title=err_title, userid=userid))
return
arg_str = str(arg_str).strip()
@@ -718,14 +577,12 @@ class SiteChain(ChainBase):
elif len(args) != 3:
self.post_message(Notification(
channel=channel,
source=source,
title=err_title, userid=userid))
return
site_id = args[0]
if not site_id.isdigit():
self.post_message(Notification(
channel=channel,
source=source,
title=err_title, userid=userid))
return
# 站点ID
@@ -735,12 +592,10 @@ class SiteChain(ChainBase):
if not site_info:
self.post_message(Notification(
channel=channel,
source=source,
title=f"站点编号 {site_id} 不存在!", userid=userid))
return
self.post_message(Notification(
channel=channel,
source=source,
title=f"开始更新【{site_info.name}】Cookie&UA ...", userid=userid))
# 用户名
username = args[1]
@@ -755,76 +610,11 @@ class SiteChain(ChainBase):
logger.error(msg)
self.post_message(Notification(
channel=channel,
source=source,
title=f"{site_info.name}】 Cookie&UA更新失败",
text=f"错误原因:{msg}",
userid=userid))
else:
self.post_message(Notification(
channel=channel,
source=source,
title=f"{site_info.name}】 Cookie&UA更新成功",
userid=userid))
def remote_refresh_userdatas(self, channel: MessageChannel,
userid: Union[str, int] = None, source: str = None):
"""
刷新所有站点用户数据
"""
logger.info("收到命令,开始刷新站点数据 ...")
self.post_message(Notification(
channel=channel,
source=source,
title="开始刷新站点数据 ...",
userid=userid
))
# 刷新站点数据
site_datas = self.refresh_userdatas()
if site_datas:
# 发送消息
messages = {}
# 总上传
incUploads = 0
# 总下载
incDownloads = 0
# 今天日期
today_date = datetime.now().strftime("%Y-%m-%d")
for rand, site in enumerate(site_datas.keys()):
upload = int(site_datas[site].upload or 0)
download = int(site_datas[site].download or 0)
updated_date = site_datas[site].updated_day
if updated_date and updated_date != today_date:
updated_date = f"{updated_date}"
else:
updated_date = ""
if upload > 0 or download > 0:
incUploads += upload
incDownloads += download
messages[upload + (rand / 1000)] = (
f"{site}{updated_date}\n"
+ f"上传量:{StringUtils.str_filesize(upload)}\n"
+ f"下载量:{StringUtils.str_filesize(download)}\n"
+ "————————————"
)
if incDownloads or incUploads:
sorted_messages = [messages[key] for key in sorted(messages.keys(), reverse=True)]
sorted_messages.insert(0, f"【汇总】\n"
f"总上传:{StringUtils.str_filesize(incUploads)}\n"
f"总下载:{StringUtils.str_filesize(incDownloads)}\n"
f"————————————")
self.post_message(Notification(
channel=channel,
source=source,
title="【站点数据统计】",
text="\n".join(sorted_messages),
userid=userid
))
else:
self.post_message(Notification(
channel=channel,
source=source,
title="没有刷新到任何站点数据!",
userid=userid
))

View File

@@ -1,177 +0,0 @@
from pathlib import Path
from typing import Optional, Tuple, List, Dict
from app import schemas
from app.chain import ChainBase
from app.core.config import settings
from app.helper.directory import DirectoryHelper
from app.log import logger
from app.schemas import MediaType
class StorageChain(ChainBase):
"""
存储处理链
"""
def __init__(self):
super().__init__()
self.directoryhelper = DirectoryHelper()
def save_config(self, storage: str, conf: dict) -> None:
"""
保存存储配置
"""
self.run_module("save_config", storage=storage, conf=conf)
def generate_qrcode(self, storage: str) -> Optional[Tuple[dict, str]]:
"""
生成二维码
"""
return self.run_module("generate_qrcode", storage=storage)
def check_login(self, storage: str, **kwargs) -> Optional[Tuple[dict, str]]:
"""
登录确认
"""
return self.run_module("check_login", storage=storage, **kwargs)
def list_files(self, fileitem: schemas.FileItem, recursion: bool = False) -> Optional[List[schemas.FileItem]]:
"""
查询当前目录下所有目录和文件
"""
return self.run_module("list_files", fileitem=fileitem, recursion=recursion)
def any_files(self, fileitem: schemas.FileItem, extensions: list = None) -> Optional[bool]:
"""
查询当前目录下是否存在指定扩展名任意文件
"""
return self.run_module("any_files", fileitem=fileitem, extensions=extensions)
def create_folder(self, fileitem: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
"""
创建目录
"""
return self.run_module("create_folder", fileitem=fileitem, name=name)
def download_file(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
"""
下载文件
:param fileitem: 文件项
:param path: 本地保存路径
"""
return self.run_module("download_file", fileitem=fileitem, path=path)
def upload_file(self, fileitem: schemas.FileItem, path: Path,
new_name: str = None) -> Optional[schemas.FileItem]:
"""
上传文件
:param fileitem: 保存目录项
:param path: 本地文件路径
:param new_name: 新文件名
"""
return self.run_module("upload_file", fileitem=fileitem, path=path, new_name=new_name)
def delete_file(self, fileitem: schemas.FileItem) -> Optional[bool]:
"""
删除文件或目录
"""
return self.run_module("delete_file", fileitem=fileitem)
def rename_file(self, fileitem: schemas.FileItem, name: str) -> Optional[bool]:
"""
重命名文件或目录
"""
return self.run_module("rename_file", fileitem=fileitem, name=name)
def get_item(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
"""
查询目录或文件
"""
return self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path))
def get_file_item(self, storage: str, path: Path) -> Optional[schemas.FileItem]:
"""
根据路径获取文件项
"""
return self.run_module("get_file_item", storage=storage, path=path)
def get_parent_item(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
"""
获取上级目录项
"""
return self.run_module("get_parent_item", fileitem=fileitem)
def snapshot_storage(self, storage: str, path: Path) -> Optional[Dict[str, float]]:
"""
快照存储
"""
return self.run_module("snapshot_storage", storage=storage, path=path)
def storage_usage(self, storage: str) -> Optional[schemas.StorageUsage]:
"""
存储使用情况
"""
return self.run_module("storage_usage", storage=storage)
def support_transtype(self, storage: str) -> Optional[dict]:
"""
获取支持的整理方式
"""
return self.run_module("support_transtype", storage=storage)
def delete_media_file(self, fileitem: schemas.FileItem,
mtype: MediaType = None, delete_self: bool = True) -> bool:
"""
删除媒体文件,以及不含媒体文件的目录
"""
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
if fileitem.path == "/" or len(Path(fileitem.path).parts) <= 2:
logger.warn(f"{fileitem.storage}{fileitem.path} 根目录或一级目录不允许删除")
return False
if fileitem.type == "dir":
# 本身是目录
if _blue_dir := self.list_files(fileitem=fileitem, recursion=False):
# 删除蓝光目录
for _f in _blue_dir:
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
logger.warn(f"{fileitem.storage}{_f.path} 删除蓝光目录")
self.delete_file(_f)
if self.any_files(fileitem, extensions=media_exts) is False:
logger.warn(f"{fileitem.storage}{fileitem.path} 不存在其它媒体文件,删除空目录")
return self.delete_file(fileitem)
return False
elif delete_self:
# 本身是文件
logger.warn(f"正在删除【{fileitem.storage}{fileitem.path}")
if not self.delete_file(fileitem):
logger.warn(f"{fileitem.storage}{fileitem.path} 删除失败")
return False
if mtype:
# 重命名格式
rename_format = settings.TV_RENAME_FORMAT \
if mtype == MediaType.TV else settings.MOVIE_RENAME_FORMAT
# 计算重命名中的文件夹层数
rename_format_level = len(rename_format.split("/")) - 1
if rename_format_level < 1:
return True
# 处理上级目录
dir_item = self.get_file_item(storage=fileitem.storage,
path=Path(fileitem.path).parents[rename_format_level - 1])
else:
dir_item = self.get_parent_item(fileitem)
if dir_item and len(Path(dir_item.path).parts) > 2:
# 如何目录是所有下载目录、媒体库目录的上级,则不处理
for d in self.directoryhelper.get_dirs():
if d.download_path and Path(d.download_path).is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 是下载目录本级或上级目录,不删除")
return True
if d.library_path and Path(d.library_path).is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 是媒体库目录本级或上级目录,不删除")
return True
# 不存在其他媒体文件,删除空目录
if self.any_files(dir_item, extensions=media_exts) is False:
logger.warn(f"{dir_item.storage}{dir_item.path} 不存在其它媒体文件,删除空目录")
return self.delete_file(dir_item)
return True

File diff suppressed because it is too large Load Diff

View File

@@ -10,7 +10,6 @@ from app.schemas import Notification, MessageChannel
from app.utils.http import RequestUtils
from app.utils.singleton import Singleton
from app.utils.system import SystemUtils
from version import FRONTEND_VERSION, APP_VERSION
class SystemChain(ChainBase, metaclass=Singleton):
@@ -20,25 +19,20 @@ class SystemChain(ChainBase, metaclass=Singleton):
_restart_file = "__system_restart__"
def __init__(self):
super().__init__()
# 重启完成检测
self.restart_finish()
def remote_clear_cache(self, channel: MessageChannel, userid: Union[int, str], source: str = None):
def remote_clear_cache(self, channel: MessageChannel, userid: Union[int, str]):
"""
清理系统缓存
"""
self.clear_cache()
self.post_message(Notification(channel=channel, source=source,
self.post_message(Notification(channel=channel,
title=f"缓存清理完成!", userid=userid))
def restart(self, channel: MessageChannel, userid: Union[int, str], source: str = None):
def restart(self, channel: MessageChannel, userid: Union[int, str]):
"""
重启系统
"""
if channel and userid:
self.post_message(Notification(channel=channel, source=source,
self.post_message(Notification(channel=channel,
title="系统正在重启,请耐心等候!", userid=userid))
# 保存重启信息
self.save_cache({
@@ -65,11 +59,11 @@ class SystemChain(ChainBase, metaclass=Singleton):
title += f"当前前端版本:{front_local_version},远程版本:{front_release_version}"
return title
def version(self, channel: MessageChannel, userid: Union[int, str], source: str = None):
def version(self, channel: MessageChannel, userid: Union[int, str]):
"""
查看当前版本、远程版本
"""
self.post_message(Notification(channel=channel, source=source,
self.post_message(Notification(channel=channel,
title=self.__get_version_message(),
userid=userid))
@@ -99,63 +93,60 @@ class SystemChain(ChainBase, metaclass=Singleton):
@staticmethod
def __get_server_release_version():
"""
获取后端V2最新版本
获取后端最新版本
"""
try:
# 获取所有发布的版本列表
response = RequestUtils(
proxies=settings.PROXY,
headers=settings.GITHUB_HEADERS
).get_res("https://api.github.com/repos/jxxghp/MoviePilot/releases")
if response:
releases = [release['tag_name'] for release in response.json()]
v2_releases = [tag for tag in releases if re.match(r"^v2\.", tag)]
if not v2_releases:
logger.warn("获取v2后端最新版本版本出错")
else:
# 找到最新的v2版本
latest_v2 = sorted(v2_releases, key=lambda s: list(map(int, re.findall(r'\d+', s))))[-1]
logger.info(f"获取到后端最新版本:{latest_v2}")
return latest_v2
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
"https://api.github.com/repos/jxxghp/MoviePilot/releases/latest")
if version_res:
ver_json = version_res.json()
version = f"{ver_json['tag_name']}"
return version
else:
logger.error("无法获取后端版本信息请检查网络连接或GitHub API请求。")
return None
except Exception as err:
logger.error(f"获取后端最新版本失败:{str(err)}")
return None
return None
@staticmethod
def __get_front_release_version():
"""
获取前端V2最新版本
获取前端最新版本
"""
try:
# 获取所有发布的版本列表
response = RequestUtils(
proxies=settings.PROXY,
headers=settings.GITHUB_HEADERS
).get_res("https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases")
if response:
releases = [release['tag_name'] for release in response.json()]
v2_releases = [tag for tag in releases if re.match(r"^v2\.", tag)]
if not v2_releases:
logger.warn("获取v2前端最新版本版本出错")
else:
# 找到最新的v2版本
latest_v2 = sorted(v2_releases, key=lambda s: list(map(int, re.findall(r'\d+', s))))[-1]
logger.info(f"获取到前端最新版本:{latest_v2}")
return latest_v2
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
"https://api.github.com/repos/jxxghp/MoviePilot-Frontend/releases/latest")
if version_res:
ver_json = version_res.json()
version = f"{ver_json['tag_name']}"
return version
else:
logger.error("无法获取前端版本信息请检查网络连接或GitHub API请求。")
return None
except Exception as err:
logger.error(f"获取前端最新版本失败:{str(err)}")
return None
return None
@staticmethod
def get_server_local_version():
"""
查看当前版本
"""
return APP_VERSION
version_file = settings.ROOT_PATH / "version.py"
if version_file.exists():
try:
with open(version_file, 'rb') as f:
version = f.read()
pattern = r"'([^']*)'"
match = re.search(pattern, str(version))
if match:
version = match.group(1)
return version
else:
logger.warn("未找到版本号")
return None
except Exception as err:
logger.error(f"加载版本文件 {version_file} 出错:{str(err)}")
@staticmethod
def get_frontend_version():
@@ -172,5 +163,7 @@ class SystemChain(ChainBase, metaclass=Singleton):
version = str(f.read()).strip()
return version
except Exception as err:
logger.debug(f"加载版本文件 {version_file} 出错:{str(err)}")
return FRONTEND_VERSION
logger.error(f"加载版本文件 {version_file} 出错:{str(err)}")
else:
logger.warn("未找到前端版本文件,请正确设置 FRONTEND_PATH")
return None

View File

@@ -1,9 +1,11 @@
import random
from typing import Optional, List
from cachetools import cached, TTLCache
from app import schemas
from app.chain import ChainBase
from app.core.cache import cached
from app.core.config import settings
from app.core.context import MediaInfo
from app.schemas import MediaType
from app.utils.singleton import Singleton
@@ -14,38 +16,19 @@ class TmdbChain(ChainBase, metaclass=Singleton):
TheMovieDB处理链单例运行
"""
def tmdb_discover(self, mtype: MediaType,
sort_by: str,
with_genres: str,
with_original_language: str,
with_keywords: str,
with_watch_providers: str,
vote_average: float,
vote_count: int,
release_date: str,
page: int = 1) -> Optional[List[MediaInfo]]:
def tmdb_discover(self, mtype: MediaType, sort_by: str, with_genres: str,
with_original_language: str, page: int = 1) -> Optional[List[MediaInfo]]:
"""
:param mtype: 媒体类型
:param sort_by: 排序方式
:param with_genres: 类型
:param with_original_language: 语言
:param with_keywords: 关键字
:param with_watch_providers: 提供商
:param vote_average: 评分
:param vote_count: 评分人数
:param release_date: 上映日期
:param page: 页码
:return: 媒体信息列表
"""
return self.run_module("tmdb_discover", mtype=mtype,
sort_by=sort_by,
with_genres=with_genres,
sort_by=sort_by, with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
def tmdb_trending(self, page: int = 1) -> Optional[List[MediaInfo]]:
@@ -56,13 +39,6 @@ class TmdbChain(ChainBase, metaclass=Singleton):
"""
return self.run_module("tmdb_trending", page=page)
def tmdb_collection(self, collection_id: int) -> Optional[List[MediaInfo]]:
"""
根据合集ID查询集合
:param collection_id: 合集ID
"""
return self.run_module("tmdb_collection", collection_id=collection_id)
def tmdb_seasons(self, tmdbid: int) -> List[schemas.TmdbSeason]:
"""
根据TMDBID查询themoviedb所有季信息
@@ -137,7 +113,7 @@ class TmdbChain(ChainBase, metaclass=Singleton):
"""
return self.run_module("tmdb_person_credits", person_id=person_id, page=page)
@cached(maxsize=1, ttl=3600)
@cached(cache=TTLCache(maxsize=1, ttl=3600))
def get_random_wallpager(self) -> Optional[str]:
"""
获取随机壁纸缓存1个小时
@@ -151,12 +127,12 @@ class TmdbChain(ChainBase, metaclass=Singleton):
return info.backdrop_path
return None
@cached(maxsize=1, ttl=3600)
def get_trending_wallpapers(self, num: int = 10) -> List[str]:
@cached(cache=TTLCache(maxsize=1, ttl=3600))
def get_trending_wallpapers(self, num: int = 10) -> Optional[List[str]]:
"""
获取所有流行壁纸
"""
infos = self.tmdb_trending()
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []
return None

View File

@@ -6,7 +6,7 @@ from cachetools import cached, TTLCache
from app.chain import ChainBase
from app.chain.media import MediaChain
from app.core.config import settings, global_vars
from app.core.config import settings
from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
@@ -73,20 +73,17 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
logger.info(f'种子缓存数据清理完成')
@cached(cache=TTLCache(maxsize=128, ttl=595))
def browse(self, domain: str, keyword: str = None, cat: str = None, page: int = 0) -> List[TorrentInfo]:
def browse(self, domain: str) -> List[TorrentInfo]:
"""
浏览站点首页内容返回种子清单TTL缓存10分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = self.siteshelper.get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
return self.refresh_torrents(site=site)
@cached(cache=TTLCache(maxsize=128, ttl=295))
def rss(self, domain: str) -> List[TorrentInfo]:
@@ -123,7 +120,6 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
site_ua=site.get("ua") or settings.USER_AGENT,
site_proxy=site.get("proxy"),
site_order=site.get("pri"),
site_downloader=site.get("downloader"),
title=item.get("title"),
enclosure=item.get("enclosure"),
page_url=item.get("link"),
@@ -162,8 +158,6 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
domains = []
# 遍历站点缓存资源
for indexer in indexers:
if global_vars.is_system_stopped:
break
# 未开启的站点不刷新
if sites and indexer.get("id") not in sites:
continue
@@ -178,7 +172,7 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 按pubdate降序排列
torrents.sort(key=lambda x: x.pubdate or '', reverse=True)
# 取前N条
torrents = torrents[:settings.CACHE_CONF["refresh"]]
torrents = torrents[:settings.CACHE_CONF.get('refresh')]
if torrents:
# 过滤出没有处理过的种子
torrents = [torrent for torrent in torrents
@@ -191,8 +185,6 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
logger.info(f'{indexer.get("name")} 没有新种子')
continue
for torrent in torrents:
if global_vars.is_system_stopped:
break
logger.info(f'处理资源:{torrent.title} ...')
# 识别
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
@@ -218,8 +210,8 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
else:
torrents_cache[domain].append(context)
# 如果超过了限制条数则移除掉前面的
if len(torrents_cache[domain]) > settings.CACHE_CONF["torrents"]:
torrents_cache[domain] = torrents_cache[domain][-settings.CACHE_CONF["torrents"]:]
if len(torrents_cache[domain]) > settings.CACHE_CONF.get('torrents'):
torrents_cache[domain] = torrents_cache[domain][-settings.CACHE_CONF.get('torrents'):]
# 回收资源
del torrents
else:

File diff suppressed because it is too large Load Diff

View File

@@ -1,237 +1,15 @@
import secrets
from typing import Optional, Tuple, Union
from typing import Optional
from app.chain import ChainBase
from app.core.config import settings
from app.core.security import get_password_hash, verify_password
from app.db.models.user import User
from app.db.user_oper import UserOper
from app.log import logger
from app.schemas import AuthCredentials, AuthInterceptCredentials
from app.schemas.types import ChainEventType
from app.utils.otp import OtpUtils
from app.utils.singleton import Singleton
PASSWORD_INVALID_CREDENTIALS_MESSAGE = "用户名或密码或二次校验码不正确"
class UserChain(ChainBase, metaclass=Singleton):
"""
用户链,处理多种认证协议
"""
class UserChain(ChainBase):
def __init__(self):
super().__init__()
self.user_oper = UserOper()
def user_authenticate(
self,
username: Optional[str] = None,
password: Optional[str] = None,
mfa_code: Optional[str] = None,
code: Optional[str] = None,
grant_type: str = "password"
) -> Union[Tuple[bool, Optional[str]], Tuple[bool, Optional[User]]]:
def user_authenticate(self, name, password) -> Optional[str]:
"""
认证用户,根据不同的 grant_type 处理不同的认证流程
:param username: 用户名,适用于 "password" grant_type
:param password: 用户密码,适用于 "password" grant_type
:param mfa_code: 一次性密码,适用于 "password" grant_type
:param code: 授权码,适用于 "authorization_code" grant_type
:param grant_type: 认证类型,如 "password", "authorization_code", "client_credentials"
:return:
- 对于成功的认证,返回 (True, User)
- 对于失败的认证,返回 (False, "错误信息")
辅助完成用户认证
:param name: 用户名
:param password: 密码
:return: token
"""
credentials = AuthCredentials(
username=username,
password=password,
mfa_code=mfa_code,
code=code,
grant_type=grant_type
)
logger.debug(f"认证类型:{grant_type},开始准备对用户 {username} 进行身份校验")
if credentials.grant_type == "password":
# Password 认证
success, user_or_message = self.password_authenticate(credentials=credentials)
if success:
# 如果用户启用了二次验证码,则进一步验证
if not self._verify_mfa(user_or_message, credentials.mfa_code):
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
logger.info(f"用户 {username} 通过密码认证成功")
return True, user_or_message
else:
# 用户不存在或密码错误,考虑辅助认证
if settings.AUXILIARY_AUTH_ENABLE:
logger.warning("密码认证失败,尝试通过外部服务进行辅助认证 ...")
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
if aux_success:
# 辅助认证成功后再验证二次验证码
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
return True, aux_user_or_message
else:
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
else:
logger.debug(f"辅助认证未启用,用户 {username} 认证失败")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
elif credentials.grant_type == "authorization_code":
# 处理其他认证类型的分支
if settings.AUXILIARY_AUTH_ENABLE:
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
if aux_success:
return True, aux_user_or_message
else:
return False, "认证失败"
else:
return False, "认证失败"
else:
logger.debug(f"辅助认证未启用,认证类型 {grant_type} 未实现")
return False, "不支持的认证类型"
def password_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
"""
密码认证
:param credentials: 认证凭证,包含用户名、密码以及可选的 MFA 认证码
:return:
- 成功时返回 (True, User),其中 User 是认证通过的用户对象
- 失败时返回 (False, "错误信息")
"""
if not credentials or credentials.grant_type != "password":
logger.info("密码认证失败,认证类型不匹配")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
user = self.user_oper.get_by_name(name=credentials.username)
if not user:
logger.info(f"密码认证失败,用户 {credentials.username} 不存在")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
if not user.is_active:
logger.info(f"密码认证失败,用户 {credentials.username} 已被禁用")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
if not verify_password(credentials.password, str(user.hashed_password)):
logger.info(f"密码认证失败,用户 {credentials.username} 的密码验证不通过")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
return True, user
def auxiliary_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
"""
辅助用户认证
:param credentials: 认证凭证,包含必要的认证信息
:return:
- 成功时返回 (True, User),其中 User 是认证通过的用户对象
- 失败时返回 (False, "错误信息")
"""
if not credentials:
return False, "认证凭证无效"
# 检查是否因为用户被禁用
if credentials.username:
user = self.user_oper.get_by_name(name=credentials.username)
if user and not user.is_active:
logger.info(f"用户 {user.name} 已被禁用,跳过后续身份校验")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
logger.debug(f"认证类型:{credentials.grant_type},尝试通过系统模块进行辅助认证,用户: {credentials.username}")
result = self.run_module("user_authenticate", credentials=credentials)
if not result:
logger.debug(f"通过系统模块辅助认证失败,尝试触发 {ChainEventType.AuthVerification} 事件")
event = self.eventmanager.send_event(etype=ChainEventType.AuthVerification, data=credentials)
if not event or not event.event_data:
logger.error(f"认证类型:{credentials.grant_type},辅助认证失败,未返回有效数据")
return False, f"认证类型:{credentials.grant_type},辅助认证事件失败或无效"
credentials = event.event_data # 使用事件返回的认证数据
else:
logger.info(f"通过系统模块辅助认证成功,用户: {credentials.username}")
credentials = result # 使用模块认证返回的认证数据
# 处理认证成功的逻辑
success = self._process_auth_success(username=credentials.username, credentials=credentials)
if success:
logger.info(f"用户 {credentials.username} 辅助认证通过")
return True, self.user_oper.get_by_name(credentials.username)
else:
logger.warning(f"用户 {credentials.username} 辅助认证未通过")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
@staticmethod
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
"""
验证 MFA二次验证码
:param user: 用户对象
:param mfa_code: 二次验证码
:return: 如果验证成功返回 True否则返回 False
"""
if not user.is_otp:
return True
if not mfa_code:
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
return False
if not OtpUtils.check(str(user.otp_secret), mfa_code):
logger.info(f"用户 {user.name} 的 MFA 认证失败")
return False
return True
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:
"""
处理辅助认证成功的逻辑,返回用户对象或创建新用户
:param username: 用户名
:param credentials: 认证凭证,包含 token、channel、service 等信息
:return:
- 如果认证成功并且用户存在或已创建,返回 User 对象
- 如果认证被拦截或失败,返回 None
"""
if not username:
logger.info(f"未能获取到对应的用户信息,{credentials.grant_type} 认证不通过")
return False
token, channel, service = credentials.token, credentials.channel, credentials.service
if not all([token, channel, service]):
logger.info(f"用户 {username} 未通过 {credentials.grant_type} 认证,必要信息不足")
return False
# 触发认证通过的拦截事件
intercept_event = self.eventmanager.send_event(
etype=ChainEventType.AuthIntercept,
data=AuthInterceptCredentials(username=username, channel=channel, service=service,
token=token, status="completed")
)
if intercept_event and intercept_event.event_data:
intercept_data: AuthInterceptCredentials = intercept_event.event_data
if intercept_data.cancel:
logger.warning(
f"认证被拦截,用户:{username},渠道:{channel},服务:{service},拦截源:{intercept_data.source}")
return False
# 检查用户是否存在,如果不存在且当前为密码认证时则创建新用户
user = self.user_oper.get_by_name(name=username)
if user:
# 如果用户存在,但是已经被禁用,则直接响应
if not user.is_active:
logger.info(f"辅助认证失败,用户 {username} 已被禁用")
return False
anonymized_token = f"{token[:len(token) // 2]}********"
logger.info(
f"认证类型:{credentials.grant_type},用户:{username},渠道:{channel}"
f"服务:{service} 认证成功token{anonymized_token}")
return True
else:
if credentials.grant_type == "password":
self.user_oper.add(name=username, is_active=True, is_superuser=False,
hashed_password=get_password_hash(secrets.token_urlsafe(16)))
logger.info(f"用户 {username} 不存在,已通过 {credentials.grant_type} 认证并已创建普通用户")
return True
else:
logger.warning(
f"认证类型:{credentials.grant_type},用户:{username},渠道:{channel}"
f"服务:{service} 认证不通过,未能在本地找到对应的用户信息")
return False
return self.run_module("user_authenticate", name=name, password=password)

View File

@@ -1,51 +0,0 @@
from typing import List
from app.chain import ChainBase
from app.core.workflow import WorkFlowManager
from app.db.workflow_oper import WorkflowOper
from app.log import logger
from app.schemas import Workflow, ActionContext, Action
class WorkflowChain(ChainBase):
"""
工作流链
"""
def __init__(self):
super().__init__()
self.workflowoper = WorkflowOper()
self.workflowmanager = WorkFlowManager()
def process(self, workflow_id: int) -> bool:
"""
处理工作流
"""
workflow = self.workflowoper.get(workflow_id)
if not workflow:
logger.warn(f"工作流 {workflow_id} 不存在")
return False
if not workflow.actions:
logger.warn(f"工作流 {workflow.name} 无动作")
return False
logger.info(f"开始处理 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
# 启用上下文
context = ActionContext()
self.workflowoper.start(workflow_id)
for act in workflow.actions:
action = Action(**act)
state, context = self.workflowmanager.excute(action, context)
self.workflowoper.step(workflow_id, action=action.name, context=context.dict())
if not state:
logger.error(f"动作 {action.name} 执行失败,工作流失败")
self.workflowoper.fail(workflow_id, result=f"动作 {action.name} 执行失败")
return False
logger.info(f"工作流 {workflow.name} 执行完成")
self.workflowoper.success(workflow_id)
return True
def get_workflows(self) -> List[Workflow]:
"""
获取工作流列表
"""
return self.workflowoper.list_enabled()

View File

@@ -1,7 +1,9 @@
import copy
import importlib
import threading
import traceback
from typing import Any, Union, Dict, Optional
from threading import Thread
from typing import Any, Union, Dict
from app.chain import ChainBase
from app.chain.download import DownloadChain
@@ -10,37 +12,52 @@ from app.chain.subscribe import SubscribeChain
from app.chain.system import SystemChain
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.event import Event as ManagerEvent, eventmanager, Event
from app.core.event import Event as ManagerEvent, eventmanager, EventManager
from app.core.plugin import PluginManager
from app.helper.message import MessageHelper
from app.helper.thread import ThreadHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas import Notification, CommandRegisterEventData
from app.schemas.types import EventType, MessageChannel, ChainEventType
from app.schemas import Notification
from app.schemas.types import EventType, MessageChannel
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.utils.structures import DictUtils
class CommandChain(ChainBase):
pass
class CommandChian(ChainBase):
"""
插件处理链
"""
def process(self, *args, **kwargs):
pass
class Command(metaclass=Singleton):
"""
全局命令管理,消费事件
"""
# 内建命令
_commands = {}
# 退出事件
_event = threading.Event()
def __init__(self):
# 事件管理器
self.eventmanager = EventManager()
# 插件管理器
super().__init__()
# 注册的命令集合
self._registered_commands = {}
# 所有命令集合
self._commands = {}
# 内建命令集合
self._preset_commands = {
self.pluginmanager = PluginManager()
# 处理链
self.chain = CommandChian()
# 定时服务管理
self.scheduler = Scheduler()
# 消息管理器
self.messagehelper = MessageHelper()
# 线程管理器
self.threader = ThreadHelper()
# 内置命令
self._commands = {
"/cookiecloud": {
"id": "cookiecloud",
"type": "scheduler",
@@ -58,11 +75,6 @@ class Command(metaclass=Singleton):
"description": "更新站点Cookie",
"data": {}
},
"/site_statistic": {
"func": SiteChain().remote_refresh_userdatas,
"description": "站点数据统计",
"data": {}
},
"/site_enable": {
"func": SiteChain().remote_enable,
"description": "启用站点",
@@ -143,148 +155,98 @@ class Command(metaclass=Singleton):
"data": {}
}
}
# 插件命令集合
self._plugin_commands = {}
# 其他命令集合
self._other_commands = {}
# 初始化锁
self._rlock = threading.RLock()
# 插件管理
self.pluginmanager = PluginManager()
# 定时服务管理
self.scheduler = Scheduler()
# 消息管理器
self.messagehelper = MessageHelper()
# 初始化命令
self.init_commands()
def init_commands(self, pid: Optional[str] = None) -> None:
"""
初始化菜单命令
"""
if settings.DEV:
logger.debug("Development mode active. Skipping command initialization.")
return
# 使用线程池提交后台任务,避免引起阻塞
ThreadHelper().submit(self.__init_commands_background, pid)
def __init_commands_background(self, pid: Optional[str] = None) -> None:
"""
后台初始化菜单命令
"""
try:
with self._rlock:
logger.debug("Acquired lock for initializing commands in background.")
self._plugin_commands = self.__build_plugin_commands(pid)
self._commands = {
**self._preset_commands,
**self._plugin_commands,
**self._other_commands
# 汇总插件命令
plugin_commands = self.pluginmanager.get_plugin_commands()
for command in plugin_commands:
self.register(
cmd=command.get('cmd'),
func=Command.send_plugin_event,
desc=command.get('desc'),
category=command.get('category'),
data={
'etype': command.get('event'),
'data': command.get('data')
}
)
# 广播注册命令菜单
if not settings.DEV:
self.chain.register_commands(commands=self.get_commands())
# 消息处理线程
self._thread = Thread(target=self.__run)
# 启动事件处理线程
self._thread.start()
# 重启msg
SystemChain().restart_finish()
# 强制触发注册
force_register = False
# 触发事件允许可以拦截和调整命令
event, initial_commands = self.__trigger_register_commands_event()
if event and event.event_data:
# 如果事件返回有效的 event_data使用事件中调整后的命令
event_data: CommandRegisterEventData = event.event_data
# 如果事件被取消,跳过命令注册
if event_data.cancel:
logger.debug(f"Command initialization canceled by event: {event_data.source}")
return
# 如果拦截源与插件标识一致时,这里认为需要强制触发注册
if pid is not None and pid == event_data.source:
force_register = True
initial_commands = event_data.commands or {}
logger.debug(f"Registering command count from event: {len(initial_commands)}")
else:
logger.debug(f"Registering initial command count: {len(initial_commands)}")
# initial_commands 必须是 self._commands 的子集
filtered_initial_commands = DictUtils.filter_keys_to_subset(initial_commands, self._commands)
# 如果 filtered_initial_commands 为空,则跳过注册
if not filtered_initial_commands and not force_register:
logger.debug("Filtered commands are empty, skipping registration.")
return
# 对比调整后的命令与当前命令
if filtered_initial_commands != self._registered_commands or force_register:
logger.debug("Command set has changed or force registration is enabled.")
self._registered_commands = filtered_initial_commands
CommandChain().register_commands(commands=filtered_initial_commands)
else:
logger.debug("Command set unchanged, skipping broadcast registration.")
except Exception as e:
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
def __trigger_register_commands_event(self) -> (Optional[Event], dict):
def __run(self):
"""
触发事件,允许调整命令数据
事件处理线程
"""
while not self._event.is_set():
event, handlers = self.eventmanager.get_event()
if event:
logger.info(f"处理事件:{event.event_type} - {handlers}")
for handler in handlers:
names = handler.__qualname__.split(".")
[class_name, method_name] = names
try:
if class_name in self.pluginmanager.get_plugin_ids():
# 插件事件
self.threader.submit(
self.pluginmanager.run_plugin_method,
class_name, method_name, copy.deepcopy(event)
)
def add_commands(source, command_type):
"""
添加命令集合
"""
for cmd, command in source.items():
command_data = {
"type": command_type,
"description": command.get("description"),
"category": command.get("category")
}
# 如果有 pid则添加到命令数据中
plugin_id = command.get("pid")
if plugin_id:
command_data["pid"] = plugin_id
commands[cmd] = command_data
else:
# 检查全局变量中是否存在
if class_name not in globals():
# 导入模块除了插件和Command本身只有chain能响应事件
try:
module = importlib.import_module(
f"app.chain.{class_name[:-5].lower()}"
)
class_obj = getattr(module, class_name)()
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
continue
# 初始化命令字典
commands: Dict[str, dict] = {}
add_commands(self._preset_commands, "preset")
add_commands(self._plugin_commands, "plugin")
add_commands(self._other_commands, "other")
else:
# 通过类名创建类实例
class_obj = globals()[class_name]()
# 检查类是否存在并调用方法
if hasattr(class_obj, method_name):
self.threader.submit(
getattr(class_obj, method_name),
copy.deepcopy(event)
)
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
self.messagehelper.put(title=f"{event.event_type} 事件处理出错",
message=f"{class_name}.{method_name}{str(e)}",
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "event",
"event_type": event.event_type,
"event_handle": f"{class_name}.{method_name}",
"error": str(e),
"traceback": traceback.format_exc()
}
)
# 触发事件允许可以拦截和调整命令
event_data = CommandRegisterEventData(commands=commands, origin="CommandChain", service=None)
event = eventmanager.send_event(ChainEventType.CommandRegister, event_data)
return event, commands
def __build_plugin_commands(self, _: Optional[str] = None) -> Dict[str, dict]:
"""
构建插件命令
"""
# 为了保证命令顺序的一致性,目前这里没有直接使用 pid 获取单一插件命令,后续如果存在性能问题,可以考虑优化这里的逻辑
plugin_commands = {}
for command in self.pluginmanager.get_plugin_commands():
cmd = command.get("cmd")
if cmd:
plugin_commands[cmd] = {
"pid": command.get("pid"),
"func": self.send_plugin_event,
"description": command.get("desc"),
"category": command.get("category"),
"data": {
"etype": command.get("event"),
"data": command.get("data")
}
}
return plugin_commands
def __run_command(self, command: Dict[str, any], data_str: str = "",
channel: MessageChannel = None, source: str = None, userid: Union[str, int] = None):
def __run_command(self, command: Dict[str, any],
data_str: str = "",
channel: MessageChannel = None, userid: Union[str, int] = None):
"""
运行定时服务
"""
if command.get("type") == "scheduler":
# 定时服务
if userid:
CommandChain().post_message(
self.chain.post_message(
Notification(
channel=channel,
source=source,
title=f"开始执行 {command.get('description')} ...",
userid=userid
)
@@ -294,67 +256,75 @@ class Command(metaclass=Singleton):
self.scheduler.start(job_id=command.get("id"))
if userid:
CommandChain().post_message(
self.chain.post_message(
Notification(
channel=channel,
source=source,
title=f"{command.get('description')} 执行完成",
userid=userid
)
)
else:
# 命令
cmd_data = copy.deepcopy(command['data']) if command.get('data') else {}
cmd_data = command['data'] if command.get('data') else {}
args_num = ObjectUtils.arguments(command['func'])
if args_num > 0:
if cmd_data:
# 有内置参数直接使用内置参数
data = cmd_data.get("data") or {}
data['channel'] = channel
data['source'] = source
data['user'] = userid
if data_str:
data['arg_str'] = data_str
data['args'] = data_str
cmd_data['data'] = data
command['func'](**cmd_data)
elif args_num == 3:
# 没有输入参数,只输入渠道来源、用户ID和消息来源
command['func'](channel, userid, source)
elif args_num > 3:
elif args_num == 2:
# 没有输入参数,只输入渠道用户ID
command['func'](channel, userid)
elif args_num > 2:
# 多个输入参数用户输入、用户ID
command['func'](data_str, channel, userid, source)
command['func'](data_str, channel, userid)
else:
# 没有参数
command['func']()
def stop(self):
"""
停止事件处理线程
"""
logger.info("正在停止事件处理...")
self._event.set()
try:
self._thread.join()
logger.info("事件处理停止完成")
except Exception as e:
logger.error(f"停止事件处理线程出错:{str(e)} - {traceback.format_exc()}")
def get_commands(self):
"""
获取命令列表
"""
return self._commands
def get(self, cmd: str) -> Any:
"""
获取命令
"""
return self._commands.get(cmd, {})
def register(self, cmd: str, func: Any, data: dict = None,
desc: str = None, category: str = None) -> None:
"""
注册单个命令
注册命令
"""
# 单独调用的,统一注册到其他
self._other_commands[cmd] = {
self._commands[cmd] = {
"func": func,
"description": desc,
"category": category,
"data": data or {}
}
def get(self, cmd: str) -> Any:
"""
获取命令
"""
return self._commands.get(cmd, {})
def execute(self, cmd: str, data_str: str = "",
channel: MessageChannel = None, source: str = None,
userid: Union[str, int] = None) -> None:
channel: MessageChannel = None, userid: Union[str, int] = None) -> None:
"""
执行命令
"""
@@ -368,7 +338,7 @@ class Command(metaclass=Singleton):
# 执行命令
self.__run_command(command, data_str=data_str,
channel=channel, source=source, userid=userid)
channel=channel, userid=userid)
if userid:
logger.info(f"用户 {userid} {command.get('description')} 执行完成")
@@ -385,7 +355,7 @@ class Command(metaclass=Singleton):
"""
发送插件命令
"""
eventmanager.send_event(etype, data)
EventManager().send_event(etype, data)
@eventmanager.register(EventType.CommandExcute)
def command_event(self, event: ManagerEvent) -> None:
@@ -399,21 +369,10 @@ class Command(metaclass=Singleton):
event_str = event.event_data.get('cmd')
# 消息渠道
event_channel = event.event_data.get('channel')
# 消息来源
event_source = event.event_data.get('source')
# 消息用户
event_user = event.event_data.get('user')
if event_str:
cmd = event_str.split()[0]
args = " ".join(event_str.split()[1:])
if self.get(cmd):
self.execute(cmd=cmd, data_str=args,
channel=event_channel, source=event_source, userid=event_user)
@eventmanager.register(EventType.ModuleReload)
def module_reload_event(self, _: ManagerEvent) -> None:
"""
注册模块重载事件
"""
# 发生模块重载时,重新注册命令
self.init_commands()
self.execute(cmd, args, event_channel, event_user)

View File

@@ -1,567 +0,0 @@
import inspect
import json
import pickle
from abc import ABC, abstractmethod
from functools import wraps
from typing import Any, Dict, Optional
from urllib.parse import quote
import redis
from cachetools import TTLCache
from cachetools.keys import hashkey
from app.core.config import settings
from app.log import logger
# 默认缓存区
DEFAULT_CACHE_REGION = "DEFAULT"
class CacheBackend(ABC):
"""
缓存后端基类,定义通用的缓存接口
"""
@abstractmethod
def set(self, key: str, value: Any, ttl: int, region: str = DEFAULT_CACHE_REGION, **kwargs) -> None:
"""
设置缓存
:param key: 缓存的键
:param value: 缓存的值
:param ttl: 缓存的存活时间,单位秒
:param region: 缓存的区
:param kwargs: 其他参数
"""
pass
@abstractmethod
def exists(self, key: str, region: str = DEFAULT_CACHE_REGION) -> bool:
"""
判断缓存键是否存在
:param key: 缓存的键
:param region: 缓存的区
:return: 存在返回 True否则返回 False
"""
pass
@abstractmethod
def get(self, key: str, region: str = DEFAULT_CACHE_REGION) -> Any:
"""
获取缓存
:param key: 缓存的键
:param region: 缓存的区
:return: 返回缓存的值,如果缓存不存在返回 None
"""
pass
@abstractmethod
def delete(self, key: str, region: str = DEFAULT_CACHE_REGION) -> None:
"""
删除缓存
:param key: 缓存的键
:param region: 缓存的区
"""
pass
@abstractmethod
def clear(self, region: Optional[str] = None) -> None:
"""
清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
pass
@abstractmethod
def close(self) -> None:
"""
关闭缓存连接
"""
pass
@staticmethod
def get_region(region: str = DEFAULT_CACHE_REGION):
"""
获取缓存的区
"""
return f"region:{region}" if region else "region:default"
@staticmethod
def get_cache_key(func, args, kwargs):
"""
获取缓存的键,通过哈希函数对函数的参数进行处理
:param func: 被装饰的函数
:param args: 位置参数
:param kwargs: 关键字参数
:return: 缓存键
"""
signature = inspect.signature(func)
# 绑定传入的参数并应用默认值
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
# 忽略第一个参数,如果它是实例(self)或类(cls)
parameters = list(signature.parameters.keys())
if parameters and parameters[0] in ("self", "cls"):
bound.arguments.pop(parameters[0], None)
# 按照函数签名顺序提取参数值列表
keys = [
bound.arguments[param] for param in signature.parameters if param in bound.arguments
]
# 使用有序参数生成缓存键
return f"{func.__name__}_{hashkey(*keys)}"
class CacheToolsBackend(CacheBackend):
"""
基于 `cachetools.TTLCache` 实现的缓存后端
特性:
- 支持动态设置缓存的 TTLTime To Live存活时间和最大条目数Maxsize
- 缓存实例按区域region划分不同 region 拥有独立的缓存实例
- 同一 region 共享相同的 TTL 和 Maxsize设置时只能作用于整个 region
限制:
- 不支持按 `key` 独立隔离 TTL 和 Maxsize仅支持作用于 region 级别
"""
def __init__(self, maxsize: int = 1000, ttl: int = 1800):
"""
初始化缓存实例
:param maxsize: 缓存的最大条目数
:param ttl: 默认缓存存活时间,单位秒
"""
self.maxsize = maxsize
self.ttl = ttl
# 存储各个 region 的缓存实例region -> TTLCache
self._region_caches: Dict[str, TTLCache] = {}
def __get_region_cache(self, region: str) -> Optional[TTLCache]:
"""
获取指定区域的缓存实例,如果不存在则返回 None
"""
region = self.get_region(region)
return self._region_caches.get(region)
def set(self, key: str, value: Any, ttl: int = None, region: str = DEFAULT_CACHE_REGION, **kwargs) -> None:
"""
设置缓存值支持每个 key 独立配置 TTL 和 Maxsize
:param key: 缓存的键
:param value: 缓存的值
:param ttl: 缓存的存活时间,单位秒如果未传入则使用默认值
:param region: 缓存的区
:param kwargs: maxsize: 缓存的最大条目数如果未传入则使用默认值
"""
ttl = ttl or self.ttl
maxsize = kwargs.get("maxsize", self.maxsize)
region = self.get_region(region)
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
region_cache = self._region_caches.setdefault(region, TTLCache(maxsize=maxsize, ttl=ttl))
# 设置缓存值
region_cache[key] = value
def exists(self, key: str, region: str = DEFAULT_CACHE_REGION) -> bool:
"""
判断缓存键是否存在
:param key: 缓存的键
:param region: 缓存的区
:return: 存在返回 True否则返回 False
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return False
return key in region_cache
def get(self, key: str, region: str = DEFAULT_CACHE_REGION) -> Any:
"""
获取缓存的值
:param key: 缓存的键
:param region: 缓存的区
:return: 返回缓存的值,如果缓存不存在返回 None
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return None
return region_cache.get(key)
def delete(self, key: str, region: str = DEFAULT_CACHE_REGION) -> None:
"""
删除缓存
:param key: 缓存的键
:param region: 缓存的区
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return None
del region_cache[key]
def clear(self, region: Optional[str] = None) -> None:
"""
清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
if region:
# 清理指定缓存区
region_cache = self.__get_region_cache(region)
if region_cache:
region_cache.clear()
logger.info(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
region_cache.clear()
logger.info("Cleared all cache")
def close(self) -> None:
"""
内存缓存不需要关闭资源
"""
pass
class RedisBackend(CacheBackend):
"""
基于 Redis 实现的缓存后端,支持通过 Redis 存储缓存
特性:
- 支持动态设置缓存的 TTLTime To Live存活时间
- 支持分区域region管理缓存不同的 region 采用独立的命名空间
- 支持自定义最大内存限制maxmemory和内存淘汰策略如 allkeys-lru
限制:
- 由于 Redis 的分布式特性,写入和读取可能受到网络延迟的影响
- Pickle 反序列化可能存在安全风险,需进一步重构调用来源,避免复杂对象缓存
"""
# 类型缓存集合,针对非容器简单类型
_complex_serializable_types = set()
_simple_serializable_types = set()
def __init__(self, redis_url: str = "redis://localhost", ttl: int = 1800):
"""
初始化 Redis 缓存实例
:param redis_url: Redis 服务的 URL
:param ttl: 缓存的存活时间,单位秒
"""
self.redis_url = redis_url
self.ttl = ttl
try:
self.client = redis.Redis.from_url(
redis_url,
decode_responses=False,
socket_timeout=30,
socket_connect_timeout=5,
health_check_interval=60,
)
# 测试连接,确保 Redis 可用
self.client.ping()
logger.debug(f"Successfully connected to Redis")
self.set_memory_limit()
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise RuntimeError("Redis connection failed") from e
def set_memory_limit(self, policy: str = "allkeys-lru"):
"""
动态设置 Redis 最大内存和内存淘汰策略
:param policy: 淘汰策略(如 'allkeys-lru'
"""
try:
# 如果有显式值,则直接使用,为 0 时说明不限制,如果未配置,开启 BIG_MEMORY_MODE 时为 "1024mb",未开启时为 "256mb"
maxmemory = settings.CACHE_REDIS_MAXMEMORY or ("1024mb" if settings.BIG_MEMORY_MODE else "256mb")
self.client.config_set("maxmemory", maxmemory)
self.client.config_set("maxmemory-policy", policy)
logger.debug(f"Redis maxmemory set to {maxmemory}, policy: {policy}")
except Exception as e:
logger.error(f"Failed to set Redis maxmemory or policy: {e}")
@staticmethod
def is_container_type(t):
return t in (list, dict, tuple, set)
@classmethod
def serialize(cls, value: Any) -> bytes:
"""
将值序列化为二进制数据,根据序列化方式标识格式
"""
vt = type(value)
# 针对非容器类型使用缓存策略
if not cls.is_container_type(vt):
# 如果已知需要复杂序列化
if vt in cls._complex_serializable_types:
return b"PICKLE" + b"\x00" + pickle.dumps(value)
# 如果已知可以简单序列化
if vt in cls._simple_serializable_types:
json_data = json.dumps(value).encode("utf-8")
return b"JSON" + b"\x00" + json_data
# 对于未知的非容器类型,尝试简单序列化,如抛出异常,再使用复杂序列化
try:
json_data = json.dumps(value).encode("utf-8")
cls._simple_serializable_types.add(vt)
return b"JSON" + b"\x00" + json_data
except TypeError:
cls._complex_serializable_types.add(vt)
return b"PICKLE" + b"\x00" + pickle.dumps(value)
# 针对容器类型,每次尝试简单序列化,不使用缓存
else:
try:
json_data = json.dumps(value).encode("utf-8")
return b"JSON" + b"\x00" + json_data
except TypeError:
return b"PICKLE" + b"\x00" + pickle.dumps(value)
@classmethod
def deserialize(cls, value: bytes) -> Any:
"""
将二进制数据反序列化为原始值,根据格式标识区分序列化方式
"""
format_marker, data = value.split(b"\x00", 1)
if format_marker == b"JSON":
return json.loads(data.decode("utf-8"))
elif format_marker == b"PICKLE":
return pickle.loads(data)
else:
raise ValueError("Unknown serialization format")
# @staticmethod
# def serialize(value: Any) -> bytes:
# return msgpack.packb(value, use_bin_type=True)
#
# @staticmethod
# def deserialize(value: bytes) -> Any:
# return msgpack.unpackb(value, raw=False)
def get_redis_key(self, region: str, key: str) -> str:
"""
获取缓存 Key
"""
# 使用 region 作为缓存键的一部分
region = self.get_region(quote(region))
return f"{region}:key:{quote(key)}"
def set(self, key: str, value: Any, ttl: int = None, region: str = DEFAULT_CACHE_REGION, **kwargs) -> None:
"""
设置缓存
:param key: 缓存的键
:param value: 缓存的值
:param ttl: 缓存的存活时间,单位秒如果未传入则使用默认值
:param region: 缓存的区
:param kwargs: kwargs
"""
try:
ttl = ttl or self.ttl
redis_key = self.get_redis_key(region, key)
# 对值进行序列化
serialized_value = self.serialize(value)
kwargs.pop("maxsize", None)
self.client.set(redis_key, serialized_value, ex=ttl, **kwargs)
except Exception as e:
logger.error(f"Failed to set key: {key} in region: {region}, error: {e}")
def exists(self, key: str, region: str = DEFAULT_CACHE_REGION) -> bool:
"""
判断缓存键是否存在
:param key: 缓存的键
:param region: 缓存的区
:return: 存在返回 True否则返回 False
"""
try:
redis_key = self.get_redis_key(region, key)
return self.client.exists(redis_key) == 1
except Exception as e:
logger.error(f"Failed to exists key: {key} region: {region}, error: {e}")
return False
def get(self, key: str, region: str = DEFAULT_CACHE_REGION) -> Optional[Any]:
"""
获取缓存的值
:param key: 缓存的键
:param region: 缓存的区
:return: 返回缓存的值,如果缓存不存在返回 None
"""
try:
redis_key = self.get_redis_key(region, key)
value = self.client.get(redis_key)
if value is not None:
return self.deserialize(value) # noqa
return None
except Exception as e:
logger.error(f"Failed to get key: {key} in region: {region}, error: {e}")
return None
def delete(self, key: str, region: str = DEFAULT_CACHE_REGION) -> None:
"""
删除缓存
:param key: 缓存的键
:param region: 缓存的区
"""
try:
redis_key = self.get_redis_key(region, key)
self.client.delete(redis_key)
except Exception as e:
logger.error(f"Failed to delete key: {key} in region: {region}, error: {e}")
def clear(self, region: Optional[str] = None) -> None:
"""
清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
try:
if region:
cache_region = self.get_region(quote(region))
redis_key = f"{cache_region}:key:*"
# self.client.delete(*self.client.keys(redis_key))
with self.client.pipeline() as pipe:
for key in self.client.scan_iter(redis_key):
pipe.delete(key)
pipe.execute()
logger.info(f"Cleared Redis cache for region: {region}")
else:
self.client.flushdb()
logger.info("Cleared all Redis cache")
except Exception as e:
logger.error(f"Failed to clear cache, region: {region}, error: {e}")
def close(self) -> None:
"""
关闭 Redis 客户端的连接池
"""
if self.client:
self.client.close()
def get_cache_backend(maxsize: int = 1000, ttl: int = 1800) -> CacheBackend:
"""
根据配置获取缓存后端实例
:param maxsize: 缓存的最大条目数
:param ttl: 缓存的默认存活时间,单位秒
:return: 返回缓存后端实例
"""
cache_type = settings.CACHE_BACKEND_TYPE
logger.debug(f"Cache backend type from settings: {cache_type}")
if cache_type == "redis":
redis_url = settings.CACHE_BACKEND_URL
if redis_url:
try:
logger.debug(f"Attempting to use RedisBackend with URL: {redis_url}, TTL: {ttl}")
return RedisBackend(redis_url=redis_url, ttl=ttl)
except RuntimeError:
logger.warning("Falling back to CacheToolsBackend due to Redis connection failure.")
else:
logger.debug("Cache backend type is redis, but no valid REDIS_URL found. "
"Falling back to CacheToolsBackend.")
# 如果不是 Redis回退到内存缓存
logger.debug(f"Using CacheToolsBackend with default maxsize: {maxsize}, TTL: {ttl}")
return CacheToolsBackend(maxsize=maxsize, ttl=ttl)
def cached(region: Optional[str] = None, maxsize: int = 1000, ttl: int = 1800,
skip_none: bool = True, skip_empty: bool = False):
"""
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
:param region: 缓存的区
:param maxsize: 缓存的最大条目数,默认值为 1000
:param ttl: 缓存的存活时间,单位秒,默认值为 1800
:param skip_none: 跳过 None 缓存,默认为 True
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
:return: 装饰器函数
"""
def should_cache(value: Any) -> bool:
"""
判断是否应该缓存结果,如果返回值是 None 或空值则不缓存
:param value: 要判断的缓存值
:return: 是否缓存结果
"""
if skip_none and value is None:
return False
# if skip_empty and value in [None, [], {}, "", set()]:
if skip_empty and not value:
return False
return True
def is_valid_cache_value(cache_key: str, cached_value: Any, cache_region: str) -> bool:
"""
判断指定的值是否为一个有效的缓存值
:param cache_key: 缓存的键
:param cached_value: 缓存的值
:param cache_region: 缓存的区
:return: 若值是有效的缓存值返回 True否则返回 False
"""
# 如果 skip_none 为 False且 value 为 None需要判断缓存实际是否存在
if not skip_none and cached_value is None:
if not cache_backend.exists(key=cache_key, region=cache_region):
return False
return True
def decorator(func):
# 获取缓存区
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
def cache_clear():
"""
清理缓存区
"""
# 清理缓存区
cache_backend.clear(region=cache_region)
wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear
return wrapper
return decorator
# 缓存后端实例
cache_backend = get_cache_backend()
def close_cache() -> None:
"""
关闭缓存后端连接并清理资源
"""
try:
if cache_backend:
cache_backend.close()
logger.info("Cache backend closed successfully.")
except Exception as e:
logger.info(f"Error while closing cache backend: {e}")

View File

@@ -1,28 +1,18 @@
import copy
import os
import re
import secrets
import sys
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Optional, List
from dotenv import set_key
from pydantic import BaseModel, BaseSettings, validator, Field
from pydantic import BaseSettings, validator
from app.log import logger, log_settings, LogConfigModel
from app.utils.system import SystemUtils
from app.utils.url import UrlUtils
class ConfigModel(BaseModel):
class Settings(BaseSettings):
"""
Pydantic 配置模型,描述所有配置项及其类型和默认值
系统配置类
"""
class Config:
extra = "ignore" # 忽略未定义的配置项
# 项目名称
PROJECT_NAME = "MoviePilot"
# 域名 格式https://movie-pilot.org
@@ -33,14 +23,10 @@ class ConfigModel(BaseModel):
FRONTEND_PATH: str = "/public"
# 密钥
SECRET_KEY: str = secrets.token_urlsafe(32)
# RESOURCE密钥
RESOURCE_SECRET_KEY: str = secrets.token_urlsafe(32)
# 允许的域名
ALLOWED_HOSTS: list = Field(default_factory=lambda: ["*"])
ALLOWED_HOSTS: list = ["*"]
# TOKEN过期时间
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
# RESOURCE_TOKEN过期时间
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
# 时区
TZ: str = "Asia/Shanghai"
# API监听地址
@@ -53,42 +39,18 @@ class ConfigModel(BaseModel):
DEBUG: bool = False
# 是否开发模式
DEV: bool = False
# 是否在控制台输出 SQL 语句,默认关闭
DB_ECHO: bool = False
# 数据库连接池类型QueuePool, NullPool
DB_POOL_TYPE: str = "QueuePool"
# 是否在获取连接时进行预先 ping 操作,默认关闭
DB_POOL_PRE_PING: bool = False
# 数据库连接池的大小,默认 100
DB_POOL_SIZE: int = 100
# 数据库连接的回收时间(秒),默认 1800 秒
DB_POOL_RECYCLE: int = 1800
# 数据库连接池获取连接的超时时间(秒),默认 60 秒
DB_POOL_TIMEOUT: int = 60
# 数据库连接池最大溢出连接数,默认 500
DB_MAX_OVERFLOW: int = 500
# SQLite 的 busy_timeout 参数,默认为 60 秒
DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认关闭
DB_WAL_ENABLE: bool = False
# 缓存类型,支持 cachetools 和 redis默认使用 cachetools
CACHE_BACKEND_TYPE: str = "cachetools"
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached需要
CACHE_BACKEND_URL: Optional[str] = None
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
CACHE_REDIS_MAXMEMORY: Optional[str] = None
# 是否开启插件热加载
PLUGIN_AUTO_RELOAD: bool = False
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
API_TOKEN: Optional[str] = None
API_TOKEN: str = "moviepilot"
# 登录页面电影海报,tmdb/bing
WALLPAPER: str = "tmdb"
# 网络代理 IP:PORT
PROXY_HOST: Optional[str] = None
# 登录页面电影海报,tmdb/bing/mediaserver
WALLPAPER: str = "tmdb"
# 媒体搜索来源 themoviedb/douban/bangumi多个用,分隔
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
# 媒体识别来源 themoviedb/douban
@@ -109,74 +71,124 @@ class ConfigModel(BaseModel):
FANART_ENABLE: bool = True
# Fanart API Key
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS = [16]
# 用户认证站点
AUTH_SITE: str = ""
# 自动检查和更新站点资源包(站点索引、认证等)
AUTO_UPDATE_RESOURCE: bool = True
# 是否启用DOH解析域名
DOH_ENABLE: bool = False
# 使用 DOH 解析的域名列表
DOH_DOMAINS: str = ("api.themoviedb.org,"
"api.tmdb.org,"
"webservice.fanart.tv,"
"api.github.com,"
"github.com,"
"raw.githubusercontent.com,"
"api.telegram.org")
# DOH 解析服务器列表
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
# 支持的后缀格式
RMT_MEDIAEXT: list = Field(
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
'.rmvb', '.avi', '.mov', '.mpeg',
'.mpg', '.wmv', '.3gp', '.asf',
'.m4v', '.flv', '.m2ts', '.strm',
'.tp', '.f4v']
)
RMT_MEDIAEXT: list = ['.mp4', '.mkv', '.ts', '.iso',
'.rmvb', '.avi', '.mov', '.mpeg',
'.mpg', '.wmv', '.3gp', '.asf',
'.m4v', '.flv', '.m2ts', '.strm',
'.tp', '.f4v']
# 支持的字幕文件后缀格式
RMT_SUBEXT: list = Field(default_factory=lambda: ['.srt', '.ass', '.ssa', '.sup'])
# 支持的音轨文件后缀格式
RMT_AUDIO_TRACK_EXT: list = Field(default_factory=lambda: ['.mka'])
# 音轨文件后缀格式
RMT_AUDIOEXT: list = Field(
default_factory=lambda: ['.aac', '.ac3', '.amr', '.caf', '.cda', '.dsf',
'.dff', '.kar', '.m4a', '.mp1', '.mp2', '.mp3',
'.mid', '.mod', '.mka', '.mpc', '.nsf', '.ogg',
'.pcm', '.rmi', '.s3m', '.snd', '.spx', '.tak',
'.tta', '.vqf', '.wav', '.wma',
'.aifc', '.aiff', '.alac', '.adif', '.adts',
'.flac', '.midi', '.opus', '.sfalc']
)
RMT_SUBEXT: list = ['.srt', '.ass', '.ssa', '.sup']
# 下载器临时文件后缀
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
# 媒体服务器同步间隔(小时)
MEDIASERVER_SYNC_INTERVAL: int = 6
DOWNLOAD_TMPEXT: list = ['.!qB', '.part']
# 支持的音轨文件后缀格式
RMT_AUDIO_TRACK_EXT: list = ['.mka']
# 索引器
INDEXER: str = "builtin"
# 订阅模式
SUBSCRIBE_MODE: str = "spider"
# RSS订阅模式刷新时间间隔分钟
SUBSCRIBE_RSS_INTERVAL: int = 30
# 订阅数据共享
SUBSCRIBE_STATISTIC_SHARE: bool = True
# 订阅搜索开关
SUBSCRIBE_SEARCH: bool = False
# 检查本地媒体库是否存在资源开关
LOCAL_EXISTS_SEARCH: bool = False
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# 站点数据刷新间隔(小时)
SITEDATA_REFRESH_INTERVAL: int = 6
# 读取和发送站点消息
SITE_MESSAGE: bool = True
# 用户认证站点
AUTH_SITE: str = ""
# 交互搜索自动下载用户ID使用,分割
AUTO_DOWNLOAD_USER: Optional[str] = None
# 消息通知渠道 telegram/wechat/slack/synologychat/vocechat/webpush多个通知渠道用,分隔
MESSAGER: str = "webpush"
# WeChat企业ID
WECHAT_CORPID: Optional[str] = None
# WeChat应用Secret
WECHAT_APP_SECRET: Optional[str] = None
# WeChat应用ID
WECHAT_APP_ID: Optional[str] = None
# WeChat代理服务器
WECHAT_PROXY: str = "https://qyapi.weixin.qq.com"
# WeChat Token
WECHAT_TOKEN: Optional[str] = None
# WeChat EncodingAESKey
WECHAT_ENCODING_AESKEY: Optional[str] = None
# WeChat 管理员
WECHAT_ADMINS: Optional[str] = None
# Telegram Bot Token
TELEGRAM_TOKEN: Optional[str] = None
# Telegram Chat ID
TELEGRAM_CHAT_ID: Optional[str] = None
# Telegram 用户ID使用,分隔
TELEGRAM_USERS: str = ""
# Telegram 管理员ID使用,分隔
TELEGRAM_ADMINS: str = ""
# Slack Bot User OAuth Token
SLACK_OAUTH_TOKEN: str = ""
# Slack App-Level Token
SLACK_APP_TOKEN: str = ""
# Slack 频道名称
SLACK_CHANNEL: str = ""
# SynologyChat Webhook
SYNOLOGYCHAT_WEBHOOK: str = ""
# SynologyChat Token
SYNOLOGYCHAT_TOKEN: str = ""
# VoceChat地址
VOCECHAT_HOST: str = ""
# VoceChat ApiKey
VOCECHAT_API_KEY: str = ""
# VoceChat 频道ID
VOCECHAT_CHANNEL_ID: str = ""
# 下载器 qbittorrent/transmission启用多个下载器时使用,分隔,只有第一个会被默认使用
DOWNLOADER: str = "qbittorrent"
# 下载器监控开关
DOWNLOADER_MONITOR: bool = True
# Qbittorrent地址IP:PORT
QB_HOST: Optional[str] = None
# Qbittorrent用户名
QB_USER: Optional[str] = None
# Qbittorrent密码
QB_PASSWORD: Optional[str] = None
# Qbittorrent分类自动管理
QB_CATEGORY: bool = False
# Qbittorrent按顺序下载
QB_SEQUENTIAL: bool = True
# Qbittorrent忽略队列限制强制继续
QB_FORCE_RESUME: bool = False
# Transmission地址IP:PORT
TR_HOST: Optional[str] = None
# Transmission用户名
TR_USER: Optional[str] = None
# Transmission密码
TR_PASSWORD: Optional[str] = None
# 种子标签
TORRENT_TAG: str = "MOVIEPILOT"
# 下载站点字幕
DOWNLOAD_SUBTITLE: bool = True
# 交互搜索自动下载用户ID使用,分割
AUTO_DOWNLOAD_USER: Optional[str] = None
# 媒体服务器 emby/jellyfin/plex多个媒体服务器,分割
MEDIASERVER: str = "emby"
# 媒体服务器同步间隔(小时)
MEDIASERVER_SYNC_INTERVAL: Optional[int] = 6
# 媒体服务器同步黑名单,多个媒体库名称,分割
MEDIASERVER_SYNC_BLACKLIST: Optional[str] = None
# EMBY服务器地址IP:PORT
EMBY_HOST: Optional[str] = None
# EMBY外网地址http(s)://DOMAIN:PORT未设置时使用EMBY_HOST
EMBY_PLAY_HOST: Optional[str] = None
# EMBY Api Key
EMBY_API_KEY: Optional[str] = None
# Jellyfin服务器地址IP:PORT
JELLYFIN_HOST: Optional[str] = None
# Jellyfin外网地址http(s)://DOMAIN:PORT未设置时使用JELLYFIN_HOST
JELLYFIN_PLAY_HOST: Optional[str] = None
# Jellyfin Api Key
JELLYFIN_API_KEY: Optional[str] = None
# Plex服务器地址IP:PORT
PLEX_HOST: Optional[str] = None
# Plex外网地址http(s)://DOMAIN:PORT未设置时使用PLEX_HOST
PLEX_PLAY_HOST: Optional[str] = None
# Plex Token
PLEX_TOKEN: Optional[str] = None
# 转移方式 link/copy/move/softlink
TRANSFER_TYPE: str = "copy"
# 是否同盘优先
TRANSFER_SAME_DISK: bool = True
# CookieCloud是否启动本地服务
COOKIECLOUD_ENABLE_LOCAL: Optional[bool] = False
# CookieCloud服务器地址
@@ -189,8 +201,12 @@ class ConfigModel(BaseModel):
COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24
# CookieCloud同步黑名单多个域名,分割
COOKIECLOUD_BLACKLIST: Optional[str] = None
# OCR服务器地址
OCR_HOST: str = "https://movie-pilot.org"
# CookieCloud对应的浏览器UA
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.57"
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS = [16]
# 电影重命名格式
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
@@ -200,262 +216,72 @@ class ConfigModel(BaseModel):
"/Season {{season}}" \
"/{{title}} - {{season_episode}}{% if part %}-{{part}}{% endif %}{% if episode %} - 第 {{episode}} 集{% endif %}" \
"{{fileExt}}"
# OCR服务器地址
OCR_HOST: str = "https://movie-pilot.org"
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
MP_SERVER_HOST: str = "https://movie-pilot.org"
# 插件市场仓库地址,多个地址使用,分隔,地址以/结尾
PLUGIN_MARKET: str = ("https://github.com/jxxghp/MoviePilot-Plugins,"
"https://github.com/thsrite/MoviePilot-Plugins,"
"https://github.com/honue/MoviePilot-Plugins,"
"https://github.com/InfinityPacer/MoviePilot-Plugins")
# 插件安装数据共享
PLUGIN_STATISTIC_SHARE: bool = True
# 是否开启插件热加载
PLUGIN_AUTO_RELOAD: bool = False
# Github token提高请求api限流阈值 ghp_****
GITHUB_TOKEN: Optional[str] = None
# Github代理服务器格式https://mirror.ghproxy.com/
GITHUB_PROXY: Optional[str] = ''
# pip镜像站点格式https://pypi.tuna.tsinghua.edu.cn/simple
PIP_PROXY: Optional[str] = ''
# 指定的仓库Github token多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
REPO_GITHUB_TOKEN: Optional[str] = None
# 转移时覆盖模式
OVERWRITE_MODE: str = "size"
# 大内存模式
BIG_MEMORY_MODE: bool = False
# 全局图片缓存,将媒体图片缓存到本地
GLOBAL_IMAGE_CACHE: bool = False
# 是否启用编码探测的性能模式
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
# 编码探测的最低置信度阈值
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
# 允许的图片缓存域名
SECURITY_IMAGE_DOMAINS: List[str] = Field(
default_factory=lambda: ["image.tmdb.org",
"static-mdb.v.geilijiasu.com",
"doubanio.com",
"lain.bgm.tv",
"raw.githubusercontent.com",
"github.com",
"thetvdb.com",
"cctvpic.com",
"iqiyipic.com",
"hdslb.com",
"cmvideo.cn",
"ykimg.com",
"qpic.cn"]
)
# 允许的图片文件后缀格式
SECURITY_IMAGE_SUFFIXES: List[str] = Field(
default_factory=lambda: [".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"]
)
# 重命名时支持的S0别名
RENAME_FORMAT_S0_NAMES: List[str] = Field(
default_factory=lambda: ["Specials", "SPs"]
)
# 启用分词搜索
TOKENIZED_SEARCH: bool = False
# 为指定默认字幕添加.default后缀
DEFAULT_SUB: Optional[str] = "zh-cn"
# 插件市场仓库地址,多个地址使用,分隔,地址以/结尾
PLUGIN_MARKET: str = "https://github.com/jxxghp/MoviePilot-Plugins,https://github.com/thsrite/MoviePilot-Plugins,https://github.com/honue/MoviePilot-Plugins,https://github.com/InfinityPacer/MoviePilot-Plugins"
# Github token提高请求api限流阈值 ghp_****
GITHUB_TOKEN: Optional[str] = None
# 指定的仓库Github token多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
REPO_GITHUB_TOKEN: Optional[str] = None
# Github代理服务器格式https://mirror.ghproxy.com/
GITHUB_PROXY: Optional[str] = ''
# 自动检查和更新站点资源包(站点索引、认证等)
AUTO_UPDATE_RESOURCE: bool = False
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 是否启用DOH解析域名
DOH_ENABLE: bool = True
# 使用 DOH 解析的域名列表
DOH_DOMAINS: str = "api.themoviedb.org,api.tmdb.org,webservice.fanart.tv,api.github.com,github.com,raw.githubusercontent.com,api.telegram.org"
# DOH 解析服务器列表
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# 订阅数据共享
SUBSCRIBE_STATISTIC_SHARE: bool = True
# 插件安装数据共享
PLUGIN_STATISTIC_SHARE: bool = True
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
MP_SERVER_HOST: str = "https://movie-pilot.org"
# 【已弃用】刮削入库的媒体文件
SCRAP_METADATA: bool = True
# 【已弃用】下载保存目录,容器内映射路径需要一致
DOWNLOAD_PATH: Optional[str] = None
# 【已弃用】电影下载保存目录,容器内映射路径需要一致
DOWNLOAD_MOVIE_PATH: Optional[str] = None
# 【已弃用】电视剧下载保存目录,容器内映射路径需要一致
DOWNLOAD_TV_PATH: Optional[str] = None
# 【已弃用】动漫下载保存目录,容器内映射路径需要一致
DOWNLOAD_ANIME_PATH: Optional[str] = None
# 【已弃用】下载目录二级分类
DOWNLOAD_CATEGORY: bool = False
# 【已弃用】媒体库目录,多个目录使用,分隔
LIBRARY_PATH: Optional[str] = None
# 【已弃用】电影媒体库目录名
LIBRARY_MOVIE_NAME: str = "电影"
# 【已弃用】电视剧媒体库目录名
LIBRARY_TV_NAME: str = "电视剧"
# 【已弃用】动漫媒体库目录名,不设置时使用电视剧目录
LIBRARY_ANIME_NAME: Optional[str] = None
# 【已弃用】二级分类
LIBRARY_CATEGORY: bool = True
class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""
系统配置类
"""
class Config:
case_sensitive = True
env_file = SystemUtils.get_env_path()
env_file_encoding = "utf-8"
def __init__(self, **kwargs):
super().__init__(**kwargs)
# 初始化配置目录及子目录
for path in [self.CONFIG_PATH, self.TEMP_PATH, self.LOG_PATH, self.COOKIE_PATH]:
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
# 如果是二进制程序,确保配置文件存在
if SystemUtils.is_frozen():
app_env_path = self.CONFIG_PATH / "app.env"
if not app_env_path.exists():
SystemUtils.copy(self.INNER_CONFIG_PATH / "app.env", app_env_path)
@staticmethod
def validate_api_token(value: Any, original_value: Any) -> Tuple[Any, bool]:
"""
校验 API_TOKEN
"""
if isinstance(value, (list, dict, set)):
value = copy.deepcopy(value)
value = value.strip() if isinstance(value, str) else None
if not value or len(value) < 16:
new_token = secrets.token_urlsafe(16)
if not value:
logger.info(f"'API_TOKEN' 未设置已随机生成新的【API_TOKEN】{new_token}")
else:
logger.warning(f"'API_TOKEN' 长度不足 16 个字符存在安全隐患已随机生成新的【API_TOKEN】{new_token}")
return new_token, True
return value, str(value) != str(original_value)
@staticmethod
def generic_type_converter(value: Any, original_value: Any, expected_type: Type, default: Any, field_name: str,
raise_exception: bool = False) -> Tuple[Any, bool]:
"""
通用类型转换函数,根据预期类型转换值。如果转换失败,返回默认值
"""
if isinstance(value, (list, dict, set)):
value = copy.deepcopy(value)
# 如果 value 是 None仍需要检查与 original_value 是否不一致
if value is None:
return default, str(value) != str(original_value)
if isinstance(value, str):
value = value.strip()
@validator("SUBSCRIBE_RSS_INTERVAL",
"COOKIECLOUD_INTERVAL",
"MEDIASERVER_SYNC_INTERVAL",
"META_CACHE_EXPIRE",
pre=True, always=True)
def convert_int(cls, value):
if not value:
return 0
try:
if expected_type is bool:
if isinstance(value, bool):
return value, str(value).lower() != str(original_value).lower()
if isinstance(value, str):
value_clean = value.lower()
bool_map = {
"false": False, "no": False, "0": False, "off": False,
"true": True, "yes": True, "1": True, "on": True
}
if value_clean in bool_map:
converted = bool_map[value_clean]
return converted, str(converted).lower() != str(original_value).lower()
elif isinstance(value, (int, float)):
converted = bool(value)
return converted, str(converted).lower() != str(original_value).lower()
return default, True
elif expected_type is int:
if isinstance(value, int):
return value, str(value) != str(original_value)
if isinstance(value, str):
converted = int(value)
return converted, str(converted) != str(original_value)
elif expected_type is float:
if isinstance(value, float):
return value, str(value) != str(original_value)
if isinstance(value, str):
converted = float(value)
return converted, str(converted) != str(original_value)
elif expected_type is str:
# 清理 value 中所有空白字符的字段
fields_not_keep_spaces = {"AUTO_DOWNLOAD_USER", "REPO_GITHUB_TOKEN", "PLUGIN_MARKET"}
if field_name in fields_not_keep_spaces:
value = re.sub(r"\s+", "", value)
return value, str(value) != str(original_value)
# # 后续考虑支持 list 类型的处理
# elif expected_type is list:
# if isinstance(value, list):
# return value, False
# if isinstance(value, str):
# items = [item.strip() for item in value.split(",") if item.strip()]
# return items, items != original_value.split(",")
# 可根据需要添加更多类型处理
else:
return value, str(value) != str(original_value)
except (ValueError, TypeError) as e:
if raise_exception:
raise ValueError(f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型") from e
logger.error(
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
return default, True
@validator('*', pre=True, always=True)
def generic_type_validator(cls, value: Any, field): # noqa
"""
通用校验器,尝试将配置值转换为期望的类型
"""
if field.name == "API_TOKEN":
converted_value, needs_update = cls.validate_api_token(value, value)
else:
converted_value, needs_update = cls.generic_type_converter(value, value, field.type_, field.default,
field.name)
if needs_update:
cls.update_env_config(field, value, converted_value)
return converted_value
@staticmethod
def update_env_config(field: Any, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
"""
更新 env 配置
"""
message = None
is_converted = original_value is not None and str(original_value) != str(converted_value)
if is_converted:
message = f"配置项 '{field.name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
logger.warning(message)
if field.name in os.environ:
message = f"配置项 '{field.name}' 已在环境变量中设置,请手动更新以保持一致性"
logger.warning(message)
return False, message
else:
set_key(SystemUtils.get_env_path(), field.name, str(converted_value) if converted_value is not None else "")
if is_converted:
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
return True, message
def update_setting(self, key: str, value: Any) -> Tuple[bool, str]:
"""
更新单个配置项
"""
if not hasattr(self, key):
return False, f"配置项 '{key}' 不存在"
try:
field = self.__fields__[key]
original_value = getattr(self, key)
if field.name == "API_TOKEN":
converted_value, needs_update = self.validate_api_token(value, original_value)
else:
converted_value, needs_update = self.generic_type_converter(value, original_value, field.type_,
field.default, key)
# 如果没有抛出异常,则统一使用 converted_value 进行更新
if needs_update or str(value) != str(converted_value):
success, message = self.update_env_config(field, value, converted_value)
# 仅成功更新配置时,才更新内存
if success:
setattr(self, key, converted_value)
if hasattr(log_settings, key):
setattr(log_settings, key, converted_value)
return success, message
return True, ""
except Exception as e:
return False, str(e)
def update_settings(self, env: Dict[str, Any]) -> Dict[str, Tuple[bool, str]]:
"""
更新多个配置项
"""
results = {}
log_updated, plugin_monitor_updated = False, False
for k, v in env.items():
results[k] = self.update_setting(k, v)
if hasattr(log_settings, k):
log_updated = True
if k in ["PLUGIN_AUTO_RELOAD", "DEV"]:
plugin_monitor_updated = True
# 本次更新存在日志配置项更新,需要重新加载日志配置
if log_updated:
logger.update_loggers()
# 本次更新存在插件监控配置项更新,需要重新加载插件监控
if plugin_monitor_updated:
# 解决顶层循环导入问题
from app.core.plugin import PluginManager
PluginManager().reload_monitor()
return results
@property
def VERSION_FLAG(self) -> str:
"""
版本标识用来区分重大版本为空则为v1不允许外部修改
"""
return "v2"
return int(value)
except (ValueError, TypeError):
raise ValueError(f"{value} 格式错误,不是有效数字!")
@property
def INNER_CONFIG_PATH(self):
@@ -475,10 +301,6 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
def TEMP_PATH(self):
return self.CONFIG_PATH / "temp"
@property
def CACHE_PATH(self):
return self.CONFIG_PATH / "cache"
@property
def ROOT_PATH(self):
return Path(__file__).parents[2]
@@ -497,34 +319,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property
def CACHE_CONF(self):
"""
{
"torrents": "缓存种子数量",
"refresh": "订阅刷新处理数量",
"tmdb": "TMDB请求缓存数量",
"douban": "豆瓣请求缓存数量",
"fanart": "Fanart请求缓存数量",
"meta": "元数据缓存过期时间(秒)"
}
"""
if self.BIG_MEMORY_MODE:
return {
"torrents": 200,
"refresh": 100,
"tmdb": 1024,
"refresh": 50,
"torrents": 100,
"douban": 512,
"bangumi": 512,
"fanart": 512,
"meta": (self.META_CACHE_EXPIRE or 24) * 3600
"meta": (self.META_CACHE_EXPIRE or 168) * 3600
}
return {
"torrents": 100,
"refresh": 50,
"tmdb": 256,
"refresh": 30,
"torrents": 50,
"douban": 256,
"bangumi": 256,
"fanart": 128,
"meta": (self.META_CACHE_EXPIRE or 2) * 3600
"meta": (self.META_CACHE_EXPIRE or 72) * 3600
}
@property
@@ -585,6 +395,24 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
# 如果传入了指定的仓库名称,则返回该仓库的请求头信息,否则返回默认请求头
return headers.get(repo, self.GITHUB_HEADERS)
@property
def DEFAULT_DOWNLOADER(self):
"""
默认下载器
"""
if not self.DOWNLOADER:
return None
return next((d for d in settings.DOWNLOADER.split(",") if d), None)
@property
def DOWNLOADERS(self):
"""
下载器列表
"""
if not self.DOWNLOADER:
return []
return [d for d in settings.DOWNLOADER.split(",") if d]
@property
def VAPID(self):
return {
@@ -596,7 +424,33 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
def MP_DOMAIN(self, url: str = None):
if not self.APP_DOMAIN:
return None
return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url)
domain = self.APP_DOMAIN.rstrip("/")
if not domain.startswith("http"):
domain = "http://" + domain
if not url:
return domain
return domain + "/" + url.lstrip("/")
def __init__(self, **kwargs):
super().__init__(**kwargs)
with self.CONFIG_PATH as p:
if not p.exists():
p.mkdir(parents=True, exist_ok=True)
if SystemUtils.is_frozen():
if not (p / "app.env").exists():
SystemUtils.copy(self.INNER_CONFIG_PATH / "app.env", p / "app.env")
with self.TEMP_PATH as p:
if not p.exists():
p.mkdir(parents=True, exist_ok=True)
with self.LOG_PATH as p:
if not p.exists():
p.mkdir(parents=True, exist_ok=True)
with self.COOKIE_PATH as p:
if not p.exists():
p.mkdir(parents=True, exist_ok=True)
class Config:
case_sensitive = True
class GlobalVar(object):
@@ -614,7 +468,6 @@ class GlobalVar(object):
"""
self.STOP_EVENT.set()
@property
def is_system_stopped(self):
"""
是否停止
@@ -635,7 +488,10 @@ class GlobalVar(object):
# 实例化配置
settings = Settings()
settings = Settings(
_env_file=Settings().CONFIG_PATH / "app.env",
_env_file_encoding="utf-8"
)
# 全局标识
global_vars = GlobalVar()

View File

@@ -1,6 +1,5 @@
import re
from dataclasses import dataclass, field
from datetime import datetime
from dataclasses import dataclass, field, asdict
from typing import List, Dict, Any, Tuple
from app.core.config import settings
@@ -24,8 +23,6 @@ class TorrentInfo:
site_proxy: bool = False
# 站点优先级
site_order: int = 0
# 站点下载器
site_downloader: str = None
# 种子名称
title: str = None
# 种子副标题
@@ -124,25 +121,11 @@ class TorrentInfo:
return ""
return StringUtils.diff_time_str(self.freedate)
def pub_minutes(self) -> float:
"""
返回发布时间距离当前时间的分钟数
"""
if not self.pubdate:
return 0
try:
pub_date = datetime.strptime(self.pubdate, "%Y-%m-%d %H:%M:%S")
now_datetime = datetime.now()
return (now_datetime - pub_date).total_seconds() // 60
except Exception as e:
print(f"种子发布时间获取失败: {e}")
return 0
def to_dict(self):
"""
返回字典
"""
dicts = vars(self).copy()
dicts = asdict(self)
dicts["volume_factor"] = self.volume_factor
dicts["freedate_diff"] = self.freedate_diff
return dicts
@@ -158,10 +141,6 @@ class MediaInfo:
title: str = None
# 英文标题
en_title: str = None
# 香港标题
hk_title: str = None
# 台湾标题
tw_title: str = None
# 新加坡标题
sg_title: str = None
# 年份
@@ -178,8 +157,6 @@ class MediaInfo:
douban_id: str = None
# Bangumi ID
bangumi_id: int = None
# 合集ID
collection_id: int = None
# 媒体原语种
original_language: str = None
# 媒体原发行标题
@@ -262,8 +239,6 @@ class MediaInfo:
runtime: int = None
# 下一集
next_episode_to_air: dict = field(default_factory=dict)
# 内容分级
content_rating: str = None
def __post_init__(self):
# 设置媒体信息
@@ -401,8 +376,6 @@ class MediaInfo:
if info.get("external_ids"):
self.tvdb_id = info.get("external_ids", {}).get("tvdb_id")
self.imdb_id = info.get("external_ids", {}).get("imdb_id")
# 合集ID
self.collection_id = info.get('collection_id')
# 评分
self.vote_average = round(float(info.get('vote_average')), 1) if info.get('vote_average') else 0
# 描述
@@ -413,10 +386,6 @@ class MediaInfo:
self.original_language = info.get('original_language')
# 英文标题
self.en_title = info.get('en_title')
# 香港标题
self.hk_title = info.get('hk_title')
# 台湾标题
self.tw_title = info.get('tw_title')
# 新加坡标题
self.sg_title = info.get('sg_title')
if self.type == MediaType.MOVIE:
@@ -746,7 +715,7 @@ class MediaInfo:
"""
返回字典
"""
dicts = vars(self).copy()
dicts = asdict(self)
dicts["type"] = self.type.value if self.type else None
dicts["detail_link"] = self.detail_link
dicts["title_year"] = self.title_year

View File

@@ -1,545 +1,123 @@
import copy
import importlib
import inspect
import random
import threading
import time
import traceback
import uuid
from functools import lru_cache
from queue import Empty, PriorityQueue
from typing import Callable, Dict, List, Optional, Union
from queue import Queue, Empty
from typing import Dict, Any
from app.helper.message import MessageHelper
from app.helper.thread import ThreadHelper
from app.log import logger
from app.schemas import ChainEventData
from app.schemas.types import ChainEventType, EventType
from app.utils.limit import ExponentialBackoffRateLimiter
from app.utils.singleton import Singleton
DEFAULT_EVENT_PRIORITY = 10 # 事件的默认优先级
MIN_EVENT_CONSUMER_THREADS = 1 # 最小事件消费者线程数
INITIAL_EVENT_QUEUE_IDLE_TIMEOUT_SECONDS = 1 # 事件队列空闲时的初始超时时间(秒)
MAX_EVENT_QUEUE_IDLE_TIMEOUT_SECONDS = 5 # 事件队列空闲时的最大超时时间(秒)
class Event:
"""
事件类,封装事件的基本信息
"""
def __init__(self, event_type: Union[EventType, ChainEventType],
event_data: Optional[Union[Dict, ChainEventData]] = None,
priority: int = DEFAULT_EVENT_PRIORITY):
"""
:param event_type: 事件的类型,支持 EventType 或 ChainEventType
:param event_data: 可选,事件携带的数据,默认为空字典
:param priority: 可选,事件的优先级,默认为 10
"""
self.event_id = str(uuid.uuid4()) # 事件ID
self.event_type = event_type # 事件类型
self.event_data = event_data or {} # 事件数据
self.priority = priority # 事件优先级
def __repr__(self) -> str:
"""
重写 __repr__ 方法用于返回事件的详细信息包括事件类型、事件ID和优先级
"""
event_kind = Event.get_event_kind(self.event_type)
return f"<{event_kind}: {self.event_type.value}, ID: {self.event_id}, Priority: {self.priority}>"
def __lt__(self, other):
"""
定义事件对象的比较规则,基于优先级比较
优先级小的事件会被认为“更小”,优先级高的事件将被认为“更大”
"""
return self.priority < other.priority
@staticmethod
def get_event_kind(event_type: Union[EventType, ChainEventType]) -> str:
"""
根据事件类型判断事件是广播事件还是链式事件
:param event_type: 事件类型,支持 EventType 或 ChainEventType
:return: 返回 Broadcast Event 或 Chain Event
"""
return "Broadcast Event" if isinstance(event_type, EventType) else "Chain Event"
from app.schemas.types import EventType
class EventManager(metaclass=Singleton):
"""
EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件
事件管理器
"""
# 退出事件
__event = threading.Event()
def __init__(self):
self.__messagehelper = MessageHelper()
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
self.__event_queue = PriorityQueue() # 优先级队列
self.__broadcast_subscribers: Dict[EventType, Dict[str, Callable]] = {} # 广播事件的订阅者
self.__chain_subscribers: Dict[ChainEventType, Dict[str, tuple[int, Callable]]] = {} # 链式事件的订阅者
self.__disabled_handlers = set() # 禁用的事件处理器集合
self.__disabled_classes = set() # 禁用的事件处理器类集合
self.__lock = threading.Lock() # 线程锁
# 事件队列
self._eventQueue = Queue()
# 事件响应函数字典
self._handlers: Dict[str, Dict[str, Any]] = {}
# 已禁用的事件响应
self._disabled_handlers = []
def start(self):
def get_event(self):
"""
开始广播事件处理线程
获取事件
"""
# 启动消费者线程用于处理广播事件
self.__event.set()
for _ in range(MIN_EVENT_CONSUMER_THREADS):
thread = threading.Thread(target=self.__broadcast_consumer_loop, daemon=True)
thread.start()
self.__consumer_threads.append(thread) # 将线程对象保存到列表中
def stop(self):
"""
停止广播事件处理线程
"""
logger.info("正在停止事件处理...")
self.__event.clear() # 停止广播事件处理
try:
# 通过遍历保存的线程来等待它们完成
for consumer_thread in self.__consumer_threads:
consumer_thread.join()
logger.info("事件处理停止完成")
except Exception as e:
logger.error(f"停止事件处理线程出错:{str(e)} - {traceback.format_exc()}")
event = self._eventQueue.get(block=True, timeout=1)
handlers = self._handlers.get(event.event_type) or {}
if handlers:
# 去除掉被禁用的事件响应
handlerList = [handler for handler in handlers.values()
if handler.__qualname__.split(".")[0] not in self._disabled_handlers]
return event, handlerList
return event, []
except Empty:
return None, []
def check(self, etype: Union[EventType, ChainEventType]) -> bool:
def check(self, etype: EventType):
"""
检查是否有启用的事件处理器可以响应某个事件类型
:param etype: 事件类型 (EventType 或 ChainEventType)
:return: 返回是否存在可用的处理器
检查事件是否存在响应,去除掉被禁用的事件响应
"""
if isinstance(etype, ChainEventType):
handlers = self.__chain_subscribers.get(etype, {})
return any(
self.__is_handler_enabled(handler)
for _, handler in handlers.values()
)
else:
handlers = self.__broadcast_subscribers.get(etype, {})
return any(
self.__is_handler_enabled(handler)
for handler in handlers.values()
)
def send_event(self, etype: Union[EventType, ChainEventType], data: Optional[Union[Dict, ChainEventData]] = None,
priority: int = DEFAULT_EVENT_PRIORITY) -> Optional[Event]:
"""
发送事件,根据事件类型决定是广播事件还是链式事件
:param etype: 事件类型 (EventType 或 ChainEventType)
:param data: 可选,事件数据
:param priority: 广播事件的优先级,默认为 10
:return: 如果是链式事件,返回处理后的事件数据;否则返回 None
"""
event = Event(etype, data, priority)
if isinstance(etype, EventType):
self.__trigger_broadcast_event(event)
elif isinstance(etype, ChainEventType):
return self.__trigger_chain_event(event)
else:
logger.error(f"Unknown event type: {etype}")
def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable,
priority: int = DEFAULT_EVENT_PRIORITY):
"""
注册事件处理器,将处理器添加到对应的事件订阅列表中
:param event_type: 事件类型 (EventType 或 ChainEventType)
:param handler: 处理器
:param priority: 可选,链式事件的优先级,默认为 10广播事件不需要优先级
"""
with self.__lock:
handler_identifier = self.__get_handler_identifier(handler)
if isinstance(event_type, ChainEventType):
# 链式事件,按优先级排序
if event_type not in self.__chain_subscribers:
self.__chain_subscribers[event_type] = {}
handlers = self.__chain_subscribers[event_type]
if handler_identifier in handlers:
handlers.pop(handler_identifier)
else:
logger.debug(
f"Subscribed to chain event: {event_type.value}, "
f"Priority: {priority} - {handler_identifier}")
handlers[handler_identifier] = (priority, handler)
# 根据优先级排序
self.__chain_subscribers[event_type] = dict(
sorted(self.__chain_subscribers[event_type].items(), key=lambda x: x[1][0])
)
else:
# 广播事件
if event_type not in self.__broadcast_subscribers:
self.__broadcast_subscribers[event_type] = {}
handlers = self.__broadcast_subscribers[event_type]
if handler_identifier in handlers:
handlers.pop(handler_identifier)
else:
logger.debug(f"Subscribed to broadcast event: {event_type.value} - {handler_identifier}")
handlers[handler_identifier] = handler
def remove_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable):
"""
移除事件处理器,将处理器从对应事件的订阅列表中删除
:param event_type: 事件类型 (EventType 或 ChainEventType)
:param handler: 要移除的处理器
"""
with self.__lock:
handler_identifier = self.__get_handler_identifier(handler)
if isinstance(event_type, ChainEventType) and event_type in self.__chain_subscribers:
self.__chain_subscribers[event_type].pop(handler_identifier, None)
logger.debug(f"Unsubscribed from chain event: {event_type.value} - {handler_identifier}")
elif event_type in self.__broadcast_subscribers:
self.__broadcast_subscribers[event_type].pop(handler_identifier, None)
logger.debug(f"Unsubscribed from broadcast event: {event_type.value} - {handler_identifier}")
def disable_event_handler(self, target: Union[Callable, type]):
"""
禁用指定的事件处理器或事件处理器类
:param target: 处理器函数或类
"""
identifier = self.__get_handler_identifier(target)
if identifier in self.__disabled_handlers or identifier in self.__disabled_classes:
return
if isinstance(target, type):
self.__disabled_classes.add(identifier)
logger.debug(f"Disabled event handler class - {identifier}")
else:
self.__disabled_handlers.add(identifier)
logger.debug(f"Disabled event handler - {identifier}")
def enable_event_handler(self, target: Union[Callable, type]):
"""
启用指定的事件处理器或事件处理器类
:param target: 处理器函数或类
"""
identifier = self.__get_handler_identifier(target)
if isinstance(target, type):
self.__disabled_classes.discard(identifier)
logger.debug(f"Enabled event handler class - {identifier}")
else:
self.__disabled_handlers.discard(identifier)
logger.debug(f"Enabled event handler - {identifier}")
def visualize_handlers(self) -> List[Dict]:
"""
可视化所有事件处理器,包括是否被禁用的状态
:return: 处理器列表,包含事件类型、处理器标识符、优先级(如果有)和状态
"""
def parse_handler_data(data):
"""
解析处理器数据,判断是否包含优先级
:param data: 订阅者数据,可能是元组或单一值
:return: (priority, handler),若没有优先级则返回 (None, handler)
"""
if isinstance(data, tuple) and len(data) == 2:
return data
return None, data
handler_info = []
# 统一处理广播事件和链式事件
for event_type, subscribers in {**self.__broadcast_subscribers, **self.__chain_subscribers}.items():
for handler_identifier, handler_data in subscribers.items():
# 解析优先级和处理器
priority, handler = parse_handler_data(handler_data)
# 检查处理器的启用状态
status = "enabled" if self.__is_handler_enabled(handler) else "disabled"
# 构建处理器信息字典
handler_dict = {
"event_type": event_type.value,
"handler_identifier": handler_identifier,
"status": status
}
if priority is not None:
handler_dict["priority"] = priority
handler_info.append(handler_dict)
return handler_info
@classmethod
@lru_cache(maxsize=1000)
def __get_handler_identifier(cls, target: Union[Callable, type]) -> Optional[str]:
"""
获取处理器或处理器类的唯一标识符,包括模块名和类名/方法名
:param target: 处理器函数或类
:return: 唯一标识符
"""
# 统一使用 inspect.getmodule 来获取模块名
module = inspect.getmodule(target)
module_name = module.__name__ if module else "unknown_module"
# 使用 __qualname__ 获取目标的限定名
qualname = target.__qualname__
return f"{module_name}.{qualname}"
@classmethod
@lru_cache(maxsize=1000)
def __get_class_from_callable(cls, handler: Callable) -> Optional[str]:
"""
获取可调用对象所属类的唯一标识符
:param handler: 可调用对象(函数、方法等)
:return: 类的唯一标识符
"""
# 对于绑定方法,通过 __self__.__class__ 获取类
if inspect.ismethod(handler) and hasattr(handler, "__self__"):
return cls.__get_handler_identifier(handler.__self__.__class__)
# 对于类实例(实现了 __call__ 方法)
if not inspect.isfunction(handler) and hasattr(handler, "__call__"):
handler_cls = handler.__class__ # noqa
return cls.__get_handler_identifier(handler_cls)
# 对于未绑定方法、静态方法、类方法,使用 __qualname__ 提取类信息
qualname_parts = handler.__qualname__.split(".")
if len(qualname_parts) > 1:
class_name = ".".join(qualname_parts[:-1])
module = inspect.getmodule(handler)
module_name = module.__name__ if module else "unknown_module"
return f"{module_name}.{class_name}"
def __is_handler_enabled(self, handler: Callable) -> bool:
"""
检查处理器是否已启用(没有被禁用)
:param handler: 处理器函数
:return: 如果处理器启用则返回 True否则返回 False
"""
# 获取处理器的唯一标识符
handler_id = self.__get_handler_identifier(handler)
# 获取处理器所属类的唯一标识符
class_id = self.__get_class_from_callable(handler)
# 检查处理器或类是否被禁用,只要其中之一被禁用则返回 False
if handler_id in self.__disabled_handlers or (class_id is not None and class_id in self.__disabled_classes):
if etype.value not in self._handlers:
return False
handlers = self._handlers.get(etype.value)
return any([handler for handler in handlers.values()
if handler.__qualname__.split(".")[0] not in self._disabled_handlers])
return True
def __trigger_chain_event(self, event: Event) -> Optional[Event]:
def add_event_listener(self, etype: EventType, handler: type):
"""
触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
注册事件处理
"""
logger.debug(f"Triggering synchronous chain event: {event}")
dispatch = self.__dispatch_chain_event(event)
return event if dispatch else None
def __trigger_broadcast_event(self, event: Event):
"""
触发广播事件,将事件插入到优先级队列中
:param event: 要处理的事件对象
"""
logger.debug(f"Triggering broadcast event: {event}")
self.__event_queue.put((event.priority, event))
def __dispatch_chain_event(self, event: Event) -> bool:
"""
同步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
:param event: 要调度的事件对象
"""
handlers = self.__chain_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for chain event: {event}")
return False
# 过滤出启用的处理器
enabled_handlers = {handler_id: (priority, handler) for handler_id, (priority, handler) in handlers.items()
if self.__is_handler_enabled(handler)}
if not enabled_handlers:
logger.debug(f"No enabled handlers found for chain event: {event}. Skipping execution.")
return False
self.__log_event_lifecycle(event, "Started")
for handler_id, (priority, handler) in enabled_handlers.items():
start_time = time.time()
self.__safe_invoke_handler(handler, event)
logger.debug(
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
f"completed in {time.time() - start_time:.3f}s for event: {event}"
)
self.__log_event_lifecycle(event, "Completed")
return True
def __dispatch_broadcast_event(self, event: Event):
"""
异步方式调度广播事件,通过线程池逐个调用事件处理器
:param event: 要调度的事件对象
"""
handlers = self.__broadcast_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for broadcast event: {event}")
return
for handler_id, handler in handlers.items():
self.__executor.submit(self.__safe_invoke_handler, handler, event)
def __safe_invoke_handler(self, handler: Callable, event: Event):
"""
调用处理器,处理链式或广播事件
:param handler: 处理器
:param event: 事件对象
"""
if not self.__is_handler_enabled(handler):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
# 根据事件类型判断是否需要深复制
is_broadcast_event = isinstance(event.event_type, EventType)
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
try:
from app.core.plugin import PluginManager
handlers = self._handlers[etype.value]
except KeyError:
handlers = {}
self._handlers[etype.value] = handlers
if handler.__qualname__ in handlers:
handlers.pop(handler.__qualname__)
else:
logger.debug(f"Event Registed{etype.value} - {handler.__qualname__}")
handlers[handler.__qualname__] = handler
if class_name in PluginManager().get_plugin_ids():
# 定义一个插件调用函数
def plugin_callable():
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
if is_broadcast_event:
self.__executor.submit(plugin_callable)
else:
plugin_callable()
else:
# 获取全局对象或模块类的实例
class_obj = self.__get_class_instance(class_name)
if class_obj and hasattr(class_obj, method_name):
method = getattr(class_obj, method_name)
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
except Exception as e:
self.__handle_event_error(event, handler, e)
@staticmethod
def __get_class_instance(class_name: str):
def disable_events_hander(self, class_name: str):
"""
根据类名获取类实例,首先检查全局变量中是否存在该类,如果不存在则尝试动态导入模块。
:param class_name: 类的名称
:return: 类的实例
标记对应类事件处理为不可用
"""
# 检查类是否在全局变量中
if class_name in globals():
try:
class_obj = globals()[class_name]()
return class_obj
except Exception as e:
logger.error(f"事件处理出错:创建全局类实例出错:{str(e)} - {traceback.format_exc()}")
return None
if class_name not in self._disabled_handlers:
self._disabled_handlers.append(class_name)
logger.debug(f"Event Disabled{class_name}")
# 如果类不在全局变量中,尝试动态导入模块并创建实例
try:
if class_name == "Command":
module_name = "app.command"
module = importlib.import_module(module_name)
elif class_name.endswith("Chain"):
module_name = f"app.chain.{class_name[:-5].lower()}"
module = importlib.import_module(module_name)
else:
logger.debug(f"事件处理出错:无效的 Chain 类名: {class_name},类名必须以 'Chain' 结尾")
return None
if hasattr(module, class_name):
class_obj = getattr(module, class_name)()
return class_obj
else:
logger.debug(f"事件处理出错:模块 {module_name} 中没有找到类 {class_name}")
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
return None
def enable_events_hander(self, class_name: str):
"""
标记对应类事件处理为可用
"""
if class_name in self._disabled_handlers:
self._disabled_handlers.remove(class_name)
logger.debug(f"Event Enabled{class_name}")
def __broadcast_consumer_loop(self):
def send_event(self, etype: EventType, data: dict = None):
"""
持续从队列中提取事件的后台广播消费者线程
发送事件
"""
jitter_factor = 0.1
rate_limiter = ExponentialBackoffRateLimiter(base_wait=INITIAL_EVENT_QUEUE_IDLE_TIMEOUT_SECONDS,
max_wait=MAX_EVENT_QUEUE_IDLE_TIMEOUT_SECONDS,
backoff_factor=2.0,
source="BroadcastConsumer",
enable_logging=False)
while self.__event.is_set():
try:
priority, event = self.__event_queue.get(timeout=rate_limiter.current_wait)
rate_limiter.reset()
self.__dispatch_broadcast_event(event)
except Empty:
rate_limiter.current_wait = rate_limiter.current_wait * random.uniform(1, 1 + jitter_factor)
rate_limiter.trigger_limit()
if etype not in EventType:
return
event = Event(etype.value)
event.event_data = data or {}
logger.debug(f"发送事件:{etype.value} - {event.event_data}")
self._eventQueue.put(event)
@staticmethod
def __log_event_lifecycle(event: Event, stage: str):
def register(self, etype: [EventType, list]):
"""
记录事件的生命周期日志
"""
logger.debug(f"{stage} - {event}")
def __handle_event_error(self, event: Event, handler: Callable, e: Exception):
"""
全局错误处理器,用于处理事件处理中的异常
"""
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
self.__messagehelper.put(title=f"{event.event_type} 事件处理出错",
message=f"{class_name}.{method_name}{str(e)}",
role="system")
self.send_event(
EventType.SystemError,
{
"type": "event",
"event_type": event.event_type,
"event_handle": f"{class_name}.{method_name}",
"error": str(e),
"traceback": traceback.format_exc()
}
)
def register(self, etype: Union[EventType, ChainEventType, List[Union[EventType, ChainEventType]], type],
priority: int = DEFAULT_EVENT_PRIORITY):
"""
事件注册装饰器,用于将函数注册为事件的处理器
:param etype:
- 单个事件类型成员 (如 EventType.MetadataScrape, ChainEventType.PluginAction)
- 事件类型类 (EventType, ChainEventType)
- 或事件类型成员的列表
:param priority: 可选,链式事件的优先级,默认为 DEFAULT_EVENT_PRIORITY
事件注册
:param etype: 事件类型
"""
def decorator(f: Callable):
# 将输入的事件类型统一转换为列表格式
def decorator(f):
if isinstance(etype, list):
# 传入的已经是列表,直接使用
event_list = etype
for et in etype:
self.add_event_listener(et, f)
elif type(etype) == type(EventType):
for et in etype.__members__.values():
self.add_event_listener(et, f)
else:
# 不是列表则包裹成单一元素的列表
event_list = [etype]
# 遍历列表,处理每个事件类型
for event in event_list:
if isinstance(event, (EventType, ChainEventType)):
self.add_event_listener(event, f, priority)
elif isinstance(event, type) and issubclass(event, (EventType, ChainEventType)):
# 如果是 EventType 或 ChainEventType 类,提取该类中的所有成员
for et in event.__members__.values():
self.add_event_listener(et, f, priority)
else:
raise ValueError(f"无效的事件类型: {event}")
self.add_event_listener(etype, f)
return f
return decorator
# 全局实例定义
class Event(object):
"""
事件对象
"""
def __init__(self, event_type=None):
# 事件类型
self.event_type = event_type
# 字典用于保存具体的事件数据
self.event_data = {}
# 实例引用,用于注册事件
eventmanager = EventManager()

View File

@@ -81,6 +81,7 @@ class MetaAnime(MetaBase):
_, self.cn_name, _, _, _, _ = StringUtils.get_keyword(self.cn_name)
if self.cn_name:
self.cn_name = re.sub(r'%s' % self._name_nostring_re, '', self.cn_name, flags=re.IGNORECASE).strip()
self.cn_name = zhconv.convert(self.cn_name, "zh-hans")
if self.en_name:
self.en_name = re.sub(r'%s' % self._name_nostring_re, '', self.en_name, flags=re.IGNORECASE).strip().title()
self._name = StringUtils.str_title(self.en_name)

View File

@@ -1,13 +1,13 @@
import traceback
from dataclasses import dataclass
from dataclasses import dataclass, asdict
from typing import Union, Optional, List, Self
import cn2an
import regex as re
from app.log import logger
from app.schemas.types import MediaType
from app.utils.string import StringUtils
from app.schemas.types import MediaType
@dataclass
@@ -69,7 +69,7 @@ class MetaBase(object):
_subtitle_flag = False
_title_episodel_re = r"Episode\s+(\d{1,4})"
_subtitle_season_re = r"(?<![全共]\s*)[第\s]+([0-9一二三四五六七八九十S\-]+)\s*季(?!\s*[全共])"
_subtitle_season_all_re = r"[全共]\s*([0-9一二三四五六七八九十]+)\s*季"
_subtitle_season_all_re = r"[全共]\s*([0-9一二三四五六七八九十]+)\s*季|([0-9一二三四五六七八九十]+)\s*季\s*全"
_subtitle_episode_re = r"(?<![全共]\s*)[第\s]+([0-9一二三四五六七八九十百零EP]+)\s*[集话話期幕](?!\s*[全共])"
_subtitle_episode_between_re = r"[第]*\s*([0-9一二三四五六七八九十百零]+)\s*[集话話期幕]?\s*-\s*第*\s*([0-9一二三四五六七八九十百零]+)\s*[集话話期幕]"
_subtitle_episode_all_re = r"([0-9一二三四五六七八九十百零]+)\s*集\s*全|[全共]\s*([0-9一二三四五六七八九十百零]+)\s*[集话話期幕]"
@@ -247,7 +247,7 @@ class MetaBase(object):
self.type = MediaType.TV
self._subtitle_flag = True
return
# x集全/全x集
# x集全
episode_all_str = re.search(r'%s' % self._subtitle_episode_all_re, title_text, re.IGNORECASE)
if episode_all_str:
episode_all = episode_all_str.group(1)
@@ -259,6 +259,8 @@ class MetaBase(object):
except Exception as err:
logger.debug(f'识别集失败:{str(err)} - {traceback.format_exc()}')
return
self.begin_episode = None
self.end_episode = None
self.type = MediaType.TV
self._subtitle_flag = True
return
@@ -587,10 +589,9 @@ class MetaBase(object):
"""
转为字典
"""
dicts = vars(self).copy()
dicts = asdict(self)
dicts["type"] = self.type.value if self.type else None
dicts["season_episode"] = self.season_episode
dicts["edition"] = self.edition
dicts["name"] = self.name
dicts["episode_list"] = self.episode_list
return dicts

View File

@@ -30,8 +30,8 @@ class MetaVideo(MetaBase):
_episode_re = r"EP?(\d{2,4})$|^EP?(\d{1,4})$|^S\d{1,2}EP?(\d{1,4})$|S\d{2}EP?(\d{2,4})"
_part_re = r"(^PART[0-9ABI]{0,2}$|^CD[0-9]{0,2}$|^DVD[0-9]{0,2}$|^DISK[0-9]{0,2}$|^DISC[0-9]{0,2}$)"
_roman_numerals = r"^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$"
_source_re = r"^BLURAY$|^HDTV$|^UHDTV$|^HDDVD$|^WEBRIP$|^DVDRIP$|^BDRIP$|^BLU$|^WEB$|^BD$|^HDRip$|^REMUX$|^UHD$"
_effect_re = r"^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$"
_source_re = r"^BLURAY$|^HDTV$|^UHDTV$|^HDDVD$|^WEBRIP$|^DVDRIP$|^BDRIP$|^BLU$|^WEB$|^BD$|^HDRip$"
_effect_re = r"^REMUX$|^UHD$|^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$"
_resources_type_re = r"%s|%s" % (_source_re, _effect_re)
_name_no_begin_re = r"^[\[【].+?[\]】]"
_name_no_chinese_re = r".*版|.*字幕"
@@ -524,37 +524,6 @@ class MetaVideo(MetaBase):
"""
if not self.name:
return
if token.upper() == "DL" \
and self._last_token_type == "source" \
and self._last_token == "WEB":
self._source = "WEB-DL"
self._continue_flag = False
return
elif token.upper() == "RAY" \
and self._last_token_type == "source" \
and self._last_token == "BLU":
# UHD BluRay组合
if self._source == "UHD":
self._source = "UHD BluRay"
else:
self._source = "BluRay"
self._continue_flag = False
return
elif token.upper() == "WEBDL":
self._source = "WEB-DL"
self._continue_flag = False
return
# UHD REMUX组合
if token.upper() == "REMUX" \
and self._source == "BluRay":
self._source = "BluRay REMUX"
self._continue_flag = False
return
elif token.upper() == "BLURAY" \
and self._source == "UHD":
self._source = "UHD BluRay"
self._continue_flag = False
return
source_res = re.search(r"(%s)" % self._source_re, token, re.IGNORECASE)
if source_res:
self._last_token_type = "source"
@@ -564,6 +533,22 @@ class MetaVideo(MetaBase):
self._source = source_res.group(1)
self._last_token = self._source.upper()
return
elif token.upper() == "DL" \
and self._last_token_type == "source" \
and self._last_token == "WEB":
self._source = "WEB-DL"
self._continue_flag = False
return
elif token.upper() == "RAY" \
and self._last_token_type == "source" \
and self._last_token == "BLU":
self._source = "BluRay"
self._continue_flag = False
return
elif token.upper() == "WEBDL":
self._source = "WEB-DL"
self._continue_flag = False
return
effect_res = re.search(r"(%s)" % self._effect_re, token, re.IGNORECASE)
if effect_res:
self._last_token_type = "effect"

View File

@@ -70,12 +70,11 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"U2": [],
"ultrahd": [],
"others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:|yG)',
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )', 'UBWEB'],
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社',
'悠哈璃羽字幕社',
'❀拨雪寻春❀', '沸羊羊(?:制作|字幕组)', '(?:桜|樱)都字幕组']
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )'],
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组','极影字幕社','悠哈璃羽字幕社',
'❀拨雪寻春❀', '沸羊羊(?:制作|字幕组)', '(?:桜|樱)都字幕组',]
}
def __init__(self):

View File

@@ -14,7 +14,7 @@ class WordsMatcher(metaclass=Singleton):
def __init__(self):
self.systemconfig = SystemConfigOper()
def prepare(self, title: str, custom_words: List[str] = None) -> Tuple[str, List[str]]:
def prepare(self, title: str) -> Tuple[str, List[str]]:
"""
预处理标题,支持三种格式
1屏蔽词
@@ -23,7 +23,7 @@ class WordsMatcher(metaclass=Singleton):
"""
appley_words = []
# 读取自定义识别词
words: List[str] = custom_words or self.systemconfig.get(SystemConfigKey.CustomIdentifiers) or []
words: List[str] = self.systemconfig.get(SystemConfigKey.CustomIdentifiers) or []
for word in words:
if not word or word.startswith("#"):
continue

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Tuple, List
from typing import Tuple
import regex as re
@@ -10,18 +10,17 @@ from app.log import logger
from app.schemas.types import MediaType
def MetaInfo(title: str, subtitle: str = None, custom_words: List[str] = None) -> MetaBase:
def MetaInfo(title: str, subtitle: str = None) -> MetaBase:
"""
根据标题和副标题识别元数据
:param title: 标题、种子名、文件名
:param subtitle: 副标题、描述
:param custom_words: 自定义识别词列表
:return: MetaAnime、MetaVideo
"""
# 原标题
org_title = title
# 预处理标题
title, apply_words = WordsMatcher().prepare(title, custom_words=custom_words)
title, apply_words = WordsMatcher().prepare(title)
# 获取标题中媒体信息
title, metainfo = find_metainfo(title)
# 判断是否处理文件

View File

@@ -1,12 +1,9 @@
import traceback
from typing import Generator, Optional, Tuple, Any, Union
from typing import Generator, Optional, Tuple, Any
from app.core.config import settings
from app.core.event import eventmanager
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas.types import EventType, ModuleType, DownloaderType, MediaServerType, MessageChannel, StorageSchema, \
OtherModulesType
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
@@ -20,8 +17,6 @@ class ModuleManager(metaclass=Singleton):
_modules: dict = {}
# 运行态模块列表
_running_modules: dict = {}
# 子模块类型集合
SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType]
def __init__(self):
self.load_modules()
@@ -64,7 +59,7 @@ class ModuleManager(metaclass=Singleton):
logger.info(f"Moudle Stoped{module_id}")
except Exception as err:
logger.error(f"Stop Moudle Error{module_id}{str(err)} - {traceback.format_exc()}", exc_info=True)
logger.info("所有模块停止完成")
logger.info("模块停止完成")
def reload(self):
"""
@@ -72,21 +67,17 @@ class ModuleManager(metaclass=Singleton):
"""
self.stop()
self.load_modules()
eventmanager.send_event(etype=EventType.ModuleReload, data={})
def test(self, modleid: str) -> Tuple[bool, str]:
"""
测试模块
"""
if modleid not in self._running_modules:
return False, ""
return False, "模块未加载,请检查参数设置"
module = self._running_modules[modleid]
if hasattr(module, "test") \
and ObjectUtils.check_method(getattr(module, "test")):
result = module.test()
if not result:
return False, ""
return result
return module.test()
return True, "模块不支持测试"
@staticmethod
@@ -127,28 +118,6 @@ class ModuleManager(metaclass=Singleton):
and ObjectUtils.check_method(getattr(module, method)):
yield module
def get_running_type_modules(self, module_type: ModuleType) -> Generator:
"""
获取指定类型的模块列表
"""
if not self._running_modules:
return []
for _, module in self._running_modules.items():
if hasattr(module, 'get_type') \
and module.get_type() == module_type:
yield module
def get_running_subtype_module(self, module_subtype: SubType) -> Generator:
"""
获取指定子类型的模块
"""
if not self._running_modules:
return []
for _, module in self._running_modules.items():
if hasattr(module, 'get_subtype') \
and module.get_subtype() == module_subtype:
yield module
def get_module(self, module_id: str) -> Any:
"""
根据模块id获取模块

View File

@@ -3,9 +3,9 @@ import concurrent.futures
import importlib.util
import inspect
import os
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
@@ -21,9 +21,8 @@ from app.helper.module import ModuleHelper
from app.helper.plugin import PluginHelper
from app.helper.sites import SitesHelper
from app.log import logger
from app.schemas.types import EventType, SystemConfigKey
from app.schemas.types import SystemConfigKey
from app.utils.crypto import RSAUtils
from app.utils.limit import rate_limit_window
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
@@ -31,6 +30,14 @@ from app.utils.system import SystemUtils
class PluginMonitorHandler(FileSystemEventHandler):
# 计时器
__reload_timer = None
# 防抖时间间隔
__debounce_interval = 0.5
# 最近一次修改时间
__last_modified = 0
# 修改间隔
__timeout = 2
def on_modified(self, event):
"""
@@ -43,6 +50,10 @@ class PluginMonitorHandler(FileSystemEventHandler):
if not event_path.name.endswith(".py") or "pycache" in event_path.parts:
return
current_time = time.time()
if current_time - self.__last_modified < self.__timeout:
return
self.__last_modified = current_time
# 读取插件根目录下的__init__.py文件读取class XXXX(_PluginBase)的类名
try:
plugins_root = settings.ROOT_PATH / "app" / "plugins"
@@ -64,12 +75,15 @@ class PluginMonitorHandler(FileSystemEventHandler):
if line.startswith("class") and "(_PluginBase)" in line:
pid = line.split("class ")[1].split("(_PluginBase)")[0].strip()
if pid:
self.__reload_plugin(pid)
# 防抖处理,通过计时器延迟加载
if self.__reload_timer:
self.__reload_timer.cancel()
self.__reload_timer = threading.Timer(self.__debounce_interval, self.__reload_plugin, [pid])
self.__reload_timer.start()
except Exception as e:
logger.error(f"插件文件修改后重载出错:{str(e)}")
@staticmethod
@rate_limit_window(max_calls=1, window_seconds=2, source="PluginMonitor", enable_logging=False)
def __reload_plugin(pid):
"""
重新加载插件
@@ -158,7 +172,7 @@ class PluginManager(metaclass=Singleton):
# 未安装的不加载
if plugin_id not in installed_plugins:
# 设置事件状态为不可用
eventmanager.disable_event_handler(plugin)
eventmanager.disable_events_hander(plugin_id)
continue
# 生成实例
plugin_obj = plugin()
@@ -169,9 +183,9 @@ class PluginManager(metaclass=Singleton):
logger.info(f"加载插件:{plugin_id} 版本:{plugin_obj.plugin_version}")
# 启用的插件才设置事件注册状态可用
if plugin_obj.get_state():
eventmanager.enable_event_handler(plugin)
eventmanager.enable_events_hander(plugin_id)
else:
eventmanager.disable_event_handler(plugin)
eventmanager.disable_events_hander(plugin_id)
except Exception as err:
logger.error(f"加载插件 {plugin_id} 出错:{str(err)} - {traceback.format_exc()}")
@@ -181,18 +195,15 @@ class PluginManager(metaclass=Singleton):
:param plugin_id: 插件ID
:param conf: 插件配置
"""
plugin = self._running_plugins.get(plugin_id)
if not plugin:
if not self._running_plugins.get(plugin_id):
return
# 初始化插件
plugin.init_plugin(conf)
# 检查插件状态并启用/禁用事件处理器
if plugin.get_state():
# 启用插件类的事件处理器
eventmanager.enable_event_handler(type(plugin))
self._running_plugins[plugin_id].init_plugin(conf)
if self._running_plugins[plugin_id].get_state():
# 设置启用的插件事件注册状态可用
eventmanager.enable_events_hander(plugin_id)
else:
# 禁用插件类的事件处理器
eventmanager.disable_event_handler(type(plugin))
# 设置事件状态为不可用
eventmanager.disable_events_hander(plugin_id)
def stop(self, pid: str = None):
"""
@@ -207,7 +218,6 @@ class PluginManager(metaclass=Singleton):
for plugin_id, plugin in self._running_plugins.items():
if pid and plugin_id != pid:
continue
eventmanager.disable_event_handler(type(plugin))
self.__stop_plugin(plugin)
# 清空对像
if pid:
@@ -220,23 +230,11 @@ class PluginManager(metaclass=Singleton):
self._running_plugins = {}
logger.info("插件停止完成")
def reload_monitor(self):
"""
重新加载插件文件修改监测
"""
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
if self._observer and self._observer.is_alive():
logger.info("插件文件修改监测已经在运行中...")
else:
self.__start_monitor()
else:
self.stop_monitor()
def __start_monitor(self):
"""
启用监测插件文件修改监测
开发者模式下监测插件文件修改
"""
logger.info("开始监测插件文件修改...")
logger.info("发者模式下开始监测插件文件修改...")
monitor_handler = PluginMonitorHandler()
self._observer = Observer()
self._observer.schedule(monitor_handler, str(settings.ROOT_PATH / "app" / "plugins"), recursive=True)
@@ -244,16 +242,14 @@ class PluginManager(metaclass=Singleton):
def stop_monitor(self):
"""
停止监测插件文件修改监测
停止监测插件修改
"""
# 停止监测
if self._observer and self._observer.is_alive():
if self._observer:
logger.info("正在停止插件文件修改监测...")
self._observer.stop()
self._observer.join()
logger.info("插件文件修改监测停止完成")
else:
logger.info("未启用插件文件修改监测,无需停止")
@staticmethod
def __stop_plugin(plugin: Any):
@@ -284,86 +280,35 @@ class PluginManager(metaclass=Singleton):
self.stop(plugin_id)
# 重新加载
self.start(plugin_id)
# 广播事件
eventmanager.send_event(EventType.PluginReload, data={"plugin_id": plugin_id})
def sync(self) -> List[str]:
def install_online_plugin(self):
"""
安装本地不存在的在线插件
"""
def install_plugin(plugin):
start_time = time.time()
state, msg = self.pluginhelper.install(pid=plugin.id, repo_url=plugin.repo_url, force_install=True)
elapsed_time = time.time() - start_time
if state:
logger.info(
f"插件 {plugin.plugin_name} 安装成功,版本:{plugin.plugin_version},耗时:{elapsed_time:.2f}")
sync_plugins.append(plugin.id)
else:
logger.error(
f"插件 {plugin.plugin_name} v{plugin.plugin_version} 安装失败:{msg},耗时:{elapsed_time:.2f}")
failed_plugins.append(plugin.id)
if SystemUtils.is_frozen():
return []
# 获取已安装插件列表
install_plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件列表
online_plugins = self.get_online_plugins()
# 确定需要安装的插件
plugins_to_install = [
plugin for plugin in online_plugins
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id)
]
if not plugins_to_install:
return []
return
logger.info("开始安装第三方插件...")
sync_plugins = []
failed_plugins = []
# 使用 ThreadPoolExecutor 进行并发安装
total_start_time = time.time()
with ThreadPoolExecutor(max_workers=5) as executor:
futures = {
executor.submit(install_plugin, plugin): plugin
for plugin in plugins_to_install
}
for future in as_completed(futures):
plugin = futures[future]
try:
future.result()
except Exception as exc:
logger.error(f"插件 {plugin.plugin_name} 安装过程中出现异常: {exc}")
total_elapsed_time = time.time() - total_start_time
logger.info(
f"第三方插件安装完成,成功:{len(sync_plugins)} 个,"
f"失败:{len(failed_plugins)} 个,总耗时:{total_elapsed_time:.2f}"
)
return sync_plugins
def install_plugin_missing_dependencies(self) -> List[str]:
"""
安装插件中缺失或不兼容的依赖项
"""
# 第一步:获取需要安装的依赖项列表
missing_dependencies = self.pluginhelper.find_missing_dependencies()
if not missing_dependencies:
return missing_dependencies
logger.debug(f"检测到缺失的依赖项: {missing_dependencies}")
logger.info(f"开始安装缺失的依赖项,共 {len(missing_dependencies)} 个...")
# 第二步:安装依赖项并返回结果
total_start_time = time.time()
success, message = self.pluginhelper.install_dependencies(missing_dependencies)
total_elapsed_time = time.time() - total_start_time
if success:
logger.info(f"已完成 {len(missing_dependencies)} 个依赖项安装,总耗时:{total_elapsed_time:.2f}")
else:
logger.warning(f"存在缺失依赖项安装失败,请尝试手动安装,总耗时:{total_elapsed_time:.2f}")
return missing_dependencies
# 已安装插件
install_plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
# 在线插件
online_plugins = self.get_online_plugins()
if not online_plugins:
logger.error("未获取到第三方插件")
return
# 支持更新的插件自动更新
for plugin in online_plugins:
# 只处理已安装的插件
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id):
# 下载安装
state, msg = self.pluginhelper.install(pid=plugin.id,
repo_url=plugin.repo_url)
# 安装失败
if not state:
logger.error(
f"插件 {plugin.plugin_name} v{plugin.plugin_version} 安装失败:{msg}")
continue
logger.info(f"插件 {plugin.plugin_name} 安装成功,版本:{plugin.plugin_version}")
logger.info("第三方插件安装完成")
def get_plugin_config(self, pid: str) -> dict:
"""
@@ -431,7 +376,7 @@ class PluginManager(metaclass=Singleton):
return plugin.get_page() or []
return []
def get_plugin_dashboard(self, pid: str, key: str = None, **kwargs) -> Optional[schemas.PluginDashboard]:
def get_plugin_dashboard(self, pid: str, key: str, **kwargs) -> Optional[schemas.PluginDashboard]:
"""
获取插件仪表盘
:param pid: 插件ID
@@ -469,42 +414,27 @@ class PluginManager(metaclass=Singleton):
)
return None
def get_plugin_state(self, pid: str) -> bool:
"""
获取插件状态
:param pid: 插件ID
"""
plugin = self._running_plugins.get(pid)
return plugin.get_state() if plugin else False
def get_plugin_commands(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
def get_plugin_commands(self) -> List[Dict[str, Any]]:
"""
获取插件命令
[{
"cmd": "/xx",
"event": EventType.xx,
"desc": "xxxx",
"data": {},
"pid": "",
"data": {}
}]
"""
ret_commands = []
for plugin_id, plugin in self._running_plugins.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_command") and ObjectUtils.check_method(plugin.get_command):
for _, plugin in self._running_plugins.items():
if hasattr(plugin, "get_command") \
and ObjectUtils.check_method(plugin.get_command):
try:
if not plugin.get_state():
continue
commands = plugin.get_command() or []
for command in commands:
command["pid"] = plugin_id
ret_commands.extend(commands)
ret_commands += plugin.get_command() or []
except Exception as e:
logger.error(f"获取插件命令出错:{str(e)}")
return ret_commands
def get_plugin_apis(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
def get_plugin_apis(self, plugin_id: str = None) -> List[Dict[str, Any]]:
"""
获取插件API
[{
@@ -512,27 +442,25 @@ class PluginManager(metaclass=Singleton):
"endpoint": self.xxx,
"methods": ["GET", "POST"],
"summary": "API名称",
"description": "API说明",
"allow_anonymous": false
"description": "API说明"
}]
"""
ret_apis = []
for plugin_id, plugin in self._running_plugins.items():
if pid and pid != plugin_id:
for pid, plugin in self._running_plugins.items():
if plugin_id and pid != plugin_id:
continue
if hasattr(plugin, "get_api") and ObjectUtils.check_method(plugin.get_api):
if hasattr(plugin, "get_api") \
and ObjectUtils.check_method(plugin.get_api):
try:
if not plugin.get_state():
continue
apis = plugin.get_api() or []
for api in apis:
api["path"] = f"/{plugin_id}{api['path']}"
api["path"] = f"/{pid}{api['path']}"
ret_apis.extend(apis)
except Exception as e:
logger.error(f"获取插件 {plugin_id} API出错{str(e)}")
logger.error(f"获取插件 {pid} API出错{str(e)}")
return ret_apis
def get_plugin_services(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
def get_plugin_services(self) -> List[Dict[str, Any]]:
"""
获取插件服务
[{
@@ -540,22 +468,19 @@ class PluginManager(metaclass=Singleton):
"name": "服务名称",
"trigger": "触发器cron、interval、date、CronTrigger.from_crontab()",
"func": self.xxx,
"kwargs": {} # 定时器参数,
"func_kwargs": {} # 方法参数
"kwagrs": {} # 定时器参数
}]
"""
ret_services = []
for plugin_id, plugin in self._running_plugins.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_service") and ObjectUtils.check_method(plugin.get_service):
for pid, plugin in self._running_plugins.items():
if hasattr(plugin, "get_service") \
and ObjectUtils.check_method(plugin.get_service):
try:
if not plugin.get_state():
continue
services = plugin.get_service() or []
ret_services.extend(services)
services = plugin.get_service()
if services:
ret_services.extend(services)
except Exception as e:
logger.error(f"获取插件 {plugin_id} 服务出错:{str(e)}")
logger.error(f"获取插件 {pid} 服务出错:{str(e)}")
return ret_services
def get_plugin_dashboard_meta(self):
@@ -632,59 +557,122 @@ class PluginManager(metaclass=Singleton):
"""
获取所有在线插件信息
"""
def __get_plugin_info(market: str) -> Optional[List[schemas.Plugin]]:
"""
获取插件信息
"""
online_plugins = self.pluginhelper.get_plugins(market) or {}
if not online_plugins:
logger.warn(f"获取插件库失败:{market}")
return
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, plugin_info.get("version")) < 0:
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
continue
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
# 汇总
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins
if not settings.PLUGIN_MARKET:
return []
# 返回值
all_plugins = []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
base_version_plugins = []
# 已安装插件
installed_apps = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
# 使用多线程获取线上插件
with concurrent.futures.ThreadPoolExecutor() as executor:
futures_to_version = {}
futures = []
for m in settings.PLUGIN_MARKET.split(","):
if not m:
continue
# 提交任务获取 v1 版本插件,存储 future 到 version 的映射
base_future = executor.submit(self.get_plugins_from_market, m, None)
futures_to_version[base_future] = "base_version"
# 提交任务获取高版本插件(如 v2、v3存储 future 到 version 的映射
if settings.VERSION_FLAG:
higher_version_future = executor.submit(self.get_plugins_from_market, m, settings.VERSION_FLAG)
futures_to_version[higher_version_future] = "higher_version"
# 按照完成顺序处理结果
for future in concurrent.futures.as_completed(futures_to_version):
futures.append(executor.submit(__get_plugin_info, m))
for future in concurrent.futures.as_completed(futures):
plugins = future.result()
version = futures_to_version[future]
if plugins:
if version == "higher_version":
higher_version_plugins.extend(plugins) # 收集高版本插件
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
# 优先处理高版本插件
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
all_plugins.extend(plugins)
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
# 所有插件按repo在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大版本
# 相同ID的插件保留版本号最大版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, max_versions[p.id]) > 0:
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
result = [p for p in all_plugins if
p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
@@ -773,113 +761,13 @@ class PluginManager(metaclass=Singleton):
# 构建包名
package_name = f"app.plugins.{pid.lower()}"
# 检查包是否存在
spec = importlib.util.find_spec(package_name)
package_exists = spec is not None and spec.origin is not None
package_exists = importlib.util.find_spec(package_name) is not None
logger.debug(f"{pid} exists: {package_exists}")
return package_exists
except Exception as e:
logger.debug(f"获取插件是否在本地包中存在失败,{e}")
return False
def get_plugins_from_market(self, market: str, package_version: str = None) -> Optional[List[schemas.Plugin]]:
"""
从指定的市场获取插件信息
:param market: 市场的 URL 或标识
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:return: 返回插件的列表,若获取失败返回 []
"""
if not market:
return []
# 已安装插件
installed_apps = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = self.pluginhelper.get_plugins(market, package_version) or {}
if not online_plugins:
if not package_version:
logger.warning(f"获取插件库失败:{market},请检查 GitHub 网络连接")
return []
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
continue
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
continue
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
# 汇总
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins
def __set_and_check_auth_level(self, plugin: Union[schemas.Plugin, Type[Any]],
source: Optional[Union[dict, Type[Any]]] = None) -> bool:
"""

View File

@@ -5,165 +5,55 @@ import json
import os
import traceback
from datetime import datetime, timedelta
from typing import Any, Union, Annotated, Optional
from typing import Any, Union, Optional, Annotated
import jwt
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from cryptography.fernet import Fernet
from fastapi import HTTPException, status, Security, Request, Response
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
from fastapi import HTTPException, status, Depends, Header
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from app import schemas
from app.core.config import settings
from cryptography.fernet import Fernet
from app.log import logger
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256"
# OAuth2PasswordBearer 用于 JWT Token 认证
oauth2_scheme = OAuth2PasswordBearer(
# Token认证
reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
# RESOURCE TOKEN 通过 Cookie 认证
resource_token_cookie = APIKeyCookie(name=settings.PROJECT_NAME, auto_error=False, scheme_name="resource_token_cookie")
# API TOKEN 通过 QUERY 认证
api_token_query = APIKeyQuery(name="token", auto_error=False, scheme_name="api_token_query")
# API KEY 通过 Header 认证
api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="api_key_header")
# API KEY 通过 QUERY 认证
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
def create_access_token(
userid: Union[str, Any],
username: str,
super_user: bool = False,
expires_delta: Optional[timedelta] = None,
level: int = 1,
purpose: Optional[str] = "authentication"
userid: Union[str, Any], username: str, super_user: bool = False,
expires_delta: timedelta = None, level: int = 1
) -> str:
"""
创建 JWT 访问令牌,包含用户 ID、用户名、是否为超级用户以及权限等级
:param userid: 用户的唯一标识符,通常是字符串或整数
:param username: 用户名,用于标识用户的账户名
:param super_user: 是否为超级用户,默认值为 False
:param expires_delta: 令牌的有效期时长,如果不提供则根据用途使用默认过期时间
:param level: 用户的权限级别,默认为 1
:param purpose: 令牌的用途,"authentication""resource"
:return: 编码后的 JWT 令牌字符串
:raises ValueError: 如果 expires_delta 为负数
"""
if purpose == "resource":
default_expire = timedelta(seconds=settings.RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS)
secret_key = settings.RESOURCE_SECRET_KEY
else:
default_expire = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
secret_key = settings.SECRET_KEY
if expires_delta is not None:
if expires_delta.total_seconds() <= 0:
raise ValueError("过期时间必须为正数")
if expires_delta:
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + default_expire
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {
"exp": expire,
"iat": datetime.utcnow(),
"sub": str(userid),
"username": username,
"super_user": super_user,
"level": level,
"purpose": purpose
"level": level
}
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def __set_or_refresh_resource_token_cookie(request: Request, response: Response, payload: schemas.TokenPayload):
"""
设置资源令牌 Cookie
:param request: 包含请求相关的上下文数据
:param response: 用于在服务器响应时设置 Cookie
:param payload: 已通过身份验证的 TokenPayload 对象
"""
resource_token = request.cookies.get(settings.PROJECT_NAME)
if resource_token:
# 检查令牌剩余时间
try:
decoded_token = jwt.decode(resource_token, settings.RESOURCE_SECRET_KEY, algorithms=[ALGORITHM])
exp = decoded_token.get("exp")
if exp:
remaining_time = datetime.utcfromtimestamp(exp) - datetime.utcnow()
# 根据剩余时长提前刷新令牌
if remaining_time < timedelta(seconds=(settings.RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS / 3)):
raise jwt.ExpiredSignatureError
except jwt.PyJWTError:
logger.debug(f"Token error occurred. refreshing token")
except Exception as e:
logger.debug(f"Unexpected error occurred while decoding token: {e}")
else:
# 如果令牌有效且没有即将过期,则不需要刷新
return
# 创建新的资源访问令牌
resource_token_expires = timedelta(seconds=settings.RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS)
resource_token = create_access_token(
userid=payload.sub,
username=payload.username,
super_user=payload.super_user,
expires_delta=resource_token_expires,
level=payload.level,
purpose="resource"
)
# 设置会话级别的 HttpOnly Cookie
response.set_cookie(
key=settings.PROJECT_NAME,
value=resource_token,
httponly=True,
secure=request.url.scheme == "https", # 根据当前请求的协议设置 secure 属性
samesite="lax" # 不同浏览器对 "Strict" 的处理可能不同,设置 SameSite 为 "Lax",以平衡安全性和兼容性
)
def __verify_token(token: str, purpose: str = "authentication") -> schemas.TokenPayload:
"""
使用 JWT Token 进行身份认证并解析 Token 的内容
:param token: JWT 令牌
:param purpose: 期望的令牌用途,默认为 "authentication"
:return: 包含用户身份信息的 Token 负载数据
:raises HTTPException: 如果令牌无效或用途不匹配
"""
def verify_token(token: str = Depends(reusable_oauth2)) -> schemas.TokenPayload:
try:
if purpose == "resource":
secret_key = settings.RESOURCE_SECRET_KEY
else:
secret_key = settings.SECRET_KEY
if not token:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"{purpose} token not found"
)
payload = jwt.decode(
token, secret_key, algorithms=[ALGORITHM]
token, settings.SECRET_KEY, algorithms=[ALGORITHM]
)
token_payload = schemas.TokenPayload(**payload)
if token_payload.purpose != purpose:
raise jwt.InvalidTokenError("令牌用途不匹配")
return schemas.TokenPayload(**payload)
except (jwt.DecodeError, jwt.InvalidTokenError, jwt.ImmatureSignatureError):
raise HTTPException(
@@ -172,98 +62,54 @@ def __verify_token(token: str, purpose: str = "authentication") -> schemas.Token
)
def verify_token(
request: Request,
response: Response,
token: str = Security(oauth2_scheme)
) -> schemas.TokenPayload:
def __get_token(token: str = None) -> str:
"""
验证 JWT 令牌并自动处理 resource_token 写入
:param request: 请求对象,用于访问 Cookie 和请求信息
:param response: 响应对象,用于设置 Cookie
:param token: 从 Authorization 头部获取的 JWT 令牌
:return: 解析后的 TokenPayload
:raises HTTPException: 如果令牌无效或用途不匹配
从请求URL中获取token
"""
# 验证并解析 JWT 认证令牌
payload = __verify_token(token=token, purpose="authentication")
# 如果没有 resource_token生成并写入到 Cookie
__set_or_refresh_resource_token_cookie(request, response, payload)
return payload
return token
def verify_resource_token(
resource_token: str = Security(resource_token_cookie)
) -> schemas.TokenPayload:
def __get_apikey(apikey: str = None, x_api_key: Annotated[str | None, Header()] = None) -> str:
"""
验证资源访问令牌(从 Cookie 中获取)
:param resource_token: 从 Cookie 中获取的资源访问令牌
:return: 解析后的 TokenPayload
:raises HTTPException: 如果资源访问令牌无效
从请求URL中获取apikey
"""
# 验证并解析资源访问令牌
return __verify_token(token=resource_token, purpose="resource")
return apikey or x_api_key
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None
) -> str:
def verify_apitoken(token: str = Depends(__get_token)) -> str:
"""
从 URL 查询参数中获取 API Token
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
:return: 返回获取到的 API Token若无则返回 None
通过依赖项使用token进行身份认证
"""
return token_query
def __get_api_key(
key_query: Annotated[str | None, Security(api_key_query)] = None,
key_header: Annotated[str | None, Security(api_key_header)] = None
) -> str:
"""
从 URL 查询参数或请求头部获取 API Key优先使用 URL 参数
:param key_query: URL 中的 `apikey` 查询参数
:param key_header: 请求头中的 `X-API-KEY` 参数
:return: 返回从 URL 或请求头中获取的 API Key若无则返回 None
"""
return key_query or key_header
def __verify_key(key: str, expected_key: str, key_type: str) -> str:
"""
通用的 API Key 或 Token 验证函数
:param key: 从请求中获取的 API Key 或 Token
:param expected_key: 系统配置中的期望值,用于验证的 API Key 或 Token
:param key_type: 键的类型(例如 "API_KEY""API_TOKEN"),用于错误消息
:return: 返回校验通过的 API Key 或 Token
:raises HTTPException: 如果校验不通过,抛出 401 错误
"""
if key != expected_key:
if token != settings.API_TOKEN:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"{key_type} 校验不通过"
detail="token校验不通过"
)
return key
return token
def verify_apitoken(token: str = Security(__get_api_token)) -> str:
def verify_apikey(apikey: str = Depends(__get_apikey)) -> str:
"""
使用 API Token 进行身份认证
:param token: API Token从 URL 查询参数中获取
:return: 返回校验通过的 API Token
通过依赖项使用apikey进行身份认证
"""
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
if apikey != settings.API_TOKEN:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="apikey校验不通过"
)
return apikey
def verify_apikey(apikey: str = Security(__get_api_key)) -> str:
def verify_uri_token(token: str = Depends(__get_token)) -> str:
"""
使用 API Key 进行身份认证
:param apikey: API Key从 URL 查询参数或请求头中获取
:return: 返回校验通过的 API Key
通过依赖项使用token进行身份认证
"""
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
if not verify_token(token):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="token校验不通过"
)
return token
def verify_password(plain_password: str, hashed_password: str) -> bool:
@@ -286,7 +132,7 @@ def decrypt(data: bytes, key: bytes) -> Optional[bytes]:
return None
def encrypt_message(message: str, key: bytes) -> str:
def encrypt_message(message: str, key: bytes):
"""
使用给定的key对消息进行加密并返回加密后的字符串
"""
@@ -295,14 +141,14 @@ def encrypt_message(message: str, key: bytes) -> str:
return encrypted_message.decode()
def hash_sha256(message: str) -> str:
def hash_sha256(message):
"""
对字符串做hash运算
"""
return hashlib.sha256(message.encode()).hexdigest()
def aes_decrypt(data: str, key: str) -> str:
def aes_decrypt(data, key):
"""
AES解密
"""
@@ -322,7 +168,7 @@ def aes_decrypt(data: str, key: str) -> str:
return result.decode('utf-8')
def aes_encrypt(data: str, key: str) -> str:
def aes_encrypt(data, key):
"""
AES加密
"""
@@ -338,7 +184,7 @@ def aes_encrypt(data: str, key: str) -> str:
return base64.b64encode(cipher.iv + result).decode('utf-8')
def nexusphp_encrypt(data_str: str, key: bytes) -> str:
def nexusphp_encrypt(data_str: str, key):
"""
NexusPHP加密
"""

View File

@@ -1,77 +0,0 @@
from time import sleep
from typing import Dict, Any, Tuple
from app.actions import BaseAction
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas import Action, ActionContext
from app.utils.singleton import Singleton
class WorkFlowManager(metaclass=Singleton):
"""
工作流管理器
"""
# 所有动作定义
_actions: Dict[str, BaseAction] = {}
def __init__(self):
self.init()
def init(self):
"""
初始化
"""
def filter_func(obj: Any):
"""
过滤函数,确保只加载新定义的类
"""
if not isinstance(obj, type):
return False
if not hasattr(obj, 'execute') or not hasattr(obj, "name"):
return False
if obj.__name__ == "BaseAction":
return False
return obj.__module__.startswith("app.actions")
# 加载所有动作
self._actions = {}
actions = ModuleHelper.load(
"app.actions",
filter_func=lambda _, obj: filter_func(obj)
)
for action in actions:
logger.debug(f"加载动作: {action.__name__}")
self._actions[action.__name__] = action
def stop(self):
"""
停止
"""
pass
def excute(self, action: Action, context: ActionContext = None) -> Tuple[bool, ActionContext]:
"""
执行工作流动作
"""
if not context:
context = ActionContext()
if action.id in self._actions:
action_obj = self._actions[action.id]
logger.info(f"执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(action.params, context)
logger.info(f"{action.name} 执行结果: {action_obj.success}")
if action.loop and action.loop_interval:
while not action_obj.done:
logger.info(f"{action.name} 等待 {action.loop_interval} 秒后继续执行")
sleep(action.loop_interval)
logger.info(f"继续执行动作: {action.id} - {action.name}")
result_context = action_obj.execute(action.params, result_context)
logger.info(f"{action.name} 执行结果: {action_obj.success}")
logger.info(f"{action.name} 执行完成")
return action_obj.success, result_context
else:
logger.error(f"未找到动作: {action.id} - {action.name}")
return False, context

View File

@@ -1,41 +1,23 @@
from typing import Any, Generator, List, Optional, Self, Tuple
from typing import Any, Self, List
from typing import Tuple, Optional, Generator
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
from sqlalchemy import create_engine, QueuePool
from sqlalchemy import inspect
from sqlalchemy.orm import declared_attr
from sqlalchemy.orm import sessionmaker, Session, scoped_session, as_declarative
from app.core.config import settings
# 根据池类型设置 poolclass 和相关参
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
connect_args = {
"timeout": settings.DB_TIMEOUT
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
connect_args["check_same_thread"] = False
db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if pool_class == QueuePool:
db_kwargs.update({
"pool_size": settings.DB_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_MAX_OVERFLOW
})
# 创建数据库引擎
Engine = create_engine(**db_kwargs)
# 根据配置设置日志模式
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with Engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
# 数据库引擎
Engine = create_engine(f"sqlite:///{settings.CONFIG_PATH}/user.db",
pool_pre_ping=True,
echo=False,
poolclass=QueuePool,
pool_size=1024,
pool_recycle=3600,
pool_timeout=180,
max_overflow=10,
connect_args={"timeout": 60})
# 会话工厂
SessionFactory = sessionmaker(bind=Engine)
@@ -57,36 +39,6 @@ def get_db() -> Generator:
db.close()
def perform_checkpoint(mode: str = "PASSIVE"):
"""
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库
:param mode: checkpoint 模式,可选值包括 "PASSIVE""FULL""RESTART""TRUNCATE"
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
"""
if not settings.DB_WAL_ENABLE:
return
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}
if mode.upper() not in valid_modes:
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}")
try:
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
with Engine.connect() as conn:
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
except Exception as e:
print(f"Error during WAL checkpoint: {e}")
def close_database():
"""
关闭所有数据库连接并清理资源
"""
try:
# 释放连接池SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint
Engine.dispose()
except Exception as e:
print(f"Error while disposing database connections: {e}")
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
"""
从参数中获取数据库Session对象
@@ -198,7 +150,7 @@ class Base:
@classmethod
@db_query
def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(and_(cls.id == rid)).first()
return db.query(cls).filter(cls.id == rid).first()
@db_update
def update(self, db: Session, payload: dict):
@@ -211,7 +163,7 @@ class Base:
@classmethod
@db_update
def delete(cls, db: Session, rid):
db.query(cls).filter(and_(cls.id == rid)).delete()
db.query(cls).filter(cls.id == rid).delete()
@classmethod
@db_update
@@ -225,7 +177,7 @@ class Base:
return list(result)
def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns}
@declared_attr
def __tablename__(self) -> str:

View File

@@ -23,14 +23,6 @@ class DownloadHistoryOper(DbOper):
"""
return DownloadHistory.get_by_hash(self._db, download_hash)
def get_by_mediaid(self, tmdbid: int, doubanid: str) -> List[DownloadHistory]:
"""
按媒体ID查询下载记录
:param tmdbid: tmdbid
:param doubanid: doubanid
"""
return DownloadHistory.get_by_mediaid(self._db, tmdbid=tmdbid, doubanid=doubanid)
def add(self, **kwargs):
"""
新增下载历史

View File

@@ -1,8 +1,13 @@
import random
import string
from alembic.command import upgrade
from alembic.config import Config
from app.core.config import settings
from app.db import Engine, Base
from app.core.security import get_password_hash
from app.db import Engine, SessionFactory, Base
from app.db.models import *
from app.log import logger
@@ -11,7 +16,28 @@ def init_db():
初始化数据库
"""
# 全量建表
Base.metadata.create_all(bind=Engine) # noqa
Base.metadata.create_all(bind=Engine)
def init_super_user():
"""
初始化超级管理员
"""
# 初始化超级管理员
with SessionFactory() as db:
_user = User.get_by_name(db=db, name=settings.SUPERUSER)
if not _user:
# 定义包含数字、大小写字母的字符集合
characters = string.ascii_letters + string.digits
# 生成随机密码
random_password = ''.join(random.choice(characters) for _ in range(16))
logger.info(f"【超级管理员初始密码】{random_password} 请登录系统后在设定中修改。 注:该密码只会显示一次,请注意保存。")
_user = User(
name=settings.SUPERUSER,
hashed_password=get_password_hash(random_password),
is_superuser=True,
)
_user.create(db)
def update_db():

View File

@@ -1,3 +1,4 @@
import json
from typing import Optional
from sqlalchemy.orm import Session
@@ -18,8 +19,6 @@ class MediaServerOper(DbOper):
"""
新增媒体服务器数据
"""
# MediaServerItem中没有的属性剔除
kwargs = {k: v for k, v in kwargs.items() if hasattr(MediaServerItem, k)}
item = MediaServerItem(**kwargs)
if not item.get_by_itemid(self._db, kwargs.get("item_id")):
item.create(self._db)
@@ -53,7 +52,7 @@ class MediaServerOper(DbOper):
# 判断季是否存在
if not item.seasoninfo:
return None
seasoninfo = item.seasoninfo or {}
seasoninfo = json.loads(item.seasoninfo) or {}
if kwargs.get("season") not in seasoninfo.keys():
return None
return item

View File

@@ -1,3 +1,4 @@
import json
import time
from typing import Optional, Union
@@ -18,7 +19,6 @@ class MessageOper(DbOper):
def add(self,
channel: MessageChannel = None,
source: str = None,
mtype: NotificationType = None,
title: str = None,
text: str = None,
@@ -31,7 +31,6 @@ class MessageOper(DbOper):
"""
新增媒体服务器数据
:param channel: 消息渠道
:param source: 来源
:param mtype: 消息类型
:param title: 标题
:param text: 文本内容
@@ -43,7 +42,6 @@ class MessageOper(DbOper):
"""
kwargs.update({
"channel": channel.value if channel else '',
"source": source,
"mtype": mtype.value if mtype else '',
"title": title,
"text": text,
@@ -52,14 +50,8 @@ class MessageOper(DbOper):
"userid": userid,
"action": action,
"reg_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note or {}
"note": json.dumps(note) if note else ''
})
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
Message(**kwargs).create(self._db)
def list_by_page(self, page: int = 1, count: int = 30) -> Optional[str]:

View File

@@ -1,6 +1,6 @@
import time
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -29,8 +29,6 @@ class DownloadHistory(Base):
episodes = Column(String)
# 海报
image = Column(String)
# 下载器
downloader = Column(String)
# 下载任务Hash
download_hash = Column(String, index=True)
# 种子名称
@@ -48,22 +46,12 @@ class DownloadHistory(Base):
# 创建时间
date = Column(String)
# 附加信息
note = Column(JSON)
# 自定义媒体类别
media_category = Column(String)
note = Column(String)
@staticmethod
@db_query
def get_by_hash(db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc()
).first()
@staticmethod
@db_query
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str):
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.doubanid == doubanid).all()
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).first()
@staticmethod
@db_query
@@ -170,10 +158,10 @@ class DownloadFiles(Base):
下载文件记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 下载器
downloader = Column(String)
# 下载任务Hash
download_hash = Column(String, index=True)
# 下载器
downloader = Column(String)
# 完整路径
fullpath = Column(String, index=True)
# 保存路径

View File

@@ -1,7 +1,7 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -35,9 +35,9 @@ class MediaServerItem(Base):
# 路径
path = Column(String)
# 季集
seasoninfo = Column(JSON, default=dict)
seasoninfo = Column(String)
# 备注
note = Column(JSON)
note = Column(String)
# 同步时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, Base
@@ -11,8 +11,6 @@ class Message(Base):
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 消息渠道
channel = Column(String)
# 消息来源
source = Column(String)
# 消息类型
mtype = Column(String)
# 标题
@@ -30,7 +28,7 @@ class Message(Base):
# 消息方向0-接收息1-发送消息
action = Column(Integer)
# 附件json
note = Column(JSON)
note = Column(String)
@staticmethod
@db_query

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -11,7 +11,7 @@ class PluginData(Base):
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
plugin_id = Column(String, nullable=False, index=True)
key = Column(String, index=True, nullable=False)
value = Column(JSON)
value = Column(String)
@staticmethod
@db_query

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON
from sqlalchemy import Boolean, Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -38,7 +38,7 @@ class Site(Base):
# 是否公开站点
public = Column(Integer)
# 附加信息
note = Column(JSON)
note = Column(String)
# 流控单位周期
limit_interval = Column(Integer, default=0)
# 流控次数
@@ -46,13 +46,11 @@ class Site(Base):
# 流控间隔
limit_seconds = Column(Integer, default=0)
# 超时时间
timeout = Column(Integer, default=15)
timeout = Column(Integer, default=0)
# 是否启用
is_active = Column(Boolean(), default=True)
# 创建时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# 下载器
downloader = Column(String)
@staticmethod
@db_query

View File

@@ -1,6 +1,6 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -24,7 +24,7 @@ class SiteStatistic(Base):
# 最后访问时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
# 耗时记录 Json
note = Column(JSON)
note = Column(String)
@staticmethod
@db_query

View File

@@ -1,93 +0,0 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_
from sqlalchemy.orm import Session
from app.db import db_query, Base
class SiteUserData(Base):
"""
站点数据表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 站点域名
domain = Column(String, index=True)
# 站点名称
name = Column(String)
# 用户名
username = Column(String)
# 用户ID
userid = Column(Integer)
# 用户等级
user_level = Column(String)
# 加入时间
join_at = Column(String)
# 积分
bonus = Column(Float, default=0)
# 上传量
upload = Column(Float, default=0)
# 下载量
download = Column(Float, default=0)
# 分享率
ratio = Column(Float, default=0)
# 做种数
seeding = Column(Float, default=0)
# 下载数
leeching = Column(Float, default=0)
# 做种体积
seeding_size = Column(Float, default=0)
# 下载体积
leeching_size = Column(Float, default=0)
# 做种人数, 种子大小 JSON
seeding_info = Column(JSON, default=dict)
# 未读消息
message_unread = Column(Integer, default=0)
# 未读消息内容 JSON
message_unread_contents = Column(JSON, default=list)
# 错误信息
err_msg = Column(String)
# 更新日期
updated_day = Column(String, index=True, default=datetime.now().strftime('%Y-%m-%d'))
# 更新时间
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
@staticmethod
@db_query
def get_by_domain(db: Session, domain: str, workdate: str = None, worktime: str = None):
if workdate and worktime:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate,
SiteUserData.updated_time == worktime).all()
elif workdate:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate).all()
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all()
@staticmethod
@db_query
def get_by_date(db: Session, date: str):
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all()
@staticmethod
@db_query
def get_latest(db: Session):
"""
获取各站点最新一天的数据
"""
subquery = (
db.query(
SiteUserData.domain,
func.max(SiteUserData.updated_day).label('latest_update_day')
)
.group_by(SiteUserData.domain)
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
return db.query(SiteUserData).join(
subquery,
(SiteUserData.domain == subquery.c.domain) &
(SiteUserData.updated_day == subquery.c.latest_update_day)
).order_by(SiteUserData.updated_time.desc()).all()

View File

@@ -1,6 +1,6 @@
import time
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -24,7 +24,6 @@ class Subscribe(Base):
tvdbid = Column(Integer)
doubanid = Column(String, index=True)
bangumiid = Column(Integer, index=True)
mediaid = Column(String, index=True)
# 季号
season = Column(Integer)
# 海报
@@ -54,8 +53,8 @@ class Subscribe(Base):
# 缺失集数
lack_episode = Column(Integer)
# 附加信息
note = Column(JSON)
# 状态N-新建 R-订阅中 P-待定 S-暂停
note = Column(String)
# 状态N-新建 R-订阅中
state = Column(String, nullable=False, index=True, default='N')
# 最后更新时间
last_update = Column(String)
@@ -64,9 +63,7 @@ class Subscribe(Base):
# 订阅用户
username = Column(String)
# 订阅站点
sites = Column(JSON, default=list)
# 下载器
downloader = Column(String)
sites = Column(String)
# 是否洗版
best_version = Column(Integer, default=0)
# 当前优先级
@@ -77,12 +74,6 @@ class Subscribe(Base):
search_imdbid = Column(Integer, default=0)
# 是否手动修改过总集数 0否 1是
manual_total_episode = Column(Integer, default=0)
# 自定义识别词
custom_words = Column(String)
# 自定义媒体类别
media_category = Column(String)
# 过滤规则组
filter_groups = Column(JSON, default=list)
@staticmethod
@db_query
@@ -99,23 +90,9 @@ class Subscribe(Base):
@staticmethod
@db_query
def get_by_state(db: Session, state: str):
# 如果 state 为空或 None返回所有订阅
if not state:
result = db.query(Subscribe).all()
else:
# 如果传入的状态不为空,拆分成多个状态
states = state.split(',')
result = db.query(Subscribe).filter(Subscribe.state.in_(states)).all()
result = db.query(Subscribe).filter(Subscribe.state == state).all()
return list(result)
@staticmethod
@db_query
def get_by_title(db: Session, title: str, season: int = None):
if season:
return db.query(Subscribe).filter(Subscribe.name == title,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod
@db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: int = None):
@@ -126,6 +103,14 @@ class Subscribe(Base):
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all()
return list(result)
@staticmethod
@db_query
def get_by_title(db: Session, title: str, season: int = None):
if season:
return db.query(Subscribe).filter(Subscribe.name == title,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod
@db_query
def get_by_doubanid(db: Session, doubanid: str):
@@ -136,11 +121,6 @@ class Subscribe(Base):
def get_by_bangumiid(db: Session, bangumiid: int):
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first()
@staticmethod
@db_query
def get_by_mediaid(db: Session, mediaid: str):
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
subscrbies = self.get_by_tmdbid(db, tmdbid, season)
@@ -155,13 +135,6 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id)
return True
@db_update
def delete_by_mediaid(self, db: Session, mediaid: str):
subscribe = self.get_by_mediaid(db, mediaid)
if subscribe:
subscribe.delete(db, subscribe.id)
return True
@staticmethod
@db_query
def list_by_username(db: Session, username: str, state: str = None, mtype: str = None):

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, Base
@@ -22,7 +22,6 @@ class SubscribeHistory(Base):
tvdbid = Column(Integer)
doubanid = Column(String, index=True)
bangumiid = Column(Integer, index=True)
mediaid = Column(String, index=True)
# 季号
season = Column(Integer)
# 海报
@@ -54,19 +53,13 @@ class SubscribeHistory(Base):
# 订阅用户
username = Column(String)
# 订阅站点
sites = Column(JSON)
sites = Column(String)
# 是否洗版
best_version = Column(Integer, default=0)
# 保存路径
save_path = Column(String)
# 是否使用 imdbid 搜索
search_imdbid = Column(Integer, default=0)
# 自定义识别词
custom_words = Column(String)
# 自定义媒体类别
media_category = Column(String)
# 过滤规则组
filter_groups = Column(JSON, default=list)
@staticmethod
@db_query
@@ -74,18 +67,6 @@ class SubscribeHistory(Base):
result = db.query(SubscribeHistory).filter(
SubscribeHistory.type == mtype
).order_by(
SubscribeHistory.date.desc()
SubscribeHistory.date.desc()
).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod
@db_query
def exists(db: Session, tmdbid: int = None, doubanid: str = None, season: int = None):
if tmdbid:
if season:
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid,
SubscribeHistory.season == season).first()
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first()
elif doubanid:
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first()
return None

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -12,7 +12,7 @@ class SystemConfig(Base):
# 主键
key = Column(String, index=True)
# 值
value = Column(JSON)
value = Column(String, nullable=True)
@staticmethod
@db_query

View File

@@ -1,6 +1,6 @@
import time
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -8,21 +8,13 @@ from app.db import db_query, db_update, Base
class TransferHistory(Base):
"""
整理记录
转移历史记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 源路径
# 源目录
src = Column(String, index=True)
# 源存储
src_storage = Column(String)
# 源文件项
src_fileitem = Column(JSON, default=dict)
# 目标路径
# 目标目录
dest = Column(String)
# 目标存储
dest_storage = Column(String)
# 目标文件项
dest_fileitem = Column(JSON, default=dict)
# 转移模式 move/copy/link...
mode = Column(String)
# 类型 电影/电视剧
@@ -43,8 +35,6 @@ class TransferHistory(Base):
episodes = Column(String)
# 海报
image = Column(String)
# 下载器
downloader = Column(String)
# 下载器hash
download_hash = Column(String, index=True)
# 转移成功状态
@@ -54,7 +44,7 @@ class TransferHistory(Base):
# 时间
date = Column(String, index=True)
# 文件清单以JSON存储
files = Column(JSON, default=list)
files = Column(String)
@staticmethod
@db_query
@@ -97,12 +87,8 @@ class TransferHistory(Base):
@staticmethod
@db_query
def get_by_src(db: Session, src: str, storage: str = None):
if storage:
return db.query(TransferHistory).filter(TransferHistory.src == src,
TransferHistory.src_storage == storage).first()
else:
return db.query(TransferHistory).filter(TransferHistory.src == src).first()
def get_by_src(db: Session, src: str):
return db.query(TransferHistory).filter(TransferHistory.src == src).first()
@staticmethod
@db_query

View File

@@ -1,7 +1,12 @@
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String
from typing import Tuple, Optional
from sqlalchemy import Boolean, Column, Integer, String, Sequence
from sqlalchemy.orm import Session
from app.db import Base, db_query, db_update
from app.core.security import verify_password
from app.db import db_query, db_update, Base
from app.schemas import User
from app.utils.otp import OtpUtils
class User(Base):
@@ -10,9 +15,9 @@ class User(Base):
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 用户名,唯一值
# 用户名
name = Column(String, index=True, nullable=False)
# 邮箱
# 邮箱,未启用
email = Column(String)
# 加密后密码
hashed_password = Column(String)
@@ -26,21 +31,25 @@ class User(Base):
is_otp = Column(Boolean(), default=False)
# otp秘钥
otp_secret = Column(String, default=None)
# 用户权限 json
permissions = Column(JSON, default=dict)
# 用户个性化设置 json
settings = Column(JSON, default=dict)
@staticmethod
@db_query
def authenticate(db: Session, name: str, password: str, otp_password: str) -> Tuple[bool, Optional[User]]:
user = db.query(User).filter(User.name == name).first()
if not user:
return False, None
if not verify_password(password, str(user.hashed_password)):
return False, user
if user.is_otp:
if not otp_password or not OtpUtils.check(user.otp_secret, otp_password):
return False, user
return True, user
@staticmethod
@db_query
def get_by_name(db: Session, name: str):
return db.query(User).filter(User.name == name).first()
@staticmethod
@db_query
def get_by_id(db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first()
@db_update
def delete_by_name(self, db: Session, name: str):
user = self.get_by_name(db, name)
@@ -48,13 +57,6 @@ class User(Base):
user.delete(db, user.id)
return True
@db_update
def delete_by_id(self, db: Session, user_id: int):
user = self.get_by_id(db, user_id)
if user:
user.delete(db, user.id)
return True
@db_update
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
user = self.get_by_name(db, name)

View File

@@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, Sequence, UniqueConstraint, Index, JSON
from sqlalchemy import Column, Integer, String, Sequence, UniqueConstraint, Index
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
@@ -14,7 +14,7 @@ class UserConfig(Base):
# 配置键
key = Column(String)
# 值
value = Column(JSON)
value = Column(String, nullable=True)
__table_args__ = (
# 用户名和配置键联合唯一

View File

@@ -1,69 +0,0 @@
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, Base
class UserRequest(Base):
"""
用户请求表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 申请用户
req_user = Column(String, index=True, nullable=False)
# 申请时间
req_time = Column(String)
# 申请备注
req_remark = Column(String)
# 审批用户
app_user = Column(String, index=True, nullable=False)
# 审批时间
app_time = Column(String)
# 审批状态 0-待审批 1-通过 2-拒绝
app_status = Column(Integer, default=0)
# 类型
type = Column(String)
# 标题
title = Column(String)
# 年份
year = Column(String)
# 媒体ID
tmdbid = Column(Integer)
imdbid = Column(String)
tvdbid = Column(Integer)
doubanid = Column(String)
bangumiid = Column(Integer)
# 季号
season = Column(Integer)
# 海报
poster = Column(String)
# 背景图
backdrop = Column(String)
# 评分float
vote = Column(Float)
# 简介
description = Column(String)
@staticmethod
@db_query
def get_by_req_user(db: Session, req_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.req_user == req_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.req_user == req_user).all()
@staticmethod
@db_query
def get_by_app_user(db: Session, app_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.app_user == app_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.app_user == app_user).all()
@staticmethod
@db_query
def get_by_status(db: Session, status: int):
return db.query(UserRequest).filter(UserRequest.app_status == status).all()

View File

@@ -1,87 +0,0 @@
from datetime import datetime
from sqlalchemy import Column, Integer, JSON, Sequence, String
from app.db import Base, db_query, db_update
class Workflow(Base):
"""
工作流表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 名称
name = Column(String, index=True, nullable=False)
# 描述
description = Column(String)
# 定时器
timer = Column(String)
# 状态W-等待 R-运行中 P-暂停 S-成功 F-失败
state = Column(String, nullable=False, index=True, default='W')
# 当前执行动作
current_action = Column(String)
# 任务执行结果
result = Column(String)
# 已执行次数
run_count = Column(Integer, default=0)
# 任务列表
actions = Column(JSON, default=list)
# 执行上下文
context = Column(JSON, default=dict)
# 创建时间
add_time = Column(String, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 最后执行时间
last_time = Column(String)
@staticmethod
@db_query
def get_enabled_workflows(db):
return db.query(Workflow).filter(Workflow.state != 'P').all()
@staticmethod
@db_query
def get_by_name(db, name: str):
return db.query(Workflow).filter(Workflow.name == name).first()
@staticmethod
@db_update
def update_state(db, wid: int, state: str):
db.query(Workflow).filter(Workflow.id == wid).update({"state": state})
return True
@staticmethod
@db_update
def start(db, wid: int):
db.query(Workflow).filter(Workflow.id == wid).update({
"state": 'R'
})
return True
@staticmethod
@db_update
def fail(db, wid: int, result: str):
db.query(Workflow).filter(Workflow.id == wid).update({
"state": 'F',
"result": result,
"run_count": Workflow.run_count + 1,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
})
return True
@staticmethod
@db_update
def success(db, wid: int, result: str = None):
db.query(Workflow).filter(Workflow.id == wid).update({
"state": 'S',
"result": result,
"run_count": Workflow.run_count + 1,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
})
return True
@staticmethod
@db_update
def update_current_action(db, wid: int, action: str, context: dict):
db.query(Workflow).filter(Workflow.id == wid).update({"current_action": action, "context": context})
return True

View File

@@ -1,7 +1,9 @@
import json
from typing import Any
from app.db import DbOper
from app.db.models.plugindata import PluginData
from app.utils.object import ObjectUtils
class PluginDataOper(DbOper):
@@ -16,6 +18,8 @@ class PluginDataOper(DbOper):
:param key: 数据key
:param value: 数据值
"""
if ObjectUtils.is_obj(value):
value = json.dumps(value)
plugin = PluginData.get_plugin_data_by_key(self._db, plugin_id, key)
if plugin:
plugin.update(self._db, {
@@ -34,6 +38,8 @@ class PluginDataOper(DbOper):
data = PluginData.get_plugin_data_by_key(self._db, plugin_id, key)
if not data:
return None
if ObjectUtils.is_obj(data.value):
return json.loads(data.value)
return data.value
else:
return PluginData.get_plugin_data(self._db, plugin_id)

View File

@@ -1,11 +1,7 @@
from datetime import datetime
from typing import List, Tuple
from typing import Tuple, List
from app.db import DbOper
from app.db.models import SiteIcon
from app.db.models.site import Site
from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData
class SiteOper(DbOper):
@@ -102,131 +98,3 @@ class SiteOper(DbOper):
"rss": rss
})
return True, "更新站点RSS地址成功"
def update_userdata(self, domain: str, name: str, payload: dict) -> Tuple[bool, str]:
"""
更新站点用户数据
"""
# 当前系统日期
current_day = datetime.now().strftime('%Y-%m-%d')
current_time = datetime.now().strftime('%H:%M:%S')
payload.update({
"domain": domain,
"name": name,
"updated_day": current_day,
"updated_time": current_time,
"err_msg": payload.get("err_msg") or ""
})
# 按站点+天判断是否存在数据
siteuserdatas = SiteUserData.get_by_domain(self._db, domain=domain, workdate=current_day)
if siteuserdatas:
# 存在则更新
siteuserdatas[0].update(self._db, payload)
else:
# 不存在则插入
SiteUserData(**payload).create(self._db)
return True, "更新站点用户数据成功"
def get_userdata(self) -> List[SiteUserData]:
"""
获取站点用户数据
"""
return SiteUserData.list(self._db)
def get_userdata_by_domain(self, domain: str, workdate: str = None) -> List[SiteUserData]:
"""
获取站点用户数据
"""
return SiteUserData.get_by_domain(self._db, domain=domain, workdate=workdate)
def get_userdata_by_date(self, date: str) -> List[SiteUserData]:
"""
获取站点用户数据
"""
return SiteUserData.get_by_date(self._db, date)
def get_userdata_latest(self) -> List[SiteUserData]:
"""
获取站点最新数据
"""
return SiteUserData.get_latest(self._db)
def get_icon_by_domain(self, domain: str) -> SiteIcon:
"""
按域名获取站点图标
"""
return SiteIcon.get_by_domain(self._db, domain)
def update_icon(self, name: str, domain: str, icon_url: str, icon_base64: str) -> bool:
"""
更新站点图标
"""
icon_base64 = f"data:image/ico;base64,{icon_base64}" if icon_base64 else ""
siteicon = self.get_icon_by_domain(domain)
if not siteicon:
SiteIcon(name=name, domain=domain, url=icon_url, base64=icon_base64).create(self._db)
elif icon_base64:
siteicon.update(self._db, {
"url": icon_url,
"base64": icon_base64
})
return True
def success(self, domain: str, seconds: int = None):
"""
站点访问成功
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = SiteStatistic.get_by_domain(self._db, domain)
if sta:
avg_seconds, note = None, {}
if seconds is not None:
note: dict = sta.note or {}
note[lst_date] = seconds or 1
avg_times = len(note.keys())
if avg_times > 10:
note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10])
avg_seconds = sum([v for v in note.values()]) // avg_times
sta.update(self._db, {
"success": sta.success + 1,
"seconds": avg_seconds or sta.seconds,
"lst_state": 0,
"lst_mod_date": lst_date,
"note": note or sta.note
})
else:
note = {}
if seconds is not None:
note = {
lst_date: seconds or 1
}
SiteStatistic(
domain=domain,
success=1,
fail=0,
seconds=seconds or 1,
lst_state=0,
lst_mod_date=lst_date,
note=note
).create(self._db)
def fail(self, domain: str):
"""
站点访问失败
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = SiteStatistic.get_by_domain(self._db, domain)
if sta:
sta.update(self._db, {
"fail": sta.fail + 1,
"lst_state": 1,
"lst_mod_date": lst_date
})
else:
SiteStatistic(
domain=domain,
success=0,
fail=1,
lst_state=1,
lst_mod_date=lst_date
).create(self._db)

Some files were not shown because too many files have changed in this diff Show More